MaskLLM: 大規模言語モデルのための学習可能な半構造化スパース性
MaskLLM: Learnable Semi-Structured Sparsity for Large Language Models
September 26, 2024
著者: Gongfan Fang, Hongxu Yin, Saurav Muralidharan, Greg Heinrich, Jeff Pool, Jan Kautz, Pavlo Molchanov, Xinchao Wang
cs.AI
要旨
大規模言語モデル(LLM)は、通常、膨大なパラメータ数によって特徴付けられ、それにより重要な冗長性が生じます。本研究では、推論時の計算オーバーヘッドを削減することを目的とした、LLMに半構造化(または「N:M」)スパースネスを確立する学習可能なプルーニング手法であるMaskLLMを紹介します。新しい重要性基準を開発する代わりに、MaskLLMはN:MパターンをGumbel Softmaxサンプリングを通じて学習可能な分布として明示的にモデル化します。このアプローチは大規模データセットでのエンドツーエンドのトレーニングを容易にし、次の2つの注目すべき利点を提供します:1)高品質のマスク - 当社の手法は効果的に大規模データセットにスケーリングし、正確なマスクを学習します;2)移転性 - マスク分布の確率モデリングにより、スパースネスの転移学習がドメインやタスク間で可能になります。私たちは、843Mから15Bのパラメータを持つLLMa-2、Nemotron-4、およびGPT-3を含むさまざまなLLMで2:4スパースネスを使用してMaskLLMを評価し、実験結果は最先端の手法に比べて実質的な改善が示されました。たとえば、主要な手法はWikitextで10以上のPerplexity(PPL)を達成しますが、密なモデルの5.12 PPLに対してMaskLLMは凍結された重みでマスクを学習するだけで著しく低い6.72 PPLを達成します。さらに、MaskLLMの学習可能性により、ダウンストリームタスクやドメインに2:4スパースネスを損失なく適用するためのカスタマイズされたマスクが可能になります。コードはhttps://github.com/NVlabs/MaskLLMで入手可能です。
English
Large Language Models (LLMs) are distinguished by their massive parameter
counts, which typically result in significant redundancy. This work introduces
MaskLLM, a learnable pruning method that establishes Semi-structured (or
``N:M'') Sparsity in LLMs, aimed at reducing computational overhead during
inference. Instead of developing a new importance criterion, MaskLLM explicitly
models N:M patterns as a learnable distribution through Gumbel Softmax
sampling. This approach facilitates end-to-end training on large-scale datasets
and offers two notable advantages: 1) High-quality Masks - our method
effectively scales to large datasets and learns accurate masks; 2)
Transferability - the probabilistic modeling of mask distribution enables the
transfer learning of sparsity across domains or tasks. We assessed MaskLLM
using 2:4 sparsity on various LLMs, including LLaMA-2, Nemotron-4, and GPT-3,
with sizes ranging from 843M to 15B parameters, and our empirical results show
substantial improvements over state-of-the-art methods. For instance, leading
approaches achieve a perplexity (PPL) of 10 or greater on Wikitext compared to
the dense model's 5.12 PPL, but MaskLLM achieves a significantly lower 6.72 PPL
solely by learning the masks with frozen weights. Furthermore, MaskLLM's
learnable nature allows customized masks for lossless application of 2:4
sparsity to downstream tasks or domains. Code is available at
https://github.com/NVlabs/MaskLLM.Summary
AI-Generated Summary