PyTorchモデルの最適化~TorchScriptの仕組みと活用法~

PyTorchモデルの最適化~TorchScriptの仕組みと活用法~
Photo by Nicolas Hoizey / Unsplash

こんにちは!

本日は PyTorch で開発したAIアプリケーションの本番化に欠かせない、「最適化」についての内容です。具体的には「 TorchScript」 を使用した各種学習モデルの最適化についてみていきたいとおもいます。

TorchScriptの基礎

1 TorchScriptとは

TorchScriptは、PyTorchモデルを最適化された中間表現(IR)に変換する技術です。

、、といってもちょっと難しく聞こえるかもしれません。

平易な言葉で言い換えますと、

要するに、PyTorchで作った機械学習モデルを高速かつ多種多様な環境で動作させることをするための技術です。

例えば、、

・Pythonがインストールされていない環境でも動かせるようにする
・スマホはじめ、各種組み込み機器でも使えるようにする
・動かすときの速度を段違いに上げる
・複数の処理を同時に効率よく実行する

などを目論むときは TorchScript がおすすめです。

つまり、TorchScriptは「本番サービス」で使うときにすごく役立ちます。

2 TorchScriptの動作の仕組み

それでは、PyTorchのモデルをTorchScriptに変換する基本的な仕組みを説明します。

2.1 モデルの定義

まず、シンプルな画像認識用のCNNモデルを作りましょう

import torch


class SimpleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 64, 3)
        self.relu = torch.nn.ReLU()
        self.pool = torch.nn.MaxPool2d(2)

        # 入力サイズから全結合層の入力次元を計算
        # 入力サイズが32x32の場合の計算過程:
        # 1. Conv2d(kernel=3, stride=1, padding=0): 32x32 → 30x30
        # 2. MaxPool2d(kernel=2, stride=2): 30x30 → 15x15
        input_features = 64 * 15 * 15  # 64はconv1の出力チャンネル数
        self.fc = torch.nn.Linear(input_features, 10)

    def forward(self, x):
        x = self.conv1(x)  # [B, 64, 30, 30]
        x = self.relu(x)
        x = self.pool(x)  # [B, 64, 15, 15]
        x = x.view(x.size(0), -1)  # [B, 64*15*15]
        return self.fc(x)

まずは上記モデルの説明をいたします。

SimpleModelの構造解説

はい、このコードは画像認識のための基本的なCNNモデルですね。

モデルは32×32ピクセルのRGB画像(3チャンネル)を入力として受け取り、3チャンネルの入力を64チャンネルに拡張し、3×3のカーネルで特徴を検出させます。

畳み込み層の後には活性化関数としてReLUを配置、続くプーリング層は、2×2の領域から最大値を取り出すことで特徴マップのサイズを半分に縮小しています。

最後に全結合層では、それまでの層で抽出された特徴を受け取り、10個のクラスに分類されるようにしています。で、この層に入力する前に、多次元の特徴マップを1次元フラット化します。

2.2. TorchScriptへの変換

さて、上述のCNNをTorchScript変換するまえに、変換方法についてみてみましょう。

TorchScriptの変換には2つの方法があります、それは「トレース方式」と「スクリプト方式」です。

2つの方法といっても、両者、実行方法はとてもシンプルで以下のコードのようにtorchscript変換コードを記述することができます

def convert_to_torchscript(model, method='trace'):
    model.eval()  # 評価モードに設定
    
    if method == 'trace':
        # トレース方式
        example_input = torch.randn(1, 3, 32, 32)
        traced_model = torch.jit.trace(model, example_input)
        return traced_model
    else:
        # スクリプト方式
        return torch.jit.script(model)

2つの方式には以下のような特徴があります

トレース方式
  • ダミーデータを使って実際の処理を追跡
  • 固定的な入力パターンに適している
  • 処理の流れが単純な場合に使用
スクリプト方式
  • コードを直接解析して変換
  • 柔軟な処理が可能
  • 条件分岐やループを含む複雑なモデルに適している

