通过跨层注意力减小Transformer键值缓存大小
Reducing Transformer Key-Value Cache Size with Cross-Layer Attention
May 21, 2024
作者: William Brandon, Mayank Mishra, Aniruddha Nrusimha, Rameswar Panda, Jonathan Ragan Kelly
cs.AI
摘要
键-值(KV)缓存在加速基于变压器的自回归大型语言模型(LLMs)解码中起着至关重要的作用。然而,存储KV缓存所需的内存量在长序列长度和大批量大小下可能变得难以承受。自变压器的发明以来,用于减小KV缓存大小的两种最有效的干预措施是多查询注意力(MQA)及其泛化形式,分组查询注意力(GQA)。MQA和GQA都修改了注意力块的设计,使多个查询头可以共享单个键/值头,大幅减少不同键/值头的数量,同时只略微降低准确性。在本文中,我们展示了可以通过在相邻层之间共享键和值头进一步发展多查询注意力,从而产生一种我们称之为跨层注意力(CLA)的新型注意力设计。通过CLA,我们发现可以将KV缓存的大小再减少2倍,同时保持几乎与未修改的MQA相同的准确性。在从头开始训练10亿和30亿参数模型的实验中,我们证明了CLA相对于传统MQA可能的内存/准确性权衡提供了帕累托改进,使推断可以使用比以往更长的序列长度和更大的批量大小。
English
Key-value (KV) caching plays an essential role in accelerating decoding for
transformer-based autoregressive large language models (LLMs). However, the
amount of memory required to store the KV cache can become prohibitive at long
sequence lengths and large batch sizes. Since the invention of the transformer,
two of the most effective interventions discovered for reducing the size of the
KV cache have been Multi-Query Attention (MQA) and its generalization,
Grouped-Query Attention (GQA). MQA and GQA both modify the design of the
attention block so that multiple query heads can share a single key/value head,
reducing the number of distinct key/value heads by a large factor while only
minimally degrading accuracy. In this paper, we show that it is possible to
take Multi-Query Attention a step further by also sharing key and value heads
between adjacent layers, yielding a new attention design we call Cross-Layer
Attention (CLA). With CLA, we find that it is possible to reduce the size of
the KV cache by another 2x while maintaining nearly the same accuracy as
unmodified MQA. In experiments training 1B- and 3B-parameter models from
scratch, we demonstrate that CLA provides a Pareto improvement over the
memory/accuracy tradeoffs which are possible with traditional MQA, enabling
inference with longer sequence lengths and larger batch sizes than would
otherwise be possible