MaskLLM: Learnable Semi-Structured Sparsity for Large Language Models
September 26, 2024
Authors: Gongfan Fang, Hongxu Yin, Saurav Muralidharan, Greg Heinrich, Jeff Pool, Jan Kautz, Pavlo Molchanov, Xinchao Wang
cs.AI
Abstract
Large Language Models (LLMs) are distinguished by their massive parameter counts, which typically result in significant redundancy. This work introduces MaskLLM, a learnable pruning method that establishes semi-structured (or "N:M") sparsity in LLMs, aimed at reducing computational overhead during inference. Instead of developing a new importance criterion, MaskLLM explicitly models N:M patterns as a learnable distribution through Gumbel Softmax sampling. This approach facilitates end-to-end training on large-scale datasets and offers two notable advantages: 1) High-quality masks - our method effectively scales to large datasets and learns accurate masks; 2) Transferability - the probabilistic modeling of the mask distribution enables transfer learning of sparsity across domains or tasks. We assessed MaskLLM using 2:4 sparsity on various LLMs, including LLaMA-2, Nemotron-4, and GPT-3, with sizes ranging from 843M to 15B parameters, and our empirical results show substantial improvements over state-of-the-art methods. For instance, leading approaches achieve a perplexity (PPL) of 10 or greater on Wikitext compared to the dense model's 5.12 PPL, but MaskLLM achieves a significantly lower 6.72 PPL solely by learning the masks with frozen weights. Furthermore, MaskLLM's learnable nature allows customized masks for lossless application of 2:4 sparsity to downstream tasks or domains. Code is available at https://github.com/NVlabs/MaskLLM.
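
To make the core idea concrete, below is a minimal sketch (not the official MaskLLM implementation) of how 2:4 mask selection via Gumbel-Softmax can be expressed in PyTorch: each group of four weights has a learnable logit over the six candidate masks that keep exactly two values, and a differentiable sample of that distribution gates the frozen weights. Function names, tensor shapes, and the use of `torch.nn.functional.gumbel_softmax` are illustrative assumptions, not the authors' code.

```python
# Sketch only: learnable 2:4 mask selection via Gumbel-Softmax (assumed API/shapes).
import itertools
import torch
import torch.nn.functional as F

# All C(4, 2) = 6 binary masks that keep exactly 2 of every 4 weights.
CANDIDATES = torch.tensor(
    [[1.0 if i in combo else 0.0 for i in range(4)]
     for combo in itertools.combinations(range(4), 2)]
)  # shape: (6, 4)

def sample_2to4_mask(logits: torch.Tensor, tau: float = 1.0) -> torch.Tensor:
    """Draw a differentiable 2:4 mask for each group of 4 weights.

    logits: (num_groups, 6) learnable scores, one per candidate mask.
    Returns a (num_groups, 4) mask; hard selection with straight-through gradients.
    """
    # Gumbel-Softmax yields a (near-)one-hot choice over the 6 candidates;
    # hard=True keeps the forward pass discrete while gradients reach the logits.
    probs = F.gumbel_softmax(logits, tau=tau, hard=True)  # (num_groups, 6)
    return probs @ CANDIDATES                             # (num_groups, 4)

# Toy usage: mask a frozen weight tensor reshaped into groups of 4.
weight = torch.randn(8, 16)                                    # dense weight (frozen)
groups = weight.view(-1, 4)                                    # (32, 4)
logits = torch.zeros(groups.size(0), 6, requires_grad=True)    # learnable mask logits
mask = sample_2to4_mask(logits)                                # two ones per row
sparse_weight = (groups * mask).view_as(weight)                # 2:4-sparse weights
```

In this reading, only the per-group logits are trained end-to-end on large data while the weights stay frozen, which matches the abstract's claim that the 6.72 PPL result is obtained "solely by learning the masks with frozen weights"; the learned distributions can also serve as priors when transferring sparsity to new domains or tasks.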