

FlashDecoding++: Faster Large Language Model Inference on GPUs

November 2, 2023
Authors: Ke Hong, Guohao Dai, Jiaming Xu, Qiuli Mao, Xiuhong Li, Jun Liu, Kangdi Chen, Hanyu Dong, Yu Wang
cs.AI

Abstract

Large Language Models (LLMs) are becoming increasingly important in various domains, yet the following challenges in accelerating LLM inference remain unsolved: (1) Synchronized partial softmax update. The softmax operation requires a synchronized update among partial softmax results, adding roughly 20% overhead to the attention computation in LLMs. (2) Under-utilized computation of flat GEMM. The matrices in LLM inference GEMMs are flat, leading to under-utilized computation and more than 50% performance loss after padding with zeros in previous designs. (3) Performance loss due to static dataflow. Kernel performance in LLM inference depends on input data features, hardware configurations, and more; a single, static dataflow can cause a 50.25% performance loss across GEMMs of different shapes. We present FlashDecoding++, a fast LLM inference engine supporting mainstream LLMs and hardware back-ends. To tackle these challenges, FlashDecoding++ proposes: (1) Asynchronized softmax with a unified max value. FlashDecoding++ introduces a unified max value technique for the different partial softmax computations, avoiding synchronization. (2) Flat GEMM optimization with double buffering. FlashDecoding++ observes that flat GEMMs of different shapes face different bottlenecks, and introduces techniques such as double buffering accordingly. (3) Heuristic dataflow with hardware resource adaptation. FlashDecoding++ heuristically optimizes the dataflow, using different hardware resources depending on input dynamics. Owing to the versatility of these optimizations, FlashDecoding++ achieves up to 4.86x and 2.18x speedups on NVIDIA and AMD GPUs, respectively, compared to Hugging Face implementations, and an average 1.37x speedup over state-of-the-art LLM inference engines on mainstream LLMs.
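
To make the asynchronized-softmax idea concrete, here is a minimal NumPy sketch (an illustration, not the paper's kernel code): each partial softmax block is exponentiated against a shared, statically chosen max value `phi` rather than a running max, so the partial results can be produced independently and combined with a single final normalization. The function names and the choice of `phi` are assumptions made for this example.

```python
import numpy as np

# Sketch of softmax with a unified max value. `phi` is a statically chosen
# upper bound on the logits; each block exponentiates against `phi`
# independently, so no synchronized rescaling between blocks is needed.

def partial_exp(block, phi):
    # Each block can run asynchronously: no dependence on other blocks' max.
    return np.exp(block - phi)

def softmax_unified_max(x, phi, block_size=4):
    blocks = [x[i:i + block_size] for i in range(0, len(x), block_size)]
    partials = [partial_exp(b, phi) for b in blocks]   # independent work
    num = np.concatenate(partials)
    return num / num.sum()                             # one final reduction

x = np.random.randn(16).astype(np.float32)
phi = 6.0  # assumed static bound; in practice derived from logit statistics
ref = np.exp(x - x.max()) / np.exp(x - x.max()).sum()
print(np.allclose(softmax_unified_max(x, phi), ref, atol=1e-6))  # True
```

Mathematically the result is unchanged, since the `phi`-dependent factor cancels in the normalization; the trade-off is that a logit above `phi` can overflow, so in practice a fallback path (e.g., recomputation) is needed for the rare out-of-bound case.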
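
The flat-GEMM problem can also be seen with simple arithmetic: during decoding, the M dimension of the activation matrix is tiny (often the batch size, e.g. 1-8), while GPU GEMM kernels tile M in chunks of 64 or 128, so padding a flat matrix up to the tile boundary wastes most of the computed tile. A toy calculation (the tile sizes here are illustrative):

```python
# Illustrative utilization of a padded flat GEMM: an M x K @ K x N product
# whose M dimension is padded up to the kernel's tile size.
def padded_utilization(m, tile_m):
    padded_m = -(-m // tile_m) * tile_m  # round M up to the tile boundary
    return m / padded_m

for m in (1, 4, 8):
    for tile_m in (64, 128):
        u = padded_utilization(m, tile_m)
        print(f"M={m}, tile={tile_m}: {u:.2%} of the tile does useful work")
```

Even at M=8 with a 64-wide tile, only 12.5% of the tile does useful work, which is consistent with the >50% performance loss the abstract attributes to zero-padding.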
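
Finally, the heuristic dataflow can be pictured as a lightweight dispatch over specialized kernels keyed on the GEMM's runtime shape. The sketch below is a hypothetical illustration of that idea; the thresholds and kernel names are placeholders, not the engine's actual decision rule.

```python
import numpy as np

# Hypothetical dispatch in the spirit of FlashDecoding++'s heuristic
# dataflow: pick a kernel path based on the runtime M dimension.
# The kernels below are stand-ins (plain matmuls), not real GPU kernels.

def gemv_kernel(a, b):       # stand-in for a memory-bound GEMV path
    return a @ b

def flat_gemm_kernel(a, b):  # stand-in for a double-buffered flat GEMM path
    return a @ b

def library_gemm(a, b):      # stand-in for a library GEMM (e.g. cuBLAS)
    return a @ b

def dispatch_gemm(a, b, gemv_max_m=1, flat_max_m=64):
    m = a.shape[0]
    if m <= gemv_max_m:
        return gemv_kernel(a, b)        # decode with batch 1: GEMV-like
    if m <= flat_max_m:
        return flat_gemm_kernel(a, b)   # small-batch decode: flat GEMM
    return library_gemm(a, b)           # prefill / large batch: regular GEMM

a = np.random.randn(4, 4096).astype(np.float32)
b = np.random.randn(4096, 4096).astype(np.float32)
print(dispatch_gemm(a, b).shape)  # (4, 4096)
```

In a real engine the thresholds would depend on the hardware configuration and be tuned per GEMM shape, which is what "hardware resource adaptation" refers to.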