Мамба в ламе: дистилляция и ускорение гибридных моделей
The Mamba in the Llama: Distilling and Accelerating Hybrid Models
August 27, 2024
Авторы: Junxiong Wang, Daniele Paliotta, Avner May, Alexander M. Rush, Tri Dao
cs.AI
Аннотация
Линейные архитектуры RNN, такие как Mamba, могут быть конкурентоспособны с моделями Transformer в языковом моделировании, обладая преимуществами в развертывании. Учитывая акцент на обучении масштабных моделей Transformer, мы рассматриваем задачу преобразования этих предварительно обученных моделей для развертывания. Мы демонстрируем, что возможно дистиллировать большие Transformers в линейные RNN путем повторного использования весов линейной проекции из слоев внимания с помощью академических ресурсов GPU. Полученная гибридная модель, которая включает четверть слоев внимания, достигает производительности, сравнимой с оригинальным Transformer в чат-бенчмарках и превосходит гибридные модели Mamba с открытым исходным кодом, обученные с нуля с триллионами токенов как в чат-бенчмарках, так и в общих бенчмарках. Более того, мы представляем алгоритм спекулятивного декодирования, учитывающий аппаратные средства, который ускоряет скорость вывода Mamba и гибридных моделей. В целом мы показываем, как с ограниченными вычислительными ресурсами мы можем удалить многие из оригинальных слоев внимания и более эффективно генерировать из полученной модели. Наша наиболее эффективная модель, дистиллированная из Llama3-8B-Instruct, достигает победного показателя 29.61 в контролируемой длине на AlpacaEval 2 против GPT-4 и 7.35 на MT-Bench, превосходя лучшую инструкционно настроенную линейную RNN модель.
English
Linear RNN architectures, like Mamba, can be competitive with Transformer
models in language modeling while having advantageous deployment
characteristics. Given the focus on training large-scale Transformer models, we
consider the challenge of converting these pretrained models for deployment. We
demonstrate that it is feasible to distill large Transformers into linear RNNs
by reusing the linear projection weights from attention layers with academic
GPU resources. The resulting hybrid model, which incorporates a quarter of the
attention layers, achieves performance comparable to the original Transformer
in chat benchmarks and outperforms open-source hybrid Mamba models trained from
scratch with trillions of tokens in both chat benchmarks and general
benchmarks. Moreover, we introduce a hardware-aware speculative decoding
algorithm that accelerates the inference speed of Mamba and hybrid models.
Overall we show how, with limited computation resources, we can remove many of
the original attention layers and generate from the resulting model more
efficiently. Our top-performing model, distilled from Llama3-8B-Instruct,
achieves a 29.61 length-controlled win rate on AlpacaEval 2 against GPT-4 and
7.35 on MT-Bench, surpassing the best instruction-tuned linear RNN model.Summary
AI-Generated Summary