FlashSampling: Schnelles und speichereffizientes exaktes Sampling
FlashSampling: Fast and Memory-Efficient Exact Sampling
March 16, 2026
Autoren: Tomas Ruiz, Zhen Qin, Yifan Zhang, Xuyang Shen, Yiran Zhong, Mengdi Wang
cs.AI
Zusammenfassung
Das Abtasten aus einer kategorialen Verteilung ist mathematisch einfach, führt jedoch bei der Dekodierung mit großem Vokabular oft zu zusätzlichem Speicherverkehr und zusätzlichen Kernels nach dem LM-Head. Wir stellen FlashSampling vor, eine exakte Abtastprimitive, die das Abtasten in die LM-Head-Matmul fusioniert und den Logits-Tensor niemals im HBM materialisiert. Die Methode ist einfach: Berechne Logits tileweise auf dem Chip, füge Gumbel-Rauschen hinzu, behalte nur einen Maximierer pro Zeile und pro Vokabular-Tile und schließe mit einer kleinen Reduktion über die Tiles ab. Der fusionierte Tile-Kernel ist exakt, weil sich Argmax über eine Partition zerlegen lässt; gruppierte Varianten für Online- und Tensor-Parallel-Einstellungen sind durch hierarchische Faktorisierung der kategorialen Verteilung exakt. Auf H100-, H200-, B200- und B300-GPUs beschleunigt FlashSampling Kernel-level-Dekodierlasten, und in Ende-zu-Ende-vLLM-Experimenten reduziert es die Zeit pro Ausgabetoken bei den von uns getesteten Modellen um bis zu 19%. Diese Ergebnisse zeigen, dass exaktes Abtasten ohne Approximation in die Matmul selbst integriert werden kann, wodurch ein bandbreitenbeschränkter Nachverarbeitungsschritt in einen leichtgewichtigen Epilog verwandelt wird. Projektseite: https://github.com/FlashSampling/FlashSampling.
English
Sampling from a categorical distribution is mathematically simple, but in large-vocabulary decoding, it often triggers extra memory traffic and extra kernels after the LM head. We present FlashSampling, an exact sampling primitive that fuses sampling into the LM-head matmul and never materializes the logits tensor in HBM. The method is simple: compute logits tile-by-tile on chip, add Gumbel noise, keep only one maximizer per row and per vocabulary tile, and finish with a small reduction over tiles. The fused tiled kernel is exact because argmax decomposes over a partition; grouped variants for online and tensor-parallel settings are exact by hierarchical factorization of the categorical distribution. Across H100, H200, B200, and B300 GPUs, FlashSampling speeds up kernel-level decode workloads, and in end-to-end vLLM experiments, it reduces time per output token by up to 19% on the models we test. These results show that exact sampling, with no approximation, can be integrated into the matmul itself, turning a bandwidth-bound postprocessing step into a lightweight epilogue. Project Page: https://github.com/FlashSampling/FlashSampling.