Scalify: 저정밀도 LLM 훈련을 위한 효율적인 스케일 전파 기법
Scalify: scale propagation for efficient low-precision LLM training
July 24, 2024
저자: Paul Balança, Sam Hosegood, Carlo Luschi, Andrew Fitzgibbon
cs.AI
초록
float8과 같은 저정밀도 형식은 대규모 언어 모델의 학습 및 추론을 위한 계산 효율성을 향상시키기 위해 머신러닝 가속 하드웨어에 도입되었습니다. 그러나 ML 커뮤니티의 채택은 더 높은 정밀도의 학습 정확도를 맞추기 위해 필요한 복잡하고 때로는 취약한 기술들로 인해 더딘 상태입니다. 본 연구에서는 기존의 텐서 스케일링 방법을 일반화하고 공식화한, 계산 그래프를 위한 종단 간 스케일 전파 패러다임인 Scalify를 소개합니다. 실험 결과는 Scalify가 float8 행렬 곱셈 및 그래디언트 표현, 그리고 float16 옵티마이저 상태 저장을 즉시 지원함을 보여줍니다. Scalify의 JAX 구현은 https://github.com/graphcore-research/jax-scalify에서 오픈소스로 제공됩니다.
English
Low-precision formats such as float8 have been introduced in machine learning
accelerated hardware to improve computational efficiency for large language
models training and inference. Nevertheless, adoption by the ML community has
been slowed down by the complex, and sometimes brittle, techniques required to
match higher precision training accuracy. In this work, we present Scalify, a
end-to-end scale propagation paradigm for computational graphs, generalizing
and formalizing existing tensor scaling methods. Experiment results show that
Scalify supports out-of-the-box float8 matrix multiplication and gradients
representation, as well as float16 optimizer state storage. Our JAX
implementation of Scalify is open-sourced at
https://github.com/graphcore-research/jax-scalifySummary
AI-Generated Summary