GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection

March 6, 2024
Authors: Jiawei Zhao, Zhenyu Zhang, Beidi Chen, Zhangyang Wang, Anima Anandkumar, Yuandong Tian
cs.AI

Abstract

Training Large Language Models (LLMs) presents significant memory challenges, predominantly due to the growing size of weights and optimizer states. Common memory-reduction approaches, such as low-rank adaptation (LoRA), add a trainable low-rank matrix to the frozen pre-trained weight in each layer, reducing trainable parameters and optimizer states. However, such approaches typically underperform training with full-rank weights in both the pre-training and fine-tuning stages, since they limit the parameter search to a low-rank subspace and alter the training dynamics, and may further require a full-rank warm start. In this work, we propose Gradient Low-Rank Projection (GaLore), a training strategy that allows full-parameter learning but is more memory-efficient than common low-rank adaptation methods such as LoRA. Our approach reduces memory usage by up to 65.5% in optimizer states while maintaining both efficiency and performance for pre-training LLaMA 1B and 7B architectures on the C4 dataset with up to 19.7B tokens, and for fine-tuning RoBERTa on GLUE tasks. Our 8-bit GaLore further reduces optimizer memory by up to 82.5% and total training memory by 63.3%, compared to a BF16 baseline. Notably, we demonstrate, for the first time, the feasibility of pre-training a 7B model on consumer GPUs with 24GB memory (e.g., NVIDIA RTX 4090) without model parallelism, checkpointing, or offloading strategies.
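
To make the gradient low-rank projection idea summarized above concrete, here is a minimal sketch in PyTorch for a single weight matrix: the full-rank gradient is projected onto a low-rank subspace, Adam-style optimizer statistics are kept only for the projected (rank-r) gradient, and the low-rank update is projected back before being applied. The rank, the projection refresh interval, and the simplified Adam update are illustrative assumptions for this sketch, not the authors' released implementation.

```python
# Minimal sketch of gradient low-rank projection (GaLore-style).
# Assumes plain PyTorch; hyperparameters and helper names are illustrative.
import torch


def galore_step(W, G, state, rank=4, lr=1e-3, betas=(0.9, 0.999),
                eps=1e-8, update_proj_every=200):
    """One illustrative update of weight W (m x n) from its full-rank gradient G.

    Optimizer statistics are stored only for the projected gradient (r x n),
    which is where the optimizer-state memory saving comes from.
    """
    step = state.get("step", 0)

    # Periodically refresh the projection matrix P (m x r) from the
    # top-r left singular vectors of the current gradient.
    if "P" not in state or step % update_proj_every == 0:
        U, _, _ = torch.linalg.svd(G, full_matrices=False)
        state["P"] = U[:, :rank]
    P = state["P"]

    # Project the gradient into the low-rank subspace: R has shape (r x n).
    R = P.T @ G

    # Adam-style first/second moments over the *projected* gradient only.
    m = state.get("m", torch.zeros_like(R))
    v = state.get("v", torch.zeros_like(R))
    m = betas[0] * m + (1 - betas[0]) * R
    v = betas[1] * v + (1 - betas[1]) * R * R
    state.update(step=step + 1, m=m, v=v)

    # Bias correction, then project the low-rank update back to full shape.
    m_hat = m / (1 - betas[0] ** (step + 1))
    v_hat = v / (1 - betas[1] ** (step + 1))
    W -= lr * (P @ (m_hat / (v_hat.sqrt() + eps)))
    return W


# Toy usage: one step on a random 64 x 32 weight with a random gradient.
if __name__ == "__main__":
    W, G, state = torch.randn(64, 32), torch.randn(64, 32), {}
    W = galore_step(W, G, state, rank=4)
```

In this sketch the memory saving comes from `m` and `v` being (r x n) instead of (m x n); the projection matrix refresh interval trades SVD cost against how closely the subspace tracks the evolving gradients.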