FlashSampling: 高速でメモリ効率の良い正確なサンプリング
FlashSampling: Fast and Memory-Efficient Exact Sampling
March 16, 2026
著者: Tomas Ruiz, Zhen Qin, Yifan Zhang, Xuyang Shen, Yiran Zhong, Mengdi Wang
cs.AI
要旨
カテゴリカル分布からのサンプリングは数学的には単純だが、大規模語彙デコーディングでは、LMヘッドの後に余分なメモリトラフィックや余分なカーネルを引き起こすことが多い。本論文では、サンプリングをLMヘッドの行列乗算に融合し、ロジットテンソルをHBMに実体化しない正確なサンプリングプリミティブ「FlashSampling」を提案する。手法は単純である:オンチップでロジットをタイイルごとに計算し、ガンベルノイズを加え、行と語彙タイルごとに最大値のみを保持し、最後にタイル間の小規模なリダクションを行う。融合されたタイルカーネルは、argmaxが分割に対して分解可能であるため正確である。オンライン設定およびテンソル並列設定のためのグループ化変種は、カテゴリカル分布の階層的因子分解によって正確である。H100、H200、B200、B300 GPUにおける評価では、FlashSamplingはカーネルレベルのデコードワークロードを高速化し、エンドツーエンドのvLLM実験では、テストしたモデルにおいてトークン当たりの処理時間を最大19%削減した。これらの結果は、近似を伴わない正確なサンプリングが行列乗算自体に統合可能であり、帯域幅制約のある後処理ステップを軽量なエピローグに変え得ることを示す。プロジェクトページ:https://github.com/FlashSampling/FlashSampling。
English
Sampling from a categorical distribution is mathematically simple, but in large-vocabulary decoding, it often triggers extra memory traffic and extra kernels after the LM head. We present FlashSampling, an exact sampling primitive that fuses sampling into the LM-head matmul and never materializes the logits tensor in HBM. The method is simple: compute logits tile-by-tile on chip, add Gumbel noise, keep only one maximizer per row and per vocabulary tile, and finish with a small reduction over tiles. The fused tiled kernel is exact because argmax decomposes over a partition; grouped variants for online and tensor-parallel settings are exact by hierarchical factorization of the categorical distribution. Across H100, H200, B200, and B300 GPUs, FlashSampling speeds up kernel-level decode workloads, and in end-to-end vLLM experiments, it reduces time per output token by up to 19% on the models we test. These results show that exact sampling, with no approximation, can be integrated into the matmul itself, turning a bandwidth-bound postprocessing step into a lightweight epilogue. Project Page: https://github.com/FlashSampling/FlashSampling.