ATLAS : Apprentissage de la mémorisation optimale du contexte au moment du test
ATLAS: Learning to Optimally Memorize the Context at Test Time
May 29, 2025
Auteurs: Ali Behrouz, Zeman Li, Praneeth Kacham, Majid Daliri, Yuan Deng, Peilin Zhong, Meisam Razaviyayn, Vahab Mirrokni
cs.AI
Résumé
Les Transformers se sont imposés comme les architectures de référence pour la modélisation de séquences, principalement en raison de leur efficacité dans les tâches de récupération en contexte et de leur capacité à apprendre à grande échelle. Cependant, leur complexité quadratique en mémoire et en temps limite leur applicabilité pour les séquences longues, ce qui a incité les chercheurs à explorer des architectures alternatives efficaces, telles que les réseaux de neurones récurrents modernes (également appelés modules de mémoire récurrente à long terme). Malgré leur récent succès dans diverses tâches en aval, ces modèles peinent dans les tâches nécessitant une compréhension de contexte étendu et une extrapolation à des séquences plus longues. Nous observons que ces lacunes proviennent de trois aspects disjoints dans leur conception : (1) une capacité mémoire limitée, contrainte par l'architecture de la mémoire et la cartographie des caractéristiques de l'entrée ; (2) une nature en ligne des mises à jour, c'est-à-dire l'optimisation de la mémoire uniquement par rapport à la dernière entrée ; et (3) une gestion peu expressive de leur mémoire de taille fixe. Pour améliorer ces trois aspects, nous présentons ATLAS, un module de mémoire à long terme de haute capacité qui apprend à mémoriser le contexte en optimisant la mémoire sur la base des tokens actuels et passés, surmontant ainsi la nature en ligne des modèles de mémoire à long terme. Sur la base de cette idée, nous introduisons une nouvelle famille d'architectures de type Transformer, appelées DeepTransformers, qui constituent des généralisations strictes de l'architecture Transformer originale. Nos résultats expérimentaux sur des tâches de modélisation du langage, de raisonnement de bon sens, de rappel intensif et de compréhension de contexte long montrent qu'ATLAS surpasse les performances des Transformers et des modèles récurrents linéaires récents. ATLAS améliore également les performances en contexte long des Titans, atteignant une précision de +80\% pour une longueur de contexte de 10M dans le benchmark BABILong.
English
Transformers have been established as the most popular backbones in sequence
modeling, mainly due to their effectiveness in in-context retrieval tasks and
the ability to learn at scale. Their quadratic memory and time complexity,
however, bound their applicability in longer sequences and so has motivated
researchers to explore effective alternative architectures such as modern
recurrent neural networks (a.k.a long-term recurrent memory module). Despite
their recent success in diverse downstream tasks, they struggle in tasks that
requires long context understanding and extrapolation to longer sequences. We
observe that these shortcomings come from three disjoint aspects in their
design: (1) limited memory capacity that is bounded by the architecture of
memory and feature mapping of the input; (2) online nature of update, i.e.,
optimizing the memory only with respect to the last input; and (3) less
expressive management of their fixed-size memory. To enhance all these three
aspects, we present ATLAS, a long-term memory module with high capacity that
learns to memorize the context by optimizing the memory based on the current
and past tokens, overcoming the online nature of long-term memory models.
Building on this insight, we present a new family of Transformer-like
architectures, called DeepTransformers, that are strict generalizations of the
original Transformer architecture. Our experimental results on language
modeling, common-sense reasoning, recall-intensive, and long-context
understanding tasks show that ATLAS surpasses the performance of Transformers
and recent linear recurrent models. ATLAS further improves the long context
performance of Titans, achieving +80\% accuracy in 10M context length of
BABILong benchmark.Summary
AI-Generated Summary