2.3. 使用例

さて、さきほどつくったCNNモデルをTorchScriptとして「本番」向けに変換してみましょう。

method='trace' とすることで、「トレース方式」で変換します。

# モデルの作成
model = SimpleModel()

# 本番環境用に変換
production_model = convert_to_torchscript(model, method='trace')

# 変換したモデルで推論
input_data = torch.randn(1, 3, 32, 32)
result = production_model(input_data)

convert_to_torchscript では example_input = torch.randn(1, 3, 32, 32) のようにダミーデータを流し込んでいますね。
ダミーデータ(ダミーじゃなくて本番データでも別にかまいません)を流し込むことでネットワークを「トレース」します。

いかがでしょうか。
とても簡単に「本番用」に変換できました。

基本的には実にシンプルな手順で研究開発で作成したPyTorchモデルを、本番環境で効率的に動作する形式に変換することができまます。

これだけで速度の向上や、Pythonに依存しない実行環境での利用が可能な形式を実現できます。

3 TorchScriptの内部表現

TorchScriptで変換したモデルの中身を覗いてみましょう。

モデルの内部で何が起きているのか、どんな最適化が行われているのか、実際に確認する方法をお伝えします。

3.1 内部表現の確認方法

まずは、内部表現を確認するための関数を作ってみましょう:

def inspect_torchscript(model):
    # モデルのグラフ構造を表示
    print("Graph Structure:")
    print(model.graph)
    
    # 最適化されたコードを表示
    print("\nOptimized Code:")
    print(model.code)
    
    # 使用されている演算子を表示
    print("\nOperators:")
    for node in model.graph.nodes():
        print(f"- {node.kind()}")

3.2 内部表現の理解

この関数を使うと、まず最初にモデルの計算フローがグラフ構造として表示されます。これは私たちの書いたPyTorchのコードが、実際にどのような計算の流れに変換されているのかを教えてくれます。データがモデルの中でどのように変換され、どの層がどんな順序で処理を行うのか、その全体像を把握することができます。

次に表示される最適化済みコードは、実際に実行される形に変換された後のコードです。TorchScriptによる最適化で、私たちが書いた元のコードがどのように変更されたのか、その違いを見ることができます。この情報は特にデバッグを行う際にとても役立ちます。

最後に表示される演算子のリストは、モデル内で使用されているすべての演算の種類を教えてくれます。これを見ることで、処理の重い演算子が含まれていないかチェックしたり、最適化の余地がある部分を見つけたりすることができます。

3.3 実践的な使い方

では、実際に使ってみましょうか。

まずモデルを作成してTorchScriptに変換し、その内部を覗いてみます

    model = SimpleModel()
    production_model = convert_to_torchscript(model, method='script')
    input_data = torch.randn(1, 3, 32, 32)
    result = production_model(input_data)
    print(f"Input shape: {input_data.shape}")
    print(f"Output shape: {result.shape}")

    inspect_torchscript(production_model)

3.4 TorchScriptと計算グラフ

さて、さきほどの実行結果は以下のようになりました。


Input shape: torch.Size([1, 3, 32, 32])
Output shape: torch.Size([1, 10])

Graph Structure:
graph(%self : __torch__.SimpleModel,
      %x.1 : Tensor):
  %17 : int = prim::Constant[value=-1]() # ex/ts1.py:21:30
  %13 : int = prim::Constant[value=0]() # ex/ts1.py:21:26
  %conv1 : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="conv1"](%self)
  %x.5 : Tensor = prim::CallMethod[name="forward"](%conv1, %x.1) # ex/ts1.py:18:12
  %relu : __torch__.torch.nn.modules.activation.ReLU = prim::GetAttr[name="relu"](%self)
  %x.9 : Tensor = prim::CallMethod[name="forward"](%relu, %x.5) # ex/ts1.py:19:12
  %pool : __torch__.torch.nn.modules.pooling.MaxPool2d = prim::GetAttr[name="pool"](%self)
  %x.13 : Tensor = prim::CallMethod[name="forward"](%pool, %x.9) # ex/ts1.py:20:12
  %14 : int = aten::size(%x.13, %13) # ex/ts1.py:21:19
  %18 : int[] = prim::ListConstruct(%14, %17)
  %x.19 : Tensor = aten::view(%x.13, %18) # ex/ts1.py:21:12
  %fc : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="fc"](%self)
  %22 : Tensor = prim::CallMethod[name="forward"](%fc, %x.19) # ex/ts1.py:22:15
  return (%22)


