Modelos de linguagem com atenção linear simples equilibram a relação entre recall e taxa de processamento.
Simple linear attention language models balance the recall-throughput tradeoff
February 28, 2024
Autores: Simran Arora, Sabri Eyuboglu, Michael Zhang, Aman Timalsina, Silas Alberti, Dylan Zinsley, James Zou, Atri Rudra, Christopher Ré
cs.AI
Resumo
Trabalhos recentes demonstraram que modelos de linguagem baseados em atenção se destacam na capacidade de recall, ou seja, na habilidade de fundamentar gerações em tokens previamente vistos no contexto. No entanto, a eficiência desses modelos baseados em atenção é limitada durante a inferência pelo consumo agressivo de memória do cache KV. Neste trabalho, exploramos se é possível melhorar a eficiência dos modelos de linguagem (por exemplo, reduzindo o consumo de memória) sem comprometer o recall. Aplicando experimentos e teoria a um amplo conjunto de arquiteturas, identificamos uma troca fundamental entre o tamanho do estado do modelo e sua capacidade de recall. Mostramos que alternativas eficientes à atenção (por exemplo, H3, Mamba, RWKV) mantêm um estado recorrente de tamanho fixo, mas têm dificuldades com o recall. Propomos o BASED, uma arquitetura simples que combina atenção linear e atenção por janela deslizante. Variando o tamanho da janela do BASED e a dimensão das características da atenção linear, podemos ajustar o tamanho do estado e percorrer a fronteira de Pareto da curva de troca entre recall e memória, recuperando a qualidade total da atenção em um extremo e o pequeno tamanho de estado das alternativas à atenção no outro. Treinamos modelos de linguagem com até 1,3 bilhão de parâmetros e mostramos que o BASED iguala os modelos subquadráticos mais fortes (por exemplo, Mamba) em perplexidade e os supera em tarefas do mundo real intensivas em recall por 6,22 pontos de precisão. Implementações de atenção linear costumam ser menos eficientes do que implementações otimizadas de atenção padrão. Para tornar o BASED competitivo, desenvolvemos algoritmos conscientes de E/S que permitem um throughput 24 vezes maior na geração de linguagem do que o FlashAttention-2, ao gerar 1024 tokens usando modelos de 1,3 bilhão de parâmetros. O código deste trabalho está disponível em: https://github.com/HazyResearch/based.
English
Recent work has shown that attention-based language models excel at recall,
the ability to ground generations in tokens previously seen in context.
However, the efficiency of attention-based models is bottle-necked during
inference by the KV-cache's aggressive memory consumption. In this work, we
explore whether we can improve language model efficiency (e.g. by reducing
memory consumption) without compromising on recall. By applying experiments and
theory to a broad set of architectures, we identify a key tradeoff between a
model's state size and recall ability. We show that efficient alternatives to
attention (e.g. H3, Mamba, RWKV) maintain a fixed-size recurrent state, but
struggle at recall. We propose BASED a simple architecture combining linear and
sliding window attention. By varying BASED window size and linear attention
feature dimension, we can dial the state size and traverse the pareto frontier
of the recall-memory tradeoff curve, recovering the full quality of attention
on one end and the small state size of attention-alternatives on the other. We
train language models up to 1.3b parameters and show that BASED matches the
strongest sub-quadratic models (e.g. Mamba) in perplexity and outperforms them
on real-world recall-intensive tasks by 6.22 accuracy points. Implementations
of linear attention are often less efficient than optimized standard attention
implementations. To make BASED competitive, we develop IO-aware algorithms that
enable 24x higher throughput on language generation than FlashAttention-2, when
generating 1024 tokens using 1.3b parameter models. Code for this work is
provided at: https://github.com/HazyResearch/based.