Deja Vu : Sparsité contextuelle pour des LLM efficaces lors de l'inférence
Deja Vu: Contextual Sparsity for Efficient LLMs at Inference Time
October 26, 2023
Auteurs: Zichang Liu, Jue Wang, Tri Dao, Tianyi Zhou, Binhang Yuan, Zhao Song, Anshumali Shrivastava, Ce Zhang, Yuandong Tian, Christopher Re, Beidi Chen
cs.AI
Résumé
Les grands modèles de langage (LLM) avec des centaines de milliards de paramètres ont suscité une nouvelle vague d'applications passionnantes en IA. Cependant, ils sont coûteux en calcul au moment de l'inférence. La parcimonie est une approche naturelle pour réduire ce coût, mais les méthodes existantes nécessitent soit un réentraînement coûteux, soit renoncent à la capacité d'apprentissage en contexte des LLM, soit ne permettent pas d'accélérer le temps réel sur le matériel moderne. Nous émettons l'hypothèse que la parcimonie contextuelle, qui consiste en de petits ensembles de têtes d'attention et de paramètres MLP dépendants de l'entrée et produisant approximativement la même sortie que le modèle dense pour une entrée donnée, peut résoudre ces problèmes. Nous montrons que la parcimonie contextuelle existe, qu'elle peut être prédite avec précision, et que nous pouvons l'exploiter pour accélérer l'inférence des LLM en temps réel sans compromettre la qualité du modèle ou sa capacité d'apprentissage en contexte. Sur la base de ces observations, nous proposons DejaVu, un système qui utilise un algorithme peu coûteux pour prédire la parcimonie contextuelle à la volée en fonction des entrées de chaque couche, ainsi qu'une implémentation asynchrone et adaptée au matériel qui accélère l'inférence des LLM. Nous validons que DejaVu peut réduire la latence d'inférence d'OPT-175B de plus de 2 fois par rapport à FasterTransformer, l'état de l'art, et de plus de 6 fois par rapport à l'implémentation largement utilisée de Hugging Face, sans compromettre la qualité du modèle. Le code est disponible à l'adresse https://github.com/FMInference/DejaVu.
English
Large language models (LLMs) with hundreds of billions of parameters have
sparked a new wave of exciting AI applications. However, they are
computationally expensive at inference time. Sparsity is a natural approach to
reduce this cost, but existing methods either require costly retraining, have
to forgo LLM's in-context learning ability, or do not yield wall-clock time
speedup on modern hardware. We hypothesize that contextual sparsity, which are
small, input-dependent sets of attention heads and MLP parameters that yield
approximately the same output as the dense model for a given input, can address
these issues. We show that contextual sparsity exists, that it can be
accurately predicted, and that we can exploit it to speed up LLM inference in
wall-clock time without compromising LLM's quality or in-context learning
ability. Based on these insights, we propose DejaVu, a system that uses a
low-cost algorithm to predict contextual sparsity on the fly given inputs to
each layer, along with an asynchronous and hardware-aware implementation that
speeds up LLM inference. We validate that DejaVu can reduce the inference
latency of OPT-175B by over 2X compared to the state-of-the-art
FasterTransformer, and over 6X compared to the widely used Hugging Face
implementation, without compromising model quality. The code is available at
https://github.com/FMInference/DejaVu.