PyTorchモデルの最適化~TorchScriptの仕組みと活用法~

PyTorchモデルの最適化~TorchScriptの仕組みと活用法~
Photo by Nicolas Hoizey / Unsplash

こんにちは!

本日は PyTorch で開発したAIアプリケーションの本番化に欠かせない、「最適化」についての内容です。具体的には「 TorchScript」 を使用した各種学習モデルの最適化についてみていきたいとおもいます。

TorchScriptの基礎

1 TorchScriptとは

TorchScriptは、PyTorchモデルを最適化された中間表現(IR)に変換する技術です。

、、といってもちょっと難しく聞こえるかもしれません。

平易な言葉で言い換えますと、

要するに、PyTorchで作った機械学習モデルを高速かつ多種多様な環境で動作させることをするための技術です。

例えば、、

・Pythonがインストールされていない環境でも動かせるようにする
・スマホはじめ、各種組み込み機器でも使えるようにする
・動かすときの速度を段違いに上げる
・複数の処理を同時に効率よく実行する

などを目論むときは TorchScript がおすすめです。

つまり、TorchScriptは「本番サービス」で使うときにすごく役立ちます。

2 TorchScriptの動作の仕組み

それでは、PyTorchのモデルをTorchScriptに変換する基本的な仕組みを説明します。

2.1 モデルの定義

まず、シンプルな画像認識用のCNNモデルを作りましょう

import torch


class SimpleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 64, 3)
        self.relu = torch.nn.ReLU()
        self.pool = torch.nn.MaxPool2d(2)

        # 入力サイズから全結合層の入力次元を計算
        # 入力サイズが32x32の場合の計算過程:
        # 1. Conv2d(kernel=3, stride=1, padding=0): 32x32 → 30x30
        # 2. MaxPool2d(kernel=2, stride=2): 30x30 → 15x15
        input_features = 64 * 15 * 15  # 64はconv1の出力チャンネル数
        self.fc = torch.nn.Linear(input_features, 10)

    def forward(self, x):
        x = self.conv1(x)  # [B, 64, 30, 30]
        x = self.relu(x)
        x = self.pool(x)  # [B, 64, 15, 15]
        x = x.view(x.size(0), -1)  # [B, 64*15*15]
        return self.fc(x)

まずは上記モデルの説明をいたします。

SimpleModelの構造解説

はい、このコードは画像認識のための基本的なCNNモデルですね。

モデルは32×32ピクセルのRGB画像(3チャンネル)を入力として受け取り、3チャンネルの入力を64チャンネルに拡張し、3×3のカーネルで特徴を検出させます。

畳み込み層の後には活性化関数としてReLUを配置、続くプーリング層は、2×2の領域から最大値を取り出すことで特徴マップのサイズを半分に縮小しています。

最後に全結合層では、それまでの層で抽出された特徴を受け取り、10個のクラスに分類されるようにしています。で、この層に入力する前に、多次元の特徴マップを1次元フラット化します。

2.2. TorchScriptへの変換

さて、上述のCNNをTorchScript変換するまえに、変換方法についてみてみましょう。

TorchScriptの変換には2つの方法があります、それは「トレース方式」と「スクリプト方式」です。

2つの方法といっても、両者、実行方法はとてもシンプルで以下のコードのようにtorchscript変換コードを記述することができます

def convert_to_torchscript(model, method='trace'):
    model.eval()  # 評価モードに設定
    
    if method == 'trace':
        # トレース方式
        example_input = torch.randn(1, 3, 32, 32)
        traced_model = torch.jit.trace(model, example_input)
        return traced_model
    else:
        # スクリプト方式
        return torch.jit.script(model)

2つの方式には以下のような特徴があります

トレース方式
  • ダミーデータを使って実際の処理を追跡
  • 固定的な入力パターンに適している
  • 処理の流れが単純な場合に使用
スクリプト方式
  • コードを直接解析して変換
  • 柔軟な処理が可能
  • 条件分岐やループを含む複雑なモデルに適している

2.3. 使用例

さて、さきほどつくったCNNモデルをTorchScriptとして「本番」向けに変換してみましょう。

