ChatPaper.aiChatPaper

Dualità Spazio degli Stati Compiler-First e Caching Autoregressivo Portatile O(1) per l'Inferenza

Compiler-First State Space Duality and Portable O(1) Autoregressive Caching for Inference

March 10, 2026
Autori: Cosmo Santoni
cs.AI

Abstract

I rilasci di modelli state-space sono tipicamente accoppiati a kernel CUDA e Triton fusi, ereditando una forte dipendenza dall'hardware NVIDIA. Dimostriamo che l'algoritmo di dualità state-space di Mamba-2 – struttura di stato diagonale, ricorrenza suddivisibile in blocchi e calcolo dominato da einsum con flusso di controllo statico – si adatta perfettamente a ciò che le passate di fusione e tiling di XLA ottimizzano effettivamente, rendendo i kernel personalizzati opzionali piuttosto che obbligatori. Implementiamo l'intero percorso di inferenza (prefill, decodifica autoregressiva in cache) come primitive standard conformate sotto XLA, senza kernel scritti a mano, e realizziamo la gestione dello stato teorica O(1) dell'architettura come una cache compilata sul dispositivo che non richiede sincronizzazione con l'host durante la generazione. L'implementazione viene eseguita senza modifiche su CPU, GPU NVIDIA e Google Cloud TPU da un'unica sorgente JAX. Su TPU v6e attraverso cinque scale del modello (130M–2.7B parametri), il codice generato da XLA raggiunge circa 140 TFLOPS su prefill a flusso singolo (15% MFU) e fino al 64% di utilizzo della banda su decode. La decodifica greedy corrisponde al riferimento PyTorch/CUDA token-per-token attraverso 64 passi, con accordo dello stato nascosto entro la tolleranza di arrotondamento float32. Lo schema si trasferisce a qualsiasi ricorrenza SSM che soddisfi le stesse condizioni strutturali, su qualsiasi piattaforma con un backend XLA maturo. L'implementazione è pubblicamente disponibile all'indirizzo https://github.com/CosmoNaught/mamba2-jax e integrata nella libreria di modelli Bonsai JAX.
English
State-space model releases are typically coupled to fused CUDA and Triton kernels, inheriting a hard dependency on NVIDIA hardware. We show that Mamba-2's state space duality algorithm -- diagonal state structure, chunkable recurrence, and einsum-dominated compute with static control flow -- maps cleanly onto what XLA's fusion and tiling passes actually optimise, making custom kernels optional rather than required. We implement the full inference path (prefill, cached autoregressive decoding) as shaped standard primitives under XLA, without hand-written kernels, and realise the architecture's theoretical O(1) state management as a compiled on-device cache requiring no host synchronisation during generation. The implementation runs unmodified on CPU, NVIDIA GPU, and Google Cloud TPU from a single JAX source. On TPU v6e across five model scales (130M--2.7B parameters), XLA-generated code reaches approximately 140 TFLOPS on single-stream prefill (15% MFU) and up to 64% bandwidth utilisation on decode. Greedy decoding matches the PyTorch/CUDA reference token-for-token across 64 steps, with hidden-state agreement within float32 rounding tolerance. The pattern transfers to any SSM recurrence satisfying the same structural conditions, on any platform with a mature XLA backend. The implementation is publicly available at https://github.com/CosmoNaught/mamba2-jax and merged into the Bonsai JAX model library.
PDF12March 26, 2026