ChatPaper.aiChatPaper

Scalify: Skalierungsausbreitung für effizientes Training von LLM mit geringer Präzision

Scalify: scale propagation for efficient low-precision LLM training

July 24, 2024
Autoren: Paul Balança, Sam Hosegood, Carlo Luschi, Andrew Fitzgibbon
cs.AI

Zusammenfassung

Niedrigpräzisionsformate wie float8 wurden in der beschleunigten Hardware für maschinelles Lernen eingeführt, um die Rechenleistung bei der Schulung und Inferenz großer Sprachmodelle zu verbessern. Dennoch wurde die Akzeptanz in der ML-Community durch die komplexen und manchmal spröden Techniken, die erforderlich sind, um die Schulungsgenauigkeit höherer Präzision zu erreichen, verlangsamt. In dieser Arbeit stellen wir Scalify vor, ein End-to-End-Skalenpropagationsparadigma für Berechnungsgraphen, das bestehende Tensor-Skalierungsmethoden verallgemeinert und formalisiert. Experimentelle Ergebnisse zeigen, dass Scalify die Out-of-the-Box-Matrixmultiplikation und Gradientendarstellung in float8 unterstützt, sowie die Speicherung des Optimizer-Zustands in float16. Unsere JAX-Implementierung von Scalify ist unter https://github.com/graphcore-research/jax-scalify als Open Source verfügbar.
English
Low-precision formats such as float8 have been introduced in machine learning accelerated hardware to improve computational efficiency for large language models training and inference. Nevertheless, adoption by the ML community has been slowed down by the complex, and sometimes brittle, techniques required to match higher precision training accuracy. In this work, we present Scalify, a end-to-end scale propagation paradigm for computational graphs, generalizing and formalizing existing tensor scaling methods. Experiment results show that Scalify supports out-of-the-box float8 matrix multiplication and gradients representation, as well as float16 optimizer state storage. Our JAX implementation of Scalify is open-sourced at https://github.com/graphcore-research/jax-scalify

Summary

AI-Generated Summary

PDF132November 28, 2024