ChatPaper.aiChatPaper

FlashDecoding++: Inferencia más rápida de modelos de lenguaje grande en GPUs

FlashDecoding++: Faster Large Language Model Inference on GPUs

November 2, 2023
Autores: Ke Hong, Guohao Dai, Jiaming Xu, Qiuli Mao, Xiuhong Li, Jun Liu, Kangdi Chen, Hanyu Dong, Yu Wang
cs.AI

Resumen

A medida que los Modelos de Lenguaje de Gran Escala (LLM, por sus siglas en inglés) adquieren una importancia creciente en diversos dominios, persisten desafíos sin resolver en la aceleración de la inferencia de LLM: (1) Actualización sincronizada del softmax parcial. La operación de softmax requiere una actualización sincronizada entre cada resultado parcial de softmax, lo que genera un sobrecosto de ~20% en el cálculo de atención en los LLM. (2) Subutilización del cálculo en GEMM plano. La forma de las matrices que realizan GEMM en la inferencia de LLM es plana, lo que resulta en una subutilización del cálculo y una pérdida de rendimiento >50% tras rellenar con ceros en diseños previos. (3) Pérdida de rendimiento debido al flujo de datos estático. El rendimiento del kernel en LLM depende de características variadas de los datos de entrada, configuraciones de hardware, etc. Un flujo de datos único y estático puede generar una pérdida de rendimiento del 50.25% en GEMM de diferentes formas durante la inferencia de LLM. Presentamos FlashDecoding++, un motor de inferencia de LLM rápido que soporta modelos principales de LLM y back-ends de hardware. Para abordar los desafíos mencionados, FlashDecoding++ propone de manera creativa: (1) Softmax asincronizado con valor máximo unificado. FlashDecoding++ introduce una técnica de valor máximo unificado para diferentes cálculos parciales de softmax, evitando la sincronización. (2) Optimización de GEMM plano con doble buffer. FlashDecoding++ señala que los GEMM planos con diferentes formas enfrentan cuellos de botella variados. Luego, se introducen técnicas como el doble buffer. (3) Flujo de datos heurístico con adaptación a recursos de hardware. FlashDecoding++ optimiza heurísticamente el flujo de datos utilizando diferentes recursos de hardware, considerando la dinámica de la entrada. Gracias a la versatilidad de las optimizaciones en FlashDecoding++, este puede lograr una aceleración de hasta 4.86x y 2.18x en GPUs de NVIDIA y AMD, respectivamente, en comparación con las implementaciones de Hugging Face. Además, FlashDecoding++ alcanza una aceleración promedio de 1.37x frente a los motores de inferencia de LLM más avanzados en modelos principales de LLM.
English
As the Large Language Model (LLM) becomes increasingly important in various domains. However, the following challenges still remain unsolved in accelerating LLM inference: (1) Synchronized partial softmax update. The softmax operation requires a synchronized update operation among each partial softmax result, leading to ~20% overheads for the attention computation in LLMs. (2) Under-utilized computation of flat GEMM. The shape of matrices performing GEMM in LLM inference is flat, leading to under-utilized computation and >50% performance loss after padding zeros in previous designs. (3) Performance loss due to static dataflow. Kernel performance in LLM depends on varied input data features, hardware configurations, etc. A single and static dataflow may lead to a 50.25% performance loss for GEMMs of different shapes in LLM inference. We present FlashDecoding++, a fast LLM inference engine supporting mainstream LLMs and hardware back-ends. To tackle the above challenges, FlashDecoding++ creatively proposes: (1) Asynchronized softmax with unified max value. FlashDecoding++ introduces a unified max value technique for different partial softmax computations to avoid synchronization. (2) Flat GEMM optimization with double buffering. FlashDecoding++ points out that flat GEMMs with different shapes face varied bottlenecks. Then, techniques like double buffering are introduced. (3) Heuristic dataflow with hardware resource adaptation. FlashDecoding++ heuristically optimizes dataflow using different hardware resource considering input dynamics. Due to the versatility of optimizations in FlashDecoding++, FlashDecoding++ can achieve up to 4.86x and 2.18x speedup on both NVIDIA and AMD GPUs compared to Hugging Face implementations. FlashDecoding++ also achieves an average speedup of 1.37x compared to state-of-the-art LLM inference engines on mainstream LLMs.
PDF373December 15, 2024