Dualidad Primero el Compilador del Espacio de Estados y Caché Autoregresivo Portátil O(1) para Inferencia
Compiler-First State Space Duality and Portable O(1) Autoregressive Caching for Inference
March 10, 2026
Autores: Cosmo Santoni
cs.AI
Resumen
Las versiones de modelos de espacio de estados suelen estar acopladas a kernels fusionados de CUDA y Triton, heredando una dependencia estricta del hardware de NVIDIA. Demostramos que el algoritmo de dualidad de espacio de estados de Mamba-2 —estructura de estados diagonal, recurrencia fragmentable y cómputo dominado por einsum con flujo de control estático— se adapta perfectamente a lo que las pasadas de fusión y mosaico de XLA optimizan realmente, haciendo que los kernels personalizados sean opcionales en lugar de obligatorios. Implementamos la ruta de inferencia completa (prellenado, decodificación autoregresiva en caché) como primitivas estándar moldeadas bajo XLA, sin kernels escritos a mano, y materializamos la gestión teórica O(1) de estados de la arquitectura como una caché compilada en el dispositivo que no requiere sincronización con el host durante la generación. La implementación se ejecuta sin modificaciones en CPU, GPU NVIDIA y TPU de Google Cloud a partir de una única fuente JAX. En TPU v5e a través de cinco escalas de modelo (130M–2.7B parámetros), el código generado por XLA alcanza aproximadamente 140 TFLOPS en prellenado de flujo único (15% MFU) y hasta un 64% de utilización de ancho de banda en decodificación. La decodificación voraz coincide con la referencia PyTorch/CUDA token por token a lo largo de 64 pasos, con concordancia del estado oculto dentro de la tolerancia de redondeo float32. El patrón se transfiere a cualquier recurrencia de SSM que satisfaga las mismas condiciones estructurales, en cualquier plataforma con un backend XLA maduro. La implementación está disponible públicamente en https://github.com/CosmoNaught/mamba2-jax e integrada en la biblioteca de modelos Bonsai JAX.
English
State-space model releases are typically coupled to fused CUDA and Triton kernels, inheriting a hard dependency on NVIDIA hardware. We show that Mamba-2's state space duality algorithm -- diagonal state structure, chunkable recurrence, and einsum-dominated compute with static control flow -- maps cleanly onto what XLA's fusion and tiling passes actually optimise, making custom kernels optional rather than required. We implement the full inference path (prefill, cached autoregressive decoding) as shaped standard primitives under XLA, without hand-written kernels, and realise the architecture's theoretical O(1) state management as a compiled on-device cache requiring no host synchronisation during generation. The implementation runs unmodified on CPU, NVIDIA GPU, and Google Cloud TPU from a single JAX source. On TPU v6e across five model scales (130M--2.7B parameters), XLA-generated code reaches approximately 140 TFLOPS on single-stream prefill (15% MFU) and up to 64% bandwidth utilisation on decode. Greedy decoding matches the PyTorch/CUDA reference token-for-token across 64 steps, with hidden-state agreement within float32 rounding tolerance. The pattern transfers to any SSM recurrence satisfying the same structural conditions, on any platform with a mature XLA backend. The implementation is publicly available at https://github.com/CosmoNaught/mamba2-jax and merged into the Bonsai JAX model library.