Optimized Code:
def forward(self,
    x: Tensor) -> Tensor:
  conv1 = self.conv1
  x0 = (conv1).forward(x, )
  relu = self.relu
  x1 = (relu).forward(x0, )
  pool = self.pool
  x2 = (pool).forward(x1, )
  x3 = torch.view(x2, [torch.size(x2, 0), -1])
  fc = self.fc
  return (fc).forward(x3, )


Operators:
- prim::Constant
- prim::Constant
- prim::GetAttr
- prim::CallMethod
- prim::GetAttr
- prim::CallMethod
- prim::GetAttr
- prim::CallMethod
- aten::size
- prim::ListConstruct
- aten::view
- prim::GetAttr
- prim::CallMethod

上記のようにSimpleModelをTorchScriptに変換し、その内部構造を表示させてみました。

まずは、CNNのコードが正しく実装されている証拠として、入力として32×32ピクセルのRGB画像(shape: [1, 3, 32, 32])を与えると、10クラスの予測値(shape: [1, 10])が出力されることが確認できました。

計算グラフを思い出して、TorchScriptのありがたみをかみしめる

TorchScriptは私たちが書いたPythonコードを「計算グラフ」という形式に変換します。

計算グラフ覚えてますか?

PyTorchのAutoGradなどに慣れてしまうと、計算グラフをイチイチ意識しなくなるのですが、
少し思い出しましょう。

計算グラフはモデル内での演算の流れを有向グラフとして表現します。各ノードは演算(畳み込みやReLUなどなど)を表し、エッジはデータの流れを表します。この構造により、順伝播(forward)での計算過程を正確にトレースします。

モデルの学習時は、まず順伝播で損失値を計算しますよね。

次に、この損失値を基に各パラメータの勾配を計算する必要があります。

各パラメータが最終的な損失にどれだけ影響を与えているのかを知る必要があるからでしたね。

損失から見て、直前の層のパラメータの影響は計算しやすいのですが、より手前の層になればなるほど、間に複数の計算が入るため直接的な影響を計算するのは難しくなります。

そこで、損失から順番に逆向きに計算していくことで、「連鎖律」を使って効率的に各パラメータの勾配を求めることができます。

これが逆伝播(backward)でしたね。

うっすら思い出してきたところで、さらに理解を深めるために、具体例で「計算グラフ」と「連鎖率」をみてみましょう。

計算グラフも連鎖率も余裕で覚えてるっていう方、TorchScriptの結果だけみたい場合は以下読み飛ばしてもらって構いません。

計算グラフの具体例

入力$x$に対して、まず$a$をかけて、次に$b$を足し、最後に2乗するという計算を考えてみましょう。これを数式で書くと:

$$
\begin{aligned}
y &= (ax + b)^2
\end{aligned}
$$

この計算は、以下のような小さなステップに分解できます:

$$
\begin{aligned}
p &= ax \
q &= p + b \
y &= q^2
\end{aligned}
$$

連鎖律と勾配の計算

この勾配を以下のように分解できるのが「連鎖律」ですね

$$
\begin{aligned}
\frac{dy}{dx} &= \frac{dy}{dq} \times \frac{dq}{dp} \times \frac{dp}{dx} \
&= 2q \times 1 \times a \
&= 2a(ax + b)
\end{aligned}
$$

はい、この各ステップでの勾配を見てみましょう

$$
\begin{aligned}
\frac{dy}{dq} &= 2q & \text{(2乗の微分)} \
\frac{dq}{dp} &= 1 & \text{(足し算の微分)} \
\frac{dp}{dx} &= a & \text{(掛け算の微分)}
\end{aligned}
$$

