Flash-KMeans:快速且記憶體高效的精確K-Means演算法
Flash-KMeans: Fast and Memory-Efficient Exact K-Means
March 10, 2026
作者: Shuo Yang, Haocheng Xi, Yilong Zhao, Muyang Li, Xiaoze Fan, Jintao Zhang, Han Cai, Yujun Lin, Xiuyu Li, Kurt Keutzer, Song Han, Chenfeng Xu, Ion Stoica
cs.AI
摘要
k-means 演算法歷史上主要被定位為離線處理原語,通常用於資料集組織或嵌入預處理,而非作為線上系統的一等組件。在本研究中,我們透過現代 AI 系統設計的視角重新審視這一經典演算法,將 k-means 實現為線上原語。我們指出,現有的 GPU k-means 實現仍受制於底層系統約束而非理論演算法複雜度。具體而言,分配階段因高頻寬記憶體中 N×K 距離矩陣的大規模顯式實體化而存在嚴重 IO 瓶頸;同時,質心更新階段因不規則散射式 token 聚合導致的硬體級原子寫入競爭而效能受限。為消除此效能差距,我們提出 flash-kmeans——專為現代 GPU 工作負載設計的 IO 感知無競爭 k-means 實現。該實現包含兩項核心核心級創新:(1) FlashAssign 透過距離計算與線上 argmin 的融合,徹底規避中介記憶體實體化;(2) 排序反向更新透過顯式構建反向映射,將高競爭原子散射轉化為高頻寬分段級局部歸約。此外,我們整合演算法-系統協同設計(包括分塊流重疊與快取感知編譯啟發式),確保實際部署可行性。在 NVIDIA H200 GPU 上的大量實驗表明,flash-kmeans 相較最佳基準實現實現最高 17.9 倍的端到端加速,並分別以 33 倍和 200 倍以上的效能優勢超越 cuML、FAISS 等業界標準函式庫。
English
k-means has historically been positioned primarily as an offline processing primitive, typically used for dataset organization or embedding preprocessing rather than as a first-class component in online systems. In this work, we revisit this classical algorithm under the lens of modern AI system design and enable k-means as an online primitive. We point out that existing GPU implementations of k-means remain fundamentally bottlenecked by low-level system constraints rather than theoretical algorithmic complexity. Specifically, the assignment stage suffers from a severe IO bottleneck due to the massive explicit materialization of the N times K distance matrix in High Bandwidth Memory (HBM). Simultaneously, the centroid update stage is heavily penalized by hardware-level atomic write contention caused by irregular, scatter-style token aggregations. To bridge this performance gap, we propose flash-kmeans, an IO-aware and contention-free k-means implementation for modern GPU workloads. Flash-kmeans introduces two core kernel-level innovations: (1) FlashAssign, which fuses distance computation with an online argmin to completely bypass intermediate memory materialization; (2) sort-inverse update, which explicitly constructs an inverse mapping to transform high-contention atomic scatters into high-bandwidth, segment-level localized reductions. Furthermore, we integrate algorithm-system co-designs, including chunked-stream overlap and cache-aware compile heuristics, to ensure practical deployability. Extensive evaluations on NVIDIA H200 GPUs demonstrate that flash-kmeans achieves up to 17.9times end-to-end speedup over best baselines, while outperforming industry-standard libraries like cuML and FAISS by 33times and over 200times, respectively.