PyTorchモデルの最適化~TorchScriptの仕組みと活用法~
こんにちは!
本日は PyTorch で開発したAIアプリケーションの本番化に欠かせない、「最適化」についての内容です。具体的には「 TorchScript」 を使用した各種学習モデルの最適化についてみていきたいとおもいます。
TorchScriptの基礎
1 TorchScriptとは
TorchScriptは、PyTorchモデルを最適化された中間表現(IR)に変換する技術です。
、、といってもちょっと難しく聞こえるかもしれません。
平易な言葉で言い換えますと、
要するに、PyTorchで作った機械学習モデルを高速かつ多種多様な環境で動作させることをするための技術です。
例えば、、
・Pythonがインストールされていない環境でも動かせるようにする
・スマホはじめ、各種組み込み機器でも使えるようにする
・動かすときの速度を段違いに上げる
・複数の処理を同時に効率よく実行する
などを目論むときは TorchScript がおすすめです。
つまり、TorchScriptは「本番サービス」で使うときにすごく役立ちます。
2 TorchScriptの動作の仕組み
それでは、PyTorchのモデルをTorchScriptに変換する基本的な仕組みを説明します。
2.1 モデルの定義
まず、シンプルな画像認識用のCNNモデルを作りましょう
import torch
class SimpleModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(3, 64, 3)
self.relu = torch.nn.ReLU()
self.pool = torch.nn.MaxPool2d(2)
# 入力サイズから全結合層の入力次元を計算
# 入力サイズが32x32の場合の計算過程:
# 1. Conv2d(kernel=3, stride=1, padding=0): 32x32 → 30x30
# 2. MaxPool2d(kernel=2, stride=2): 30x30 → 15x15
input_features = 64 * 15 * 15 # 64はconv1の出力チャンネル数
self.fc = torch.nn.Linear(input_features, 10)
def forward(self, x):
x = self.conv1(x) # [B, 64, 30, 30]
x = self.relu(x)
x = self.pool(x) # [B, 64, 15, 15]
x = x.view(x.size(0), -1) # [B, 64*15*15]
return self.fc(x)
まずは上記モデルの説明をいたします。
SimpleModelの構造解説
はい、このコードは画像認識のための基本的なCNNモデルですね。
モデルは32×32ピクセルのRGB画像(3チャンネル)を入力として受け取り、3チャンネルの入力を64チャンネルに拡張し、3×3のカーネルで特徴を検出させます。
畳み込み層の後には活性化関数としてReLUを配置、続くプーリング層は、2×2の領域から最大値を取り出すことで特徴マップのサイズを半分に縮小しています。
最後に全結合層では、それまでの層で抽出された特徴を受け取り、10個のクラスに分類されるようにしています。で、この層に入力する前に、多次元の特徴マップを1次元フラット化します。
2.2. TorchScriptへの変換
さて、上述のCNNをTorchScript変換するまえに、変換方法についてみてみましょう。
TorchScriptの変換には2つの方法があります、それは「トレース方式」と「スクリプト方式」です。
2つの方法といっても、両者、実行方法はとてもシンプルで以下のコードのようにtorchscript変換コードを記述することができます
def convert_to_torchscript(model, method='trace'):
model.eval() # 評価モードに設定
if method == 'trace':
# トレース方式
example_input = torch.randn(1, 3, 32, 32)
traced_model = torch.jit.trace(model, example_input)
return traced_model
else:
# スクリプト方式
return torch.jit.script(model)
2つの方式には以下のような特徴があります
トレース方式
- ダミーデータを使って実際の処理を追跡
- 固定的な入力パターンに適している
- 処理の流れが単純な場合に使用
スクリプト方式
- コードを直接解析して変換
- 柔軟な処理が可能
- 条件分岐やループを含む複雑なモデルに適している
2.3. 使用例
さて、さきほどつくったCNNモデルをTorchScriptとして「本番」向けに変換してみましょう。
method='trace' とすることで、「トレース方式」で変換します。
# モデルの作成
model = SimpleModel()
# 本番環境用に変換
production_model = convert_to_torchscript(model, method='trace')
# 変換したモデルで推論
input_data = torch.randn(1, 3, 32, 32)
result = production_model(input_data)
convert_to_torchscript
では example_input = torch.randn(1, 3, 32, 32)
のようにダミーデータを流し込んでいますね。
ダミーデータ(ダミーじゃなくて本番データでも別にかまいません)を流し込むことでネットワークを「トレース」します。
いかがでしょうか。
とても簡単に「本番用」に変換できました。
基本的には実にシンプルな手順で研究開発で作成したPyTorchモデルを、本番環境で効率的に動作する形式に変換することができまます。
これだけで速度の向上や、Pythonに依存しない実行環境での利用が可能な形式を実現できます。
3 TorchScriptの内部表現
TorchScriptで変換したモデルの中身を覗いてみましょう。
モデルの内部で何が起きているのか、どんな最適化が行われているのか、実際に確認する方法をお伝えします。
3.1 内部表現の確認方法
まずは、内部表現を確認するための関数を作ってみましょう:
def inspect_torchscript(model):
# モデルのグラフ構造を表示
print("Graph Structure:")
print(model.graph)
# 最適化されたコードを表示
print("\nOptimized Code:")
print(model.code)
# 使用されている演算子を表示
print("\nOperators:")
for node in model.graph.nodes():
print(f"- {node.kind()}")
3.2 内部表現の理解
この関数を使うと、まず最初にモデルの計算フローがグラフ構造として表示されます。これは私たちの書いたPyTorchのコードが、実際にどのような計算の流れに変換されているのかを教えてくれます。データがモデルの中でどのように変換され、どの層がどんな順序で処理を行うのか、その全体像を把握することができます。
次に表示される最適化済みコードは、実際に実行される形に変換された後のコードです。TorchScriptによる最適化で、私たちが書いた元のコードがどのように変更されたのか、その違いを見ることができます。この情報は特にデバッグを行う際にとても役立ちます。
最後に表示される演算子のリストは、モデル内で使用されているすべての演算の種類を教えてくれます。これを見ることで、処理の重い演算子が含まれていないかチェックしたり、最適化の余地がある部分を見つけたりすることができます。
3.3 実践的な使い方
では、実際に使ってみましょうか。
まずモデルを作成してTorchScriptに変換し、その内部を覗いてみます
model = SimpleModel()
production_model = convert_to_torchscript(model, method='script')
input_data = torch.randn(1, 3, 32, 32)
result = production_model(input_data)
print(f"Input shape: {input_data.shape}")
print(f"Output shape: {result.shape}")
inspect_torchscript(production_model)
3.4 TorchScriptと計算グラフ
さて、さきほどの実行結果は以下のようになりました。
Input shape: torch.Size([1, 3, 32, 32])
Output shape: torch.Size([1, 10])
Graph Structure:
graph(%self : __torch__.SimpleModel,
%x.1 : Tensor):
%17 : int = prim::Constant[value=-1]() # ex/ts1.py:21:30
%13 : int = prim::Constant[value=0]() # ex/ts1.py:21:26
%conv1 : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="conv1"](%self)
%x.5 : Tensor = prim::CallMethod[name="forward"](%conv1, %x.1) # ex/ts1.py:18:12
%relu : __torch__.torch.nn.modules.activation.ReLU = prim::GetAttr[name="relu"](%self)
%x.9 : Tensor = prim::CallMethod[name="forward"](%relu, %x.5) # ex/ts1.py:19:12
%pool : __torch__.torch.nn.modules.pooling.MaxPool2d = prim::GetAttr[name="pool"](%self)
%x.13 : Tensor = prim::CallMethod[name="forward"](%pool, %x.9) # ex/ts1.py:20:12
%14 : int = aten::size(%x.13, %13) # ex/ts1.py:21:19
%18 : int[] = prim::ListConstruct(%14, %17)
%x.19 : Tensor = aten::view(%x.13, %18) # ex/ts1.py:21:12
%fc : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="fc"](%self)
%22 : Tensor = prim::CallMethod[name="forward"](%fc, %x.19) # ex/ts1.py:22:15
return (%22)
Optimized Code:
def forward(self,
x: Tensor) -> Tensor:
conv1 = self.conv1
x0 = (conv1).forward(x, )
relu = self.relu
x1 = (relu).forward(x0, )
pool = self.pool
x2 = (pool).forward(x1, )
x3 = torch.view(x2, [torch.size(x2, 0), -1])
fc = self.fc
return (fc).forward(x3, )
Operators:
- prim::Constant
- prim::Constant
- prim::GetAttr
- prim::CallMethod
- prim::GetAttr
- prim::CallMethod
- prim::GetAttr
- prim::CallMethod
- aten::size
- prim::ListConstruct
- aten::view
- prim::GetAttr
- prim::CallMethod
上記のようにSimpleModelをTorchScriptに変換し、その内部構造を表示させてみました。
まずは、CNNのコードが正しく実装されている証拠として、入力として32×32ピクセルのRGB画像(shape: [1, 3, 32, 32])を与えると、10クラスの予測値(shape: [1, 10])が出力されることが確認できました。
計算グラフを思い出して、TorchScriptのありがたみをかみしめる
TorchScriptは私たちが書いたPythonコードを「計算グラフ」という形式に変換します。
計算グラフ覚えてますか?
PyTorchのAutoGradなどに慣れてしまうと、計算グラフをイチイチ意識しなくなるのですが、
少し思い出しましょう。
計算グラフはモデル内での演算の流れを有向グラフとして表現します。各ノードは演算(畳み込みやReLUなどなど)を表し、エッジはデータの流れを表します。この構造により、順伝播(forward)での計算過程を正確にトレースします。
モデルの学習時は、まず順伝播で損失値を計算しますよね。
次に、この損失値を基に各パラメータの勾配を計算する必要があります。
各パラメータが最終的な損失にどれだけ影響を与えているのかを知る必要があるからでしたね。
損失から見て、直前の層のパラメータの影響は計算しやすいのですが、より手前の層になればなるほど、間に複数の計算が入るため直接的な影響を計算するのは難しくなります。
そこで、損失から順番に逆向きに計算していくことで、「連鎖律」を使って効率的に各パラメータの勾配を求めることができます。
これが逆伝播(backward)でしたね。
うっすら思い出してきたところで、さらに理解を深めるために、具体例で「計算グラフ」と「連鎖率」をみてみましょう。
計算グラフも連鎖率も余裕で覚えてるっていう方、TorchScriptの結果だけみたい場合は以下読み飛ばしてもらって構いません。
計算グラフの具体例
入力$x$に対して、まず$a$をかけて、次に$b$を足し、最後に2乗するという計算を考えてみましょう。これを数式で書くと:
$$
\begin{aligned}
y &= (ax + b)^2
\end{aligned}
$$
この計算は、以下のような小さなステップに分解できます:
$$
\begin{aligned}
p &= ax \
q &= p + b \
y &= q^2
\end{aligned}
$$
連鎖律と勾配の計算
この勾配を以下のように分解できるのが「連鎖律」ですね
$$
\begin{aligned}
\frac{dy}{dx} &= \frac{dy}{dq} \times \frac{dq}{dp} \times \frac{dp}{dx} \
&= 2q \times 1 \times a \
&= 2a(ax + b)
\end{aligned}
$$
はい、この各ステップでの勾配を見てみましょう
$$
\begin{aligned}
\frac{dy}{dq} &= 2q & \text{(2乗の微分)} \
\frac{dq}{dp} &= 1 & \text{(足し算の微分)} \
\frac{dp}{dx} &= a & \text{(掛け算の微分)}
\end{aligned}
$$
「このパラメータを少し変えたとき、誤差はどう変化するか」を微分で知ることができますよね。逆伝播は「誤差から逆向きに微分を計算していく」方法ですので、この仕組みのおかげで、どんなに複雑なニューラルネットワークでも、一つ一つの演算の勾配さえ計算できれば(=微分できれば)、全体の勾配を効率的に求めることができます。
ちょっとくどくなりましたが、計算グラフは、この逆伝播の過程をシンプルな計算の組み合わせで表したものといえますね。
さて、ではTorchScriptの計算グラフに戻りましょう。この計算グラフでは、データの流れが明確に可視化されています。
まずは各層の取得方法です。
prim::GetAttr
という操作によって、モデルから畳み込み層(conv1)、ReLU層(relu)、プーリング層(pool)、全結合層(fc)が順番に取得されています。そして各層の実行はprim::CallMethod
という操作で行われます。
テンソルのリサイズ操作です。
aten::size
でバッチサイズを取得し、prim::ListConstruct
で新しい形状を作り、最後にaten::view
でテンソルの形状を変更しています。これはPythonコードのx.view(x.size(0), -1)
に対応する操作です。
最適化されたコード
今まで見てきたように、TorchScriptは元のコードを「最適化」し、より効率的な形式に変換します。
その証拠に最適化されたコードを見ると、元のPythonコードがより直接的な形式に変換されていることがわかります。
各層はローカル変数として取得され、それぞれforward関数が呼び出されています。
「最適化」されたコードですので最短距離を目指してますから、Pythonのオーバーヘッドが排除され、より効率的な実行が可能になっています。
例えば、テンソルの形状変更操作もtorch.view
として直接実行されるように変換されていますね。
使用される演算子
最後に、このモデルで使用されている演算子を見てみましょう。
基本的な定数生成(prim::Constant)、属性取得(prim::GetAttr)、メソッド呼び出し(prim::CallMethod)に加えて、テンソル操作(aten::size、aten::view)が使用されていますね。これらはいわゆる低レベルな操作を表しており、モデルの実行がより低レベルな操作ではどのように行われるかを確認することができました
4 TorchScriptでよくあるエラーと対処方法
TorchScriptよさげなのですが、どんなコードでも適用できるというわけではありません。
無邪気につくるとTorchScript化できません。
4.1 可変長引数とキーワード引数に関するエラー
無邪気につくっていると、以下のようなエラーメッセージがでます
Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults
たとえば、以下みたいなコードをかくともれなく発生します。
# エラーが発生するコード例
class ProblematicModel(torch.nn.Module):
def forward(self, x, return_latents=False, **kwargs): # ←ここがエラー
# 処理内容
pass
原因
TorchScriptには以下のような制約があります
- 可変長引数(
*args
,**kwargs
)は使用できない - デフォルト値を持つキーワード専用引数は使用できない
いくつかのモデルを切り替えるときに、他のモデルでは不要だけど、このモデルではこのパラメータ入れたいよね、というときに forward の **kwargs
にひとまずぶっこんでおくということをやっていまいがちですが、これはTorchScriptとは相性がよくないので、最適化をしたいなら、複数モデルの若干パラメータが異なるモデルを安易に切り替えるようなコードは書かない方が無難です
対処方法
対象法ですが、例えば以下のような対策があります
- すべての引数を明示的に定義する
# 修正例1:必要な引数を全て明示的に指定
class FixedModel(torch.nn.Module):
def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
# 処理内容
pass
- 設定用のクラスやデータクラスを使用する
# 修正例2:設定をまとめたクラスを使用
@dataclass
class ModelConfig:
return_latents: bool = False
return_rgb: bool = True
randomize_noise: bool = True
class BetterModel(torch.nn.Module):
def forward(self, x, config: ModelConfig):
# configから設定を取得して処理
pass
4.2 実装のベストプラクティス
前節、前々節をふまえ、TorchScript用のモデルを書くときは、以下の点に注意しましょう
-
引数は明示的に定義する
- 可変長引数(
*args
,**kwargs
)は避ける - デフォルト値が必要な場合は通常の引数として定義
- 可変長引数(
-
型ヒントを活用する
from typing import Tuple, Optional
class TypedModel(torch.nn.Module):
def forward(self, x: torch.Tensor, return_latents: bool = False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# 処理内容
pass
- 複雑な設定が必要な場合は設定クラスを使用
from dataclasses import dataclass
@dataclass
class GeneratorConfig:
return_latents: bool = False
return_rgb: bool = True
randomize_noise: bool = True
noise_scale: float = 1.0
class Generator(torch.nn.Module):
def forward(self, x: torch.Tensor, config: GeneratorConfig) -> torch.Tensor:
# configを使用して処理を行う
output = self.process(x, noise_scale=config.noise_scale)
if config.return_latents:
# 潜在変数を返す処理
pass
return output
これらの対処方法を適用することで、TorchScriptへの変換がスムーズに行えるようになり、本番直前になって、「あれ、TorchScriptにできねー」という悲劇を避けることができます。
今回は、PyTorchの最適化、とくにPythonコードから「TorchScript」への変換方法や、その背後で行われている処理、計算グラフのおさらい、TorchScriptアンチパターンについて解説いたしました。
最後までお読みいただきありがとうございました。
次回もお楽しみに!
付録:今回使用したコード全体
import torch
class SimpleModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(3, 64, 3)
self.relu = torch.nn.ReLU()
self.pool = torch.nn.MaxPool2d(2)
# 入力サイズから全結合層の入力次元を計算
# 入力サイズが32x32の場合の計算過程:
# 1. Conv2d(kernel=3, stride=1, padding=0): 32x32 → 30x30
# 2. MaxPool2d(kernel=2, stride=2): 30x30 → 15x15
input_features = 64 * 15 * 15 # 64はconv1の出力チャンネル数
self.fc = torch.nn.Linear(input_features, 10)
def forward(self, x):
x = self.conv1(x) # [B, 64, 30, 30]
x = self.relu(x)
x = self.pool(x) # [B, 64, 15, 15]
x = x.view(x.size(0), -1) # [B, 64*15*15]
return self.fc(x)
def convert_to_torchscript(model, method='trace'):
model.eval()
if method == 'trace':
example_input = torch.randn(1, 3, 32, 32)
traced_model = torch.jit.trace(model, example_input)
return traced_model
else:
return torch.jit.script(model)
def inspect_torchscript(model):
# モデルのグラフ構造を表示
print("Graph Structure:")
print(model.graph)
# 最適化されたコードを表示
print("\nOptimized Code:")
print(model.code)
# 使用されている演算子を表示
print("\nOperators:")
for node in model.graph.nodes():
print(f"- {node.kind()}")
# 実行
if __name__ == "__main__":
model = SimpleModel()
production_model = convert_to_torchscript(model, method='script')
input_data = torch.randn(1, 3, 32, 32)
result = production_model(input_data)
print(f"Input shape: {input_data.shape}")
print(f"Output shape: {result.shape}")
inspect_torchscript(production_model)