FlashFFTConv: Effiziente Faltungen für lange Sequenzen mit Tensor-Cores
FlashFFTConv: Efficient Convolutions for Long Sequences with Tensor Cores
November 10, 2023
Autoren: Daniel Y. Fu, Hermann Kumbong, Eric Nguyen, Christopher Ré
cs.AI
Zusammenfassung
Faltungsmodelle mit langen Filtern haben in vielen Aufgaben mit langen Sequenzen state-of-the-art Fähigkeiten im Bereich des Schlussfolgerns demonstriert, hinken jedoch den am stärksten optimierten Transformern in Bezug auf die Echtzeitleistung hinterher. Ein wesentlicher Engpass ist die Schnelle Fourier-Transformation (FFT), die es ermöglicht, lange Faltungen in O(N logN) Zeit bezogen auf die Sequenzlänge N auszuführen, jedoch eine schlechte Hardwareauslastung aufweist. In dieser Arbeit untersuchen wir, wie die FFT-Faltung optimiert werden kann. Wir identifizieren zwei zentrale Engpässe: Die FFT nutzt spezialisierte Matrix-Multiplikationseinheiten nicht effektiv und verursacht teure I/O-Operationen zwischen den Ebenen der Speicherhierarchie. Als Antwort darauf schlagen wir FlashFFTConv vor. FlashFFTConv verwendet eine Matrixzerlegung, die die FFT mithilfe von Matrix-Multiplikationseinheiten berechnet und Kernel-Fusion für lange Sequenzen ermöglicht, wodurch I/O reduziert wird. Wir stellen außerdem zwei Algorithmen für spärliche Faltungen vor – 1) partielle Faltungen und 2) frequenzspärliche Faltungen – die einfach durch das Überspringen von Blöcken in der Matrixzerlegung implementiert werden können, was weitere Möglichkeiten zur Einsparung von Speicher und Rechenleistung bietet. FlashFFTConv beschleunigt exakte FFT-Faltungen um bis zu 7,93-mal gegenüber PyTorch und erreicht eine bis zu 4,4-fache Beschleunigung end-to-end. Bei gleichem Rechenbudget ermöglicht FlashFFTConv Hyena-GPT-s, eine um 2,3 Punkte bessere Perplexität auf dem PILE zu erreichen, und M2-BERT-base, eine um 3,3 Punkte höhere GLUE-Bewertung zu erzielen – was Modellen mit doppelter Parameteranzahl entspricht. FlashFFTConv erreicht außerdem eine Genauigkeit von 96,1 % auf Path-512, einer hochauflösenden Bildverarbeitungsaufgabe, bei der bisher kein Modell eine bessere Genauigkeit als 50 % erzielt hatte. Darüber hinaus ermöglichen partielle Faltungen Modelle für längere Sequenzen – was das erste DNA-Modell hervorbringt, das die längsten menschlichen Gene (2,3 Millionen Basenpaare) verarbeiten kann – und frequenzspärliche Faltungen beschleunigen vortrainierte Modelle, während die Modellqualität erhalten bleibt oder sogar verbessert wird.
English
Convolution models with long filters have demonstrated state-of-the-art
reasoning abilities in many long-sequence tasks but lag behind the most
optimized Transformers in wall-clock time. A major bottleneck is the Fast
Fourier Transform (FFT)--which allows long convolutions to run in O(N logN)
time in sequence length N but has poor hardware utilization. In this paper,
we study how to optimize the FFT convolution. We find two key bottlenecks: the
FFT does not effectively use specialized matrix multiply units, and it incurs
expensive I/O between layers of the memory hierarchy. In response, we propose
FlashFFTConv. FlashFFTConv uses a matrix decomposition that computes the FFT
using matrix multiply units and enables kernel fusion for long sequences,
reducing I/O. We also present two sparse convolution algorithms--1) partial
convolutions and 2) frequency-sparse convolutions--which can be implemented
simply by skipping blocks in the matrix decomposition, enabling further
opportunities for memory and compute savings. FlashFFTConv speeds up exact FFT
convolutions by up to 7.93times over PyTorch and achieves up to 4.4times
speedup end-to-end. Given the same compute budget, FlashFFTConv allows
Hyena-GPT-s to achieve 2.3 points better perplexity on the PILE and
M2-BERT-base to achieve 3.3 points higher GLUE score--matching models with
twice the parameter count. FlashFFTConv also achieves 96.1% accuracy on
Path-512, a high-resolution vision task where no model had previously achieved
better than 50%. Furthermore, partial convolutions enable longer-sequence
models--yielding the first DNA model that can process the longest human genes
(2.3M base pairs)--and frequency-sparse convolutions speed up pretrained models
while maintaining or improving model quality.