method='trace' とすることで、「トレース方式」で変換します。

# モデルの作成
model = SimpleModel()

# 本番環境用に変換
production_model = convert_to_torchscript(model, method='trace')

# 変換したモデルで推論
input_data = torch.randn(1, 3, 32, 32)
result = production_model(input_data)

convert_to_torchscript では example_input = torch.randn(1, 3, 32, 32) のようにダミーデータを流し込んでいますね。
ダミーデータ(ダミーじゃなくて本番データでも別にかまいません)を流し込むことでネットワークを「トレース」します。

いかがでしょうか。
とても簡単に「本番用」に変換できました。

基本的には実にシンプルな手順で研究開発で作成したPyTorchモデルを、本番環境で効率的に動作する形式に変換することができまます。

これだけで速度の向上や、Pythonに依存しない実行環境での利用が可能な形式を実現できます。

3 TorchScriptの内部表現

TorchScriptで変換したモデルの中身を覗いてみましょう。

モデルの内部で何が起きているのか、どんな最適化が行われているのか、実際に確認する方法をお伝えします。

3.1 内部表現の確認方法

まずは、内部表現を確認するための関数を作ってみましょう:

def inspect_torchscript(model):
    # モデルのグラフ構造を表示
    print("Graph Structure:")
    print(model.graph)
    
    # 最適化されたコードを表示
    print("\nOptimized Code:")
    print(model.code)
    
    # 使用されている演算子を表示
    print("\nOperators:")
    for node in model.graph.nodes():
        print(f"- {node.kind()}")

3.2 内部表現の理解

この関数を使うと、まず最初にモデルの計算フローがグラフ構造として表示されます。これは私たちの書いたPyTorchのコードが、実際にどのような計算の流れに変換されているのかを教えてくれます。データがモデルの中でどのように変換され、どの層がどんな順序で処理を行うのか、その全体像を把握することができます。

次に表示される最適化済みコードは、実際に実行される形に変換された後のコードです。TorchScriptによる最適化で、私たちが書いた元のコードがどのように変更されたのか、その違いを見ることができます。この情報は特にデバッグを行う際にとても役立ちます。

最後に表示される演算子のリストは、モデル内で使用されているすべての演算の種類を教えてくれます。これを見ることで、処理の重い演算子が含まれていないかチェックしたり、最適化の余地がある部分を見つけたりすることができます。

3.3 実践的な使い方

では、実際に使ってみましょうか。

まずモデルを作成してTorchScriptに変換し、その内部を覗いてみます

    model = SimpleModel()
    production_model = convert_to_torchscript(model, method='script')
    input_data = torch.randn(1, 3, 32, 32)
    result = production_model(input_data)
    print(f"Input shape: {input_data.shape}")
    print(f"Output shape: {result.shape}")

    inspect_torchscript(production_model)

3.4 TorchScriptと計算グラフ

さて、さきほどの実行結果は以下のようになりました。


Input shape: torch.Size([1, 3, 32, 32])
Output shape: torch.Size([1, 10])

Graph Structure:
graph(%self : __torch__.SimpleModel,
      %x.1 : Tensor):
  %17 : int = prim::Constant[value=-1]() # ex/ts1.py:21:30
  %13 : int = prim::Constant[value=0]() # ex/ts1.py:21:26
  %conv1 : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="conv1"](%self)
  %x.5 : Tensor = prim::CallMethod[name="forward"](%conv1, %x.1) # ex/ts1.py:18:12
  %relu : __torch__.torch.nn.modules.activation.ReLU = prim::GetAttr[name="relu"](%self)
  %x.9 : Tensor = prim::CallMethod[name="forward"](%relu, %x.5) # ex/ts1.py:19:12
  %pool : __torch__.torch.nn.modules.pooling.MaxPool2d = prim::GetAttr[name="pool"](%self)
  %x.13 : Tensor = prim::CallMethod[name="forward"](%pool, %x.9) # ex/ts1.py:20:12
  %14 : int = aten::size(%x.13, %13) # ex/ts1.py:21:19
  %18 : int[] = prim::ListConstruct(%14, %17)
  %x.19 : Tensor = aten::view(%x.13, %18) # ex/ts1.py:21:12
  %fc : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="fc"](%self)
  %22 : Tensor = prim::CallMethod[name="forward"](%fc, %x.19) # ex/ts1.py:22:15
  return (%22)