「このパラメータを少し変えたとき、誤差はどう変化するか」を微分で知ることができますよね。逆伝播は「誤差から逆向きに微分を計算していく」方法ですので、この仕組みのおかげで、どんなに複雑なニューラルネットワークでも、一つ一つの演算の勾配さえ計算できれば(=微分できれば)、全体の勾配を効率的に求めることができます。

ちょっとくどくなりましたが、計算グラフは、この逆伝播の過程をシンプルな計算の組み合わせで表したものといえますね。

さて、ではTorchScriptの計算グラフに戻りましょう。この計算グラフでは、データの流れが明確に可視化されています。

まずは各層の取得方法です。

prim::GetAttrという操作によって、モデルから畳み込み層(conv1)、ReLU層(relu)、プーリング層(pool)、全結合層(fc)が順番に取得されています。そして各層の実行はprim::CallMethodという操作で行われます。

テンソルのリサイズ操作です。

aten::sizeでバッチサイズを取得し、prim::ListConstructで新しい形状を作り、最後にaten::viewでテンソルの形状を変更しています。これはPythonコードのx.view(x.size(0), -1)に対応する操作です。

最適化されたコード

今まで見てきたように、TorchScriptは元のコードを「最適化」し、より効率的な形式に変換します。

その証拠に最適化されたコードを見ると、元のPythonコードがより直接的な形式に変換されていることがわかります。

各層はローカル変数として取得され、それぞれforward関数が呼び出されています。

「最適化」されたコードですので最短距離を目指してますから、Pythonのオーバーヘッドが排除され、より効率的な実行が可能になっています。

例えば、テンソルの形状変更操作もtorch.viewとして直接実行されるように変換されていますね。

使用される演算子

最後に、このモデルで使用されている演算子を見てみましょう。

基本的な定数生成(prim::Constant)、属性取得(prim::GetAttr)、メソッド呼び出し(prim::CallMethod)に加えて、テンソル操作(aten::size、aten::view)が使用されていますね。これらはいわゆる低レベルな操作を表しており、モデルの実行がより低レベルな操作ではどのように行われるかを確認することができました

4 TorchScriptでよくあるエラーと対処方法

TorchScriptよさげなのですが、どんなコードでも適用できるというわけではありません。

無邪気につくるとTorchScript化できません。

4.1 可変長引数とキーワード引数に関するエラー

無邪気につくっていると、以下のようなエラーメッセージがでます

Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults

たとえば、以下みたいなコードをかくともれなく発生します。

# エラーが発生するコード例
class ProblematicModel(torch.nn.Module):
    def forward(self, x, return_latents=False, **kwargs):  # ←ここがエラー
        # 処理内容
        pass
原因

TorchScriptには以下のような制約があります

  1. 可変長引数(*args, **kwargs)は使用できない
  2. デフォルト値を持つキーワード専用引数は使用できない

いくつかのモデルを切り替えるときに、他のモデルでは不要だけど、このモデルではこのパラメータ入れたいよね、というときに forward の **kwargs にひとまずぶっこんでおくということをやっていまいがちですが、これはTorchScriptとは相性がよくないので、最適化をしたいなら、複数モデルの若干パラメータが異なるモデルを安易に切り替えるようなコードは書かない方が無難です

対処方法

対象法ですが、例えば以下のような対策があります

  1. すべての引数を明示的に定義する
# 修正例1:必要な引数を全て明示的に指定
class FixedModel(torch.nn.Module):
    def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
        # 処理内容
        pass
  1. 設定用のクラスやデータクラスを使用する
# 修正例2:設定をまとめたクラスを使用
@dataclass
class ModelConfig:
    return_latents: bool = False
    return_rgb: bool = True
    randomize_noise: bool = True

class BetterModel(torch.nn.Module):
    def forward(self, x, config: ModelConfig):
        # configから設定を取得して処理
        pass

