Los modelos de lenguaje con atención lineal simple equilibran el compromiso entre recuperación y rendimiento.
Simple linear attention language models balance the recall-throughput tradeoff
February 28, 2024
Autores: Simran Arora, Sabri Eyuboglu, Michael Zhang, Aman Timalsina, Silas Alberti, Dylan Zinsley, James Zou, Atri Rudra, Christopher Ré
cs.AI
Resumen
Trabajos recientes han demostrado que los modelos de lenguaje basados en atención sobresalen en la capacidad de recuperación, es decir, la habilidad de fundamentar las generaciones en tokens previamente vistos en el contexto. Sin embargo, la eficiencia de los modelos basados en atención se ve limitada durante la inferencia por el alto consumo de memoria del KV-cache. En este trabajo, exploramos si es posible mejorar la eficiencia de los modelos de lenguaje (por ejemplo, reduciendo el consumo de memoria) sin comprometer la capacidad de recuperación. Aplicando experimentos y teoría a un amplio conjunto de arquitecturas, identificamos un equilibrio clave entre el tamaño del estado del modelo y su capacidad de recuperación. Mostramos que alternativas eficientes a la atención (por ejemplo, H3, Mamba, RWKV) mantienen un estado recurrente de tamaño fijo, pero tienen dificultades en la recuperación. Proponemos BASED, una arquitectura simple que combina atención lineal y atención de ventana deslizante. Al variar el tamaño de la ventana de BASED y la dimensión de las características de la atención lineal, podemos ajustar el tamaño del estado y recorrer la frontera de Pareto de la curva de equilibrio entre recuperación y memoria, recuperando la calidad completa de la atención en un extremo y el pequeño tamaño de estado de las alternativas a la atención en el otro. Entrenamos modelos de lenguaje de hasta 1.3 mil millones de parámetros y demostramos que BASED iguala a los modelos subcuadráticos más fuertes (por ejemplo, Mamba) en perplejidad y los supera en tareas del mundo real intensivas en recuperación por 6.22 puntos de precisión. Las implementaciones de atención lineal suelen ser menos eficientes que las implementaciones optimizadas de atención estándar. Para hacer que BASED sea competitivo, desarrollamos algoritmos conscientes de E/S que permiten un rendimiento 24 veces mayor en la generación de lenguaje que FlashAttention-2, al generar 1024 tokens utilizando modelos de 1.3 mil millones de parámetros. El código de este trabajo se proporciona en: https://github.com/HazyResearch/based.
English
Recent work has shown that attention-based language models excel at recall,
the ability to ground generations in tokens previously seen in context.
However, the efficiency of attention-based models is bottle-necked during
inference by the KV-cache's aggressive memory consumption. In this work, we
explore whether we can improve language model efficiency (e.g. by reducing
memory consumption) without compromising on recall. By applying experiments and
theory to a broad set of architectures, we identify a key tradeoff between a
model's state size and recall ability. We show that efficient alternatives to
attention (e.g. H3, Mamba, RWKV) maintain a fixed-size recurrent state, but
struggle at recall. We propose BASED a simple architecture combining linear and
sliding window attention. By varying BASED window size and linear attention
feature dimension, we can dial the state size and traverse the pareto frontier
of the recall-memory tradeoff curve, recovering the full quality of attention
on one end and the small state size of attention-alternatives on the other. We
train language models up to 1.3b parameters and show that BASED matches the
strongest sub-quadratic models (e.g. Mamba) in perplexity and outperforms them
on real-world recall-intensive tasks by 6.22 accuracy points. Implementations
of linear attention are often less efficient than optimized standard attention
implementations. To make BASED competitive, we develop IO-aware algorithms that
enable 24x higher throughput on language generation than FlashAttention-2, when
generating 1024 tokens using 1.3b parameter models. Code for this work is
provided at: https://github.com/HazyResearch/based.