ChatPaper.aiChatPaper

FlashSampling: Campionamento Esatto Veloce ed Efficiente in Termini di Memoria

FlashSampling: Fast and Memory-Efficient Exact Sampling

March 16, 2026
Autori: Tomas Ruiz, Zhen Qin, Yifan Zhang, Xuyang Shen, Yiran Zhong, Mengdi Wang
cs.AI

Abstract

Il campionamento da una distribuzione categorica è matematicamente semplice, ma, nel decoding con vocabolari di grandi dimensioni, spesso innesca traffico di memoria aggiuntivo e kernel supplementari dopo l'head del LM. Presentiamo FlashSampling, una primitiva di campionamento esatta che fonde il campionamento nel matmul dell'LM-head e non materializza mai il tensore dei logit nell'HBM. Il metodo è semplice: calcola i logit tile per tile sull'chip, aggiunge rumore di Gumbel, mantiene solo un massimizzatore per riga e per tile del vocabolario, e conclude con una piccola riduzione sui tile. Il kernel a tile fuso è esatto perché l'argmax si scompone su una partizione; le varianti raggruppate per contesti online e tensor-parallel sono esatte grazie alla fattorizzazione gerarchica della distribuzione categorica. Su GPU H100, H200, B200 e B300, FlashSampling accelera i carichi di lavoro di decoding a livello di kernel e, in esperimenti end-to-end con vLLM, riduce il tempo per token di output fino al 19% sui modelli testati. Questi risultati dimostrano che il campionamento esatto, senza approssimazioni, può essere integrato nel matmul stesso, trasformando un passo di post-elaborazione vincolato dalla larghezza di banda in un epilogo leggero. Pagina del progetto: https://github.com/FlashSampling/FlashSampling.
English
Sampling from a categorical distribution is mathematically simple, but in large-vocabulary decoding, it often triggers extra memory traffic and extra kernels after the LM head. We present FlashSampling, an exact sampling primitive that fuses sampling into the LM-head matmul and never materializes the logits tensor in HBM. The method is simple: compute logits tile-by-tile on chip, add Gumbel noise, keep only one maximizer per row and per vocabulary tile, and finish with a small reduction over tiles. The fused tiled kernel is exact because argmax decomposes over a partition; grouped variants for online and tensor-parallel settings are exact by hierarchical factorization of the categorical distribution. Across H100, H200, B200, and B300 GPUs, FlashSampling speeds up kernel-level decode workloads, and in end-to-end vLLM experiments, it reduces time per output token by up to 19% on the models we test. These results show that exact sampling, with no approximation, can be integrated into the matmul itself, turning a bandwidth-bound postprocessing step into a lightweight epilogue. Project Page: https://github.com/FlashSampling/FlashSampling.
PDF82March 31, 2026