大語彙言語モデルにおいて損失を削減する
Cut Your Losses in Large-Vocabulary Language Models
November 13, 2024
著者: Erik Wijmans, Brody Huval, Alexander Hertzberg, Vladlen Koltun, Philipp Krähenbühl
cs.AI
要旨
言語モデルがますます大きくなるにつれて、その語彙も増加しています。これにより、訓練中のLLMのメモリフットプリントが不均衡になり、1つの単一レイヤー、つまり損失計算のクロスエントロピーにシフトしました。クロスエントロピーは、入力トークンと語彙アイテムの各ペアごとにエントリを持つロジット行列を構築し、小規模モデルではLLM全体よりもメモリを桁違いに消費します。私たちは、すべてのトークンのロジットをグローバルメモリに具現化せずにクロスエントロピー損失を計算する方法であるCut Cross-Entropy(CCE)を提案します。代わりに、CCEは正しいトークンのロジットのみを計算し、すべてのロジットに対する対数和指数をその場で評価します。私たちは、フラッシュメモリ内で語彙全体にわたる行列乗算と対数和指数の縮小を実行するカスタムカーネルを実装し、クロスエントロピー計算におけるグローバルメモリ消費を無視できるレベルに抑えます。これには劇的な効果があります。たとえば、Gemma 2(2B)モデルを取ると、CCEにより損失計算のメモリフットプリントが24 GBから1 MBに、分類器ヘッドの合計訓練時メモリ消費が28 GBから1 GBに削減されます。CCEのスループットを向上させるために、ソフトマックスの固有の疎さを活用し、勾配計算の要素のうち、勾配への寄与が無視できる(つまり、数値精度以下)ものをスキップすることを提案します。実験では、メモリ消費の劇的な削減が、訓練速度や収束を犠牲にすることなく達成されていることが示されています。
English
As language models grow ever larger, so do their vocabularies. This has
shifted the memory footprint of LLMs during training disproportionately to one
single layer: the cross-entropy in the loss computation. Cross-entropy builds
up a logit matrix with entries for each pair of input tokens and vocabulary
items and, for small models, consumes an order of magnitude more memory than
the rest of the LLM combined. We propose Cut Cross-Entropy (CCE), a method that
computes the cross-entropy loss without materializing the logits for all tokens
into global memory. Rather, CCE only computes the logit for the correct token
and evaluates the log-sum-exp over all logits on the fly. We implement a custom
kernel that performs the matrix multiplications and the log-sum-exp reduction
over the vocabulary in flash memory, making global memory consumption for the
cross-entropy computation negligible. This has a dramatic effect. Taking the
Gemma 2 (2B) model as an example, CCE reduces the memory footprint of the loss
computation from 24 GB to 1 MB, and the total training-time memory consumption
of the classifier head from 28 GB to 1 GB. To improve the throughput of CCE, we
leverage the inherent sparsity of softmax and propose to skip elements of the
gradient computation that have a negligible (i.e., below numerical precision)
contribution to the gradient. Experiments demonstrate that the dramatic
reduction in memory consumption is accomplished without sacrificing training
speed or convergence.Summary
AI-Generated Summary