FlashDecoding++: Inferenza più veloce per modelli linguistici di grandi dimensioni su GPU
FlashDecoding++: Faster Large Language Model Inference on GPUs
November 2, 2023
Autori: Ke Hong, Guohao Dai, Jiaming Xu, Qiuli Mao, Xiuhong Li, Jun Liu, Kangdi Chen, Hanyu Dong, Yu Wang
cs.AI
Abstract
Man mano che i Large Language Model (LLM) acquisiscono un'importanza crescente in vari domini, permangono tuttavia alcune sfide irrisolte nell'accelerazione dell'inferenza degli LLM: (1) Aggiornamento sincronizzato del softmax parziale. L'operazione di softmax richiede un aggiornamento sincronizzato tra ciascun risultato parziale del softmax, causando un sovraccarico di circa il 20% nel calcolo dell'attenzione negli LLM. (2) Sottoutilizzazione del calcolo nel GEMM piatto. La forma delle matrici che eseguono il GEMM nell'inferenza degli LLM è piatta, portando a un calcolo sottoutilizzato e a una perdita di prestazioni superiore al 50% dopo l'aggiunta di zeri nei progetti precedenti. (3) Perdita di prestazioni dovuta al flusso di dati statico. Le prestazioni del kernel negli LLM dipendono da varie caratteristiche dei dati di input, configurazioni hardware, ecc. Un flusso di dati singolo e statico può portare a una perdita di prestazioni del 50,25% per GEMM di forme diverse nell'inferenza degli LLM.
Presentiamo FlashDecoding++, un motore di inferenza veloce per LLM che supporta i principali LLM e backend hardware. Per affrontare le sfide sopra descritte, FlashDecoding++ propone in modo creativo: (1) Softmax asincrono con valore massimo unificato. FlashDecoding++ introduce una tecnica di valore massimo unificato per diversi calcoli parziali del softmax per evitare la sincronizzazione. (2) Ottimizzazione del GEMM piatto con doppio buffering. FlashDecoding++ evidenzia che i GEMM piatti con forme diverse affrontano colli di bottiglia variabili. Successivamente, vengono introdotte tecniche come il doppio buffering. (3) Flusso di dati euristico con adattamento alle risorse hardware. FlashDecoding++ ottimizza euristicamente il flusso di dati utilizzando diverse risorse hardware considerando la dinamicità degli input. Grazie alla versatilità delle ottimizzazioni in FlashDecoding++, è possibile ottenere un miglioramento delle prestazioni fino a 4,86x e 2,18x su GPU NVIDIA e AMD rispetto alle implementazioni di Hugging Face. FlashDecoding++ raggiunge inoltre un miglioramento medio delle prestazioni di 1,37x rispetto ai motori di inferenza LLM all'avanguardia sui principali LLM.
English
As the Large Language Model (LLM) becomes increasingly important in various
domains. However, the following challenges still remain unsolved in
accelerating LLM inference: (1) Synchronized partial softmax update. The
softmax operation requires a synchronized update operation among each partial
softmax result, leading to ~20% overheads for the attention computation in
LLMs. (2) Under-utilized computation of flat GEMM. The shape of matrices
performing GEMM in LLM inference is flat, leading to under-utilized computation
and >50% performance loss after padding zeros in previous designs. (3)
Performance loss due to static dataflow. Kernel performance in LLM depends on
varied input data features, hardware configurations, etc. A single and static
dataflow may lead to a 50.25% performance loss for GEMMs of different shapes in
LLM inference.
We present FlashDecoding++, a fast LLM inference engine supporting mainstream
LLMs and hardware back-ends. To tackle the above challenges, FlashDecoding++
creatively proposes: (1) Asynchronized softmax with unified max value.
FlashDecoding++ introduces a unified max value technique for different partial
softmax computations to avoid synchronization. (2) Flat GEMM optimization with
double buffering. FlashDecoding++ points out that flat GEMMs with different
shapes face varied bottlenecks. Then, techniques like double buffering are
introduced. (3) Heuristic dataflow with hardware resource adaptation.
FlashDecoding++ heuristically optimizes dataflow using different hardware
resource considering input dynamics. Due to the versatility of optimizations in
FlashDecoding++, FlashDecoding++ can achieve up to 4.86x and 2.18x speedup on
both NVIDIA and AMD GPUs compared to Hugging Face implementations.
FlashDecoding++ also achieves an average speedup of 1.37x compared to
state-of-the-art LLM inference engines on mainstream LLMs.