HUOXIU

特定のバッチ サイズで突然パフォーマンスが低下するのはなぜですか?

編集者注:ディープラーニングモデルを最適化する際に、バッチサイズを大きくしてもGPU使用率が期待通りに向上せず、困惑したことはありませんか? 実際のプロジェクトでは、この問題はリソースの無駄、トレーニング効率の低下、さらにはAI製品のデリバリーサイクル全体に影響を及ぼす可能性があります。

本稿では、最新のGPUバッチ処理の動作原理を詳細に分析し、メモリ帯域幅と計算能力の微妙な関係を明らかにします。理論モデルを構築し、それを実際の実験と組み合わせることで、特定のバッチサイズがパフォーマンスの低下を突然引き起こす理由を説明するだけでなく、最適なバッチサイズを見つけるための手法も提供しています。

著者 | フィンバー・ティンバーズ

編纂者:岳陽

一般的に、現代のディープラーニングシステムで最初に行う最も重要な最適化は、バッチ処理を実装することです。推論中に単一の入力を処理する代わりに、N個の入力を含むバッチデータを同時に処理します。ほとんどの場合、この操作は追加コストを発生しません。推論に必要な時間は、単一の入力を処理する場合でもN個の入力を処理する場合でもほぼ同じです。では、なぜでしょうか?一見すると、バッチ処理はより多くのリソースを消費するように見えますが、結局のところ、ワークロードはN倍に増加するからです。

しかし、ニューラルネットワークの仕組みを理解するために単純または未熟なモデルを使用する場合、バッチ処理にはコストがかかります。実際、バッチ処理にはN倍の計算能力が必要です。特定の計算タスクをCPUで実行すれば、この点が真実であることがわかります。

しかし、同じ計算タスクを最新のGPUで実行すると、結果は変わりました。T4 GPUで観察された結果は次のとおりです。

グラフに示されているように、バッチ サイズが 1 から 3 に増加しても消費時間は増加しません。ただし、バッチ サイズが 3 を超えると、消費時間は直線的に増加します。

その理由は何でしょうか?鍵となるのは同時処理能力です。最新のGPUは複数の計算を同時に実行できます(ただし、単一スレッドで処理する場合はCPUよりも遅くなります)。

「モデルを用いた単一データサンプルの推論」について話すとき、モデルを単一のブロックとして考えることがよくあります。しかし、実際にはモデルは多数の行列で構成されています。推論中、各行列はメモリにロードされます。具体的には、行列の各ブロックがデバイスの共有メモリ(A100 GPUではわずか192KB)にロードされます。このブロックは、バッチ内の各要素の結果を計算するために使用されます。これはGPU RAM(HBM)とは異なることに注意することが重要です。A100 GPUは、モデルに応じて40GBまたは80GBのHBMを搭載していますが、デバイスメモリはわずか192KBです。そのため、数学演算を実行する際には、デバイスメモリからデータを絶えず読み書きする必要があるため、メモリ帯域幅がパフォーマンスのボトルネックとなります。重みの転送に必要な時間は、モデルサイズをメモリ帯域幅で割ることで推定でき、計算時間はモデルの浮動小数点演算(FLOPS)をGPUのFLOPSで割ることで推定できます。

多層パーセプトロン(MLP)を用いた場合、浮動小数点演算回数(FLOPS)は、パラメータ数×バッチ要素数[1]の約2倍(つまり、2 * m * n * b、バッチサイズはb、行列はm x n)となる。したがって、転送時間と計算時間が等しい場合、以下のことを意味する。

ここで、両辺のパラメータの数が互いに打ち消し合う可能性があることがわかります。

さらに、バッチ サイズに基づいてバッチを並べ替えることもできます。

バッチサイズがFLOPSとメモリ帯域幅の比よりも小さい場合、メモリ帯域幅がパフォーマンスのボトルネックになります。バッチサイズがこの比を超えると、計算能力(FLOPS)が新たなボトルネックになります。この分析は多層パーセプトロン(MLP)にのみ適用され、ResNet50のような畳み込みニューラルネットワークでは状況がより複雑になることに注意してください。

T4 GPU(製品仕様書[2])では、浮動小数点演算能力は65TFLOPS(32ビット浮動小数点数)、メモリ帯域幅は300GB/sに達します。このデータによると、理想的なマジックレシオは216です。実際に深さ8、幅1024の多層パーセプトロン(MLP)モデルを実行したところ、期待通りの結果が得られました。

