FlashDecoding++: Faster Large Language Model Inference on GPUs
November 2, 2023
Authors: Ke Hong, Guohao Dai, Jiaming Xu, Qiuli Mao, Xiuhong Li, Jun Liu, Kangdi Chen, Hanyu Dong, Yu Wang
cs.AI
Abstract
Large Language Models (LLMs) are becoming increasingly important in various
domains. However, the following challenges remain unsolved in
accelerating LLM inference: (1) Synchronized partial softmax update. The
softmax operation requires a synchronized update operation among each partial
softmax result, leading to ~20% overheads for the attention computation in
LLMs. (2) Under-utilized computation of flat GEMM. The shape of matrices
performing GEMM in LLM inference is flat, leading to under-utilized computation
and >50% performance loss after padding zeros in previous designs. (3)
Performance loss due to static dataflow. Kernel performance in LLM depends on
varied input data features, hardware configurations, etc. A single and static
dataflow may lead to a 50.25% performance loss for GEMMs of different shapes in
LLM inference.
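Challenge (1) can be made concrete with a minimal NumPy sketch. The synchronized variant below is the standard online softmax: each partial sum must be rescaled whenever a later chunk raises the running max, which is the cross-chunk synchronization the paper measures at roughly 20% of attention time. The unified-max variant sidesteps this by having every chunk subtract one shared bound `phi`, so partial sums can be accumulated independently. The function names and the choice of `phi` here are illustrative assumptions, not the paper's implementation (which must also handle the case where `phi` is a poor bound).

```python
import numpy as np

def softmax_synchronized(x, chunk=4):
    # Online softmax: partial results are rescaled whenever a new
    # chunk raises the running max -- the synchronization overhead
    # identified as challenge (1).
    m = -np.inf          # running max
    s = 0.0              # running sum of exp(x - m)
    for i in range(0, len(x), chunk):
        c = x[i:i + chunk]
        m_new = max(m, c.max())
        s = s * np.exp(m - m_new) + np.exp(c - m_new).sum()  # rescale old sum
        m = m_new
    return np.exp(x - m) / s

def softmax_unified_max(x, phi, chunk=4):
    # Unified-max sketch: all chunks share one precomputed bound
    # phi >= max(x), so partial sums accumulate with no cross-chunk
    # rescaling (mathematically, any phi gives the same softmax;
    # phi only controls numerical range).
    s = 0.0
    for i in range(0, len(x), chunk):
        s += np.exp(x[i:i + chunk] - phi).sum()
    return np.exp(x - phi) / s
```

Both functions return the same softmax; the point is that the second needs no ordering or communication between chunks, so on a GPU the partial computations can proceed asynchronously.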
We present FlashDecoding++, a fast LLM inference engine supporting mainstream
LLMs and hardware back-ends. To tackle the above challenges, FlashDecoding++
creatively proposes: (1) Asynchronized softmax with unified max value.
FlashDecoding++ introduces a unified max value technique for different partial
softmax computations to avoid synchronization. (2) Flat GEMM optimization with
double buffering. FlashDecoding++ points out that flat GEMMs with different
shapes face varied bottlenecks. Then, techniques like double buffering are
introduced. (3) Heuristic dataflow with hardware resource adaptation.
FlashDecoding++ heuristically optimizes the dataflow over different hardware
resources, taking input dynamics into account. Owing to the versatility of
these optimizations, FlashDecoding++ achieves up to 4.86x speedup on NVIDIA
GPUs and 2.18x on AMD GPUs compared with Hugging Face implementations.
FlashDecoding++ also achieves an average speedup of 1.37x compared to
state-of-the-art LLM inference engines on mainstream LLMs.
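The heuristic dataflow of point (3) amounts to selecting a kernel variant from the runtime GEMM shape rather than committing to one static implementation. The sketch below is a hypothetical dispatch rule; the thresholds and kernel names are illustrative assumptions, not the paper's actual decision procedure.

```python
def pick_gemm_kernel(m, n, k):
    """Select a dataflow for an (m x k) @ (k x n) GEMM by shape.

    Thresholds and variant names are illustrative, not the paper's.
    """
    if m == 1:
        # Single-token decode step: the GEMM degenerates to a GEMV.
        return "gemv"
    if m < 64:
        # Flat GEMM (small m, large n and k): memory-latency bound,
        # so a double-buffered kernel that overlaps loads with
        # compute is preferred over zero-padding to a tiled kernel.
        return "flat_gemm_double_buffer"
    # Large m (prefill / big batches): classic compute-bound tiling.
    return "tiled_gemm"
```

An inference engine would register one tuned kernel per dataflow and call a rule like this per GEMM, which is how a single engine avoids the up-to-50.25% loss of a static dataflow across differing shapes.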