LASP-2: Rethinking Sequence Parallelism for Linear Attention and Its Hybrid
February 11, 2025
作者: Weigao Sun, Disen Lan, Yiran Zhong, Xiaoye Qu, Yu Cheng
cs.AI
Abstract
Linear sequence modeling approaches, such as linear attention, provide
advantages such as linear-time training and constant-memory inference across
sequence lengths. However, existing sequence parallelism (SP) methods are
either not optimized for the right-product-first feature of linear attention or
use a ring-style communication strategy, which results in lower computation
parallelism and limits their scalability to longer sequences in distributed
systems. In this paper, we introduce LASP-2, a new SP method to enhance both
communication and computation parallelism when training linear attention
transformer models with very long input sequences. Compared to the previous work
LASP, LASP-2 rethinks the minimal communication requirement for SP on linear
attention layers and reorganizes the whole communication-computation workflow of
LASP. In this way, only a single AllGather collective communication is needed
on intermediate memory states, whose sizes are independent of the sequence
length, leading to significant improvements in both communication and
computation parallelism, as well as their overlap. Additionally, we extend
LASP-2 to LASP-2H by applying a similar communication redesign to standard
attention modules, offering an efficient SP solution for hybrid models that
blend linear and standard attention layers. Our evaluation on a Linear-Llama3
model, a variant of Llama3 with linear attention replacing standard attention,
demonstrates the effectiveness of LASP-2 and LASP-2H. Specifically, LASP-2
achieves training speed improvements of 15.2% over LASP and 36.6% over Ring
Attention, with a sequence length of 2048K across 64 GPUs. The code is released
as part of https://github.com/OpenSparseLLMs/Linear-MoE.
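
To make the communication pattern described in the abstract concrete, below is a minimal sketch of the idea, not the authors' released implementation (see the repository above for the actual code). It assumes each rank holds one contiguous sequence chunk as single-head q, k, v tensors of shape [chunk_len, d], uses unnormalized linear attention, and omits gating/decay, normalization, and the hybrid LASP-2H path; the function name `lasp2_style_linear_attention` is illustrative.

```python
import torch
import torch.distributed as dist


def lasp2_style_linear_attention(q, k, v, group=None):
    """Sequence-parallel causal linear attention in the spirit of LASP-2.

    Each rank holds one contiguous chunk of the sequence:
        q, k, v: [chunk_len, d]  (single head, unnormalized linear attention).
    The only collective is one AllGather over [d, d] chunk memory states,
    whose size is independent of the sequence length.
    """
    rank = dist.get_rank(group)
    world = dist.get_world_size(group)
    d = q.shape[-1]

    # 1) Chunk-local memory state, built "right-product-first": K^T V is a
    #    [d, d] matrix, so its cost is linear in chunk length and its size
    #    does not grow with the sequence length.
    local_state = k.transpose(0, 1) @ v                        # [d, d]

    # 2) The single AllGather over the small memory states.
    gathered = [torch.empty_like(local_state) for _ in range(world)]
    dist.all_gather(gathered, local_state, group=group)

    # 3) Sum the states of all preceding chunks to preserve causality
    #    across ranks.
    prefix_state = q.new_zeros(d, d)
    for r in range(rank):
        prefix_state = prefix_state + gathered[r]

    # 4) Intra-chunk causal part via a running sum of per-token outer
    #    products k_t v_t^T (kept naive here for clarity).
    per_token = k.unsqueeze(-1) * v.unsqueeze(1)               # [L, d, d]
    running = torch.cumsum(per_token, dim=0)                   # [L, d, d]
    return torch.einsum("ld,lde->le", q, prefix_state + running)
```

In this sketch, the AllGather payload per rank is a d-by-d state regardless of how long the sequence is, which is why the communication volume stays constant as sequences grow and the gather can be overlapped with the chunk-local computation.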