|
講演者:Lin Wei、Alibaba Cloud 研究員、Alibaba Cloud AI プラットフォーム PAI テクニカル リード
この記事では、PyTorch/XLA に基づく大規模モデル向けの分散トレーニング フレームワークである Alibaba Cloud TorchAcc について説明します。
過去10年間のAIの飛躍的な進歩は、主に学習技術の革新とモデル規模の急速な拡大によってもたらされました。大規模モデルは人間に匹敵する理解力を示しますが、その学習には非常に高い計算能力が求められます。十分な計算リソースがあればこそ、大規模モデルを膨大なデータセットで効果的に学習し、限られた時間内で高品質な収束を実現できるのです。
画像ソース: GTC 2024 China AI Day オンライン セッションのプレゼンテーション「TorchAcc: TorchXLA に基づく分散トレーニング フレームワーク」。 左のグラフが示すように、特に大規模モデルの成長傾向は過去5年間で顕著で、2年ごとに平均15倍のペースで規模が拡大しています。Transformerに代表される言語モデルやマルチモーダルモデルでは、その拡大率はさらに驚異的で、2年ごとに750倍という驚異的な増加を見せています。一方、右のグラフは明確な矛盾を浮き彫りにしています。それは、単一のGPUの演算能力も、GPUメモリ容量の発展速度も、このようなモデル規模の急速な拡大に追いつけないということです。この現実が、分散学習の切実な必要性を直接的に促しています。分散学習は、もはや従来の単純なデータ並列モデルに限定されず、モデル規模の拡大に対する個々の演算ユニットの演算能力とストレージの改善速度の遅れを補うために、モデル並列戦略をより重視して採用しています。
分散トレーニングの実践において、モデル並列型の分散トレーニングシステムの構築はデータ並列型よりも複雑であるという点では、開発者の間で一般的に意見が一致しています。分散の観点から見ると、データ並列化は比較的単純でシンプルです。これは、各コンピューティングノードで実行されるタスクが本質的に同等で一貫性があるためです。この場合、トレーニングプロセスの最後にAllReduceステップを挿入するだけで、各ワーカーノードが個別に計算した勾配の差を累積・統合し、平均を計算します。そして、最終的な勾配結果を参加ノード全体にブロードキャストすることで、グローバルモデルパラメータを同期的に更新します。
このタイプのシンプルな分散学習パラダイムは、確かに単一マシンコンピューティングに類似した特性を示しており、主にAllReduceによるグローバル勾配同期が用いられています。しかし、大規模モデルの時代においては、モデルのサイズが過剰になるため、単一のGPUに収容することはもはや不可能です。そのため、モデル並列戦略を採用する必要があり、開発の難易度は大幅に高まります。
その理由は、モデルの並列化には、モデルのサイズと構造に基づいてモデルを適切に「分割」する方法、つまり計算負荷を分散できる複数のモジュールに分割する必要があるためです。異なる分割戦略では、各ノードにおける演算子のアルゴリズム実装が異なります。同時に、異なる分割方法はノード間の通信プリミティブにも違いをもたらすため、最適な分割スキームとそれに対応する通信プリミティブを慎重に選択する必要があります。
モデル分割後の次のタスクは、適切な通信プリミティブを選択し、各演算子とそれに関連する通信操作を細かくスケジュールすることで、計算とネットワーク通信のオーバーラップを最大化し、基盤となるコンピューティングリソースの効率を最大限に活用することです。複数の分割オプションとスケジュール決定が可能なため、最適なモデル並列化戦略を見つける複雑さはデータ並列化よりもはるかに高く、開発者のスキルと経験に対する要求は高くなります。
画像ソース: GTC 2024 China AI Day オンライン セッションのプレゼンテーション「TorchAcc: TorchXLA に基づく分散トレーニング フレームワーク」。
この記事では、4つの主要な側面に焦点を当てます。最初のトピックは、TorchAccで多様な並列戦略を実装する方法です。従来のデータ並列処理に加え、現在普及しているFSDP(Fully Sharded Data Parallel、別名ZeRO(Zero Redundancy Optimizer))を網羅しています。さらに、演算子並列処理(Tensor Parallelism)やパイプライン並列処理など、様々な形式のモデル並列処理も網羅しています。
TorchAccのハイライトの一つは、様々な並列戦略を自動的に探索し、有機的に統合することで、ユーザーに高度に自動化された分散戦略構成ソリューションを提供することです。同時に、開発者のカスタマイズニーズを満たすために、TorchAccは半自動制御インターフェースも提供しており、ユーザーは並列戦略の自動探索プロセスに介入して調整することができ、柔軟性を維持しながらトレーニング効率とリソース利用率を最大化できます。
このように、TorchAccは、アルゴリズム開発者が分散トレーニングの具体的な実装の詳細に時間を費やすのではなく、モデルの構造設計、トレーニング手法の最適化、そしてモデルの収束性能の向上に集中できるよう効果的に支援します。TorchAccは、開発者が最適な分散トレーニングソリューションを探索し実装できるようインテリジェントに支援し、コンピューティングリソースの利用とアルゴリズムの反復処理の効率を大幅に向上させます。
第二に、モデルの並列化の必要性は、大規模なモデルが単一のGPUのメモリ容量の限界を超えるという事実に起因します。メモリ容量はモデルの学習にとって非常に重要であり、メモリボトルネックの克服は分散学習全体の効率向上に不可欠です。そのため、TorchAccは、きめ細かなスケジューリングとメモリリソースのアドレス割り当て戦略を通じて並列モデル学習の効率を最大化し、モデルが既存のメモリアドレス空間を最大限に活用できるようにするスマートなメモリアロケータを提供します。
さらに、モデル構造がますます複雑化し、規模が拡大するにつれて、ユーザーの計算リソースに対する需要も絶えず高まっています。そのため、学習中のモデルの計算強度をさらに最適化し、メモリアクセスのオーバーヘッドを削減することも重要です。
最後に、データセンターインフラの現在の発展動向を踏まえると、大規模モデルのトレーニングにおけるネットワーク環境への要件はますます厳しくなっています。現代のデータセンターサーバーは、大規模並列モデルのトレーニングにおける高速データ交換の要求を満たすため、テラバイト(TB)レベルに達する相互接続帯域幅を備えています。しかし、モデルの並列化によってもたらされる複雑な通信パターンと高頻度のデータインタラクションは、全体的なトレーニング効率にも課題をもたらします。そのため、ネットワーク帯域幅を有効活用し、反復計算における通信プロセスに費やされる時間の割合を削減することが、トレーニング効率を向上させるための重要な要素となっています。
TorchAcc の実装では、PyTorch ベースでも TensorFlow ベースでも、フロントエンドにおけるユーザーのモデルトレーニングプロセスを、一連の技術的手段を用いて統合された中間表現層(Model IR)グラフに変換します。TensorFlow の場合、それ自体が計算グラフモデルであるため、変換プロセスは比較的単純です。PyTorch の場合、シンボリックトレースや LazyTensor などの技術を用いて計算グラフをキャプチャし、それを IR グラフに変換します。
画像ソース: GTC 2024 China AI Day オンライン セッションのプレゼンテーション「TorchAcc: TorchXLA に基づく分散トレーニング フレームワーク」。 TorchAccは、中間表現層(IRグラフ)の構築に基づいて、計算最適化、ストレージ最適化、通信最適化、分散戦略最適化など、多様な最適化戦略を実装します。IRグラフは、これらの最適化されたパスを様々な組み合わせで反復実行し、最終的に最適な実行計画を取得します。この計画は、基盤となるバックエンドに渡され、モデル学習のパフォーマンス向上を最大化するために実行されます。
この包括的なソリューションにより、TorchAccは複数モデルの分散トレーニングシナリオにおいて、大幅なパフォーマンス向上を実証しました。一部のモデルのトレーニングプロセスでは、最大3倍のパフォーマンス向上を達成し、分散トレーニングの課題解決におけるTorchAccの効率性と実用性を完全に実証しました。
画像ソース: GTC 2024 China AI Day オンライン セッションのプレゼンテーション「TorchAcc: TorchXLA に基づく分散トレーニング フレームワーク」。 この画像は主にTorchAccの全体的なアーキテクチャを示しています。PyTorch/XLAをベースとし、OpenXLAを活用するTorchAccは、大規模モデルの学習を高速化するフレームワークを提供します。異なるフロントエンドを用いて構築されたモデルを処理する際、TorchAccはSymbolic TraceやLazyTensorといった適切なグラフキャプチャ技術を柔軟に活用し、FXグラフとHLOグラフという2つの異なるレベルのグラフ表現を生成します。FXグラフはより高い抽象度レベルにあり、HLOグラフはより低い抽象度レベルにあります。
キャプチャされたモデル計算グラフに基づいて、TorchAcc はさらに、前述の計算最適化、ストレージ最適化、通信最適化、分散戦略最適化という 4 種類の最適化作業を実行できます。
画像ソース: GTC 2024 China AI Day オンライン セッションのプレゼンテーション「TorchAcc: TorchXLA に基づく分散トレーニング フレームワーク」。 分散戦略の最適化レベルでは、TorchAcc は業界で広く使用されている様々な並列戦略をサポートし、これらの戦略を柔軟に組み合わせることで、特定のモデルを効果的に並列化できます。具体的には、データ並列化(DP)、パイプライン並列化(PP)、完全シャーディングデータ並列化(FSDP、ZeRO とも呼ばれる)という 3 つの分散戦略の実装と最適化はすべて、FX Graph のより高い抽象度レベルで完了します。
並列戦略をFXグラフレベルで操作することを選択した理由は、このレベルには計算グラフの構造と操作に関する十分な情報が含まれており、開発者が様々な並列戦略に適した最適化スキームを設計する上で役立つためです。下位レベルのHLOグラフで直接最適化する場合と比較して、FXグラフは抽象度と汎用性が高く、このレベルでの最適化は一般的にコストが低く、効率的でターゲットを絞った分散戦略の調整を容易に実装できます。
パイプライン並列処理を例に挙げると、システムはFX Graphレベルで異なるステージを自動的に検出し、適切な分割ポイントを決定することで、モデルを複数の連続実行ステージに効果的に分割し、パイプライン並列処理を実現します。このプロセスにおいて、FX Graphが提供する詳細な計算構造情報を活用して、インテリジェントなセグメンテーションを行うことができます。
より複雑な並列化戦略であるテンソル並列化とシーケンス並列化では、意思決定のためにより詳細かつ正確な情報が必要となります。これを実現するために、システムは順方向および逆方向の伝播中に計算グラフ全体の実行計画を分析する必要があります。この作業は主にHLO(ハードウェアローカル割り当て)の低レベル表現レベルで実行されます。
PyTorch/XLA が提供するマーク シャーディング インターフェイスを利用することで、システムはモデル パラメータに対応する分割マーカーを追加し、この分割情報を OpenXLA の SPMD 最適化パスに渡すことで、計算グラフの分割、最適化、導出、書き換えのプロセスをトリガーし、最終的に自動テンソル並列処理とシーケンス並列処理機能を実現します。
画像ソース: GTC 2024 China AI Day オンライン セッションのプレゼンテーション「TorchAcc: TorchXLA に基づく分散トレーニング フレームワーク」。 演算子最適化レベルでは、TorchAccはFlashAttentionテクノロジーを導入し、Attentionモジュールの実行効率を向上させます。まず、XLAのカスタム呼び出し機能を通じて、FlashAttentionの実装がOpenXLAコンパイラとランタイムフレームワークにシームレスに統合されます。これにより、FlashAttentionをXLAカーネルレベルで直接実行できるようになり、ハードウェアアクセラレーション機能を最大限に活用できるようになります。
統合プロセスにおいては、PyTorchとXLA間のTensorデータ転送を適切に処理することが極めて重要であり、これにより、両システム間の変換におけるデータの一貫性とパフォーマンスの最適化が確保されます。同時に、FlashAttentionの内部パラメータの転送といった細部も適切に処理する必要があります。これにより、並列計算と最適化において、これらの重要なパラメータが正しく効率的に計算に適用され、アテンションメカニズム実行時のモデルの計算速度とリソース利用率がさらに向上します。
FlashAttentionの最適化機能へのユーザーアクセスを容易にするために、2つのインターフェースを提供しています。1つはPythonインターフェースを介して、事前に記述されたFlashAttention演算子を直接呼び出す方法です。もう1つは、OpenXLA上で事前に記述されたパターンマッチパスを利用する方法です。このパスは、計算グラフ内のアテンションブロックを自動的に識別し、これらの計算構造を抽出して、カスタムFlashAttention呼び出しに置き換えます。この設計の利点は、XLAの優れたカーネルフュージョンやその他の演算子最適化機能を最大限に活用しながら、FlashAttentionが提供する高度な計算最適化手法と組み合わせることができる点です。
画像ソース: GTC 2024 China AI Day オンライン セッションのプレゼンテーション「TorchAcc: TorchXLA に基づく分散トレーニング フレームワーク」。 Llama 2-7Bモデルの性能テストでは、上記の計算最適化の効果が明確に確認できました。XLA独自の最適化技術、特にカーネルフュージョンを活用することで、メモリを大量に消費する多数の演算子を効果的に統合し、その数を大幅に削減しました。FlashAttentionを追加することで、最適化性能はさらに向上しました。
画像ソース: GTC 2024 China AI Day オンライン セッションのプレゼンテーション「TorchAcc: TorchXLA に基づく分散トレーニング フレームワーク」。 通信最適化レベルでは、分散トレーニングの効率を向上させるために、主に3つのコアタスクを完了しました。まず、散在していた集団通信演算子を統合し、演算子の数を減らすことで通信オーバーヘッドとスケジューリングの複雑さを軽減しました。次に、統合した集団通信演算子を独立したCUDAストリームに移動して実行することで、計算と通信の非同期オーバーラップを可能にしました。最後に、OpenXLAのLatency Hiding Schedulerを最大限に活用し、通信演算子のスケジューリングを細かく最適化し、それらが可能な限り早期に開始および実行されるようにすることで、通信と計算のオーバーラップを強化しました。
画像ソース: GTC 2024 China AI Day オンライン セッションのプレゼンテーション「TorchAcc: TorchXLA に基づく分散トレーニング フレームワーク」。 Llama2-7Bモデルを用いたエンドツーエンドのマルチマシンパフォーマンステストの結果、通信最適化戦略を適用した後、128枚のGPUカードで分散学習を実行した際の速度向上率が88から116に向上することがわかりました。タイムライングラフからも、最適化された通信演算子がより秩序立ち、計算とのオーバーラップが改善されていることが分かります。
画像ソース: GTC 2024 China AI Day オンライン セッションのプレゼンテーション「TorchAcc: TorchXLA に基づく分散トレーニング フレームワーク」。 この記事の最後のセクションでは、計算グラフ内の演算子の実行順序とメモリ内のテンソルのアドレス割り当てを最適化することでメモリのオーバーヘッドを削減する TorchAcc のメモリ最適化機能について説明します。
図に示すように、4つの演算子V0、V1、V2、V3を含む計算グラフがあるとします。演算子の実行順序が制御されておらず、左図に示すようにV0-V1-V2-V3の順序で実行され、各テンソルがデフォルトの方法でメモリアドレスを要求すると、図Bの左半分に示すような状況、つまりすべてのテンソルを収容するのに十分なメモリ容量がなく、メモリ不足エラーが発生する可能性があります。
しかし、メモリ割り当てを予測し、細かく管理できれば、つまりアドレス割り当て時に実行される演算子の順序を予測できれば、図Bの右半分に示すように、より適切なメモリレイアウトを実現し、限られたメモリ内で計算全体をスムーズに完了させることができます。さらに、V0-V2-V1-V3の順序で実行するなど、実行順序を厳密に制御することで、メモリ使用量を当初の約70%にまで削減できます。
このコンセプトは、XLA中間プレゼンテーション層の既存のスケジューラとバッファ管理メカニズムに基づいており、これに基づいてより高度なメモリ最適化手法を提案します。現在、業界ではメモリ割り当てを最適化するために、ヒューリスティックアルゴリズムや制約解決など様々な手法が存在します。しかし、これらの手法は適時性と効率性のバランスを取ることが難しい場合が多く、実際の本番環境のクラスタに適用する際には限界が生じる可能性があります。
画像ソース: GTC 2024 China AI Day オンライン セッションのプレゼンテーション「TorchAcc: TorchXLA に基づく分散トレーニング フレームワーク」。 トレーニング シナリオで効果的かつ効率的なメモリ最適化を実現することは、主に次の理由により、非常に困難なタスクです。
NP 困難問題の本質: モデルの規模、演算子の多様性、演算子間のメモリ割り当ての複雑さにより、メモリの最適化は典型的な NP 困難問題となり、通常、グローバルな最適解を見つけることは計算上不可能であることを意味します。 演算子実行の柔軟性:学習中、順方向伝播、逆方向伝播、重み更新といった演算は非常に柔軟です。特に重み更新においては、勾配は生成後いつでも重み更新に利用できます。しかし、実行時間の違いはGPUメモリの割り当てと解放に影響を与え、最適化の難易度を高めます。 メモリ再利用の複雑さ: トレーニング中に、順方向および逆方向の伝播によってメモリを再利用することで再計算を削減できますが、Tensor のライフサイクルとサイズのバリエーションの多様性により、メモリの再利用が非常に複雑になり、ヒューリスティック アルゴリズムなどの従来の最適化方法に深刻な課題が生じます。
上記の問題を解決するために、私たちは分割統治戦略を採用しました。
メモリを考慮した重み更新スケジューラ:勾配生成のタイミング、使用する最適化手法の種類、および現在のメモリリソースの状態に基づいて適切な重み更新タイミングを選択する、メモリを考慮した重み更新スケジューラを導入します。これにより、特にAdamのような複雑な最適化手法では、モメンタムやその他の変数の保存を考慮する必要があるため、即時更新によるメモリ負荷の増大を回避できます。 グラフ分割と局所最適化:大規模な計算グラフは、キーノード(メモリ非依存演算子)に基づいて、メモリ非依存の複数のサブグラフに分割されます。サブグラフ間の実行順序は固定ですが、サブグラフ内の実行順序は変化します。このようにして、複雑なグローバル線形計画問題を複数のローカル問題に分解し、最適な実行順序を見つける線形計画法などの効率的な最適化手法をサブグラフ内に適用できます。
上述の分割統治戦略により、最終的にこれらのサブグラフの解の結果を集約することが可能になります。これが、私たちが提案するメモリ最適化探索手法、ROAM(Reorder Operators and Arrange Tensors Address to Reduce Memory Usage)です。 上記の手法は、メモリ最適化問題を効率的に処理することに成功しました。実験結果では、ネイティブPyTorch、ヒューリスティックアルゴリズム、そしてFacebookの最新の整数線形計画法に基づく最適化手法などのベースラインと比較して、ROAMはそれぞれ約16%、13%、27%のメモリオーバーヘッドを削減し、最適化時間とスケーラビリティの点で優れた性能を示し、この手法の有効性を裏付けています。
画像ソース: GTC 2024 China AI Day オンライン セッションのプレゼンテーション「TorchAcc: TorchXLA に基づく分散トレーニング フレームワーク」。
画像ソース: GTC 2024 China AI Day オンライン セッションのプレゼンテーション「TorchAcc: TorchXLA に基づく分散トレーニング フレームワーク」。 別の観点からパフォーマンスを測定するため、アルゴリズムの時間計算量を調べました。実験の結果、一般的な深層学習シナリオにおいて、当社の最適化アルゴリズムはわずか数分で結果を生成できることが実証されました。右の比較に示すように、Facebookが最近提案したMODeL(線形計画法に基づく最適化手法)と比較して、当社の手法は解を求める時間を大幅に短縮しています。これは、MODeLが大規模グラフを効果的に分割しないのに対し、当社の手法はメモリを考慮した重み更新スケジューラとサブグラフ分割戦略を導入することで最適化問題の空間計算量を効果的に削減し、解を求める効率を向上させるためです。
要約すると、TorchAcc は、メモリ最適化、計算最適化、通信最適化、並列戦略最適化において大きな成果を達成し、分散トレーニングの効率とパフォーマンスを総合的に向上させました。
上記のコンテンツは、GTC 2024 China AI Dayのオンライン中国語プレゼンテーションセッションからの抜粋です。画像のQRコードをスキャンするか、カンファレンスのウェブサイトにアクセスして、プレゼンテーションビデオを視聴し、配布資料をダウンロードしてください。
|