PockEngine: ポケットサイズでのスパースかつ効率的なファインチューニング
PockEngine: Sparse and Efficient Fine-tuning in a Pocket
October 26, 2023
著者: Ligeng Zhu, Lanxiang Hu, Ji Lin, Wei-Chen Wang, Wei-Ming Chen, Chuang Gan, Song Han
cs.AI
要旨
オンデバイス学習と効率的なファインチューニングにより、継続的でプライバシー保護されたカスタマイズ(例:個人データを用いた大規模言語モデルのローカルファインチューニング)が可能となります。しかし、既存のトレーニングフレームワークは、強力なアクセラレータ(例:GPU、TPU)を備えたクラウドサーバー向けに設計されており、リソース制約やエッジハードウェアの多様性といった課題に直面するエッジでの学習に最適化されていません。本論文では、PockEngineを紹介します。これは、様々なエッジデバイスでのファインチューニングを可能にする、小さく、スパースで効率的なエンジンです。PockEngineはスパースなバックプロパゲーションをサポートします。これにより、後方グラフをプルーニングし、モデル品質を維持しながら、メモリ節約とレイテンシ削減を実現します。次に、PockEngineはコンパイルファーストです。トレーニンググラフ全体(前方、後方、最適化ステップを含む)がコンパイル時に導出されるため、ランタイムオーバーヘッドが削減され、グラフ変換の機会がもたらされます。PockEngineはまた、豊富なトレーニンググラフ最適化を統合しており、オペレータの並べ替えやバックエンドの切り替えを含むトレーニングコストをさらに加速します。PockEngineは多様なアプリケーション、フロントエンド、ハードウェアバックエンドをサポートします。PyTorch/TensorFlow/Jaxで定義されたモデルを柔軟にコンパイルし、モバイルCPU/GPU/DSPにバイナリを展開します。我々は、PockEngineをビジョンモデルと大規模言語モデルの両方で評価しました。PockEngineは、既存のTensorFlow(Raspberry Pi)に対して最大15倍の高速化、Jetson AGX Orinでのバックプロパゲーションにおいて5.6倍のメモリ節約を達成しました。特に、PockEngineはNVIDIA Jetson AGX Orin上でLLaMav2-7Bのファインチューニングを550トークン/秒で可能にし、PyTorchよりも7.9倍高速でした。
English
On-device learning and efficient fine-tuning enable continuous and
privacy-preserving customization (e.g., locally fine-tuning large language
models on personalized data). However, existing training frameworks are
designed for cloud servers with powerful accelerators (e.g., GPUs, TPUs) and
lack the optimizations for learning on the edge, which faces challenges of
resource limitations and edge hardware diversity. We introduce PockEngine: a
tiny, sparse and efficient engine to enable fine-tuning on various edge
devices. PockEngine supports sparse backpropagation: it prunes the backward
graph and sparsely updates the model with measured memory saving and latency
reduction while maintaining the model quality. Secondly, PockEngine is
compilation first: the entire training graph (including forward, backward and
optimization steps) is derived at compile-time, which reduces the runtime
overhead and brings opportunities for graph transformations. PockEngine also
integrates a rich set of training graph optimizations, thus can further
accelerate the training cost, including operator reordering and backend
switching. PockEngine supports diverse applications, frontends and hardware
backends: it flexibly compiles and tunes models defined in
PyTorch/TensorFlow/Jax and deploys binaries to mobile CPU/GPU/DSPs. We
evaluated PockEngine on both vision models and large language models.
PockEngine achieves up to 15 times speedup over off-the-shelf TensorFlow
(Raspberry Pi), 5.6 times memory saving back-propagation (Jetson AGX Orin).
Remarkably, PockEngine enables fine-tuning LLaMav2-7B on NVIDIA Jetson AGX Orin
at 550 tokens/s, 7.9times faster than the PyTorch.