ChatPaper.aiChatPaper

Flash-KMeans:快速且内存高效的精确K均值算法

Flash-KMeans: Fast and Memory-Efficient Exact K-Means

March 10, 2026
作者: Shuo Yang, Haocheng Xi, Yilong Zhao, Muyang Li, Xiaoze Fan, Jintao Zhang, Han Cai, Yujun Lin, Xiuyu Li, Kurt Keutzer, Song Han, Chenfeng Xu, Ion Stoica
cs.AI

摘要

k-means算法历来主要被定位为离线处理原语,通常用于数据集组织或嵌入预处理,而非作为在线系统中的核心组件。本研究通过现代AI系统设计的视角重新审视这一经典算法,将其实现为在线处理原语。我们指出,现有GPU实现的k-means算法根本瓶颈在于底层系统约束而非理论算法复杂度:分配阶段因高带宽内存中N×K距离矩阵的大规模显式物化而遭遇严重IO瓶颈;质心更新阶段则因不规则散射式令牌聚合引发的硬件级原子写竞争而严重受限。为弥补这一性能差距,我们提出flash-kmeans——面向现代GPU工作负载的IO感知无竞争k-means实现。该方案引入两项核心内核级创新:(1) FlashAssign通过融合距离计算与在线argmin操作,彻底规避中间存储物化;(2) 排序逆映射更新通过显式构建逆映射,将高竞争原子散射转换为高带宽分段局部归约。此外,我们整合算法-系统协同设计,包括分块流重叠和缓存感知编译启发式策略,确保实际可部署性。在NVIDIA H200 GPU上的大量实验表明,flash-kmeans相较最佳基线实现端到端加速达17.9倍,同时以33倍和200倍以上的优势超越cuML、FAISS等工业标准库。
English
k-means has historically been positioned primarily as an offline processing primitive, typically used for dataset organization or embedding preprocessing rather than as a first-class component in online systems. In this work, we revisit this classical algorithm under the lens of modern AI system design and enable k-means as an online primitive. We point out that existing GPU implementations of k-means remain fundamentally bottlenecked by low-level system constraints rather than theoretical algorithmic complexity. Specifically, the assignment stage suffers from a severe IO bottleneck due to the massive explicit materialization of the N times K distance matrix in High Bandwidth Memory (HBM). Simultaneously, the centroid update stage is heavily penalized by hardware-level atomic write contention caused by irregular, scatter-style token aggregations. To bridge this performance gap, we propose flash-kmeans, an IO-aware and contention-free k-means implementation for modern GPU workloads. Flash-kmeans introduces two core kernel-level innovations: (1) FlashAssign, which fuses distance computation with an online argmin to completely bypass intermediate memory materialization; (2) sort-inverse update, which explicitly constructs an inverse mapping to transform high-contention atomic scatters into high-bandwidth, segment-level localized reductions. Furthermore, we integrate algorithm-system co-designs, including chunked-stream overlap and cache-aware compile heuristics, to ensure practical deployability. Extensive evaluations on NVIDIA H200 GPUs demonstrate that flash-kmeans achieves up to 17.9times end-to-end speedup over best baselines, while outperforming industry-standard libraries like cuML and FAISS by 33times and over 200times, respectively.
PDF451March 13, 2026