Optimized Code:
def forward(self,
    x: Tensor) -> Tensor:
  conv1 = self.conv1
  x0 = (conv1).forward(x, )
  relu = self.relu
  x1 = (relu).forward(x0, )
  pool = self.pool
  x2 = (pool).forward(x1, )
  x3 = torch.view(x2, [torch.size(x2, 0), -1])
  fc = self.fc
  return (fc).forward(x3, )


Operators:
- prim::Constant
- prim::Constant
- prim::GetAttr
- prim::CallMethod
- prim::GetAttr
- prim::CallMethod
- prim::GetAttr
- prim::CallMethod
- aten::size
- prim::ListConstruct
- aten::view
- prim::GetAttr
- prim::CallMethod

上記のようにSimpleModelをTorchScriptに変換し、その内部構造を表示させてみました。

まずは、CNNのコードが正しく実装されている証拠として、入力として32×32ピクセルのRGB画像(shape: [1, 3, 32, 32])を与えると、10クラスの予測値(shape: [1, 10])が出力されることが確認できました。

計算グラフを思い出して、TorchScriptのありがたみをかみしめる

TorchScriptは私たちが書いたPythonコードを「計算グラフ」という形式に変換します。

計算グラフ覚えてますか?

PyTorchのAutoGradなどに慣れてしまうと、計算グラフをイチイチ意識しなくなるのですが、
少し思い出しましょう。

計算グラフはモデル内での演算の流れを有向グラフとして表現します。各ノードは演算(畳み込みやReLUなどなど)を表し、エッジはデータの流れを表します。この構造により、順伝播(forward)での計算過程を正確にトレースします。

モデルの学習時は、まず順伝播で損失値を計算しますよね。

次に、この損失値を基に各パラメータの勾配を計算する必要があります。

各パラメータが最終的な損失にどれだけ影響を与えているのかを知る必要があるからでしたね。

損失から見て、直前の層のパラメータの影響は計算しやすいのですが、より手前の層になればなるほど、間に複数の計算が入るため直接的な影響を計算するのは難しくなります。

そこで、損失から順番に逆向きに計算していくことで、「連鎖律」を使って効率的に各パラメータの勾配を求めることができます。

これが逆伝播(backward)でしたね。

うっすら思い出してきたところで、さらに理解を深めるために、具体例で「計算グラフ」と「連鎖率」をみてみましょう。

計算グラフも連鎖率も余裕で覚えてるっていう方、TorchScriptの結果だけみたい場合は以下読み飛ばしてもらって構いません。

計算グラフの具体例

入力$x$に対して、まず$a$をかけて、次に$b$を足し、最後に2乗するという計算を考えてみましょう。これを数式で書くと:

$$
\begin{aligned}
y &= (ax + b)^2
\end{aligned}
$$

この計算は、以下のような小さなステップに分解できます:

$$
\begin{aligned}
p &= ax \
q &= p + b \
y &= q^2
\end{aligned}
$$

連鎖律と勾配の計算

この勾配を以下のように分解できるのが「連鎖律」ですね

$$
\begin{aligned}
\frac{dy}{dx} &= \frac{dy}{dq} \times \frac{dq}{dp} \times \frac{dp}{dx} \
&= 2q \times 1 \times a \
&= 2a(ax + b)
\end{aligned}
$$

はい、この各ステップでの勾配を見てみましょう

$$
\begin{aligned}
\frac{dy}{dq} &= 2q & \text{(2乗の微分)} \
\frac{dq}{dp} &= 1 & \text{(足し算の微分)} \
\frac{dp}{dx} &= a & \text{(掛け算の微分)}
\end{aligned}
$$

