Dualidade Estado-Espaço com Abordagem Compiler-First e Cache Autoregressivo Portátil O(1) para Inferência
Compiler-First State Space Duality and Portable O(1) Autoregressive Caching for Inference
March 10, 2026
Autores: Cosmo Santoni
cs.AI
Resumo
As versões de modelos de espaço de estados são tipicamente acopladas a kernels fundidos de CUDA e Triton, herdando uma dependência rígida de hardware NVIDIA. Demonstramos que o algoritmo de dualidade de espaço de estados do Mamba-2 — estrutura de estados diagonal, recorrência segmentável e computação dominada por einsum com fluxo de controle estático — mapeia-se perfeitamente no que as passagens de fusão e blocagem (tiling) do XLA realmente otimizam, tornando os kernels personalizados opcionais em vez de obrigatórios. Implementamos o caminho completo de inferência (pré-preenchimento, decodagem autorregressiva em cache) como primitivas padrão formatadas no XLA, sem kernels escritos manualmente, e realizamos o gerenciamento de estados teórico O(1) da arquitetura como uma cache compilada no dispositivo que não requer sincronização com o host durante a geração. A implementação é executada sem modificações em CPU, GPU NVIDIA e Google Cloud TPU a partir de uma única fonte em JAX. No TPU v6e em cinco escalas de modelo (130M–2.7B de parâmetros), o código gerado pelo XLA atinge aproximadamente 140 TFLOPS no pré-preenchimento de fluxo único (15% MFU) e até 64% de utilização de largura de banda na decodagem. A decodagem gulosa (greedy) corresponde à referência PyTorch/CUDA token por token ao longo de 64 passos, com concordância do estado oculto dentro da tolerância de arredondamento float32. O padrão transfere-se para qualquer recorrência de SSM que satisfaça as mesmas condições estruturais, em qualquer plataforma com um backend XLA maduro. A implementação está publicamente disponível em https://github.com/CosmoNaught/mamba2-jax e foi incorporada à biblioteca de modelos Bonsai JAX.
English
State-space model releases are typically coupled to fused CUDA and Triton kernels, inheriting a hard dependency on NVIDIA hardware. We show that Mamba-2's state space duality algorithm -- diagonal state structure, chunkable recurrence, and einsum-dominated compute with static control flow -- maps cleanly onto what XLA's fusion and tiling passes actually optimise, making custom kernels optional rather than required. We implement the full inference path (prefill, cached autoregressive decoding) as shaped standard primitives under XLA, without hand-written kernels, and realise the architecture's theoretical O(1) state management as a compiled on-device cache requiring no host synchronisation during generation. The implementation runs unmodified on CPU, NVIDIA GPU, and Google Cloud TPU from a single JAX source. On TPU v6e across five model scales (130M--2.7B parameters), XLA-generated code reaches approximately 140 TFLOPS on single-stream prefill (15% MFU) and up to 64% bandwidth utilisation on decode. Greedy decoding matches the PyTorch/CUDA reference token-for-token across 64 steps, with hidden-state agreement within float32 rounding tolerance. The pattern transfers to any SSM recurrence satisfying the same structural conditions, on any platform with a mature XLA backend. The implementation is publicly available at https://github.com/CosmoNaught/mamba2-jax and merged into the Bonsai JAX model library.