Scalify: 低精度LLMトレーニングのための効率的なスケール伝播
Scalify: scale propagation for efficient low-precision LLM training
July 24, 2024
著者: Paul Balança, Sam Hosegood, Carlo Luschi, Andrew Fitzgibbon
cs.AI
要旨
機械学習アクセラレータハードウェアでは、大規模言語モデルの学習と推論における計算効率を向上させるため、float8などの低精度フォーマットが導入されています。しかし、MLコミュニティでの採用は、高精度学習の精度を維持するために必要な複雑で脆弱な技術によって遅れています。本研究では、既存のテンソルスケーリング手法を一般化し形式化した、計算グラフのためのエンドツーエンドのスケール伝播パラダイムであるScalifyを提案します。実験結果から、Scalifyがfloat8行列乗算と勾配表現、およびfloat16オプティマイザ状態の保存をそのままサポートすることが示されています。ScalifyのJAX実装はhttps://github.com/graphcore-research/jax-scalifyでオープンソースとして公開されています。
English
Low-precision formats such as float8 have been introduced in machine learning
accelerated hardware to improve computational efficiency for large language
models training and inference. Nevertheless, adoption by the ML community has
been slowed down by the complex, and sometimes brittle, techniques required to
match higher precision training accuracy. In this work, we present Scalify, a
end-to-end scale propagation paradigm for computational graphs, generalizing
and formalizing existing tensor scaling methods. Experiment results show that
Scalify supports out-of-the-box float8 matrix multiplication and gradients
representation, as well as float16 optimizer state storage. Our JAX
implementation of Scalify is open-sourced at
https://github.com/graphcore-research/jax-scalifySummary
AI-Generated Summary