ChatPaper.aiChatPaper

SynJax:用于JAX的结构化概率分布

SynJax: Structured Probability Distributions for JAX

August 7, 2023
作者: Miloš Stanojević, Laurent Sartran
cs.AI

摘要

深度学习软件库的发展使该领域取得了重大进展,让用户能够专注于建模,同时让库负责优化执行以适配现代硬件加速器的繁琐且耗时的任务。然而,这仅使某些类型的深度学习模型受益,比如变换器,其基本元素易于映射到向量化计算。那些明确考虑结构化对象(如树和分割)的模型并没有同等受益,因为它们需要定制算法,难以以向量化形式实现。 SynJax 直接解决了这一问题,提供了针对结构化分布的推断算法的高效向量化实现,涵盖了对齐、标记、分割、组成树和跨度树。通过 SynJax,我们可以构建明确对数据结构进行建模的大规模可微分模型。代码可在 https://github.com/deepmind/synjax 获取。
English
The development of deep learning software libraries enabled significant progress in the field by allowing users to focus on modeling, while letting the library to take care of the tedious and time-consuming task of optimizing execution for modern hardware accelerators. However, this has benefited only particular types of deep learning models, such as Transformers, whose primitives map easily to the vectorized computation. The models that explicitly account for structured objects, such as trees and segmentations, did not benefit equally because they require custom algorithms that are difficult to implement in a vectorized form. SynJax directly addresses this problem by providing an efficient vectorized implementation of inference algorithms for structured distributions covering alignment, tagging, segmentation, constituency trees and spanning trees. With SynJax we can build large-scale differentiable models that explicitly model structure in the data. The code is available at https://github.com/deepmind/synjax.
PDF60December 15, 2024