ChatPaper.aiChatPaper

Dualité Espace d'État Compilateur-First et Cache Autoregressif Portable O(1) pour l'Inférence

Compiler-First State Space Duality and Portable O(1) Autoregressive Caching for Inference

March 10, 2026
Auteurs: Cosmo Santoni
cs.AI

Résumé

Les implémentations de modèles à espace d'états sont généralement couplées à des noyaux CUDA et Triton fusionnés, héritant d'une dépendance matérielle contraignante envers NVIDIA. Nous démontrons que l'algorithme de dualité des espaces d'états de Mamba-2 - structure d'état diagonale, récurrence segmentable, et calcul dominé par einsum avec flux de contrôle statique - s'applique parfaitement aux optimisations réellement effectuées par les passes de fusion et de pavage de XLA, rendant les noyaux personnalisés optionnels plutôt que nécessaires. Nous implémentons le chemin d'inférence complet (préremplissage, décodage autorégressif avec cache) sous forme de primitives standard structurées dans XLA, sans noyaux écrits manuellement, et matérialisons la gestion théorique O(1) de l'état de l'architecture sous forme de cache compilé sur périphérique ne nécessitant aucune synchronisation hôte pendant la génération. L'implémentation s'exécute sans modification sur CPU, GPU NVIDIA et TPU Google Cloud à partir d'une unique source JAX. Sur TPU v6e à travers cinq échelles de modèles (130M à 2,7B paramètres), le code généré par XLA atteint environ 140 TFLOPS en préremplissage mono-flux (15% MFU) et jusqu'à 64% d'utilisation de bande passante en décodage. Le décodage glouton correspond parfaitement token-à-token à la référence PyTorch/CUDA sur 64 étapes, avec un accord des états cachés dans la tolérance d'arrondi float32. Ce schéma est transférable à toute récurrence SSM satisfaisant les mêmes conditions structurelles, sur toute plateforme disposant d'un backend XLA mature. L'implémentation est publiquement disponible à l'adresse https://github.com/CosmoNaught/mamba2-jax et fusionnée dans la bibliothèque de modèles Bonsai JAX.
English
State-space model releases are typically coupled to fused CUDA and Triton kernels, inheriting a hard dependency on NVIDIA hardware. We show that Mamba-2's state space duality algorithm -- diagonal state structure, chunkable recurrence, and einsum-dominated compute with static control flow -- maps cleanly onto what XLA's fusion and tiling passes actually optimise, making custom kernels optional rather than required. We implement the full inference path (prefill, cached autoregressive decoding) as shaped standard primitives under XLA, without hand-written kernels, and realise the architecture's theoretical O(1) state management as a compiled on-device cache requiring no host synchronisation during generation. The implementation runs unmodified on CPU, NVIDIA GPU, and Google Cloud TPU from a single JAX source. On TPU v6e across five model scales (130M--2.7B parameters), XLA-generated code reaches approximately 140 TFLOPS on single-stream prefill (15% MFU) and up to 64% bandwidth utilisation on decode. Greedy decoding matches the PyTorch/CUDA reference token-for-token across 64 steps, with hidden-state agreement within float32 rounding tolerance. The pattern transfers to any SSM recurrence satisfying the same structural conditions, on any platform with a mature XLA backend. The implementation is publicly available at https://github.com/CosmoNaught/mamba2-jax and merged into the Bonsai JAX model library.
PDF11March 12, 2026