
Scalify: scale propagation for efficient low-precision LLM training

July 24, 2024
Authors: Paul Balança, Sam Hosegood, Carlo Luschi, Andrew Fitzgibbon
cs.AI

Abstract

Low-precision formats such as float8 have been introduced in machine learning accelerator hardware to improve computational efficiency for large language model training and inference. Nevertheless, adoption by the ML community has been slowed by the complex, and sometimes brittle, techniques required to match higher-precision training accuracy. In this work, we present Scalify, an end-to-end scale propagation paradigm for computational graphs, generalizing and formalizing existing tensor scaling methods. Experimental results show that Scalify supports out-of-the-box float8 matrix multiplication and gradient representation, as well as float16 optimizer state storage. Our JAX implementation of Scalify is open-sourced at https://github.com/graphcore-research/jax-scalify.
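The core idea of scale propagation can be illustrated with a short sketch: each tensor is paired with an explicit higher-precision scale factor, and operations propagate scales analytically so that the low-precision payload stays near unit variance. The code below is a minimal illustration of that idea, not the jax-scalify API; the names `ScaledArray`, `as_scaled`, `scaled_matmul`, and `to_dense` are hypothetical.

```python
import jax
import jax.numpy as jnp
from typing import NamedTuple

# Illustrative sketch of scale propagation (hypothetical names, not the
# jax-scalify API): a tensor is represented as `data * scale`, where `data`
# can be stored in a low-precision format and `scale` is an fp32 scalar.
class ScaledArray(NamedTuple):
    data: jax.Array   # low-precision payload, kept near unit variance
    scale: jax.Array  # fp32 scalar carrying the dynamic range

def as_scaled(x: jax.Array) -> ScaledArray:
    # Pull the dynamic range out of the tensor so `data` stays well-scaled.
    # On supporting hardware, `data` could then be cast to a float8 dtype.
    scale = jnp.std(x).astype(jnp.float32) + 1e-12
    return ScaledArray(x / scale, scale)

def scaled_matmul(a: ScaledArray, b: ScaledArray) -> ScaledArray:
    # Scales combine multiplicatively; the matmul itself only ever sees
    # normalized operands. A sqrt(K) factor accounts for variance growth
    # over the contraction dimension K, keeping the output payload unit-scaled.
    k = a.data.shape[-1]
    out = a.data @ b.data
    return ScaledArray(out / jnp.sqrt(k), a.scale * b.scale * jnp.sqrt(k))

def to_dense(x: ScaledArray) -> jax.Array:
    # Recombine payload and scale into an ordinary array.
    return x.data * x.scale

# Usage: multiply tensors with very different dynamic ranges.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (64, 128)) * 100.0
w = jax.random.normal(jax.random.PRNGKey(1), (128, 32)) * 0.01
y = to_dense(scaled_matmul(as_scaled(x), as_scaled(w)))
# y matches x @ w up to rounding, while the matmul operands themselves
# stayed near unit variance (float8-friendly).
```

Because scales are propagated through the graph rather than recalibrated from tensor statistics at each step, this style of scaling composes with autodiff: gradients can carry their own scales in the same way, which is what enables the float8 gradient representation described above.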
