ChatPaper.aiChatPaper

FlashSampling: Muestreo Exacto Rápido y Eficiente en Memoria

FlashSampling: Fast and Memory-Efficient Exact Sampling

March 16, 2026
Autores: Tomas Ruiz, Zhen Qin, Yifan Zhang, Xuyang Shen, Yiran Zhong, Mengdi Wang
cs.AI

Resumen

El muestreo de una distribución categórica es matemáticamente simple, pero en la decodificación con vocabularios grandes, a menudo desencadena tráfico de memoria adicional y kernels extra después de la capa LM (cabeza del modelo de lenguaje). Presentamos FlashSampling, una primitiva de muestreo exacta que fusiona el muestreo en la multiplicación de matrices (matmul) de la cabeza LM y nunca materializa el tensor de logits en la memoria de alto ancho de banda (HBM). El método es simple: calcular los logits por bloques (tile-by-tile) en el chip, añadir ruido de Gumbel, mantener solo un maximizador por fila y por bloque de vocabulario, y finalizar con una pequeña reducción sobre los bloques. El kernel fusionado y en bloques es exacto porque la operación argmax se descompone sobre una partición; las variantes agrupadas para entornos en línea y de paralelismo de tensores son exactas gracias a la factorización jerárquica de la distribución categórica. En las GPU H100, H200, B200 y B300, FlashSampling acelera las cargas de trabajo de decodificación a nivel de kernel, y en experimentos de vLLM de extremo a extremo, reduce el tiempo por token de salida hasta en un 19% en los modelos que probamos. Estos resultados demuestran que el muestreo exacto, sin aproximación alguna, puede integrarse en la propia multiplicación de matrices, convirtiendo un paso de postprocesamiento limitado por el ancho de banda en un epílogo ligero. Página del proyecto: https://github.com/FlashSampling/FlashSampling.
English
Sampling from a categorical distribution is mathematically simple, but in large-vocabulary decoding, it often triggers extra memory traffic and extra kernels after the LM head. We present FlashSampling, an exact sampling primitive that fuses sampling into the LM-head matmul and never materializes the logits tensor in HBM. The method is simple: compute logits tile-by-tile on chip, add Gumbel noise, keep only one maximizer per row and per vocabulary tile, and finish with a small reduction over tiles. The fused tiled kernel is exact because argmax decomposes over a partition; grouped variants for online and tensor-parallel settings are exact by hierarchical factorization of the categorical distribution. Across H100, H200, B200, and B300 GPUs, FlashSampling speeds up kernel-level decode workloads, and in end-to-end vLLM experiments, it reduces time per output token by up to 19% on the models we test. These results show that exact sampling, with no approximation, can be integrated into the matmul itself, turning a bandwidth-bound postprocessing step into a lightweight epilogue. Project Page: https://github.com/FlashSampling/FlashSampling.
PDF52March 19, 2026