Mamba: 選択的状態空間を用いた線形時間シーケンスモデリング
Mamba: Linear-Time Sequence Modeling with Selective State Spaces
December 1, 2023
著者: Albert Gu, Tri Dao
cs.AI
要旨
深層学習のエキサイティングなアプリケーションの大部分を支える基盤モデルは、ほぼ普遍的にTransformerアーキテクチャとその中核をなすアテンションモジュールに基づいています。長いシーケンスにおけるTransformerの計算効率の低さに対処するため、線形アテンション、ゲート付き畳み込み、リカレントモデル、構造化状態空間モデル(SSM)など、サブ二次時間のアーキテクチャが数多く開発されてきました。しかし、これらのモデルは言語などの重要なモダリティにおいて、アテンションほどの性能を発揮できていません。我々は、これらのモデルの主要な弱点が、コンテンツベースの推論を実行できない点にあると特定し、いくつかの改善を加えました。まず、SSMのパラメータを入力の関数とすることで、離散モダリティにおける弱点を解消し、モデルが現在のトークンに応じてシーケンス長の次元に沿って情報を選択的に伝播または忘却できるようにしました。次に、この変更により効率的な畳み込みの使用が妨げられるものの、ハードウェアを意識した並列アルゴリズムをリカレントモードで設計しました。これらの選択的SSMを、アテンションやMLPブロックさえも持たない簡素化されたエンドツーエンドのニューラルネットワークアーキテクチャ(Mamba)に統合しました。Mambaは、高速な推論(Transformerの5倍のスループット)とシーケンス長に対する線形スケーリングを実現し、実際のデータにおいて最大百万長のシーケンスまで性能が向上します。一般的なシーケンスモデルのバックボーンとして、Mambaは言語、音声、ゲノミクスなど複数のモダリティにおいて最先端の性能を達成します。言語モデリングにおいて、我々のMamba-3Bモデルは、同じサイズのTransformerを上回り、その2倍のサイズのTransformerと同等の性能を、事前学習と下流評価の両方で示しました。
English
Foundation models, now powering most of the exciting applications in deep
learning, are almost universally based on the Transformer architecture and its
core attention module. Many subquadratic-time architectures such as linear
attention, gated convolution and recurrent models, and structured state space
models (SSMs) have been developed to address Transformers' computational
inefficiency on long sequences, but they have not performed as well as
attention on important modalities such as language. We identify that a key
weakness of such models is their inability to perform content-based reasoning,
and make several improvements. First, simply letting the SSM parameters be
functions of the input addresses their weakness with discrete modalities,
allowing the model to selectively propagate or forget information along the
sequence length dimension depending on the current token. Second, even though
this change prevents the use of efficient convolutions, we design a
hardware-aware parallel algorithm in recurrent mode. We integrate these
selective SSMs into a simplified end-to-end neural network architecture without
attention or even MLP blocks (Mamba). Mamba enjoys fast inference (5times
higher throughput than Transformers) and linear scaling in sequence length, and
its performance improves on real data up to million-length sequences. As a
general sequence model backbone, Mamba achieves state-of-the-art performance
across several modalities such as language, audio, and genomics. On language
modeling, our Mamba-3B model outperforms Transformers of the same size and
matches Transformers twice its size, both in pretraining and downstream
evaluation.