Het verkleinen van de Transformer Key-Value Cache-grootte met Cross-Layer Attention
Reducing Transformer Key-Value Cache Size with Cross-Layer Attention
May 21, 2024
Auteurs: William Brandon, Mayank Mishra, Aniruddha Nrusimha, Rameswar Panda, Jonathan Ragan Kelly
cs.AI
Samenvatting
Key-value (KV) caching speelt een essentiële rol bij het versnellen van het decoderen voor transformer-gebaseerde autoregressieve grote taalmodellen (LLMs). Echter, de hoeveelheid geheugen die nodig is om de KV-cache op te slaan, kan onhoudbaar worden bij lange sequentielengtes en grote batchgroottes. Sinds de uitvinding van de transformer zijn Multi-Query Attention (MQA) en de generalisatie daarvan, Grouped-Query Attention (GQA), twee van de meest effectieve interventies ontdekt om de grootte van de KV-cache te verminderen. MQA en GQA passen beide het ontwerp van het attention-blok aan zodat meerdere query-heads een enkele key/value-head kunnen delen, waardoor het aantal afzonderlijke key/value-heads aanzienlijk wordt verminderd terwijl de nauwkeurigheid slechts minimaal afneemt. In dit artikel laten we zien dat het mogelijk is om Multi-Query Attention een stap verder te brengen door ook key- en value-heads tussen aangrenzende lagen te delen, wat resulteert in een nieuw attention-ontwerp dat we Cross-Layer Attention (CLA) noemen. Met CLA ontdekken we dat het mogelijk is om de grootte van de KV-cache nog eens te halveren terwijl de nauwkeurigheid bijna hetzelfde blijft als bij ongewijzigde MQA. In experimenten waarbij we 1B- en 3B-parameter modellen vanaf nul trainen, demonstreren we dat CLA een Pareto-verbetering biedt ten opzichte van de geheugen/nauwkeurigheid-afwegingen die mogelijk zijn met traditionele MQA, waardoor inferentie met langere sequentielengtes en grotere batchgroottes mogelijk wordt dan anders het geval zou zijn.
English
Key-value (KV) caching plays an essential role in accelerating decoding for
transformer-based autoregressive large language models (LLMs). However, the
amount of memory required to store the KV cache can become prohibitive at long
sequence lengths and large batch sizes. Since the invention of the transformer,
two of the most effective interventions discovered for reducing the size of the
KV cache have been Multi-Query Attention (MQA) and its generalization,
Grouped-Query Attention (GQA). MQA and GQA both modify the design of the
attention block so that multiple query heads can share a single key/value head,
reducing the number of distinct key/value heads by a large factor while only
minimally degrading accuracy. In this paper, we show that it is possible to
take Multi-Query Attention a step further by also sharing key and value heads
between adjacent layers, yielding a new attention design we call Cross-Layer
Attention (CLA). With CLA, we find that it is possible to reduce the size of
the KV cache by another 2x while maintaining nearly the same accuracy as
unmodified MQA. In experiments training 1B- and 3B-parameter models from
scratch, we demonstrate that CLA provides a Pareto improvement over the
memory/accuracy tradeoffs which are possible with traditional MQA, enabling
inference with longer sequence lengths and larger batch sizes than would
otherwise be possible