「このパラメータを少し変えたとき、誤差はどう変化するか」を微分で知ることができますよね。逆伝播は「誤差から逆向きに微分を計算していく」方法ですので、この仕組みのおかげで、どんなに複雑なニューラルネットワークでも、一つ一つの演算の勾配さえ計算できれば(=微分できれば)、全体の勾配を効率的に求めることができます。

ちょっとくどくなりましたが、計算グラフは、この逆伝播の過程をシンプルな計算の組み合わせで表したものといえますね。

さて、ではTorchScriptの計算グラフに戻りましょう。この計算グラフでは、データの流れが明確に可視化されています。

まずは各層の取得方法です。

prim::GetAttrという操作によって、モデルから畳み込み層(conv1)、ReLU層(relu)、プーリング層(pool)、全結合層(fc)が順番に取得されています。そして各層の実行はprim::CallMethodという操作で行われます。

テンソルのリサイズ操作です。

aten::sizeでバッチサイズを取得し、prim::ListConstructで新しい形状を作り、最後にaten::viewでテンソルの形状を変更しています。これはPythonコードのx.view(x.size(0), -1)に対応する操作です。

最適化されたコード

今まで見てきたように、TorchScriptは元のコードを「最適化」し、より効率的な形式に変換します。

その証拠に最適化されたコードを見ると、元のPythonコードがより直接的な形式に変換されていることがわかります。

各層はローカル変数として取得され、それぞれforward関数が呼び出されています。

「最適化」されたコードですので最短距離を目指してますから、Pythonのオーバーヘッドが排除され、より効率的な実行が可能になっています。

例えば、テンソルの形状変更操作もtorch.viewとして直接実行されるように変換されていますね。

使用される演算子

最後に、このモデルで使用されている演算子を見てみましょう。

基本的な定数生成(prim::Constant)、属性取得(prim::GetAttr)、メソッド呼び出し(prim::CallMethod)に加えて、テンソル操作(aten::size、aten::view)が使用されていますね。これらはいわゆる低レベルな操作を表しており、モデルの実行がより低レベルな操作ではどのように行われるかを確認することができました

4 TorchScriptでよくあるエラーと対処方法

TorchScriptよさげなのですが、どんなコードでも適用できるというわけではありません。

無邪気につくるとTorchScript化できません。

4.1 可変長引数とキーワード引数に関するエラー

無邪気につくっていると、以下のようなエラーメッセージがでます

Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults

たとえば、以下みたいなコードをかくともれなく発生します。

# エラーが発生するコード例
class ProblematicModel(torch.nn.Module):
    def forward(self, x, return_latents=False, **kwargs):  # ←ここがエラー
        # 処理内容
        pass
原因

TorchScriptには以下のような制約があります

  1. 可変長引数(*args, **kwargs)は使用できない
  2. デフォルト値を持つキーワード専用引数は使用できない

いくつかのモデルを切り替えるときに、他のモデルでは不要だけど、このモデルではこのパラメータ入れたいよね、というときに forward の **kwargs にひとまずぶっこんでおくということをやっていまいがちですが、これはTorchScriptとは相性がよくないので、最適化をしたいなら、複数モデルの若干パラメータが異なるモデルを安易に切り替えるようなコードは書かない方が無難です

対処方法

対象法ですが、例えば以下のような対策があります

  1. すべての引数を明示的に定義する
# 修正例1:必要な引数を全て明示的に指定
class FixedModel(torch.nn.Module):
    def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
        # 処理内容
        pass
  1. 設定用のクラスやデータクラスを使用する
# 修正例2:設定をまとめたクラスを使用
@dataclass
class ModelConfig:
    return_latents: bool = False
    return_rgb: bool = True
    randomize_noise: bool = True

class BetterModel(torch.nn.Module):
    def forward(self, x, config: ModelConfig):
        # configから設定を取得して処理
        pass

4.2 実装のベストプラクティス

前節、前々節をふまえ、TorchScript用のモデルを書くときは、以下の点に注意しましょう

  1. 引数は明示的に定義する

    • 可変長引数(*args, **kwargs)は避ける
    • デフォルト値が必要な場合は通常の引数として定義
  2. 型ヒントを活用する

from typing import Tuple, Optional

