ディープラーニングモデルの安全な並列推論とパフォーマンス最適化
こんにちは!
今日は、よく聞かれる質問の1つである「単一のモデルインスタンスで安全に並列推論を行えるか?」に関する内容です!
evalモードでの並列推論の安全性
PyTorchモデルがmodel.eval()
を使用してevalモードに設定されている場合、一般的に並列推論に対して安全になります。
(ここでいう「並列」はマルチスレッドによる処理ととらえてください。バッチ推論については後述します。)
その理由は、
- パラメータの不変性
evalモードでは、順伝播(forward pass)中にモデルのパラメータが更新されません。 - 学習特有レイヤーの非活性化
BatchNormなどのレイヤーは、バッチ統計の計算ではなく、実行時統計(running statistics)を使用するモードに切り替わります。 - 入力データの独立性
各スレッドやプロセスは独自の入力データで動作し、それぞれ別のメモリ領域に存在します。
以下は、evalモードでの安全な並列推論の基本的な例です:
import torch
import threading
def safe_inference(model, data):
with torch.no_grad():
return model(data)
model = YourModel()
model.eval() # 重要: evalモードに設定
# 複数スレッドで推論を実行
threads = []
for i in range(10):
t = threading.Thread(target=safe_inference, args=(model, your_data[i]))
threads.append(t)
t.start()
for t in threads:
t.join()
注意が必要な場合
しかし、以下のような状況では注意が必要です:
- カスタムレイヤーの存在
独自に実装したレイヤーがある場合、その並列実行時の挙動を慎重に確認する必要があります。
class CustomLayer(torch.nn.Module):
def __init__(self):
super().__init__()
self.counter = 0 # 潜在的な問題源
def forward(self, x):
self.counter += 1 # スレッドセーフではない
return x + self.counter
# このようなカスタムレイヤーは並列実行時に問題を引き起こす可能性があります
- GPUメモリの制約
複数スレッドが同時に大量のデータを処理する場合、GPUメモリ不足が発生する可能性があります。 - 複雑なモデル構造
特定のタイプのAttentionメカニズムなど、一部の複雑なモデル構造では、並列実行時に予期せぬ挙動を示す可能性があります。
プールの使用
上記のような注意が必要な場合、モデルインスタンスのプールを使用することで問題を回避できる場合があります。
以下は簡単なモデルプールの実装例です
import torch
from queue import Queue
class ModelPool:
def __init__(self, model_class, num_instances):
self.pool = Queue()
for _ in range(num_instances):
model = model_class().to('cuda')
model.eval()
self.pool.put(model)
def get_model(self):
return self.pool.get()
def return_model(self, model):
self.pool.put(model)
def safe_pooled_inference(pool, data):
model = pool.get_model()
try:
with torch.no_grad():
result = model(data)
return result
finally:
pool.return_model(model)
# 使用例
pool = ModelPool(YourModel, num_instances=3)
results = [safe_pooled_inference(pool, data) for data in your_data_list]
このアプローチでは、各推論タスクが独立したモデルインスタンスを使用するため、並列実行時の問題を回避できます。
パフォーマンスの最適化の基本はバッチ
並列推論は柔軟性を提供しますが、オーバーヘッドによりパフォーマンスが低下する可能性があります。ここでは、パフォーマンスを向上させるための重要なヒントを紹介します。
バッチ処理の活用
個別の並列推論よりも、バッチ処理を活用することで大幅なパフォーマンス向上が見込めます。GPUは大量のデータを同時に処理するのに適しているため、バッチ処理はGPUの能力を最大限に活用できます。
1. 静的バッチ処理
最も単純な方法は、固定サイズのバッチを使用することです:
def batch_inference(model, data_list, batch_size=32):
results = []
for i in range(0, len(data_list), batch_size):
batch = torch.stack(data_list[i:i+batch_size])
with torch.no_grad():
batch_results = model(batch)
results.extend(batch_results)
return results
# 使用例
results = batch_inference(model, your_data_list)
ただし、都合よくバッチのタイミングでアクセスは来ない
Webサービスなのでオンデマンドな推論サービスをつくってるときには、GPUの単純な並列推論だけでは対処しきれません。
なぜなら、都合よく、同じタイミングでユーザーがアクセスしてこないからです。
むしろうまくバッチにのせられるタイミングのほうがマレです。
2. ダイナミックバッチング
リアルタイムで到着するデータを効率的に処理するために、ダイナミックバッチングを使用できます
import time
from collections import deque
class DynamicBatcher:
def __init__(self, model, max_batch_size=32, max_wait_time=0.1):
self.model = model
self.max_batch_size = max_batch_size
self.max_wait_time = max_wait_time
self.queue = deque()
self.results = {}
def add_item(self, item_id, data):
self.queue.append((item_id, data))
if len(self.queue) >= self.max_batch_size:
self.process_batch()
def process_batch(self):
batch_ids, batch_data = zip(*[self.queue.popleft() for _ in range(len(self.queue))])
batch_tensor = torch.stack(batch_data)
with torch.no_grad():
batch_results = self.model(batch_tensor)
for item_id, result in zip(batch_ids, batch_results):
self.results[item_id] = result
def get_result(self, item_id):
start_time = time.time()
while item_id not in self.results:
if time.time() - start_time > self.max_wait_time:
self.process_batch()
time.sleep(0.01)
return self.results.pop(item_id)
# 使用例
batcher = DynamicBatcher(model)
def process_item(item_id, data):
batcher.add_item(item_id, data)
return batcher.get_result(item_id)
# 複数スレッドからprocess_itemを呼び出す
このアプローチでは、データが到着次第バッチに追加され、バッチサイズが最大に達するか、最大待機時間を超えた場合に処理が実行されます。
3. 連続バッチ処理
また、連続的にデータが生成される場合、以下のような連続バッチ処理が効果的です
import torch
from torch.utils.data import DataLoader, IterableDataset
class ContinuousDataset(IterableDataset):
def __iter__(self):
while True:
yield self.get_next_item() # データ生成ロジックを実装
def get_next_item(self):
# 実際のデータ生成ロジックをここに実装
pass
def continuous_batch_inference(model, dataset, batch_size=32):
dataloader = DataLoader(dataset, batch_size=batch_size)
for batch in dataloader:
with torch.no_grad():
yield model(batch)
# 使用例
dataset = ContinuousDataset()
for batch_results in continuous_batch_inference(model, dataset):
process_results(batch_results) # 結果の処理
この方法では、データが連続的に生成される場合でも、効率的にバッチ処理を行うことができます。
まとめ
今回は、とくに1台のGPUにおける並列化とパフォーマンスについて解説しました。
evalモードでの並列推論は多くの場合安全ですが、パフォーマンスを最大化するためにはバッチ処理が必須ですね。またディープラーニング、LLM系のサービスの推論シーンは多くの場合でダイナミックバッチング、連続バッチ処理などの技術が重要となります。当社でも、ダイナミックバッチ、連続バッチを当初から研究しており、LLMや動画生成、AIキャラクター応答にも応用しています。
これらのテクニックを適切に選択し、実装することで、推論のスループットを大幅に向上させることができます。
並列化とは別観点ではありますが、モデルの量子化、TorchScript の使用、GPU 最適化など、追加の手法を組み合わせることで、さらなるパフォーマンス向上が期待できます。
GPUはとても高額な機器なので、1台のGPUを「使い切る」という視点は非常に重要で当社Qualitegでも日々技術を磨いています。
さらに大規模なアクセスには「GPUクラスター」の導入を考えましょう
一方、大量の同時アクセスが想定されるシーンでは複数台のGPUを使用した負荷分散が必須となります。そちらのテクニックについてもまた別途ブログにて投稿させていただこうとおもいますが、以下の動画に LLM におけるGPUクラスターの構成方法について解説していますので、こちらもよろしければご覧くださいませ。
それでは、また次回お会いしましょう!