Scalify: schaalpropagatie voor efficiënte training van low-precision LLM's
Scalify: scale propagation for efficient low-precision LLM training
July 24, 2024
Auteurs: Paul Balança, Sam Hosegood, Carlo Luschi, Andrew Fitzgibbon
cs.AI
Samenvatting
Laag-precisieformaten zoals float8 zijn geïntroduceerd in hardware voor machine learning-versnelling om de rekenkundige efficiëntie te verbeteren voor het trainen en inferentie van grote taalmodellen. Desalniettemin is de adoptie door de ML-gemeenschap vertraagd door de complexe en soms kwetsbare technieken die nodig zijn om de nauwkeurigheid van training met hogere precisie te evenaren. In dit werk presenteren we Scalify, een end-to-end schaalpropagatieparadigma voor computationele grafieken, dat bestaande tensorschalingsmethoden generaliseert en formaliseert. Experimentele resultaten tonen aan dat Scalify out-of-the-box float8 matrixvermenigvuldiging en gradiëntrepresentatie ondersteunt, evenals float16 optimizer state-opslag. Onze JAX-implementatie van Scalify is open-source beschikbaar op https://github.com/graphcore-research/jax-scalify.
English
Low-precision formats such as float8 have been introduced in machine learning
accelerated hardware to improve computational efficiency for large language
models training and inference. Nevertheless, adoption by the ML community has
been slowed down by the complex, and sometimes brittle, techniques required to
match higher precision training accuracy. In this work, we present Scalify, a
end-to-end scale propagation paradigm for computational graphs, generalizing
and formalizing existing tensor scaling methods. Experiment results show that
Scalify supports out-of-the-box float8 matrix multiplication and gradients
representation, as well as float16 optimizer state storage. Our JAX
implementation of Scalify is open-sourced at
https://github.com/graphcore-research/jax-scalify