class TypedModel(torch.nn.Module):
    def forward(self, x: torch.Tensor, return_latents: bool = False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # 処理内容
        pass
  1. 複雑な設定が必要な場合は設定クラスを使用
from dataclasses import dataclass

@dataclass
class GeneratorConfig:
    return_latents: bool = False
    return_rgb: bool = True
    randomize_noise: bool = True
    noise_scale: float = 1.0

class Generator(torch.nn.Module):
    def forward(self, x: torch.Tensor, config: GeneratorConfig) -> torch.Tensor:
        # configを使用して処理を行う
        output = self.process(x, noise_scale=config.noise_scale)
        if config.return_latents:
            # 潜在変数を返す処理
            pass
        return output

これらの対処方法を適用することで、TorchScriptへの変換がスムーズに行えるようになり、本番直前になって、「あれ、TorchScriptにできねー」という悲劇を避けることができます。

今回は、PyTorchの最適化、とくにPythonコードから「TorchScript」への変換方法や、その背後で行われている処理、計算グラフのおさらい、TorchScriptアンチパターンについて解説いたしました。

最後までお読みいただきありがとうございました。

次回もお楽しみに!

付録:今回使用したコード全体

import torch

class SimpleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 64, 3)
        self.relu = torch.nn.ReLU()
        self.pool = torch.nn.MaxPool2d(2)

        # 入力サイズから全結合層の入力次元を計算
        # 入力サイズが32x32の場合の計算過程:
        # 1. Conv2d(kernel=3, stride=1, padding=0): 32x32 → 30x30
        # 2. MaxPool2d(kernel=2, stride=2): 30x30 → 15x15
        input_features = 64 * 15 * 15  # 64はconv1の出力チャンネル数
        self.fc = torch.nn.Linear(input_features, 10)

    def forward(self, x):
        x = self.conv1(x)  # [B, 64, 30, 30]
        x = self.relu(x)
        x = self.pool(x)  # [B, 64, 15, 15]
        x = x.view(x.size(0), -1)  # [B, 64*15*15]
        return self.fc(x)

def convert_to_torchscript(model, method='trace'):
    model.eval()

    if method == 'trace':
        example_input = torch.randn(1, 3, 32, 32)
        traced_model = torch.jit.trace(model, example_input)
        return traced_model
    else:
        return torch.jit.script(model)


def inspect_torchscript(model):
    # モデルのグラフ構造を表示
    print("Graph Structure:")
    print(model.graph)

    # 最適化されたコードを表示
    print("\nOptimized Code:")
    print(model.code)

    # 使用されている演算子を表示
    print("\nOperators:")
    for node in model.graph.nodes():
        print(f"- {node.kind()}")

# 実行
if __name__ == "__main__":
    model = SimpleModel()
    production_model = convert_to_torchscript(model, method='script')
    input_data = torch.randn(1, 3, 32, 32)
    result = production_model(input_data)
    print(f"Input shape: {input_data.shape}")
    print(f"Output shape: {result.shape}")

    inspect_torchscript(production_model)

Read more

ディープラーニングにおけるEMA(Exponential Moving Average)

ディープラーニングにおけるEMA(Exponential Moving Average)

