Eenvoudige lineaire aandachtstaalmodellen balanceren de afweging tussen recall en doorvoersnelheid.
Simple linear attention language models balance the recall-throughput tradeoff
February 28, 2024
Auteurs: Simran Arora, Sabri Eyuboglu, Michael Zhang, Aman Timalsina, Silas Alberti, Dylan Zinsley, James Zou, Atri Rudra, Christopher Ré
cs.AI
Samenvatting
Recent onderzoek heeft aangetoond dat attention-gebaseerde taalmodellen uitblinken in recall, het vermogen om generaties te verankeren in tokens die eerder in de context zijn gezien. De efficiëntie van attention-gebaseerde modellen wordt echter beperkt tijdens inferentie door het agressieve geheugengebruik van de KV-cache. In dit werk onderzoeken we of we de efficiëntie van taalmodellen kunnen verbeteren (bijvoorbeeld door het geheugengebruik te verminderen) zonder in te leveren op recall. Door experimenten en theorie toe te passen op een breed scala aan architecturen, identificeren we een belangrijke afweging tussen de grootte van de modelstatus en het recall-vermogen. We laten zien dat efficiënte alternatieven voor attention (bijvoorbeeld H3, Mamba, RWKV) een vaste grootte van de recurrente status behouden, maar moeite hebben met recall. We stellen BASED voor, een eenvoudige architectuur die lineaire en sliding window attention combineert. Door de venstergrootte van BASED en de dimensie van de lineaire attention-feature te variëren, kunnen we de grootte van de modelstatus aanpassen en de pareto-grens van de recall-geheugen afwegingcurve doorlopen, waarbij we aan de ene kant de volledige kwaliteit van attention herstellen en aan de andere kant de kleine modelstatus van attention-alternatieven behouden. We trainen taalmodellen tot 1,3 miljard parameters en laten zien dat BASED de sterkste sub-kwadratische modellen (bijvoorbeeld Mamba) evenaart in perplexiteit en ze overtreft op real-world recall-intensieve taken met 6,22 nauwkeurigheidspunten. Implementaties van lineaire attention zijn vaak minder efficiënt dan geoptimaliseerde standaard attention-implementaties. Om BASED concurrerend te maken, ontwikkelen we IO-bewuste algoritmen die een 24x hogere doorvoer mogelijk maken bij taalgeneratie dan FlashAttention-2, bij het genereren van 1024 tokens met 1,3 miljard parameter modellen. Code voor dit werk is beschikbaar op: https://github.com/HazyResearch/based.
English
Recent work has shown that attention-based language models excel at recall,
the ability to ground generations in tokens previously seen in context.
However, the efficiency of attention-based models is bottle-necked during
inference by the KV-cache's aggressive memory consumption. In this work, we
explore whether we can improve language model efficiency (e.g. by reducing
memory consumption) without compromising on recall. By applying experiments and
theory to a broad set of architectures, we identify a key tradeoff between a
model's state size and recall ability. We show that efficient alternatives to
attention (e.g. H3, Mamba, RWKV) maintain a fixed-size recurrent state, but
struggle at recall. We propose BASED a simple architecture combining linear and
sliding window attention. By varying BASED window size and linear attention
feature dimension, we can dial the state size and traverse the pareto frontier
of the recall-memory tradeoff curve, recovering the full quality of attention
on one end and the small state size of attention-alternatives on the other. We
train language models up to 1.3b parameters and show that BASED matches the
strongest sub-quadratic models (e.g. Mamba) in perplexity and outperforms them
on real-world recall-intensive tasks by 6.22 accuracy points. Implementations
of linear attention are often less efficient than optimized standard attention
implementations. To make BASED competitive, we develop IO-aware algorithms that
enable 24x higher throughput on language generation than FlashAttention-2, when
generating 1024 tokens using 1.3b parameter models. Code for this work is
provided at: https://github.com/HazyResearch/based.