SynJax:JAX 的結構化機率分佈
SynJax: Structured Probability Distributions for JAX
August 7, 2023
作者: Miloš Stanojević, Laurent Sartran
cs.AI
摘要
深度學習軟體庫的發展使該領域取得了顯著進展,讓使用者能專注於建模,同時讓庫負責優化執行以配合現代硬體加速器的繁瑣且耗時的任務。然而,這僅對特定類型的深度學習模型帶來好處,例如 Transformers,其基本元素易於映射到向量化計算。那些明確考慮結構化對象(如樹狀結構和分割)的模型並未同等受益,因為它們需要難以以向量化形式實現的自定義算法。
SynJax 直接解決了這個問題,提供了對齊、標記、分割、組成樹和跨度樹等結構分佈的有效向量化推理算法實現。使用 SynJax,我們可以構建明確對數據結構進行建模的大規模可微分模型。程式碼可在以下網址獲得:https://github.com/deepmind/synjax。
English
The development of deep learning software libraries enabled significant
progress in the field by allowing users to focus on modeling, while letting the
library to take care of the tedious and time-consuming task of optimizing
execution for modern hardware accelerators. However, this has benefited only
particular types of deep learning models, such as Transformers, whose
primitives map easily to the vectorized computation. The models that explicitly
account for structured objects, such as trees and segmentations, did not
benefit equally because they require custom algorithms that are difficult to
implement in a vectorized form.
SynJax directly addresses this problem by providing an efficient vectorized
implementation of inference algorithms for structured distributions covering
alignment, tagging, segmentation, constituency trees and spanning trees. With
SynJax we can build large-scale differentiable models that explicitly model
structure in the data. The code is available at
https://github.com/deepmind/synjax.