4.2 実装のベストプラクティス

前節、前々節をふまえ、TorchScript用のモデルを書くときは、以下の点に注意しましょう

  1. 引数は明示的に定義する

    • 可変長引数(*args, **kwargs)は避ける
    • デフォルト値が必要な場合は通常の引数として定義
  2. 型ヒントを活用する

from typing import Tuple, Optional

class TypedModel(torch.nn.Module):
    def forward(self, x: torch.Tensor, return_latents: bool = False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # 処理内容
        pass
  1. 複雑な設定が必要な場合は設定クラスを使用
from dataclasses import dataclass

@dataclass
class GeneratorConfig:
    return_latents: bool = False
    return_rgb: bool = True
    randomize_noise: bool = True
    noise_scale: float = 1.0

class Generator(torch.nn.Module):
    def forward(self, x: torch.Tensor, config: GeneratorConfig) -> torch.Tensor:
        # configを使用して処理を行う
        output = self.process(x, noise_scale=config.noise_scale)
        if config.return_latents:
            # 潜在変数を返す処理
            pass
        return output

これらの対処方法を適用することで、TorchScriptへの変換がスムーズに行えるようになり、本番直前になって、「あれ、TorchScriptにできねー」という悲劇を避けることができます。

今回は、PyTorchの最適化、とくにPythonコードから「TorchScript」への変換方法や、その背後で行われている処理、計算グラフのおさらい、TorchScriptアンチパターンについて解説いたしました。

最後までお読みいただきありがとうございました。

次回もお楽しみに!

付録:今回使用したコード全体

import torch

class SimpleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 64, 3)
        self.relu = torch.nn.ReLU()
        self.pool = torch.nn.MaxPool2d(2)

        # 入力サイズから全結合層の入力次元を計算
        # 入力サイズが32x32の場合の計算過程:
        # 1. Conv2d(kernel=3, stride=1, padding=0): 32x32 → 30x30
        # 2. MaxPool2d(kernel=2, stride=2): 30x30 → 15x15
        input_features = 64 * 15 * 15  # 64はconv1の出力チャンネル数
        self.fc = torch.nn.Linear(input_features, 10)

    def forward(self, x):
        x = self.conv1(x)  # [B, 64, 30, 30]
        x = self.relu(x)
        x = self.pool(x)  # [B, 64, 15, 15]
        x = x.view(x.size(0), -1)  # [B, 64*15*15]
        return self.fc(x)

def convert_to_torchscript(model, method='trace'):
    model.eval()

    if method == 'trace':
        example_input = torch.randn(1, 3, 32, 32)
        traced_model = torch.jit.trace(model, example_input)
        return traced_model
    else:
        return torch.jit.script(model)


def inspect_torchscript(model):
    # モデルのグラフ構造を表示
    print("Graph Structure:")
    print(model.graph)

    # 最適化されたコードを表示
    print("\nOptimized Code:")
    print(model.code)

    # 使用されている演算子を表示
    print("\nOperators:")
    for node in model.graph.nodes():
        print(f"- {node.kind()}")

# 実行
if __name__ == "__main__":
    model = SimpleModel()
    production_model = convert_to_torchscript(model, method='script')
    input_data = torch.randn(1, 3, 32, 32)
    result = production_model(input_data)
    print(f"Input shape: {input_data.shape}")
    print(f"Output shape: {result.shape}")

    inspect_torchscript(production_model)

Read more

産業交流展2024 に出展いたしました

産業交流展2024 に出展いたしました

