Compiler-First State Space Dualiteit en Draagbare O(1) Autoregressieve Caching voor Inferentie
Compiler-First State Space Duality and Portable O(1) Autoregressive Caching for Inference
March 10, 2026
Auteurs: Cosmo Santoni
cs.AI
Samenvatting
State-space model-implementaties zijn doorgaans gekoppeld aan gefuseerde CUDA- en Triton-kernels, wat een harde afhankelijkheid van NVIDIA-hardware met zich meebrengt. Wij tonen aan dat Mamba-2's state-space dualiteitsalgoritme – diagonale staatstructuur, chunkbare recurrentie en einsum-gedomineerd rekenwerk met statische control flow – naadloos aansluit bij wat XLA's fusie- en tiling-passes daadwerkelijk optimaliseren, waardoor aangepaste kernels optioneel worden in plaats van vereist. Wij implementeren het volledige inferentiepad (prefill, gecachte autoregressieve decodering) als gevormde standaardprimitieven onder XLA, zonder handgeschreven kernels, en realiseren de architectuur's theoretische O(1) staatbeheer als een gecompileerde on-device cache die geen hostsynchronisatie vereist tijdens generatie. De implementatie draait ongewijzigd op CPU, NVIDIA GPU en Google Cloud TPU vanuit een enkele JAX-bron. Op TPU v5e over vijf modelschalen (130M–2,7B parameters) bereikt XLA-gegenereerde code ongeveer 140 TFLOPS op single-stream prefill (15% MFU) en tot 64% bandbreedtebenutting bij decodering. Greedy decodering komt token-voor-token overeen met de PyTorch/CUDA-referentie over 64 stappen, met overeenstemming van de verborgen toestanden binnen de float32-afrondingstolerantie. Het patroon is overdraagbaar naar elke SSM-recurrentie die aan dezelfde structurele voorwaarden voldoet, op elk platform met een volwassen XLA-backend. De implementatie is publiekelijk beschikbaar op https://github.com/CosmoNaught/mamba2-jax en opgenomen in de Bonsai JAX-modelbibliotheek.
English
State-space model releases are typically coupled to fused CUDA and Triton kernels, inheriting a hard dependency on NVIDIA hardware. We show that Mamba-2's state space duality algorithm -- diagonal state structure, chunkable recurrence, and einsum-dominated compute with static control flow -- maps cleanly onto what XLA's fusion and tiling passes actually optimise, making custom kernels optional rather than required. We implement the full inference path (prefill, cached autoregressive decoding) as shaped standard primitives under XLA, without hand-written kernels, and realise the architecture's theoretical O(1) state management as a compiled on-device cache requiring no host synchronisation during generation. The implementation runs unmodified on CPU, NVIDIA GPU, and Google Cloud TPU from a single JAX source. On TPU v6e across five model scales (130M--2.7B parameters), XLA-generated code reaches approximately 140 TFLOPS on single-stream prefill (15% MFU) and up to 64% bandwidth utilisation on decode. Greedy decoding matches the PyTorch/CUDA reference token-for-token across 64 steps, with hidden-state agreement within float32 rounding tolerance. The pattern transfers to any SSM recurrence satisfying the same structural conditions, on any platform with a mature XLA backend. The implementation is publicly available at https://github.com/CosmoNaught/mamba2-jax and merged into the Bonsai JAX model library.