ChatPaper.aiChatPaper

컴파일러 우선 상태 공간 이중성 및 추론을 위한 이식 가능한 O(1) 자기회귀 캐싱

Compiler-First State Space Duality and Portable O(1) Autoregressive Caching for Inference

March 10, 2026
저자: Cosmo Santoni
cs.AI

초록

일반적으로 상태 공간 모델 릴리스는 융합된 CUDA 및 Triton 커널과 결합되어 NVIDIA 하드웨어에 대한 강력한 의존성을 상속받습니다. 우리는 Mamba-2의 상태 공간 이중성 알고리즘(대각선 상태 구조, 청킹 가능한 순환, 정적 제어 흐름으로 지배되는 einsum 연산)이 XLA의 퓨전 및 타일링 패스가 실제로 최적화하는 대상에 깔끔하게 매핑되어 커스텀 커널을 필수가 아닌 선택 사항으로 만든다는 점을 보여줍니다. 우리는 핸드라이팅 커널 없이 XLA 하에서 형상화된 표준 프리미티브로 전체 추론 경로(프리필, 캐시된 자기회귀 디코딩)를 구현하고, 생성 과정 중 호스트 동기화가 필요 없는 컴파일된 온디바이스 캐시로 아키텍처의 이론적 O(1) 상태 관리를 실현합니다. 이 구현은 단일 JAX 소스 코드로 CPU, NVIDIA GPU 및 Google Cloud TPU에서 수정 없이 실행됩니다. 5가지 모델 규모(1.3억~27억 매개변수)에 걸친 TPU v6e에서 XLA 생성 코드는 단일 스트림 프리필에서 약 140 TFLOPS(15% MFU)에 도달하고 디코딩에서 최대 64% 대역폭 활용률을 보입니다. 그리디 디코딩은 64단계에 걸쳐 PyTorch/CUDA 기준과 토큰 단위로 일치하며, 은닉 상태 일치는 float32 반올림 허용 오차 범위 내입니다. 이 패턴은 동일한 구조적 조건을 만족하는 모든 SSM 순환으로, 성숙한 XLA 백엔드를 가진 모든 플랫폼에 이전 가능합니다. 구현은 https://github.com/CosmoNaught/mamba2-jax에서 공개적으로 이용 가능하며 Bonsai JAX 모델 라이브러리에 병합되었습니다.
English
State-space model releases are typically coupled to fused CUDA and Triton kernels, inheriting a hard dependency on NVIDIA hardware. We show that Mamba-2's state space duality algorithm -- diagonal state structure, chunkable recurrence, and einsum-dominated compute with static control flow -- maps cleanly onto what XLA's fusion and tiling passes actually optimise, making custom kernels optional rather than required. We implement the full inference path (prefill, cached autoregressive decoding) as shaped standard primitives under XLA, without hand-written kernels, and realise the architecture's theoretical O(1) state management as a compiled on-device cache requiring no host synchronisation during generation. The implementation runs unmodified on CPU, NVIDIA GPU, and Google Cloud TPU from a single JAX source. On TPU v6e across five model scales (130M--2.7B parameters), XLA-generated code reaches approximately 140 TFLOPS on single-stream prefill (15% MFU) and up to 64% bandwidth utilisation on decode. Greedy decoding matches the PyTorch/CUDA reference token-for-token across 64 steps, with hidden-state agreement within float32 rounding tolerance. The pattern transfers to any SSM recurrence satisfying the same structural conditions, on any platform with a mature XLA backend. The implementation is publicly available at https://github.com/CosmoNaught/mamba2-jax and merged into the Bonsai JAX model library.
PDF11March 12, 2026