こんにちは! 2024年11月21日~11月23日の3日間 東京ビックサイトにて開催された産業交流展2024(リアル展)において、当社のプロダクト・サービスの展示を行いました。 多くの方々に当社ブースへお立ち寄りいただき、誠にありがとうございました! (産業交流展2024のオンライン展示会は 2024年11月29日まで開催中です!) 本ブログでは、展示会当日の様子を簡単にレポートさせていただきます。 展示会の様子 当社ブースは「東京ビジネスフロンティア」パビリオン内に設けていただきました。 当社からは3名体制で、 エンタープライズLLMソリューション「Bestllam 」やLLMセキュリティソリューション「 LLM-Audit」 、経産省認定講座「AI・DX研修」についてデモンストレーションおよびご説明・ご案内をさせていただきました。 さらに、ステラリンク社さまのご厚意により、このかわいい移動式サイネージ「AdRobot」に、当社ブースの宣伝もしていただきました! 特典カード さて、ブースにお立ち寄りの際にお渡しした、Bestllam特典カードの招待コー

By Qualiteg ビジネス開発本部 | マーケティング部
「Windowsターミナル」を Windows Server 2022 Datacenter エディションに手軽にインストールする方法

「Windowsターミナル」を Windows Server 2022 Datacenter エディションに手軽にインストールする方法

こんにちは! 本稿はWindows Server 2022 Datacenterエディションに「Windowsターミナル」をインストールする方法のメモです。 ステップバイステップでやるのは少し手間だったので、Powershellにペタっとするだけで自動的にインストールできるよう手順をスクリプト化しました。 管理者権限で開いた Powershell に以下、スクリプトをペタっとすると、後は勝手に「Windowsターミナル」がインストールされます。 (ただしスクリプトの実行結果の保証も責任も負いかねます) なにが手間か 何が手間かというと、Windows Server 2022 では、StoreもApp Installer(winget)もデフォルトではインストールされていないため「Windowsターミナル」をマニュアルでインストールしなければなりませんでした。 そこでペタっとするだけのスクリプト化 管理者権限で開いたPowershellに以下のスクリプトをペタっとすると「Windowsターミナル」が無事インストールされます。 パッケージのダウンロード先には [ユーザ

By Qualiteg プロダクト開発部
産業交流展2024に出展いたします

産業交流展2024に出展いたします

平素は当社事業に格別のご高配を賜り、厚く御礼申し上げます。 以前にもご案内させていただきましたが、この度、株式会社Qualitegは、多くの優れた企業が一堂に会する国内最大級の総合展示会「産業交流展2024」に出展する運びとなりました。 本展示会では、当社の最新のサービス・ソリューションを展示させていただきます。ご来場の皆様に直接ご説明させていただく貴重な機会として、ぜひブースまでお立ち寄りくださいませ 展示会概要 * 名称: 産業交流展2024 * 会期: 2024年11月20日(水)~22日(金) * 会場: 東京ビッグサイト 1・2ホール、アトリウム * 西1ホール 東京ビジネスフロンティアゾーン ビ-15 * 入場料: 無料(事前登録制) 開催時間 * 11月20日(水) 10:00~17:00 * 11月21日(木) 10:00~17:00 * 11月22日(金) 10:00~16:00

By Qualiteg ニュース
Qualitegオリジナル:サービス設計のまとめ方

Qualitegオリジナル:サービス設計のまとめ方

Qualiteg blogを訪問してくださった皆様、こんにちは。Micheleです。AIを活用した新規事業やマーケティングを手がけている私には、クライアントからよく寄せられる質問があります。AIを用いた事業展開を検討されている方々が共通して直面するであろう課題に対して、このブログを通じて私なりの解答をご提供したいと思います。 はじめに スタートアップにおいて、サービス設計は成功を左右する重要な要素です。私たちは新規事業開発コンサルタントとして、長年多くの新規事業の立ち上げに関わってきました。 そして今、自社で新規事業の立ち上げを実施中です。本記事では、効果的なサービス設計のアプローチについて、実践的な観点からお伝えしたいと思います。 1. ユーザー中心の問題定義 サービス設計の第一歩は、解決すべき問題を明確に定義することです。しかし、ここでよくある失敗は、自社の技術やアイデアから出発してしまうことです。代わりに、以下のステップを踏むことをお勧めします: * ターゲットユーザーへの徹底的なインタビュー * 既存の解決策の分析と不足点の特定 * ユーザーの行動パターン

By Join us, Michele on Qualiteg's adventure to innovation