こんにちは! 本日は、画像生成、動画生成モデルなどで重要な役割を果たしている EMA ※について解説してみたいとおもいます! 当社のAIアバター動画生成サービス「MotionVox™」でも役立っています! といっても、画像生成のための専用技術というわけではなく、学習と推論(生成系も含む)というディープラーニングの運用の中で昨今かなり重宝されるテクニックとなっておりますので、基礎から実装までみていきたいとおもいます。 ※EMAの読み方は私はエマと呼んでますが、イーエムエーって言ってる人もいます。どっちでもいいでしょう。 EMA の基礎知識 EMA(Exponential Moving Average=指数移動平均)は、ざっくりいえばモデルの重みを平均化する手法です。 実は株価分析などでも使われている古くからある概念なのですが、ディープラーニングでは比較的最近になって「あ、これ結構使えるんじゃね?」と重要性が認識されるようになりました。 (”EMA”に限らず、理論の積み上げではなく「やってみたら、使えんじゃん」っていうのがかなり多いのがディープラーニング界隈のもはや常識でし

By Qualiteg 研究部
TOKYO DIGICONX 「MotionVox™」出展レポート

TOKYO DIGICONX 「MotionVox™」出展レポート

こんにちは! 2025年1月9日~11日に東京ビッグサイトにて開催された TOKYO DIGICONX に出展してまいりました。 開催中3日間の様子を簡単にレポートいたします! TOKYO DIGICONX TOKYO DIGICONX は東京ビッグサイト南3・4ホールにて開催で、正式名称は『TOKYO XR・メタバース&コンテンツ ビジネスワールド』ということで、xR・メタバース・コンテンツ・AIと先端テクノロジーが集まる展示会です 「Motion Vox™」のお披露目を行いました 当社からは、新サービス「Motion Vox™」を中心とした展示をさせていただきました MotionVox™は動画内の顔と声を簡単にAIアバター動画に変換できるAIアバター動画生成サービスです。 自分で撮影した動画をアップロードし、変換したい顔と声を選ぶだけの3ステップで完了。特別な機材は不要で、自然な表情とリップシンクを実現。 社内研修やYouTube配信、ドキュメンタリー制作など、幅広い用途で活用できます。 当社ブースの様子 「MotionVox™」の初出展とい

By Qualiteg ビジネス開発本部 | マーケティング部
【本日開催】TOKYO DIGICONX で「MotionVox」を出展~リアルを纏う、AIアバター~

【本日開催】TOKYO DIGICONX で「MotionVox」を出展~リアルを纏う、AIアバター~

こんにちは! 本日(2025年1月9日)より東京ビックサイトにて開催されている「TOKYO DIGICONX」に、フォトリアリスティック(Photorealistic Avater)な次世代アバター生成AI「MotionVox」を出展しています! XR・メタバース・AIと先端テクノロジーが集まる本展示会で、ビジネス向け次世代AI動画生成ツールとしてMotionVox™をご紹介させていただきます。 MotionVox™とは MotionVox™は、あなたの表情や発話を魅力的なアバターが完全再現する動画生成AIです。まるで本物の人間がそこにいるかのような自然な表情と圧倒的な存在感で、新しい表現の可能性を切り開きます。 主な特徴 * フォトリアリスティックな高品質アバター * 高再現度の表情同期 * プロフェッショナルなリップシンク * カスタマイズ可能なボイスチェンジ機能 * 簡単な操作性 * プライバシーの完全保護 多様な用途に対応 MotionVoxは、以下のようなさまざまなビジネスシーンで活用いただけます! * 動画配信やVTuber活動 * S

By Qualiteg ビジネス開発本部 | マーケティング部
[AI新規事業創出]Qualitegセレクション:ビジネスモデル設計①ビジネスモデル図

[AI新規事業創出]Qualitegセレクション:ビジネスモデル設計①ビジネスモデル図

Qualiteg blogを訪問してくださった皆様、こんにちは。Micheleです。AIを活用した新規事業やマーケティングを手がけている私には、クライアントからよく寄せられる質問があります。AIを用いた事業展開を検討されている方々が共通して直面するであろう課題に対して、このブログを通じて私なりの解答をご提供したいと思います。 「新規事業のビジネスモデル図の描き方 〜実践で活かせる具体的なコツ〜」 新規事業開発のコンサルティングをさせていただいておりますとクライアント企業様の現場で、「ビジネスモデル図をどう描けばいいの?」という質問をよく頂きます。 実は私も最初は悩んだのですが、数々の失敗と成功を経て、効果的なビジネスモデル図の描き方が分かってきました。今回は、その実践的なコツをお伝えしていきます。 なぜビジネスモデル図が重要なのか ビジネスモデル図は、単なる図解ではありません。これは、自分のビジネスアイデアを「検証可能な形」に落とし込むための重要なツールです。 上申の際にステークホルダーの説明をするのに使うこともできます。また、アイディア創出後のマネタイズ検討の場合も情報

By Join us, Michele on Qualiteg's adventure to innovation