Scalify: propagación de escala para un entrenamiento eficiente de LLM de baja precisión.
Scalify: scale propagation for efficient low-precision LLM training
July 24, 2024
Autores: Paul Balança, Sam Hosegood, Carlo Luschi, Andrew Fitzgibbon
cs.AI
Resumen
Los formatos de baja precisión, como float8, se han introducido en hardware acelerado de aprendizaje automático para mejorar la eficiencia computacional en el entrenamiento e inferencia de grandes modelos de lenguaje. Sin embargo, la adopción por parte de la comunidad de ML se ha visto ralentizada por las técnicas complejas, a veces frágiles, necesarias para igualar la precisión de entrenamiento de mayor precisión. En este trabajo, presentamos Scalify, un paradigma de propagación de escala de extremo a extremo para grafos computacionales, generalizando y formalizando los métodos de escalado de tensores existentes. Los resultados experimentales muestran que Scalify admite la multiplicación de matrices float8 listo para usar y la representación de gradientes, así como el almacenamiento de estado del optimizador float16. Nuestra implementación de Scalify en JAX está disponible como código abierto en https://github.com/graphcore-research/jax-scalify
English
Low-precision formats such as float8 have been introduced in machine learning
accelerated hardware to improve computational efficiency for large language
models training and inference. Nevertheless, adoption by the ML community has
been slowed down by the complex, and sometimes brittle, techniques required to
match higher precision training accuracy. In this work, we present Scalify, a
end-to-end scale propagation paradigm for computational graphs, generalizing
and formalizing existing tensor scaling methods. Experiment results show that
Scalify supports out-of-the-box float8 matrix multiplication and gradients
representation, as well as float16 optimizer state storage. Our JAX
implementation of Scalify is open-sourced at
https://github.com/graphcore-research/jax-scalifySummary
AI-Generated Summary