FlashDecoding++ : Accélération de l'inférence des grands modèles de langage sur les GPU
FlashDecoding++: Faster Large Language Model Inference on GPUs
November 2, 2023
Auteurs: Ke Hong, Guohao Dai, Jiaming Xu, Qiuli Mao, Xiuhong Li, Jun Liu, Kangdi Chen, Hanyu Dong, Yu Wang
cs.AI
Résumé
Alors que les modèles de langage de grande taille (LLM) prennent une importance croissante dans divers domaines, plusieurs défis restent non résolus pour accélérer l'inférence des LLM : (1) Mise à jour synchronisée du softmax partiel. L'opération de softmax nécessite une mise à jour synchronisée entre chaque résultat partiel de softmax, entraînant une surcharge d'environ 20 % pour le calcul de l'attention dans les LLM. (2) Sous-utilisation du calcul de GEMM plat. La forme des matrices utilisées pour le GEMM dans l'inférence des LLM est plate, ce qui entraîne une sous-utilisation des calculs et une perte de performance de plus de 50 % après le remplissage par des zéros dans les conceptions précédentes. (3) Perte de performance due au flux de données statique. La performance des noyaux dans les LLM dépend de diverses caractéristiques des données d'entrée, des configurations matérielles, etc. Un flux de données unique et statique peut entraîner une perte de performance de 50,25 % pour les GEMM de formes différentes dans l'inférence des LLM.
Nous présentons FlashDecoding++, un moteur d'inférence rapide pour les LLM prenant en charge les LLM grand public et les architectures matérielles. Pour relever ces défis, FlashDecoding++ propose de manière créative : (1) Softmax asynchronisé avec une valeur maximale unifiée. FlashDecoding++ introduit une technique de valeur maximale unifiée pour les différents calculs partiels de softmax afin d'éviter la synchronisation. (2) Optimisation du GEMM plat avec double tamponnage. FlashDecoding++ souligne que les GEMM plats de formes différentes rencontrent des goulots d'étranglement variés. Des techniques comme le double tamponnage sont alors introduites. (3) Flux de données heuristique avec adaptation aux ressources matérielles. FlashDecoding++ optimise heuristiquement le flux de données en utilisant différentes ressources matérielles en tenant compte de la dynamique des entrées. Grâce à la polyvalence des optimisations de FlashDecoding++, ce dernier peut atteindre des accélérations allant jusqu'à 4,86x et 2,18x sur les GPU NVIDIA et AMD par rapport aux implémentations de Hugging Face. FlashDecoding++ obtient également une accélération moyenne de 1,37x par rapport aux moteurs d'inférence LLM de pointe sur les LLM grand public.
English
As the Large Language Model (LLM) becomes increasingly important in various
domains. However, the following challenges still remain unsolved in
accelerating LLM inference: (1) Synchronized partial softmax update. The
softmax operation requires a synchronized update operation among each partial
softmax result, leading to ~20% overheads for the attention computation in
LLMs. (2) Under-utilized computation of flat GEMM. The shape of matrices
performing GEMM in LLM inference is flat, leading to under-utilized computation
and >50% performance loss after padding zeros in previous designs. (3)
Performance loss due to static dataflow. Kernel performance in LLM depends on
varied input data features, hardware configurations, etc. A single and static
dataflow may lead to a 50.25% performance loss for GEMMs of different shapes in
LLM inference.
We present FlashDecoding++, a fast LLM inference engine supporting mainstream
LLMs and hardware back-ends. To tackle the above challenges, FlashDecoding++
creatively proposes: (1) Asynchronized softmax with unified max value.
FlashDecoding++ introduces a unified max value technique for different partial
softmax computations to avoid synchronization. (2) Flat GEMM optimization with
double buffering. FlashDecoding++ points out that flat GEMMs with different
shapes face varied bottlenecks. Then, techniques like double buffering are
introduced. (3) Heuristic dataflow with hardware resource adaptation.
FlashDecoding++ heuristically optimizes dataflow using different hardware
resource considering input dynamics. Due to the versatility of optimizations in
FlashDecoding++, FlashDecoding++ can achieve up to 4.86x and 2.18x speedup on
both NVIDIA and AMD GPUs compared to Hugging Face implementations.
FlashDecoding++ also achieves an average speedup of 1.37x compared to
state-of-the-art LLM inference engines on mainstream LLMs.