データには多少のノイズが含まれていますが、全体的な傾向は予測と一致しています。推論時間は閾値128付近で急激に増加し始めます(ここでは、段階的に2倍にするアプローチを採用し、異なるバッチサイズが推論時間に与える影響を観察し記録しています)。MLP層の幅を変えると、この現象が様々なアーキテクチャで発生することがわかります(以下は、すべてのデータポイントがグラフに明確に表示されるように両対数プロットしたものです)。

これはすごいですね!🆒 様々なモデルアーキテクチャにおいて、重要な閾値が見られます。興味深いことに、小規模なネットワークではバッチサイズ(1から512)による速度向上は見られず、処理速度はほぼ一定のままです。当初の私の説明では、GPUは数学演算を非常に高速に実行できるのに対し、他のハードウェア(CPUなど)は比較的遅いためだと考えられます。実験の初期段階では、かなりのノイズ干渉が観測されましたが、これは今のところ「システムオーバーヘッド」によるものとしか考えられません。

多くの機械学習エンジニアは、機械学習そのものよりも、機械学習に関連しないコードで発生するシステムオーバーヘッドの削減に時間を費やすことがよくあります。強化学習(RL)研究、特に継続学習問題に焦点を当てる研究者にとって、実験でGPUを使用することは、1) 非常に大規模なニューラルネットワークを持っている場合、または2) テクノロジースタック全体を極限まで最適化している場合を除いて、費用対効果が低いことがよくあります。DeepMindで働いたことがあるエンジニアを困惑させたいなら、「グラフ内環境」について聞いてみてください。かつて私たちは、TensorFlowの計算グラフ内にRL環境を実装したことさえありました。

では、畳み込みニューラル ネットワークについてはどうでしょうか?

畳み込みニューラルネットワークでは、重みの総数はフィルター数とフィルターサイズの積です。`torch.nn.Conv2d` を例にとると、重みは `kernel_size^2` と `out_channels` の積として計算されます。解像度 (224, 224)、ストライド 1、カーネルサイズ 3 の画像を処理すると仮定すると、各フィルターは 224 回使用されます。つまり、畳み込み層では同じ重みを繰り返し使用するため、バッチ処理による大きな利点はありません。プーリング層に関しては、計算コストは​​ピクセル数に比例しており、これはご想像のとおりです。

トランスフォーマーはどうですか?

Transformerは本質的に多層パーセプトロン(MLP)であり、同じものと考えることができます。Transformerには注意機構がありますが、KVキャッシュ(計算データをメモリ上に保持できる)のおかげで、注意機構にかかる時間が大幅に短縮されます。これについては以前、詳しく書きました[3]。

この考え方は、Mixture of Expertsモデルにも当てはまります。多くのTransformer実装では、KVキャッシュはAttentionクラスに組み込まれています(例えば、MaxText[4]は典型的なケースです[5])。MoEモデルと通常のデコーダーの唯一の違いは、フィードフォワードネットワーク層の一部がMoE層に置き換えられているだけなので、KVキャッシュのパフォーマンスは推論プロセスと同様に一定に保たれますが、1つの違いがあります。

MoE層のゲーティング機構は、データバッチを複数のエキスパートに分配します。ゲーティングがデータバッチを均等に分配しない場合、問題が発生する可能性があります。これを回避するルーティング機構(「エキスパートの選択」など)は存在しますが、自己回帰デコーダーでは通常「トークンの選択」しかなく、ゲーティングバイアスにつながる可能性があります。ゲーティングによってトークンを均等に分配することは、1) 現在の研究の焦点であり、2) 学習中に最適化する必要がある重要な目標です。

読んでくれてありがとう!

このブログを楽しんで、新しいことを学んでいただければ幸いです。

著者について

フィンバー・ティンバーズ

経験主義者。機械学習研究者。以前はDeepMindでエンジニアリングに従事。🧠

終わり

今週のインタラクティブコンテンツ🍻

実際のプロジェクトではバッチサイズをどのように選択していますか?予期せぬパフォーマンスのボトルネックに遭遇したことはありますか?

🔗記事内のリンク🔗

[1]https://www.stat.cmu.edu/~ryantibs/convexopt-F18/scribes/Lecture_19.pdf

[2]https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/tesla-t4/t4-tensor-core-datasheet-951643.pdf

[3]https://www.artfintel.com/p/where-do-llms-spend-their-flops

[4]https://github.com/google/maxtext

[5]https://github.com/google/maxtext/blob/main/MaxText/layers/attentions.py#L91

オリジナルリンク:

https://www.artfintel.com/p/how-does-batching-work-on-modern