ChatPaper.aiChatPaper

Scalify: 低精度LLMトレーニングのための効率的なスケール伝播

Scalify: scale propagation for efficient low-precision LLM training

July 24, 2024
著者: Paul Balança, Sam Hosegood, Carlo Luschi, Andrew Fitzgibbon
cs.AI

要旨

機械学習アクセラレータハードウェアでは、大規模言語モデルの学習と推論における計算効率を向上させるため、float8などの低精度フォーマットが導入されています。しかし、MLコミュニティでの採用は、高精度学習の精度を維持するために必要な複雑で脆弱な技術によって遅れています。本研究では、既存のテンソルスケーリング手法を一般化し形式化した、計算グラフのためのエンドツーエンドのスケール伝播パラダイムであるScalifyを提案します。実験結果から、Scalifyがfloat8行列乗算と勾配表現、およびfloat16オプティマイザ状態の保存をそのままサポートすることが示されています。ScalifyのJAX実装はhttps://github.com/graphcore-research/jax-scalifyでオープンソースとして公開されています。
English
Low-precision formats such as float8 have been introduced in machine learning accelerated hardware to improve computational efficiency for large language models training and inference. Nevertheless, adoption by the ML community has been slowed down by the complex, and sometimes brittle, techniques required to match higher precision training accuracy. In this work, we present Scalify, a end-to-end scale propagation paradigm for computational graphs, generalizing and formalizing existing tensor scaling methods. Experiment results show that Scalify supports out-of-the-box float8 matrix multiplication and gradients representation, as well as float16 optimizer state storage. Our JAX implementation of Scalify is open-sourced at https://github.com/graphcore-research/jax-scalify

Summary

AI-Generated Summary

PDF132November 28, 2024