ChatPaper.aiChatPaper

FlashSampling: Amostragem Exata Rápida e com Eficiência de Memória

FlashSampling: Fast and Memory-Efficient Exact Sampling

March 16, 2026
Autores: Tomas Ruiz, Zhen Qin, Yifan Zhang, Xuyang Shen, Yiran Zhong, Mengdi Wang
cs.AI

Resumo

A amostragem de uma distribuição categórica é matematicamente simples, mas na decodificação de vocabulário extenso, frequentemente desencadeia tráfego adicional de memória e kernels extras após o cabeçalho do modelo de linguagem. Apresentamos o FlashSampling, um primitivo de amostragem exata que funde a amostragem na multiplicação de matrizes do cabeçalho do modelo e nunca materializa o tensor de *logits* na memória de alto desempenho. O método é simples: calcula os *logits* bloco a bloco no *chip*, adiciona ruído de Gumbel, mantém apenas um maximizador por linha e por bloco de vocabulário, e finaliza com uma pequena redução sobre os blocos. O *kernel* fusionado em blocos é exato porque o *argmax* se decompõe sobre uma partição; variantes agrupadas para configurações *online* e de paralelismo de tensores são exatas pela fatoração hierárquica da distribuição categórica. Através dos GPUs H100, H200, B200 e B300, o FlashSampling acelera cargas de trabalho de decodificação a nível de *kernel*, e em experiências *end-to-end* com vLLM, reduz o tempo por *token* de saída em até 19% nos modelos que testamos. Estes resultados mostram que a amostragem exata, sem aproximação, pode ser integrada na própria multiplicação de matrizes, transformando uma etapa de pós-processamento limitada por largura de banda num epílogo leve. Página do Projeto: https://github.com/FlashSampling/FlashSampling.
English
Sampling from a categorical distribution is mathematically simple, but in large-vocabulary decoding, it often triggers extra memory traffic and extra kernels after the LM head. We present FlashSampling, an exact sampling primitive that fuses sampling into the LM-head matmul and never materializes the logits tensor in HBM. The method is simple: compute logits tile-by-tile on chip, add Gumbel noise, keep only one maximizer per row and per vocabulary tile, and finish with a small reduction over tiles. The fused tiled kernel is exact because argmax decomposes over a partition; grouped variants for online and tensor-parallel settings are exact by hierarchical factorization of the categorical distribution. Across H100, H200, B200, and B300 GPUs, FlashSampling speeds up kernel-level decode workloads, and in end-to-end vLLM experiments, it reduces time per output token by up to 19% on the models we test. These results show that exact sampling, with no approximation, can be integrated into the matmul itself, turning a bandwidth-bound postprocessing step into a lightweight epilogue. Project Page: https://github.com/FlashSampling/FlashSampling.
PDF52March 19, 2026