ChatPaper.aiChatPaper

FlashDecoding++:GPUにおける大規模言語モデル推論の高速化

FlashDecoding++: Faster Large Language Model Inference on GPUs

November 2, 2023
著者: Ke Hong, Guohao Dai, Jiaming Xu, Qiuli Mao, Xiuhong Li, Jun Liu, Kangdi Chen, Hanyu Dong, Yu Wang
cs.AI

要旨

大規模言語モデル(LLM)がさまざまな分野で重要性を増す中、LLM推論の高速化において以下の課題が未解決のまま残されている。(1) 同期化された部分的なソフトマックス更新。ソフトマックス演算では、各部分的なソフトマックス結果間で同期化された更新操作が必要であり、これによりLLMの注意計算において約20%のオーバーヘッドが生じる。(2) 平坦なGEMM計算の低効率化。LLM推論におけるGEMM演算の行列形状は平坦であり、これにより計算効率が低下し、従来の設計ではゼロ埋めを行った後に50%以上の性能損失が発生する。(3) 静的なデータフローによる性能損失。LLMにおけるカーネル性能は、入力データの特徴やハードウェア構成などに依存する。単一で静的なデータフローでは、LLM推論における異なる形状のGEMMに対して最大50.25%の性能損失が生じる可能性がある。 本論文では、主流のLLMとハードウェアバックエンドをサポートする高速なLLM推論エンジンであるFlashDecoding++を提案する。上記の課題に対処するため、FlashDecoding++は以下の革新的な手法を提案する。(1) 統一された最大値を持つ非同期ソフトマックス。FlashDecoding++は、異なる部分的なソフトマックス計算に対して統一された最大値技術を導入し、同期化を回避する。(2) ダブルバッファリングを活用した平坦なGEMM最適化。FlashDecoding++は、異なる形状の平坦なGEMMがさまざまなボトルネックに直面することを指摘し、ダブルバッファリングなどの技術を導入する。(3) ハードウェアリソース適応型のヒューリスティックデータフロー。FlashDecoding++は、入力の動的特性を考慮し、異なるハードウェアリソースを使用してデータフローをヒューリスティックに最適化する。FlashDecoding++の最適化手法の汎用性により、NVIDIAおよびAMD GPUにおいて、Hugging Faceの実装と比較して最大4.86倍および2.18倍の高速化を実現する。さらに、主流のLLMにおいて、最先端のLLM推論エンジンと比較して平均1.37倍の高速化を達成する。
English
As the Large Language Model (LLM) becomes increasingly important in various domains. However, the following challenges still remain unsolved in accelerating LLM inference: (1) Synchronized partial softmax update. The softmax operation requires a synchronized update operation among each partial softmax result, leading to ~20% overheads for the attention computation in LLMs. (2) Under-utilized computation of flat GEMM. The shape of matrices performing GEMM in LLM inference is flat, leading to under-utilized computation and >50% performance loss after padding zeros in previous designs. (3) Performance loss due to static dataflow. Kernel performance in LLM depends on varied input data features, hardware configurations, etc. A single and static dataflow may lead to a 50.25% performance loss for GEMMs of different shapes in LLM inference. We present FlashDecoding++, a fast LLM inference engine supporting mainstream LLMs and hardware back-ends. To tackle the above challenges, FlashDecoding++ creatively proposes: (1) Asynchronized softmax with unified max value. FlashDecoding++ introduces a unified max value technique for different partial softmax computations to avoid synchronization. (2) Flat GEMM optimization with double buffering. FlashDecoding++ points out that flat GEMMs with different shapes face varied bottlenecks. Then, techniques like double buffering are introduced. (3) Heuristic dataflow with hardware resource adaptation. FlashDecoding++ heuristically optimizes dataflow using different hardware resource considering input dynamics. Due to the versatility of optimizations in FlashDecoding++, FlashDecoding++ can achieve up to 4.86x and 2.18x speedup on both NVIDIA and AMD GPUs compared to Hugging Face implementations. FlashDecoding++ also achieves an average speedup of 1.37x compared to state-of-the-art LLM inference engines on mainstream LLMs.
PDF373December 15, 2024