ChatPaper.aiChatPaper

Compiler-First-Zustandsraum-Dualität und portables O(1)-autoregressives Caching für Inferenz

Compiler-First State Space Duality and Portable O(1) Autoregressive Caching for Inference

March 10, 2026
Autoren: Cosmo Santoni
cs.AI

Zusammenfassung

State-Space-Model-Releases sind typischerweise mit fusionierten CUDA- und Triton-Kerneln gekoppelt, was eine feste Abhängigkeit von NVIDIA-Hardware zur Folge hat. Wir zeigen, dass sich Mamba-2s State-Space-Dualitätsalgorithmus – diagonale Zustandsstruktur, chunk-fähige Rekurrenz und einsum-dominierte Berechnungen mit statischem Kontrollfluss – sauber auf das abbildet, was XLAs Fusions- und Tiling-Pässe tatsächlich optimieren, wodurch benutzerdefinierte Kernel optional statt erforderlich werden. Wir implementieren den vollständigen Inferenzpfad (Prefill, gecachte autoregressive Decodierung) als geformte Standardprimitive unter XLA, ohne handgeschriebene Kernel, und realisieren das theoretische O(1)-Zustandsmanagement der Architektur als einen kompilierten On-Device-Cache, der während der Generierung keine Host-Synchronisation erfordert. Die Implementierung läuft unverändert auf CPU, NVIDIA-GPU und Google Cloud TPU aus einer einzigen JAX-Quelle. Auf TPU v6e über fünf Modellgrößen (130M–2,7B Parameter) erreicht der von XLA generierte Code etwa 140 TFLOPS beim Single-Stream-Prefill (15% MFU) und bis zu 64% Bandbreitenauslastung beim Decode. Greedy-Decoding stimmt token-für-token mit der PyTorch/CUDA-Referenz über 64 Schritte überein, mit einer Übereinstimmung der Hidden-States innerhalb der float32-Rundungstoleranz. Das Muster überträgt sich auf jede SSM-Rekurrenz, die dieselben strukturellen Bedingungen erfüllt, auf jeder Plattform mit einem ausgereiften XLA-Backend. Die Implementierung ist öffentlich verfügbar unter https://github.com/CosmoNaught/mamba2-jax und in die Bonsai JAX Model Library integriert.
English
State-space model releases are typically coupled to fused CUDA and Triton kernels, inheriting a hard dependency on NVIDIA hardware. We show that Mamba-2's state space duality algorithm -- diagonal state structure, chunkable recurrence, and einsum-dominated compute with static control flow -- maps cleanly onto what XLA's fusion and tiling passes actually optimise, making custom kernels optional rather than required. We implement the full inference path (prefill, cached autoregressive decoding) as shaped standard primitives under XLA, without hand-written kernels, and realise the architecture's theoretical O(1) state management as a compiled on-device cache requiring no host synchronisation during generation. The implementation runs unmodified on CPU, NVIDIA GPU, and Google Cloud TPU from a single JAX source. On TPU v6e across five model scales (130M--2.7B parameters), XLA-generated code reaches approximately 140 TFLOPS on single-stream prefill (15% MFU) and up to 64% bandwidth utilisation on decode. Greedy decoding matches the PyTorch/CUDA reference token-for-token across 64 steps, with hidden-state agreement within float32 rounding tolerance. The pattern transfers to any SSM recurrence satisfying the same structural conditions, on any platform with a mature XLA backend. The implementation is publicly available at https://github.com/CosmoNaught/mamba2-jax and merged into the Bonsai JAX model library.
PDF11March 12, 2026