FlashFFTConv: 텐서 코어를 활용한 장거리 시퀀스의 효율적 컨볼루션
FlashFFTConv: Efficient Convolutions for Long Sequences with Tensor Cores
November 10, 2023
저자: Daniel Y. Fu, Hermann Kumbong, Eric Nguyen, Christopher Ré
cs.AI
초록
긴 필터를 가진 컨볼루션 모델은 많은 장기 시퀀스 작업에서 최첨단 추론 능력을 보여주었지만, 실제 실행 시간 측면에서는 최적화된 트랜스포머 모델에 뒤처져 있습니다. 주요 병목 현상은 고속 푸리에 변환(FFT)에 있습니다. FFT는 긴 컨볼루션을 시퀀스 길이 N에 대해 O(N logN) 시간에 실행할 수 있게 하지만, 하드웨어 활용도가 낮습니다. 본 논문에서는 FFT 컨볼루션을 최적화하는 방법을 연구합니다. 우리는 두 가지 주요 병목 현상을 발견했습니다: FFT는 전용 행렬 곱셈 유닛을 효과적으로 사용하지 못하며, 메모리 계층 간에 비용이 많이 드는 I/O를 발생시킵니다. 이를 해결하기 위해 FlashFFTConv를 제안합니다. FlashFFTConv는 행렬 분해를 사용하여 FFT를 행렬 곱셈 유닛으로 계산하고, 긴 시퀀스에 대한 커널 퓨전을 가능하게 하여 I/O를 줄입니다. 또한 두 가지 희소 컨볼루션 알고리즘을 제시합니다: 1) 부분 컨볼루션과 2) 주파수 희소 컨볼루션. 이 알고리즘들은 행렬 분해에서 블록을 건너뛰는 방식으로 간단히 구현할 수 있어, 메모리와 계산 비용을 더욱 절약할 수 있습니다. FlashFFTConv는 PyTorch 대비 정확한 FFT 컨볼루션을 최대 7.93배 빠르게 수행하며, 종단 간 최대 4.4배의 속도 향상을 달성합니다. 동일한 계산 예산 내에서 FlashFFTConv는 Hyena-GPT-s가 PILE 데이터셋에서 2.3점 더 나은 퍼플렉서티를 달성하고, M2-BERT-base가 GLUE 점수에서 3.3점 더 높은 성적을 거두도록 하여, 매개변수 수가 두 배인 모델과 동등한 성능을 보입니다. 또한 FlashFFTConv는 고해상도 비전 작업인 Path-512에서 96.1%의 정확도를 달성했는데, 이는 이전에 어떤 모델도 50%를 넘지 못했던 작업입니다. 더 나아가, 부분 컨볼루션은 더 긴 시퀀스 모델을 가능하게 하여, 가장 긴 인간 유전자(230만 염기쌍)를 처리할 수 있는 첫 번째 DNA 모델을 구현했으며, 주파수 희소 컨볼루션은 사전 훈련된 모델의 속도를 높이면서 모델 품질을 유지하거나 개선합니다.
English
Convolution models with long filters have demonstrated state-of-the-art
reasoning abilities in many long-sequence tasks but lag behind the most
optimized Transformers in wall-clock time. A major bottleneck is the Fast
Fourier Transform (FFT)--which allows long convolutions to run in O(N logN)
time in sequence length N but has poor hardware utilization. In this paper,
we study how to optimize the FFT convolution. We find two key bottlenecks: the
FFT does not effectively use specialized matrix multiply units, and it incurs
expensive I/O between layers of the memory hierarchy. In response, we propose
FlashFFTConv. FlashFFTConv uses a matrix decomposition that computes the FFT
using matrix multiply units and enables kernel fusion for long sequences,
reducing I/O. We also present two sparse convolution algorithms--1) partial
convolutions and 2) frequency-sparse convolutions--which can be implemented
simply by skipping blocks in the matrix decomposition, enabling further
opportunities for memory and compute savings. FlashFFTConv speeds up exact FFT
convolutions by up to 7.93times over PyTorch and achieves up to 4.4times
speedup end-to-end. Given the same compute budget, FlashFFTConv allows
Hyena-GPT-s to achieve 2.3 points better perplexity on the PILE and
M2-BERT-base to achieve 3.3 points higher GLUE score--matching models with
twice the parameter count. FlashFFTConv also achieves 96.1% accuracy on
Path-512, a high-resolution vision task where no model had previously achieved
better than 50%. Furthermore, partial convolutions enable longer-sequence
models--yielding the first DNA model that can process the longest human genes
(2.3M base pairs)--and frequency-sparse convolutions speed up pretrained models
while maintaining or improving model quality.