FlashDecoding++: Snellere Inferentie van Grote Taalmodellen op GPU's
FlashDecoding++: Faster Large Language Model Inference on GPUs
November 2, 2023
Auteurs: Ke Hong, Guohao Dai, Jiaming Xu, Qiuli Mao, Xiuhong Li, Jun Liu, Kangdi Chen, Hanyu Dong, Yu Wang
cs.AI
Samenvatting
Naarmate het Large Language Model (LLM) steeds belangrijker wordt in verschillende domeinen, blijven de volgende uitdagingen onopgelost bij het versnellen van LLM-inferentie: (1) Gesynchroniseerde gedeeltelijke softmax-update. De softmax-operatie vereist een gesynchroniseerde update-operatie tussen elk gedeeltelijk softmax-resultaat, wat leidt tot ~20% overhead voor de aandachtberekening in LLM's. (2) Onderbenutte berekening van platte GEMM. De vorm van matrices die GEMM uitvoeren in LLM-inferentie is plat, wat leidt tot onderbenutte berekening en >50% prestatieverlies na het opvullen met nullen in eerdere ontwerpen. (3) Prestatieverlies door statische dataflow. De kernelprestatie in LLM hangt af van verschillende invoergegevenskenmerken, hardwareconfiguraties, enz. Een enkele en statische dataflow kan leiden tot een prestatieverlies van 50,25% voor GEMM's van verschillende vormen in LLM-inferentie.
We presenteren FlashDecoding++, een snelle LLM-inferentie-engine die mainstream LLM's en hardware-backends ondersteunt. Om de bovenstaande uitdagingen aan te pakken, stelt FlashDecoding++ creatief voor: (1) Asynchrone softmax met geünificeerde maximale waarde. FlashDecoding++ introduceert een geünificeerde maximale waardetechniek voor verschillende gedeeltelijke softmax-berekeningen om synchronisatie te vermijden. (2) Optimalisatie van platte GEMM met dubbele buffering. FlashDecoding++ wijst erop dat platte GEMM's met verschillende vormen verschillende knelpunten hebben. Vervolgens worden technieken zoals dubbele buffering geïntroduceerd. (3) Heuristische dataflow met hardwarebronnenadaptatie. FlashDecoding++ optimaliseert heuristisch de dataflow met behulp van verschillende hardwarebronnen, rekening houdend met de dynamiek van de invoer. Door de veelzijdigheid van de optimalisaties in FlashDecoding++ kan FlashDecoding++ een versnelling tot 4,86x en 2,18x bereiken op zowel NVIDIA- als AMD-GPU's in vergelijking met Hugging Face-implementaties. FlashDecoding++ behaalt ook een gemiddelde versnelling van 1,37x in vergelijking met state-of-the-art LLM-inferentie-engines op mainstream LLM's.
English
As the Large Language Model (LLM) becomes increasingly important in various
domains. However, the following challenges still remain unsolved in
accelerating LLM inference: (1) Synchronized partial softmax update. The
softmax operation requires a synchronized update operation among each partial
softmax result, leading to ~20% overheads for the attention computation in
LLMs. (2) Under-utilized computation of flat GEMM. The shape of matrices
performing GEMM in LLM inference is flat, leading to under-utilized computation
and >50% performance loss after padding zeros in previous designs. (3)
Performance loss due to static dataflow. Kernel performance in LLM depends on
varied input data features, hardware configurations, etc. A single and static
dataflow may lead to a 50.25% performance loss for GEMMs of different shapes in
LLM inference.
We present FlashDecoding++, a fast LLM inference engine supporting mainstream
LLMs and hardware back-ends. To tackle the above challenges, FlashDecoding++
creatively proposes: (1) Asynchronized softmax with unified max value.
FlashDecoding++ introduces a unified max value technique for different partial
softmax computations to avoid synchronization. (2) Flat GEMM optimization with
double buffering. FlashDecoding++ points out that flat GEMMs with different
shapes face varied bottlenecks. Then, techniques like double buffering are
introduced. (3) Heuristic dataflow with hardware resource adaptation.
FlashDecoding++ heuristically optimizes dataflow using different hardware
resource considering input dynamics. Due to the versatility of optimizations in
FlashDecoding++, FlashDecoding++ can achieve up to 4.86x and 2.18x speedup on
both NVIDIA and AMD GPUs compared to Hugging Face implementations.
FlashDecoding++ also achieves an average speedup of 1.37x compared to
state-of-the-art LLM inference engines on mainstream LLMs.