ChatPaper.aiChatPaper

FlashDecoding++: GPU에서 더 빠른 대형 언어 모델 추론

FlashDecoding++: Faster Large Language Model Inference on GPUs

November 2, 2023
저자: Ke Hong, Guohao Dai, Jiaming Xu, Qiuli Mao, Xiuhong Li, Jun Liu, Kangdi Chen, Hanyu Dong, Yu Wang
cs.AI

초록

대규모 언어 모델(LLM)이 다양한 분야에서 점점 더 중요해지고 있습니다. 그러나 LLM 추론을 가속화하는 데 있어 다음과 같은 과제들이 여전히 해결되지 않고 있습니다: (1) 동기화된 부분 소프트맥스 업데이트. 소프트맥스 연산은 각 부분 소프트맥스 결과 간의 동기화된 업데이트 연산을 필요로 하여, LLM의 어텐션 계산에서 약 20%의 오버헤드를 발생시킵니다. (2) 플랫 GEMM의 낮은 계산 활용도. LLM 추론에서 수행되는 GEMM의 행렬 형태가 플랫하여, 이전 설계에서 제로 패딩 후 50% 이상의 성능 손실이 발생합니다. (3) 정적 데이터플로우로 인한 성능 손실. LLM의 커널 성능은 다양한 입력 데이터 특성과 하드웨어 구성 등에 따라 달라집니다. 단일하고 정적인 데이터플로우는 LLM 추론에서 다양한 형태의 GEMM에 대해 최대 50.25%의 성능 손실을 초래할 수 있습니다. 이러한 문제를 해결하기 위해, 우리는 주류 LLM과 하드웨어 백엔드를 지원하는 빠른 LLM 추론 엔진인 FlashDecoding++을 제안합니다. FlashDecoding++은 다음과 같은 혁신적인 기법을 제안합니다: (1) 통합 최대값을 사용한 비동기식 소프트맥스. FlashDecoding++은 서로 다른 부분 소프트맥스 계산 간의 동기화를 피하기 위해 통합 최대값 기법을 도입합니다. (2) 더블 버퍼링을 활용한 플랫 GEMM 최적화. FlashDecoding++은 다양한 형태의 플랫 GEMM이 서로 다른 병목 현상을 겪는다는 점을 지적하고, 더블 버퍼링과 같은 기법을 도입합니다. (3) 하드웨어 자원 적응형 휴리스틱 데이터플로우. FlashDecoding++은 입력의 동적 특성을 고려하여 다양한 하드웨어 자원을 활용해 데이터플로우를 휴리스틱하게 최적화합니다. FlashDecoding++의 다재다능한 최적화 덕분에, Hugging Face 구현 대비 NVIDIA와 AMD GPU에서 각각 최대 4.86배와 2.18배의 속도 향상을 달성할 수 있습니다. 또한, FlashDecoding++은 주류 LLM에서 최신 LLM 추론 엔진 대비 평균 1.37배의 속도 향상을 보여줍니다.
English
As the Large Language Model (LLM) becomes increasingly important in various domains. However, the following challenges still remain unsolved in accelerating LLM inference: (1) Synchronized partial softmax update. The softmax operation requires a synchronized update operation among each partial softmax result, leading to ~20% overheads for the attention computation in LLMs. (2) Under-utilized computation of flat GEMM. The shape of matrices performing GEMM in LLM inference is flat, leading to under-utilized computation and >50% performance loss after padding zeros in previous designs. (3) Performance loss due to static dataflow. Kernel performance in LLM depends on varied input data features, hardware configurations, etc. A single and static dataflow may lead to a 50.25% performance loss for GEMMs of different shapes in LLM inference. We present FlashDecoding++, a fast LLM inference engine supporting mainstream LLMs and hardware back-ends. To tackle the above challenges, FlashDecoding++ creatively proposes: (1) Asynchronized softmax with unified max value. FlashDecoding++ introduces a unified max value technique for different partial softmax computations to avoid synchronization. (2) Flat GEMM optimization with double buffering. FlashDecoding++ points out that flat GEMMs with different shapes face varied bottlenecks. Then, techniques like double buffering are introduced. (3) Heuristic dataflow with hardware resource adaptation. FlashDecoding++ heuristically optimizes dataflow using different hardware resource considering input dynamics. Due to the versatility of optimizations in FlashDecoding++, FlashDecoding++ can achieve up to 4.86x and 2.18x speedup on both NVIDIA and AMD GPUs compared to Hugging Face implementations. FlashDecoding++ also achieves an average speedup of 1.37x compared to state-of-the-art LLM inference engines on mainstream LLMs.
PDF373December 15, 2024