ChatPaper.aiChatPaper

FlashSampling: Snel en Geheugenefficiënt Exact Samplen

FlashSampling: Fast and Memory-Efficient Exact Sampling

March 16, 2026
Auteurs: Tomas Ruiz, Zhen Qin, Yifan Zhang, Xuyang Shen, Yiran Zhong, Mengdi Wang
cs.AI

Samenvatting

Steekproeven nemen uit een categorische verdeling is wiskundig eenvoudig, maar bij decoding met een grote woordenschat leidt dit vaak tot extra geheugenverkeer en extra kernels na de LM-head. Wij presenteren FlashSampling, een exacte sampling-primitief die de sampling versmelt met de LM-head-matmul en nooit de logits-tensor materialiseert in HBM. De methode is eenvoudig: bereken logits tegel-voor-tegel on-chip, voeg Gumbel-ruis toe, behoud slechts één maximizer per rij en per vocabulaire tegel, en rond af met een kleine reductie over de tegels. De gefuseerde getegelde kernel is exact omdat argmax zich laat ontbinden over een partitie; gegroepeerde varianten voor online- en tensor-parallelle settings zijn exact door hiërarchische factorisatie van de categorische verdeling. Op H100-, H200-, B200- en B300-GPU's versnelt FlashSampling kernel-level decode-workloads, en in end-to-end vLLM-experimenten reduceert het de tijd per outputtoken met tot 19% bij de geteste modellen. Deze resultaten tonen aan dat exacte sampling, zonder benadering, kan worden geïntegreerd in de matmul zelf, waardoor een bandbreedtegebonden nabewerkingsstap verandert in een lichtgewicht epiloog. Projectpagina: https://github.com/FlashSampling/FlashSampling.
English
Sampling from a categorical distribution is mathematically simple, but in large-vocabulary decoding, it often triggers extra memory traffic and extra kernels after the LM head. We present FlashSampling, an exact sampling primitive that fuses sampling into the LM-head matmul and never materializes the logits tensor in HBM. The method is simple: compute logits tile-by-tile on chip, add Gumbel noise, keep only one maximizer per row and per vocabulary tile, and finish with a small reduction over tiles. The fused tiled kernel is exact because argmax decomposes over a partition; grouped variants for online and tensor-parallel settings are exact by hierarchical factorization of the categorical distribution. Across H100, H200, B200, and B300 GPUs, FlashSampling speeds up kernel-level decode workloads, and in end-to-end vLLM experiments, it reduces time per output token by up to 19% on the models we test. These results show that exact sampling, with no approximation, can be integrated into the matmul itself, turning a bandwidth-bound postprocessing step into a lightweight epilogue. Project Page: https://github.com/FlashSampling/FlashSampling.
PDF52March 19, 2026