FlashFFTConv: テンソルコアを用いた長系列のための効率的な畳み込み
FlashFFTConv: Efficient Convolutions for Long Sequences with Tensor Cores
November 10, 2023
著者: Daniel Y. Fu, Hermann Kumbong, Eric Nguyen, Christopher Ré
cs.AI
要旨
長いフィルタを持つ畳み込みモデルは、多くの長系列タスクにおいて最先端の推論能力を実証してきましたが、最適化されたTransformerモデルと比較して実時間性能では遅れをとっています。主なボトルネックは高速フーリエ変換(FFT)です。FFTは長い畳み込みを系列長Nに対してO(N logN)時間で実行可能にしますが、ハードウェアの利用率が低いという問題があります。本論文では、FFT畳み込みの最適化方法を研究します。2つの主要なボトルネックを発見しました:FFTは専用の行列乗算ユニットを効果的に使用せず、メモリ階層間で高コストなI/Oを引き起こします。これに対応して、FlashFFTConvを提案します。FlashFFTConvは、行列乗算ユニットを使用してFFTを計算する行列分解を採用し、長系列におけるカーネル融合を可能にすることでI/Oを削減します。また、2つのスパース畳み込みアルゴリズムを提示します:1)部分畳み込みと2)周波数スパース畳み込みです。これらは行列分解のブロックをスキップするだけで実装可能であり、メモリと計算のさらなる節約の機会を提供します。FlashFFTConvは、PyTorchと比較して正確なFFT畳み込みを最大7.93倍高速化し、エンドツーエンドで最大4.4倍の高速化を達成します。同じ計算予算のもとで、FlashFFTConvはHyena-GPT-sがPILEデータセットで2.3ポイント良いパープレキシティを達成し、M2-BERT-baseがGLUEスコアで3.3ポイント向上させることができます。これはパラメータ数が2倍のモデルに匹敵する性能です。また、FlashFFTConvは高解像度視覚タスクであるPath-512で96.1%の精度を達成しました。これは、これまで50%以上の精度を達成したモデルがなかったタスクです。さらに、部分畳み込みはより長い系列のモデルを可能にし、最長のヒト遺伝子(230万塩基対)を処理できる初のDNAモデルを実現しました。周波数スパース畳み込みは、事前学習済みモデルの高速化を図りながら、モデルの品質を維持または向上させます。
English
Convolution models with long filters have demonstrated state-of-the-art
reasoning abilities in many long-sequence tasks but lag behind the most
optimized Transformers in wall-clock time. A major bottleneck is the Fast
Fourier Transform (FFT)--which allows long convolutions to run in O(N logN)
time in sequence length N but has poor hardware utilization. In this paper,
we study how to optimize the FFT convolution. We find two key bottlenecks: the
FFT does not effectively use specialized matrix multiply units, and it incurs
expensive I/O between layers of the memory hierarchy. In response, we propose
FlashFFTConv. FlashFFTConv uses a matrix decomposition that computes the FFT
using matrix multiply units and enables kernel fusion for long sequences,
reducing I/O. We also present two sparse convolution algorithms--1) partial
convolutions and 2) frequency-sparse convolutions--which can be implemented
simply by skipping blocks in the matrix decomposition, enabling further
opportunities for memory and compute savings. FlashFFTConv speeds up exact FFT
convolutions by up to 7.93times over PyTorch and achieves up to 4.4times
speedup end-to-end. Given the same compute budget, FlashFFTConv allows
Hyena-GPT-s to achieve 2.3 points better perplexity on the PILE and
M2-BERT-base to achieve 3.3 points higher GLUE score--matching models with
twice the parameter count. FlashFFTConv also achieves 96.1% accuracy on
Path-512, a high-resolution vision task where no model had previously achieved
better than 50%. Furthermore, partial convolutions enable longer-sequence
models--yielding the first DNA model that can process the longest human genes
(2.3M base pairs)--and frequency-sparse convolutions speed up pretrained models
while maintaining or improving model quality.