라마 속의 맘바: 하이브리드 모델 축소 및 가속화
The Mamba in the Llama: Distilling and Accelerating Hybrid Models
August 27, 2024
저자: Junxiong Wang, Daniele Paliotta, Avner May, Alexander M. Rush, Tri Dao
cs.AI
초록
선형 RNN 아키텍처인 Mamba와 같은 모델은 유리한 배포 특성을 가지면서 언어 모델링에서 Transformer 모델과 경쟁력을 갖을 수 있습니다. 대규모 Transformer 모델을 훈련하는 데 초점을 맞춘 상황에서, 사전 훈련된 모델을 배포용으로 변환하는 과제를 고려합니다. 우리는 학술용 GPU 자원을 활용하여 어텐션 레이어에서 선형 프로젝션 가중치를 재사용함으로써 대규모 Transformer를 선형 RNN으로 축소하는 것이 가능함을 증명합니다. 어텐션 레이어의 1/4을 통합한 결과의 하이브리드 모델은 채팅 벤치마크에서 원본 Transformer와 유사한 성능을 달성하며, 오픈 소스 하이브리드 Mamba 모델보다 우수한 성과를 보입니다. 이 모델은 수조 개의 토큰으로 훈련된 오픈 소스 하이브리드 Mamba 모델을 이기는 성과를 채팅 벤치마크와 일반 벤치마크에서 달성합니다. 더불어, Mamba 및 하이브리드 모델의 추론 속도를 가속화하는 하드웨어 인식 추론 알고리즘을 소개합니다. 전체적으로 한정된 계산 자원을 이용하여 원본 어텐션 레이어 중 많은 부분을 제거하고 그 결과 모델을 더 효율적으로 생성할 수 있는 방법을 보여줍니다. Llama3-8B-Instruct에서 축소된 최고 성능의 모델은 AlpacaEval 2에서 GPT-4를 상대로 29.61의 길이 제어 승률을 달성하며, MT-Bench에서는 7.35로, 최고의 명령어 조정된 선형 RNN 모델을 능가합니다.
English
Linear RNN architectures, like Mamba, can be competitive with Transformer
models in language modeling while having advantageous deployment
characteristics. Given the focus on training large-scale Transformer models, we
consider the challenge of converting these pretrained models for deployment. We
demonstrate that it is feasible to distill large Transformers into linear RNNs
by reusing the linear projection weights from attention layers with academic
GPU resources. The resulting hybrid model, which incorporates a quarter of the
attention layers, achieves performance comparable to the original Transformer
in chat benchmarks and outperforms open-source hybrid Mamba models trained from
scratch with trillions of tokens in both chat benchmarks and general
benchmarks. Moreover, we introduce a hardware-aware speculative decoding
algorithm that accelerates the inference speed of Mamba and hybrid models.
Overall we show how, with limited computation resources, we can remove many of
the original attention layers and generate from the resulting model more
efficiently. Our top-performing model, distilled from Llama3-8B-Instruct,
achieves a 29.61 length-controlled win rate on AlpacaEval 2 against GPT-4 and
7.35 on MT-Bench, surpassing the best instruction-tuned linear RNN model.Summary
AI-Generated Summary