

FlashSampling: Fast and Memory-Efficient Exact Sampling

March 16, 2026
Authors: Tomas Ruiz, Zhen Qin, Yifan Zhang, Xuyang Shen, Yiran Zhong, Mengdi Wang
cs.AI

Abstract

Sampling from a categorical distribution is mathematically simple, but in large-vocabulary decoding, it often triggers extra memory traffic and extra kernels after the LM head. We present FlashSampling, an exact sampling primitive that fuses sampling into the LM-head matmul and never materializes the logits tensor in HBM. The method is simple: compute logits tile-by-tile on chip, add Gumbel noise, keep only one maximizer per row and per vocabulary tile, and finish with a small reduction over tiles. The fused tiled kernel is exact because argmax decomposes over a partition; grouped variants for online and tensor-parallel settings are exact by hierarchical factorization of the categorical distribution. Across H100, H200, B200, and B300 GPUs, FlashSampling speeds up kernel-level decode workloads, and in end-to-end vLLM experiments, it reduces time per output token by up to 19% on the models we test. These results show that exact sampling, with no approximation, can be integrated into the matmul itself, turning a bandwidth-bound postprocessing step into a lightweight epilogue. Project Page: https://github.com/FlashSampling/FlashSampling.
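The two exactness arguments in the abstract can be illustrated with a short sketch. Below is a minimal NumPy mock-up (not the FlashSampling CUDA kernel; function names, the `tile` parameter, and the loop structure are illustrative assumptions). `flash_sample` walks the vocabulary tile-by-tile, adds Gumbel noise to each tile of logits, and keeps only a running maximizer, which is exact because argmax decomposes over a partition of the vocabulary. `grouped_sample` sketches the hierarchical factorization used for the grouped variants: sample a candidate within each tile, then sample a tile with probability proportional to its softmax mass (its log-sum-exp).

```python
import numpy as np

def flash_sample(hidden, weight, tile=4096, rng=None):
    """Exact categorical sample from softmax(weight @ hidden) via the
    Gumbel-max trick, computed one vocabulary tile at a time so the
    full logits vector is never stored.

    hidden: (d,) activation; weight: (V, d) LM-head matrix.
    """
    rng = np.random.default_rng() if rng is None else rng
    V = weight.shape[0]
    best_val, best_idx = -np.inf, -1
    for start in range(0, V, tile):
        logits = weight[start:start + tile] @ hidden   # one tile of logits
        g = rng.gumbel(size=logits.shape)              # Gumbel(0, 1) noise
        k = int(np.argmax(logits + g))                 # per-tile maximizer
        if logits[k] + g[k] > best_val:                # tiny cross-tile reduction
            best_val, best_idx = logits[k] + g[k], start + k
    return best_idx

def grouped_sample(hidden, weight, tile=4096, rng=None):
    """Exact sample via hierarchical factorization: P(i) = P(tile) * P(i | tile),
    where P(tile) is proportional to the tile's total softmax mass."""
    rng = np.random.default_rng() if rng is None else rng
    V = weight.shape[0]
    picks, log_masses = [], []
    for start in range(0, V, tile):
        logits = weight[start:start + tile] @ hidden
        m = logits.max()
        p = np.exp(logits - m)
        Z = p.sum()
        picks.append(start + rng.choice(len(p), p=p / Z))  # local sample
        log_masses.append(m + np.log(Z))                   # tile log-sum-exp
    log_masses = np.array(log_masses)
    w = np.exp(log_masses - log_masses.max())
    w /= w.sum()
    return picks[rng.choice(len(picks), p=w)]              # pick a tile
```

Both functions draw from the same categorical distribution as a monolithic `softmax` followed by `rng.choice`; the difference is that no length-`V` logits buffer ever needs to leave fast memory, which is the bandwidth saving the abstract describes.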