ChatPaper.aiChatPaper

FlashSampling: Быстрое и ресурсоэффективное точное сэмплирование

FlashSampling: Fast and Memory-Efficient Exact Sampling

March 16, 2026
Авторы: Tomas Ruiz, Zhen Qin, Yifan Zhang, Xuyang Shen, Yiran Zhong, Mengdi Wang
cs.AI

Аннотация

Выборка из категориального распределения математически проста, однако при декодировании с большим словарем она часто приводит к дополнительной нагрузке на память и запуску дополнительных ядер после выходного слоя языковой модели (LM head). Мы представляем FlashSampling — точный примитив выборки, который интегрирует процедуру выборки в матричное умножение выходного слоя и полностью избегает материализации тензора логитов в высокоскоростной памяти (HBM). Метод прост: вычислять логиты поблочно на кристалле, добавлять шум Гумбеля, сохранять только один максимум на строку и на блок словаря, завершая процесс компактной редукцией по блокам. Объединенное поблочное ядро является точным, поскольку операция argmax декомпозируется на разделы; групповые варианты для онлайн-режима и тензорного параллелизма остаются точными благодаря иерархической факторизации категориального распределения. На GPU H100, H200, B200 и B300 FlashSampling ускоряет рабочие нагрузки декодирования на уровне ядра, а в сквозных экспериментах с vLLM сокращает время на генерацию одного токена до 19% для протестированных моделей. Эти результаты демонстрируют, что точную выборку без аппроксимаций можно интегрировать в само матричное умножение, превращая ограниченный пропускной способностью этап постобработки в легковесный эпилог. Страница проекта: https://github.com/FlashSampling/FlashSampling.
English
Sampling from a categorical distribution is mathematically simple, but in large-vocabulary decoding, it often triggers extra memory traffic and extra kernels after the LM head. We present FlashSampling, an exact sampling primitive that fuses sampling into the LM-head matmul and never materializes the logits tensor in HBM. The method is simple: compute logits tile-by-tile on chip, add Gumbel noise, keep only one maximizer per row and per vocabulary tile, and finish with a small reduction over tiles. The fused tiled kernel is exact because argmax decomposes over a partition; grouped variants for online and tensor-parallel settings are exact by hierarchical factorization of the categorical distribution. Across H100, H200, B200, and B300 GPUs, FlashSampling speeds up kernel-level decode workloads, and in end-to-end vLLM experiments, it reduces time per output token by up to 19% on the models we test. These results show that exact sampling, with no approximation, can be integrated into the matmul itself, turning a bandwidth-bound postprocessing step into a lightweight epilogue. Project Page: https://github.com/FlashSampling/FlashSampling.
PDF52March 19, 2026