
FlashSampling: Fast and Memory-Efficient Exact Sampling

March 16, 2026
Authors: Tomas Ruiz, Zhen Qin, Yifan Zhang, Xuyang Shen, Yiran Zhong, Mengdi Wang
cs.AI

Abstract

Sampling from a categorical distribution is mathematically simple, but in large-vocabulary decoding, it often triggers extra memory traffic and extra kernels after the LM head. We present FlashSampling, an exact sampling primitive that fuses sampling into the LM-head matmul and never materializes the logits tensor in HBM. The method is simple: compute logits tile-by-tile on chip, add Gumbel noise, keep only one maximizer per row and per vocabulary tile, and finish with a small reduction over tiles. The fused tiled kernel is exact because argmax decomposes over a partition; grouped variants for online and tensor-parallel settings are exact by hierarchical factorization of the categorical distribution. Across H100, H200, B200, and B300 GPUs, FlashSampling speeds up kernel-level decode workloads, and in end-to-end vLLM experiments, it reduces time per output token by up to 19% on the models we test. These results show that exact sampling, with no approximation, can be integrated into the matmul itself, turning a bandwidth-bound postprocessing step into a lightweight epilogue. Project Page: https://github.com/FlashSampling/FlashSampling.
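The tiled procedure the abstract describes (compute logits per vocabulary tile, add Gumbel noise, keep one maximizer per row and per tile, then reduce across tiles) can be illustrated with a minimal NumPy sketch. This is not the authors' kernel; the function name, shapes, and tiling loop are illustrative assumptions. It relies on the standard Gumbel-max identity: `argmax(logits + Gumbel noise)` is an exact sample from `softmax(logits)`, and argmax decomposes over any partition of the vocabulary.

```python
import numpy as np

def gumbel_tiled_sample(hidden, weight, tile_size, rng):
    """Exact categorical sampling via the Gumbel-max trick, computing
    logits one vocabulary tile at a time so the full (batch, vocab)
    logits tensor is never materialized. Illustrative sketch only.

    hidden: (batch, d) final hidden states
    weight: (vocab, d) LM-head weight matrix
    """
    batch = hidden.shape[0]
    vocab = weight.shape[0]
    # Running per-row maximizer across tiles (the "small reduction").
    best_val = np.full(batch, -np.inf)
    best_idx = np.zeros(batch, dtype=np.int64)
    for start in range(0, vocab, tile_size):
        w_tile = weight[start:start + tile_size]      # one vocab tile
        logits = hidden @ w_tile.T                    # (batch, tile) on "chip"
        gumbel = -np.log(-np.log(rng.uniform(size=logits.shape)))
        noisy = logits + gumbel
        tile_idx = noisy.argmax(axis=1)               # one maximizer per row
        tile_val = noisy[np.arange(batch), tile_idx]
        update = tile_val > best_val                  # cross-tile reduction
        best_val = np.where(update, tile_val, best_val)
        best_idx = np.where(update, start + tile_idx, best_idx)
    return best_idx
```

Because the argmax of independently Gumbel-perturbed logits is taken tile-by-tile and then reduced, the result is distributed identically to sampling from the full softmax, with no approximation; the real kernel fuses this epilogue into the LM-head matmul instead of looping in Python.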