コンパイラファースト状態空間双対性と推論のためのポータブルO(1)自己回帰キャッシング
Compiler-First State Space Duality and Portable O(1) Autoregressive Caching for Inference
March 10, 2026
著者: Cosmo Santoni
cs.AI
要旨
状態空間モデルの実装は、通常CUDAとTritonカーネルの融合に結びついており、NVIDIAハードウェアへの強固な依存性を引き継いでいる。我々は、Mamba-2の状態空間二重性アルゴリズム——対角状態構造、チャンク化可能な再帰、静的制御フローを伴うeinsum主体の計算——が、XLAの融合とタイリングパスが実際に最適化する対象にきれいにマッピングされ、カスタムカーネルを必須ではなくオプションにすることを示す。我々は、手書きのカーネルなしで、完全な推論パス(プリフィル、キャッシュされた自己回帰復号)をXLA下での形状付き標準プリミティブとして実装し、生成中のホスト同期を必要としないコンパイル済みオンデバイスキャッシュとして、このアーキテクチャの理論的なO(1)状態管理を実現する。この実装は、単一のJAXソースから、CPU、NVIDIA GPU、Google Cloud TPU上で変更なしに動作する。5つのモデル規模(1億3000万~27億パラメータ)にわたるTPU v6eでは、XLAが生成したコードは、シングルストリームプリフィルで約140 TFLOPS(15% MFU)に達し、復号時には最大64%の帯域幅利用率を示す。貪欲復号は、64ステップにわたってPyTorch/CUDAリファレンスとトークンレベルで一致し、隠れ状態の一致はfloat32の丸め誤差範囲内である。このパターンは、同じ構造条件を満たす任意のSSM再帰に転移可能であり、成熟したXLAバックエンドを備えた任意のプラットフォームで動作する。実装はhttps://github.com/CosmoNaught/mamba2-jax で公開されており、Bonsai JAXモデルライブラリにマージされている。
English
State-space model releases are typically coupled to fused CUDA and Triton kernels, inheriting a hard dependency on NVIDIA hardware. We show that Mamba-2's state space duality algorithm -- diagonal state structure, chunkable recurrence, and einsum-dominated compute with static control flow -- maps cleanly onto what XLA's fusion and tiling passes actually optimise, making custom kernels optional rather than required. We implement the full inference path (prefill, cached autoregressive decoding) as shaped standard primitives under XLA, without hand-written kernels, and realise the architecture's theoretical O(1) state management as a compiled on-device cache requiring no host synchronisation during generation. The implementation runs unmodified on CPU, NVIDIA GPU, and Google Cloud TPU from a single JAX source. On TPU v6e across five model scales (130M--2.7B parameters), XLA-generated code reaches approximately 140 TFLOPS on single-stream prefill (15% MFU) and up to 64% bandwidth utilisation on decode. Greedy decoding matches the PyTorch/CUDA reference token-for-token across 64 steps, with hidden-state agreement within float32 rounding tolerance. The pattern transfers to any SSM recurrence satisfying the same structural conditions, on any platform with a mature XLA backend. The implementation is publicly available at https://github.com/CosmoNaught/mamba2-jax and merged into the Bonsai JAX model library.