ChatPaper.aiChatPaper

FlashDecoding++: Snellere Inferentie van Grote Taalmodellen op GPU's

FlashDecoding++: Faster Large Language Model Inference on GPUs

November 2, 2023
Auteurs: Ke Hong, Guohao Dai, Jiaming Xu, Qiuli Mao, Xiuhong Li, Jun Liu, Kangdi Chen, Hanyu Dong, Yu Wang
cs.AI

Samenvatting

Naarmate het Large Language Model (LLM) steeds belangrijker wordt in verschillende domeinen, blijven de volgende uitdagingen onopgelost bij het versnellen van LLM-inferentie: (1) Gesynchroniseerde gedeeltelijke softmax-update. De softmax-operatie vereist een gesynchroniseerde update-operatie tussen elk gedeeltelijk softmax-resultaat, wat leidt tot ~20% overhead voor de aandachtberekening in LLM's. (2) Onderbenutte berekening van platte GEMM. De vorm van matrices die GEMM uitvoeren in LLM-inferentie is plat, wat leidt tot onderbenutte berekening en >50% prestatieverlies na het opvullen met nullen in eerdere ontwerpen. (3) Prestatieverlies door statische dataflow. De kernelprestatie in LLM hangt af van verschillende invoergegevenskenmerken, hardwareconfiguraties, enz. Een enkele en statische dataflow kan leiden tot een prestatieverlies van 50,25% voor GEMM's van verschillende vormen in LLM-inferentie. We presenteren FlashDecoding++, een snelle LLM-inferentie-engine die mainstream LLM's en hardware-backends ondersteunt. Om de bovenstaande uitdagingen aan te pakken, stelt FlashDecoding++ creatief voor: (1) Asynchrone softmax met geünificeerde maximale waarde. FlashDecoding++ introduceert een geünificeerde maximale waardetechniek voor verschillende gedeeltelijke softmax-berekeningen om synchronisatie te vermijden. (2) Optimalisatie van platte GEMM met dubbele buffering. FlashDecoding++ wijst erop dat platte GEMM's met verschillende vormen verschillende knelpunten hebben. Vervolgens worden technieken zoals dubbele buffering geïntroduceerd. (3) Heuristische dataflow met hardwarebronnenadaptatie. FlashDecoding++ optimaliseert heuristisch de dataflow met behulp van verschillende hardwarebronnen, rekening houdend met de dynamiek van de invoer. Door de veelzijdigheid van de optimalisaties in FlashDecoding++ kan FlashDecoding++ een versnelling tot 4,86x en 2,18x bereiken op zowel NVIDIA- als AMD-GPU's in vergelijking met Hugging Face-implementaties. FlashDecoding++ behaalt ook een gemiddelde versnelling van 1,37x in vergelijking met state-of-the-art LLM-inferentie-engines op mainstream LLM's.
English
As the Large Language Model (LLM) becomes increasingly important in various domains. However, the following challenges still remain unsolved in accelerating LLM inference: (1) Synchronized partial softmax update. The softmax operation requires a synchronized update operation among each partial softmax result, leading to ~20% overheads for the attention computation in LLMs. (2) Under-utilized computation of flat GEMM. The shape of matrices performing GEMM in LLM inference is flat, leading to under-utilized computation and >50% performance loss after padding zeros in previous designs. (3) Performance loss due to static dataflow. Kernel performance in LLM depends on varied input data features, hardware configurations, etc. A single and static dataflow may lead to a 50.25% performance loss for GEMMs of different shapes in LLM inference. We present FlashDecoding++, a fast LLM inference engine supporting mainstream LLMs and hardware back-ends. To tackle the above challenges, FlashDecoding++ creatively proposes: (1) Asynchronized softmax with unified max value. FlashDecoding++ introduces a unified max value technique for different partial softmax computations to avoid synchronization. (2) Flat GEMM optimization with double buffering. FlashDecoding++ points out that flat GEMMs with different shapes face varied bottlenecks. Then, techniques like double buffering are introduced. (3) Heuristic dataflow with hardware resource adaptation. FlashDecoding++ heuristically optimizes dataflow using different hardware resource considering input dynamics. Due to the versatility of optimizations in FlashDecoding++, FlashDecoding++ can achieve up to 4.86x and 2.18x speedup on both NVIDIA and AMD GPUs compared to Hugging Face implementations. FlashDecoding++ also achieves an average speedup of 1.37x compared to state-of-the-art LLM inference engines on mainstream LLMs.
PDF373February 7, 2026