Compiler-First State Space Duality and Portable O(1) Autoregressive Caching for Inference
March 10, 2026
Author: Cosmo Santoni
cs.AI
Abstract
State-space model releases are typically coupled to fused CUDA and Triton kernels, inheriting a hard dependency on NVIDIA hardware. We show that Mamba-2's state space duality algorithm -- diagonal state structure, chunkable recurrence, and einsum-dominated compute with static control flow -- maps cleanly onto what XLA's fusion and tiling passes actually optimise, making custom kernels optional rather than required. We implement the full inference path (prefill, cached autoregressive decoding) as statically shaped standard primitives under XLA, without hand-written kernels, and realise the architecture's theoretical O(1) state management as a compiled on-device cache requiring no host synchronisation during generation. The implementation runs unmodified on CPU, NVIDIA GPU, and Google Cloud TPU from a single JAX source. On TPU v6e across five model scales (130M--2.7B parameters), XLA-generated code reaches approximately 140 TFLOPS on single-stream prefill (15% MFU) and up to 64% bandwidth utilisation on decode. Greedy decoding matches the PyTorch/CUDA reference token-for-token across 64 steps, with hidden-state agreement within float32 rounding tolerance. The pattern transfers to any SSM recurrence satisfying the same structural conditions, on any platform with a mature XLA backend. The implementation is publicly available at https://github.com/CosmoNaught/mamba2-jax and merged into the Bonsai JAX model library.
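The O(1) cached decode step the abstract describes can be sketched in plain JAX. This is an illustrative reconstruction, not the paper's actual code: the shape names (`H` heads, `P` head dim, `N` state dim) and the function `decode_step` are assumptions, but the structure follows the diagonal SSM recurrence, where the per-token state update and readout are fixed-size einsums that XLA can compile without host synchronisation.

```python
# Hedged sketch of a single cached autoregressive decode step for a
# diagonal SSM (Mamba-2-style recurrence). Names and shapes are
# illustrative assumptions, not the released implementation's API.
import jax
import jax.numpy as jnp

H, P, N = 4, 16, 32  # heads, head dim, state dim (illustrative sizes)

@jax.jit
def decode_step(h, a, B, C, x):
    """One autoregressive step with constant-size cached state.

    h: (H, P, N)  cached SSM state carried between tokens
    a: (H,)       per-head scalar decay, e.g. exp(dt * A)
    B: (H, N)     input projection for the current token
    C: (H, N)     output projection for the current token
    x: (H, P)     current-token activations
    """
    # Diagonal recurrence: h_t = a_t * h_{t-1} + x_t outer B_t
    h = a[:, None, None] * h + jnp.einsum('hp,hn->hpn', x, B)
    # Readout: y_t = h_t contracted with C_t over the state dim
    y = jnp.einsum('hpn,hn->hp', h, C)
    return h, y

key = jax.random.PRNGKey(0)
h = jnp.zeros((H, P, N))            # state size never grows with sequence length
a = jnp.full((H,), 0.9)
B = jax.random.normal(key, (H, N))
C = jax.random.normal(key, (H, N))
x = jax.random.normal(key, (H, P))
h, y = decode_step(h, a, B, C, x)
```

Because every array here has a static shape and the step is a pair of einsums, the compiled function can be called in a device-side generation loop with no per-token host round-trips, which is the property the abstract's "compiled on-device cache" claim rests on.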