MambaMixer: Modelos de Espaço de Estados Seletivos Eficientes com Seleção Dual de Tokens e Canais
MambaMixer: Efficient Selective State Space Models with Dual Token and Channel Selection
March 29, 2024
Autores: Ali Behrouz, Michele Santacatterina, Ramin Zabih
cs.AI
Resumo
Os avanços recentes em aprendizado profundo têm se baseado principalmente em Transformers devido à sua dependência de dados e capacidade de aprender em escala. O módulo de atenção nessas arquiteturas, no entanto, exibe complexidade quadrática em tempo e espaço em relação ao tamanho da entrada, limitando sua escalabilidade para modelagem de sequências longas. Apesar de tentativas recentes de projetar arquiteturas eficientes e eficazes para dados multidimensionais, como imagens e séries temporais multivariadas, os modelos existentes são independentes de dados ou falham em permitir comunicação inter e intra-dimensões. Recentemente, Modelos de Espaço de Estados (SSMs), e mais especificamente Modelos de Espaço de Estados Seletivos, com implementação eficiente voltada para hardware, têm mostrado potencial promissor para modelagem de sequências longas. Motivados pelo sucesso dos SSMs, apresentamos o MambaMixer, uma nova arquitetura com pesos dependentes de dados que utiliza um mecanismo de seleção dupla entre tokens e canais, chamado de Selective Token and Channel Mixer. O MambaMixer conecta misturadores seletivos usando um mecanismo de média ponderada, permitindo que as camadas tenham acesso direto a características iniciais. Como prova de conceito, projetamos as arquiteturas Vision MambaMixer (ViM2) e Time Series MambaMixer (TSM2) com base no bloco MambaMixer e exploramos seu desempenho em várias tarefas de visão e previsão de séries temporais. Nossos resultados destacam a importância da mistura seletiva tanto entre tokens quanto entre canais. Em tarefas de classificação no ImageNet, detecção de objetos e segmentação semântica, o ViM2 alcança desempenho competitivo com modelos de visão bem estabelecidos e supera modelos de visão baseados em SSMs. Em previsão de séries temporais, o TSM2 alcança desempenho excepcional em comparação com métodos state-of-the-art, demonstrando um custo computacional significativamente melhorado. Esses resultados mostram que, embora Transformers, atenção entre canais e MLPs sejam suficientes para um bom desempenho em previsão de séries temporais, nenhum deles é necessário.
English
Recent advances in deep learning have mainly relied on Transformers due to
their data dependency and ability to learn at scale. The attention module in
these architectures, however, exhibits quadratic time and space in input size,
limiting their scalability for long-sequence modeling. Despite recent attempts
to design efficient and effective architecture backbone for multi-dimensional
data, such as images and multivariate time series, existing models are either
data independent, or fail to allow inter- and intra-dimension communication.
Recently, State Space Models (SSMs), and more specifically Selective State
Space Models, with efficient hardware-aware implementation, have shown
promising potential for long sequence modeling. Motivated by the success of
SSMs, we present MambaMixer, a new architecture with data-dependent weights
that uses a dual selection mechanism across tokens and channels, called
Selective Token and Channel Mixer. MambaMixer connects selective mixers using a
weighted averaging mechanism, allowing layers to have direct access to early
features. As a proof of concept, we design Vision MambaMixer (ViM2) and Time
Series MambaMixer (TSM2) architectures based on the MambaMixer block and
explore their performance in various vision and time series forecasting tasks.
Our results underline the importance of selective mixing across both tokens and
channels. In ImageNet classification, object detection, and semantic
segmentation tasks, ViM2 achieves competitive performance with well-established
vision models and outperforms SSM-based vision models. In time series
forecasting, TSM2 achieves outstanding performance compared to state-of-the-art
methods while demonstrating significantly improved computational cost. These
results show that while Transformers, cross-channel attention, and MLPs are
sufficient for good performance in time series forecasting, neither is
necessary.