GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection
March 6, 2024
Authors: Jiawei Zhao, Zhenyu Zhang, Beidi Chen, Zhangyang Wang, Anima Anandkumar, Yuandong Tian
cs.AI
Abstract
Training Large Language Models (LLMs) presents significant memory challenges,
predominantly due to the growing size of weights and optimizer states. Common
memory-reduction approaches, such as low-rank adaptation (LoRA), add a
trainable low-rank matrix to the frozen pre-trained weight in each layer,
reducing trainable parameters and optimizer states. However, such approaches
typically underperform training with full-rank weights in both pre-training and
fine-tuning stages since they limit the parameter search to a low-rank subspace
and alter the training dynamics; further, they may require a full-rank warm start.
In this work, we propose Gradient Low-Rank Projection (GaLore), a training
strategy that allows full-parameter learning but is more memory-efficient than
common low-rank adaptation methods such as LoRA. Our approach reduces memory
usage by up to 65.5% in optimizer states while maintaining both efficiency and
performance for pre-training on LLaMA 1B and 7B architectures on the C4 dataset
with up to 19.7B tokens, and for fine-tuning RoBERTa on GLUE tasks. Our 8-bit
GaLore further reduces optimizer memory by up to 82.5% and total training
memory by 63.3%, compared to a BF16 baseline. Notably, we demonstrate, for the
first time, the feasibility of pre-training a 7B model on consumer GPUs with
24GB memory (e.g., NVIDIA RTX 4090) without model parallelism, checkpointing, or
offloading strategies.
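
To make the idea described above concrete, the sketch below illustrates gradient low-rank projection with Adam-style moments kept in the rank-r subspace: the full gradient of an m x n weight is projected onto a periodically refreshed rank-r subspace, the optimizer statistics live in that smaller r x n space, and the normalized update is projected back to full rank before being applied. Because the moments are stored at rank r rather than full rank, optimizer-state memory shrinks with r, which is the source of the savings the abstract reports. This is a minimal illustration rather than the authors' released implementation; names such as galore_adam_step, rank, update_gap, and scale are assumptions made for this example.

```python
# Minimal sketch of gradient low-rank projection with Adam-style moments
# kept in the rank-r subspace. Illustrative only; not the paper's released code.
import torch


def galore_adam_step(W, G, state, rank=128, update_gap=200,
                     lr=1e-3, scale=0.25, betas=(0.9, 0.999), eps=1e-8):
    """Apply one update to an (m x n) weight W given its full-rank gradient G.

    The Adam moments are stored with shape (rank x n) rather than (m x n),
    which is where the optimizer-state memory saving comes from.
    """
    step = state.get("step", 0)

    # Refresh the projector P (m x rank) every `update_gap` steps from the
    # top-r left singular vectors of the current gradient.
    if "P" not in state or step % update_gap == 0:
        U, _, _ = torch.linalg.svd(G.float(), full_matrices=False)
        state["P"] = U[:, :rank].to(G.dtype)
    P = state["P"]

    # Project the gradient into the low-rank subspace: (rank x n).
    R = P.T @ G

    # Standard Adam statistics, but on the projected gradient.
    m = state.get("m", torch.zeros_like(R))
    v = state.get("v", torch.zeros_like(R))
    m = betas[0] * m + (1 - betas[0]) * R
    v = betas[1] * v + (1 - betas[1]) * R.square()
    m_hat = m / (1 - betas[0] ** (step + 1))
    v_hat = v / (1 - betas[1] ** (step + 1))
    N = m_hat / (v_hat.sqrt() + eps)

    # Project the normalized update back to full rank and apply it in place.
    with torch.no_grad():
        W.add_(P @ N, alpha=-lr * scale)

    state.update(step=step + 1, m=m, v=v)
    return state


# Usage on a single weight matrix (G would normally come from backpropagation):
# W = torch.randn(4096, 11008, requires_grad=True)
# G = torch.randn_like(W)          # stand-in for W.grad
# state = {}
# state = galore_adam_step(W, G, state, rank=128)
```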