FlashFFTConv: Convoluções Eficientes para Sequências Longas com Tensor Cores
FlashFFTConv: Efficient Convolutions for Long Sequences with Tensor Cores
November 10, 2023
Autores: Daniel Y. Fu, Hermann Kumbong, Eric Nguyen, Christopher Ré
cs.AI
Resumo
Modelos convolucionais com filtros longos demonstraram capacidades de raciocínio de última geração em muitas tarefas de sequências longas, mas ficam atrás dos Transformers mais otimizados em termos de tempo de execução. Um grande gargalo é a Transformada Rápida de Fourier (FFT) - que permite que convoluções longas sejam executadas em tempo O(N logN) para uma sequência de comprimento N, mas tem baixa utilização de hardware. Neste artigo, estudamos como otimizar a convolução FFT. Identificamos dois gargalos principais: a FFT não utiliza efetivamente unidades especializadas de multiplicação de matrizes, e incorre em operações de I/O caras entre as camadas da hierarquia de memória. Em resposta, propomos o FlashFFTConv. O FlashFFTConv utiliza uma decomposição matricial que calcula a FFT usando unidades de multiplicação de matrizes e permite a fusão de kernels para sequências longas, reduzindo o I/O. Também apresentamos dois algoritmos de convolução esparsa - 1) convoluções parciais e 2) convoluções esparsas em frequência - que podem ser implementados simplesmente pulando blocos na decomposição matricial, permitindo mais oportunidades de economia de memória e computação. O FlashFFTConv acelera convoluções FFT exatas em até 7,93 vezes em relação ao PyTorch e alcança um ganho de velocidade de até 4,4 vezes de ponta a ponta. Com o mesmo orçamento de computação, o FlashFFTConv permite que o Hyena-GPT-s alcance 2,3 pontos a menos de perplexidade no PILE e que o M2-BERT-base alcance 3,3 pontos a mais no GLUE score - igualando modelos com o dobro do número de parâmetros. O FlashFFTConv também alcança 96,1% de precisão no Path-512, uma tarefa de visão de alta resolução onde nenhum modelo havia alcançado anteriormente mais de 50%. Além disso, as convoluções parciais permitem modelos de sequências mais longas - produzindo o primeiro modelo de DNA que pode processar os genes humanos mais longos (2,3 milhões de pares de bases) - e as convoluções esparsas em frequência aceleram modelos pré-treinados enquanto mantêm ou melhoram a qualidade do modelo.
English
Convolution models with long filters have demonstrated state-of-the-art
reasoning abilities in many long-sequence tasks but lag behind the most
optimized Transformers in wall-clock time. A major bottleneck is the Fast
Fourier Transform (FFT)--which allows long convolutions to run in O(N logN)
time in sequence length N but has poor hardware utilization. In this paper,
we study how to optimize the FFT convolution. We find two key bottlenecks: the
FFT does not effectively use specialized matrix multiply units, and it incurs
expensive I/O between layers of the memory hierarchy. In response, we propose
FlashFFTConv. FlashFFTConv uses a matrix decomposition that computes the FFT
using matrix multiply units and enables kernel fusion for long sequences,
reducing I/O. We also present two sparse convolution algorithms--1) partial
convolutions and 2) frequency-sparse convolutions--which can be implemented
simply by skipping blocks in the matrix decomposition, enabling further
opportunities for memory and compute savings. FlashFFTConv speeds up exact FFT
convolutions by up to 7.93times over PyTorch and achieves up to 4.4times
speedup end-to-end. Given the same compute budget, FlashFFTConv allows
Hyena-GPT-s to achieve 2.3 points better perplexity on the PILE and
M2-BERT-base to achieve 3.3 points higher GLUE score--matching models with
twice the parameter count. FlashFFTConv also achieves 96.1% accuracy on
Path-512, a high-resolution vision task where no model had previously achieved
better than 50%. Furthermore, partial convolutions enable longer-sequence
models--yielding the first DNA model that can process the longest human genes
(2.3M base pairs)--and frequency-sparse convolutions speed up pretrained models
while maintaining or improving model quality.