Flash-GMM: Ein speichereffizienter Kernel für skalierbares weiches Clustering
Flash-GMM: A Memory-Efficient Kernel for Scalable Soft Clustering
June 9, 2026
Autoren: Gal Bloch, Ariel Gera, Matan Orbach, Ohad Eytan, Assaf Toledo
cs.AI
Zusammenfassung
Wir stellen Flash-GMM vor, einen fusionierten Triton-Kernel zur effizienten Berechnung von Gaußschen Mischmodellen (GMMs) über große Datenmengen in einem einzigen GPU-Durchlauf. Durch die Eliminierung der Notwendigkeit, die vollständige Verantwortlichkeitsmatrix im GPU-Speicher zu materialisieren, erreicht Flash-GMM eine 20-fache Beschleunigung gegenüber bestehenden Implementierungen und ermöglicht das Training auf Datensätzen, die mehr als 100-mal größer sind als zuvor auf einem Gerät möglich. Um die Auswirkungen zu demonstrieren, integrieren wir Flash-GMM in den IVF-Grobquantisierer für die approximative Nächste-Nachbar-Suche (ANN). Wir zeigen, dass weiches GMM-Clustering nun ein praktikabler Ersatz für k-means ist und dass GMM-Verantwortlichkeiten genutzt werden können, um Grenzvektoren mehreren Clustern zuzuweisen. Unser Ansatz erreicht festgelegte Recall-Ziele mit bis zu 1,7-mal weniger Distanzberechnungen oder gleichbedeutend mit +2–12 recall@10 bei gleichem Rechenaufwand. Wir veröffentlichen den Kernel als Open-Source-Projekt.
English
We present Flash-GMM, a fused Triton kernel for efficient computation of Gaussian Mixture Models (GMMs) over large-scale data in a single GPU pass. By eliminating the need to materialize the full responsibility matrix in GPU memory, Flash-GMM achieves a 20times speedup over existing implementations and enables training on datasets more than 100times larger than previously feasible on one device. To demonstrate its impact, we integrate Flash-GMM into the IVF coarse quantizer for approximate nearest-neighbor (ANN) search. We show that soft GMM clustering is now a viable drop-in replacement for k-means, and that GMM responsibilities can be leveraged to assign border vectors to multiple clusters. Our approach reaches fixed recall targets with up to 1.7times fewer distance computations, or equivalently, yields +2--12 recall@10 at matched computational cost. We release the kernel as an open-source project.