FlashDecoding++: Faster Large Language Model Inference on GPUs
November 2, 2023
Authors: Ke Hong, Guohao Dai, Jiaming Xu, Qiuli Mao, Xiuhong Li, Jun Liu, Kangdi Chen, Hanyu Dong, Yu Wang
cs.AI
Abstract
As Large Language Models (LLMs) become increasingly important across various
domains, the following challenges remain unsolved in accelerating LLM
inference: (1) Synchronized partial softmax update. The softmax operation
requires a synchronized update among all partial softmax results, leading to
~20% overhead for the attention computation in LLMs. (2) Under-utilized
computation of flat GEMM. The matrices in LLM inference GEMMs are flat,
leading to under-utilized computation and >50% performance loss after zero
padding in previous designs. (3) Performance loss due to static dataflow.
Kernel performance in LLM inference depends on varied input data features,
hardware configurations, etc. A single, static dataflow may cause a 50.25%
performance loss for GEMMs of different shapes in LLM inference.
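To make challenge (1) concrete, here is a minimal NumPy sketch (an illustration, not the paper's GPU kernel) of the conventional partial-softmax recurrence: whenever a new chunk raises the running max, all previously computed partial results must be rescaled, and on a GPU that rescaling is the synchronization overhead described above.

```python
import numpy as np

def partial_softmax_synchronized(x, chunk=4):
    """Conventional partial softmax over chunks of x. Each chunk may raise
    the running max m, forcing a rescaling of all previously accumulated
    partial results -- the synchronized update described above."""
    m, s, parts = -np.inf, 0.0, []
    for start in range(0, len(x), chunk):
        xc = x[start:start + chunk]
        m_new = max(m, xc.max())
        scale = np.exp(m - m_new)            # synchronization point:
        parts = [p * scale for p in parts]   # earlier partials are rescaled
        s = s * scale + np.exp(xc - m_new).sum()
        parts.append(np.exp(xc - m_new))
        m = m_new
    return np.concatenate(parts) / s

x = np.random.randn(16)
reference = np.exp(x - x.max()) / np.exp(x - x.max()).sum()
assert np.allclose(partial_softmax_synchronized(x), reference)
```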
We present FlashDecoding++, a fast LLM inference engine supporting mainstream
LLMs and hardware back-ends. To tackle the above challenges, FlashDecoding++
creatively proposes: (1) Asynchronized softmax with a unified max value.
FlashDecoding++ introduces a unified max value technique for different partial
softmax computations to avoid synchronization, as sketched below.
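A minimal NumPy sketch of this idea, where `phi` is an assumed stand-in for the unified max value (in the paper it is derived from logit statistics): since every chunk subtracts the same `phi`, the partial exponentials compose without any rescaling or synchronization, and `phi` cancels in the final division.

```python
import numpy as np

def partial_softmax_unified_max(x, phi, chunk=4):
    """Asynchronized partial softmax: every chunk subtracts the same
    unified max value phi, so partial results need no rescaling and no
    synchronization; phi cancels in the final division."""
    parts = [np.exp(x[s:s + chunk] - phi)          # chunks are independent
             for s in range(0, len(x), chunk)]
    e = np.concatenate(parts)
    # If phi badly underestimates the true max, exp() can overflow; a real
    # kernel detects this and falls back to the synchronized recurrence.
    return e / e.sum()
```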
(2) Flat GEMM optimization with double buffering. FlashDecoding++ points out
that flat GEMMs with different shapes face varied bottlenecks, and introduces
techniques such as double buffering accordingly (see the pattern sketch below).
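As a pattern sketch only (NumPy executes sequentially, so the overlap is notional; in a GPU kernel the prefetch would be an asynchronous copy into shared memory), double buffering stages the next K-tile of a flat GEMM while the current tile is being multiplied, hiding memory latency for memory-bound flat shapes:

```python
import numpy as np

def flat_gemm_double_buffered(A, B, kt=64):
    """Tiled GEMM over the K dimension with two staging buffers: tile t+1
    is loaded while tile t is consumed. NumPy runs this sequentially; on a
    GPU the prefetch would overlap with the multiply."""
    def load(t):                               # stage one K-tile of A and B
        return A[:, t*kt:(t+1)*kt], B[t*kt:(t+1)*kt, :]

    C = np.zeros((A.shape[0], B.shape[1]))
    n_tiles = -(-A.shape[1] // kt)             # ceil(K / kt)
    buf = [load(0), None]                      # preload the first tile
    for t in range(n_tiles):
        if t + 1 < n_tiles:
            buf[(t + 1) % 2] = load(t + 1)     # prefetch into the idle buffer
        a, b = buf[t % 2]
        C += a @ b                             # compute on the current buffer
    return C
```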
(3) Heuristic dataflow with hardware resource adaptation. FlashDecoding++
heuristically optimizes the dataflow across different hardware resources
according to input dynamics; the dispatch sketch below illustrates the idea.
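A hypothetical dispatch sketch of such a heuristic; the shape thresholds are illustrative placeholders, not FlashDecoding++'s tuned values:

```python
def pick_gemm_dataflow(M, N, K):
    """Hypothetical shape-based dispatch: small M (decode) is memory-bound
    and favors GEMV-style kernels, mid-sized M favors the flat-GEMM path,
    and large M (prefill) is compute-bound and favors tiled GEMM libraries.
    The cut-off points here are illustrative only."""
    if M <= 4:
        return "gemv"        # decode with tiny batch: bandwidth-bound
    elif M <= 128:
        return "flat_gemm"   # flat shapes: double-buffered custom kernel
    else:
        return "tiled_gemm"  # prefill / large batch: cuBLAS-style kernel
```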
Owing to the versatility of these optimizations, FlashDecoding++ achieves up
to 4.86x and 2.18x speedups on NVIDIA and AMD GPUs, respectively, compared to
Hugging Face implementations, and an average speedup of 1.37x over
state-of-the-art LLM inference engines on mainstream LLMs.