ChatPaper.aiChatPaper

FlashSampling : Échantillonnage Exact Rapide et Économe en Mémoire

FlashSampling: Fast and Memory-Efficient Exact Sampling

March 16, 2026
Auteurs: Tomas Ruiz, Zhen Qin, Yifan Zhang, Xuyang Shen, Yiran Zhong, Mengdi Wang
cs.AI

Résumé

L'échantillonnage à partir d'une distribution catégorielle est mathématiquement simple, mais dans le décodage à grand vocabulaire, il déclenche souvent un trafic mémoire supplémentaire et des noyaux de calcul supplémentaires après la tête de modèle de langage. Nous présentons FlashSampling, une primitive d'échantillonnage exacte qui fusionne l'échantillonnage dans le produit matriciel de la tête de LM et ne matérialise jamais le tenseur des logits en mémoire haute bande (HBM). La méthode est simple : calculer les logits par tuiles sur la puce, ajouter un bruit de Gumbel, ne conserver qu'un seul maximum par ligne et par tuile du vocabulaire, et terminer par une petite réduction sur les tuiles. Le noyau en tuiles fusionné est exact car l'argmax se décompose sur une partition ; les variantes groupées pour les contextes en ligne et parallèles par tenseur sont exactes grâce à la factorisation hiérarchique de la distribution catégorielle. Sur les GPU H100, H200, B200 et B300, FlashSampling accélère les charges de travail de décodage au niveau du noyau, et dans les expériences de bout en bout avec vLLM, il réduit le temps par token de sortie jusqu'à 19% sur les modèles testés. Ces résultats montrent que l'échantillonnage exact, sans approximation, peut être intégré dans le produit matriciel lui-même, transformant une étape de post-traitement limitée par la bande passante en un épilogue léger. Page du projet : https://github.com/FlashSampling/FlashSampling.
English
Sampling from a categorical distribution is mathematically simple, but in large-vocabulary decoding, it often triggers extra memory traffic and extra kernels after the LM head. We present FlashSampling, an exact sampling primitive that fuses sampling into the LM-head matmul and never materializes the logits tensor in HBM. The method is simple: compute logits tile-by-tile on chip, add Gumbel noise, keep only one maximizer per row and per vocabulary tile, and finish with a small reduction over tiles. The fused tiled kernel is exact because argmax decomposes over a partition; grouped variants for online and tensor-parallel settings are exact by hierarchical factorization of the categorical distribution. Across H100, H200, B200, and B300 GPUs, FlashSampling speeds up kernel-level decode workloads, and in end-to-end vLLM experiments, it reduces time per output token by up to 19% on the models we test. These results show that exact sampling, with no approximation, can be integrated into the matmul itself, turning a bandwidth-bound postprocessing step into a lightweight epilogue. Project Page: https://github.com/FlashSampling/FlashSampling.
PDF52March 19, 2026