Scalify: scale propagation for efficient low-precision LLM training
July 24, 2024
Authors: Paul Balança, Sam Hosegood, Carlo Luschi, Andrew Fitzgibbon
cs.AI
Abstract
Low-precision formats such as float8 have been introduced in machine learning accelerator hardware to improve the computational efficiency of large language model training and inference. Nevertheless, adoption by the ML community has been slowed by the complex, and sometimes brittle, techniques required to match higher-precision training accuracy. In this work, we present Scalify, an end-to-end scale propagation paradigm for computational graphs, generalizing and formalizing existing tensor scaling methods. Experimental results show that Scalify supports out-of-the-box float8 matrix multiplication and gradient representation, as well as float16 optimizer state storage. Our JAX implementation of Scalify is open-sourced at https://github.com/graphcore-research/jax-scalify.
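To make the scale-propagation idea concrete, below is a minimal JAX sketch of the general tensor-scaling approach the abstract describes: a tensor is carried as a (data, scale) pair, and operations propagate the scale algebraically so the data part stays near unit magnitude. The names `ScaledArray` and `scaled_matmul` are hypothetical illustrations, not the jax-scalify API; see the repository for the actual implementation.

```python
# Illustrative sketch of scale propagation, assuming a hypothetical
# ScaledArray pair (data tensor + scalar scale). Not the jax-scalify API.
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ScaledArray(NamedTuple):
    data: jnp.ndarray   # tensor values, kept near unit variance
    scale: jnp.ndarray  # scalar scale; the represented tensor is data * scale


def scaled_matmul(a: ScaledArray, b: ScaledArray) -> ScaledArray:
    # Propagate scales algebraically instead of re-measuring tensors:
    # (a.data * a.scale) @ (b.data * b.scale)
    #   == (a.data @ b.data) * (a.scale * b.scale).
    # Folding an extra sqrt(K) into the scale keeps the output data near
    # unit variance, which is what makes a narrow format such as float8
    # usable for the matrix multiplication itself.
    k = a.data.shape[-1]
    out = jnp.matmul(a.data, b.data) / jnp.sqrt(k)
    return ScaledArray(out, a.scale * b.scale * jnp.sqrt(k))


key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = ScaledArray(jax.random.normal(key_a, (64, 256)), jnp.float32(4.0))
b = ScaledArray(jax.random.normal(key_b, (256, 32)), jnp.float32(0.5))
c = scaled_matmul(a, b)
print(c.data.std(), c.scale)  # data stays O(1); magnitude lives in the scale
```

In this sketch the `data` component remains well-conditioned for a low-precision representation while the dynamic range is tracked in the separate `scale` scalar, which is the property that scale propagation through the computational graph is meant to preserve end to end.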