FlashFFTConv: Efficient Convolutions for Long Sequences with Tensor Cores
November 10, 2023
Authors: Daniel Y. Fu, Hermann Kumbong, Eric Nguyen, Christopher Ré
cs.AI
Abstract
Convolution models with long filters have demonstrated state-of-the-art
reasoning abilities in many long-sequence tasks but lag behind the most
optimized Transformers in wall-clock time. A major bottleneck is the Fast
Fourier Transform (FFT)--which allows long convolutions to run in O(N log N)
time in sequence length N but has poor hardware utilization. In this paper,
we study how to optimize the FFT convolution. We find two key bottlenecks: the
FFT does not effectively use specialized matrix multiply units, and it incurs
expensive I/O between layers of the memory hierarchy. In response, we propose
FlashFFTConv. FlashFFTConv uses a matrix decomposition that computes the FFT
using matrix multiply units and enables kernel fusion for long sequences,
reducing I/O. We also present two sparse convolution algorithms--1) partial
convolutions and 2) frequency-sparse convolutions--which can be implemented
simply by skipping blocks in the matrix decomposition, enabling further
opportunities for memory and compute savings. FlashFFTConv speeds up exact FFT
convolutions by up to 7.93× over PyTorch and achieves up to 4.4× speedup
end-to-end. Given the same compute budget, FlashFFTConv allows
Hyena-GPT-s to achieve 2.3 points better perplexity on the PILE and
M2-BERT-base to achieve 3.3 points higher GLUE score--matching models with
twice the parameter count. FlashFFTConv also achieves 96.1% accuracy on
Path-512, a high-resolution vision task where no model had previously achieved
better than 50%. Furthermore, partial convolutions enable longer-sequence
models--yielding the first DNA model that can process the longest human genes
(2.3M base pairs)--and frequency-sparse convolutions speed up pretrained models
while maintaining or improving model quality.
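The two ideas at the core of the abstract can be sketched in a few lines of NumPy: (1) a long convolution computed in O(N log N) via the FFT identity, and (2) the Cooley-Tukey-style decomposition that rewrites a length-N FFT as two batches of small dense matrix multiplies plus a twiddle correction, which is what lets the computation map onto matrix multiply units. This is a minimal illustrative sketch, not the paper's actual kernel; the function names and the choice of a square N1 × N2 split are assumptions for illustration.

```python
import numpy as np

def dft_matrix(n):
    # Dense DFT matrix: F[j, k] = exp(-2*pi*i*j*k / n)
    idx = np.arange(n)
    return np.exp(-2j * np.pi * np.outer(idx, idx) / n)

def fft_via_matmul(x, N1, N2):
    # Length-N FFT (N = N1*N2) as two small matmuls + twiddle factors.
    # Skipping blocks of these matmuls is how the paper's sparse
    # convolutions (partial / frequency-sparse) could be realized.
    N = N1 * N2
    A = x.reshape(N2, N1)                      # index n = N1*n2 + n1
    k2 = np.arange(N2)[:, None]
    n1 = np.arange(N1)[None, :]
    twiddle = np.exp(-2j * np.pi * k2 * n1 / N)
    C = (dft_matrix(N2) @ A) * twiddle         # small DFTs over n2, as a matmul
    C = C @ dft_matrix(N1)                     # small DFTs over n1, as a matmul
    return C.T.reshape(N)                      # X[N2*k1 + k2] = C[k2, k1]

def fft_conv(u, k):
    # Circular convolution of u with filter k via the FFT, O(N log N)
    # instead of the O(N^2) direct sum.
    return np.fft.ifft(np.fft.fft(u) * np.fft.fft(k)).real
```

In FlashFFTConv the small DFT matrices are sized to fit the matrix multiply units, and the whole decomposition is fused into one kernel so intermediates never round-trip through slow memory.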