FlashDecoding++: Schnellere Inferenz großer Sprachmodelle auf GPUs
FlashDecoding++: Faster Large Language Model Inference on GPUs
November 2, 2023
Autoren: Ke Hong, Guohao Dai, Jiaming Xu, Qiuli Mao, Xiuhong Li, Jun Liu, Kangdi Chen, Hanyu Dong, Yu Wang
cs.AI
Zusammenfassung
Da Large Language Models (LLMs) in verschiedenen Bereichen zunehmend an Bedeutung gewinnen, bleiben dennoch folgende Herausforderungen bei der Beschleunigung der LLM-Inferenz ungelöst: (1) Synchronisierte partielle Softmax-Aktualisierung. Die Softmax-Operation erfordert eine synchronisierte Aktualisierung zwischen jedem partiellen Softmax-Ergebnis, was zu einem Overhead von ~20 % für die Aufmerksamkeitsberechnung in LLMs führt. (2) Unterausgelastete Berechnung von flachen GEMMs. Die Form der Matrizen, die GEMM in der LLM-Inferenz durchführen, ist flach, was zu einer unterausgelasteten Berechnung und einem Leistungsverlust von >50 % nach dem Auffüllen mit Nullen in früheren Designs führt. (3) Leistungsverlust durch statischen Datenfluss. Die Kernel-Leistung in LLMs hängt von verschiedenen Eingabedatenmerkmalen, Hardware-Konfigurationen usw. ab. Ein einzelner und statischer Datenfluss kann zu einem Leistungsverlust von 50,25 % für GEMMs unterschiedlicher Formen in der LLM-Inferenz führen.
Wir präsentieren FlashDecoding++, eine schnelle LLM-Inferenz-Engine, die Mainstream-LLMs und Hardware-Backends unterstützt. Um die oben genannten Herausforderungen zu bewältigen, schlägt FlashDecoding++ kreativ vor: (1) Asynchronisierte Softmax mit einheitlichem Maximalwert. FlashDecoding++ führt eine Technik des einheitlichen Maximalwerts für verschiedene partielle Softmax-Berechnungen ein, um Synchronisation zu vermeiden. (2) Optimierung von flachen GEMMs mit Double Buffering. FlashDecoding++ weist darauf hin, dass flache GEMMs unterschiedlicher Formen auf verschiedene Engpässe stoßen. Anschließend werden Techniken wie Double Buffering eingeführt. (3) Heuristischer Datenfluss mit Hardware-Ressourcenanpassung. FlashDecoding++ optimiert den Datenfluss heuristisch unter Berücksichtigung der Dynamik der Eingaben und der verschiedenen Hardware-Ressourcen. Aufgrund der Vielseitigkeit der Optimierungen in FlashDecoding++ kann FlashDecoding++ eine Beschleunigung von bis zu 4,86x und 2,18x auf NVIDIA- und AMD-GPUs im Vergleich zu Hugging-Face-Implementierungen erreichen. FlashDecoding++ erzielt auch eine durchschnittliche Beschleunigung von 1,37x im Vergleich zu state-of-the-art LLM-Inferenz-Engines auf Mainstream-LLMs.
English
As the Large Language Model (LLM) becomes increasingly important in various
domains. However, the following challenges still remain unsolved in
accelerating LLM inference: (1) Synchronized partial softmax update. The
softmax operation requires a synchronized update operation among each partial
softmax result, leading to ~20% overheads for the attention computation in
LLMs. (2) Under-utilized computation of flat GEMM. The shape of matrices
performing GEMM in LLM inference is flat, leading to under-utilized computation
and >50% performance loss after padding zeros in previous designs. (3)
Performance loss due to static dataflow. Kernel performance in LLM depends on
varied input data features, hardware configurations, etc. A single and static
dataflow may lead to a 50.25% performance loss for GEMMs of different shapes in
LLM inference.
We present FlashDecoding++, a fast LLM inference engine supporting mainstream
LLMs and hardware back-ends. To tackle the above challenges, FlashDecoding++
creatively proposes: (1) Asynchronized softmax with unified max value.
FlashDecoding++ introduces a unified max value technique for different partial
softmax computations to avoid synchronization. (2) Flat GEMM optimization with
double buffering. FlashDecoding++ points out that flat GEMMs with different
shapes face varied bottlenecks. Then, techniques like double buffering are
introduced. (3) Heuristic dataflow with hardware resource adaptation.
FlashDecoding++ heuristically optimizes dataflow using different hardware
resource considering input dynamics. Due to the versatility of optimizations in
FlashDecoding++, FlashDecoding++ can achieve up to 4.86x and 2.18x speedup on
both NVIDIA and AMD GPUs compared to Hugging Face implementations.
FlashDecoding++ also achieves an average speedup of 1.37x compared to
state-of-the-art LLM inference engines on mainstream LLMs.