在JAX中实现可扩展且可微分的联邦原语
FAX: Scalable and Differentiable Federated Primitives in JAX
March 11, 2024
作者: Keith Rush, Zachary Charles, Zachary Garrett
cs.AI
摘要
我们介绍了 FAX,这是一个基于 JAX 的库,旨在支持数据中心和跨设备应用中的大规模分布式和联邦计算。FAX 利用 JAX 的分片机制,实现对 TPU 和最先进的 JAX 运行时(包括 Pathways)的本地定位。FAX 将联邦计算的构建模块作为 JAX 中的原语进行嵌入。这带来了三个关键好处。首先,FAX 计算可以转换为 XLA HLO。其次,FAX 提供了联邦自动微分的完整实现,极大简化了联邦计算的表达。最后,FAX 计算可以映射到现有的生产跨设备联邦计算系统。我们展示了 FAX 在数据中心中提供了一个易于编程、高性能和可扩展的联邦计算框架。FAX 可在 https://github.com/google-research/google-research/tree/master/fax 获取。
English
We present FAX, a JAX-based library designed to support large-scale
distributed and federated computations in both data center and cross-device
applications. FAX leverages JAX's sharding mechanisms to enable native
targeting of TPUs and state-of-the-art JAX runtimes, including Pathways. FAX
embeds building blocks for federated computations as primitives in JAX. This
enables three key benefits. First, FAX computations can be translated to XLA
HLO. Second, FAX provides a full implementation of federated automatic
differentiation, greatly simplifying the expression of federated computations.
Last, FAX computations can be interpreted out to existing production
cross-device federated compute systems. We show that FAX provides an easily
programmable, performant, and scalable framework for federated computations in
the data center. FAX is available at
https://github.com/google-research/google-research/tree/master/fax .Summary
AI-Generated Summary