FlashFFTConv: Efficient Convolutions for Long Sequences with Tensor Cores

November 10, 2023
Authors: Daniel Y. Fu, Hermann Kumbong, Eric Nguyen, Christopher Ré
cs.AI

Abstract

Convolution models with long filters have demonstrated state-of-the-art reasoning abilities in many long-sequence tasks but lag behind the most optimized Transformers in wall-clock time. A major bottleneck is the Fast Fourier Transform (FFT)--which allows long convolutions to run in O(N log N) time in sequence length N but has poor hardware utilization. In this paper, we study how to optimize the FFT convolution. We find two key bottlenecks: the FFT does not effectively use specialized matrix multiply units, and it incurs expensive I/O between layers of the memory hierarchy. In response, we propose FlashFFTConv. FlashFFTConv uses a matrix decomposition that computes the FFT using matrix multiply units and enables kernel fusion for long sequences, reducing I/O. We also present two sparse convolution algorithms--1) partial convolutions and 2) frequency-sparse convolutions--which can be implemented simply by skipping blocks in the matrix decomposition, enabling further opportunities for memory and compute savings. FlashFFTConv speeds up exact FFT convolutions by up to 7.93× over PyTorch and achieves up to 4.4× speedup end-to-end. Given the same compute budget, FlashFFTConv allows Hyena-GPT-s to achieve 2.3 points better perplexity on the PILE and M2-BERT-base to achieve 3.3 points higher GLUE score--matching models with twice the parameter count. FlashFFTConv also achieves 96.1% accuracy on Path-512, a high-resolution vision task where no model had previously achieved better than 50%. Furthermore, partial convolutions enable longer-sequence models--yielding the first DNA model that can process the longest human genes (2.3M base pairs)--and frequency-sparse convolutions speed up pretrained models while maintaining or improving model quality.
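
For context, the exact FFT convolution that FlashFFTConv accelerates can be expressed in a few lines of PyTorch. The sketch below is a minimal reference implementation of a causal long convolution via torch.fft; the function name fft_conv and the tensor shapes are illustrative choices, not FlashFFTConv's API.

```python
import torch

def fft_conv(u: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Reference FFT convolution y = u * k (causal), in O(N log N).

    u: (batch, d_model, seqlen) input sequences
    k: (d_model, seqlen) long convolution filters, one per channel
    """
    seqlen = u.shape[-1]
    fft_size = 2 * seqlen                 # zero-pad so the circular FFT convolution is linear/causal
    u_f = torch.fft.rfft(u, n=fft_size)   # O(N log N) transform of the input
    k_f = torch.fft.rfft(k, n=fft_size)   # O(N log N) transform of the filter
    y_f = u_f * k_f                       # pointwise multiply in frequency
    y = torch.fft.irfft(y_f, n=fft_size)  # inverse transform back to the time domain
    return y[..., :seqlen]                # keep the causal part

# Illustrative shapes only: batch 2, 8 channels, sequence length 4096
u = torch.randn(2, 8, 4096)
k = torch.randn(8, 4096)
print(fft_conv(u, k).shape)  # torch.Size([2, 8, 4096])
```

As the abstract notes, each transform here is asymptotically cheap but maps poorly onto specialized matrix multiply units and moves large intermediates through the memory hierarchy, which is the utilization gap FlashFFTConv targets.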
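The matrix decomposition mentioned in the abstract builds on the classical Cooley-Tukey identity, which factors one large DFT into batches of small dense matrix multiplies plus a twiddle-factor correction. The sketch below illustrates only that general identity; the helper names and the 32 × 32 split are assumptions for exposition and do not reproduce FlashFFTConv's fused kernel.

```python
import torch

def dft_matrix(n: int) -> torch.Tensor:
    """Dense DFT matrix F[k, j] = exp(-2*pi*i*k*j / n)."""
    idx = torch.arange(n, dtype=torch.float64)
    return torch.exp(-2j * torch.pi * idx[:, None] * idx[None, :] / n)

def fft_as_matmuls(x: torch.Tensor, n1: int, n2: int) -> torch.Tensor:
    """Cooley-Tukey: an (n1*n2)-point DFT as two batches of small dense matmuls.

    Reshape x to (n1, n2), apply F_{n1} down the columns, correct with twiddle
    factors, apply F_{n2} along the rows, then reorder the output. The dense
    F_{n1} and F_{n2} multiplies are the operations that can run on matrix
    multiply units.
    """
    n = n1 * n2
    x_mat = x.reshape(n1, n2)                  # x[n2*a + b] -> x_mat[a, b]
    twiddle = torch.exp(
        -2j * torch.pi
        * torch.arange(n1, dtype=torch.float64)[:, None]
        * torch.arange(n2, dtype=torch.float64)[None, :] / n
    )
    y = dft_matrix(n1).to(x.dtype) @ x_mat     # n2 inner DFTs of size n1
    y = y * twiddle.to(x.dtype)                # twiddle-factor correction
    z = y @ dft_matrix(n2).to(x.dtype).T       # n1 outer DFTs of size n2
    return z.T.reshape(n)                      # reorder output index k = k1 + n1*k2

# Matches torch.fft.fft on a length-1024 signal using a 32 x 32 decomposition
x = torch.randn(1024, dtype=torch.complex128)
assert torch.allclose(fft_as_matmuls(x, 32, 32), torch.fft.fft(x), atol=1e-8)
```

In this matrix view, the partial and frequency-sparse convolutions described in the abstract correspond to skipping blocks of these matrix multiplies.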