DPO(直接選好最適化)の基礎から画像・動画AIへの応用まで

DPO(直接選好最適化)の基礎から画像・動画AIへの応用まで
Photo by Diggity Marketing / Unsplash

こんにちは Qualiteg研究部です!

本日は、2023年、AnthropicのRafael Rafailov、Archit Sharmaらの研究チームによって提案された「直接選好最適化(Direct Preference Optimization: DPO)」について、基礎から応用までを解説します。

この手法は、論文「Direct Preference Optimization: Your Language Model is Secretly a Reward Model」で発表され、AIの学習手法に大きな影響を与えています。この論文では、言語モデル(LM)の動作を人間の好みに調整する新しい手法「Direct Preference Optimization(DPO)」を提案していますが、最近では、VLMなど言語モデルに限らず応用が広がっています。

しかも、理論は比較的シンプルなので、じわりと人気があがっていますね!

DPOが生まれた背景

言語モデルは大規模データで事前学習されるため、幅広い知識と能力を持つが、その動作を制御するのは困難でした。

そのため、従来の言語モデルの学習では、人間の選好(好み)を反映させるために強化学習(RL)が使用されていましたが、実装が複雑で計算コストが高いという課題がありました。

「人間のフィードバックによる強化学習(RLHF)」は有用ですが、複雑で計算コストが高く大変ですよね。

Anthropicのチームはこれらの課題を解決するため、より直接的かつシンプルなアプローチとしてDPOを開発しました。この手法は発表後すぐに注目を集め、OpenAI、Google、DeepMindなど、多くの主要AI研究機関でも採用されています。

DPOの基本原理

従来のRLHFでは、報酬モデルを学習し、それを最大化するようにポリシーを調整していましたが、
DPOは、報酬モデルを明示的に学習せず、人間の好みデータを直接用いて言語モデルを調整する新しい手法となります。

数学的には以下のように表現されます

L(θ) = E[log(1 + exp(β(r_w(x_w) - r_w(x_l))))]

ここで:

  • θ:モデルのパラメータ
  • β:温度パラメータ
  • r_w(x):選好スコア
  • x_w, x_l:それぞれ「望ましい」出力と「望ましくない」出力

この単純な分類損失をつかい、損失関数を最小化することで、モデルは人間の選好に沿った出力を生成できるように学習しますので、安定性が高く計算効率も良いというわけです。

これによりDPOは、感情制御、要約、対話といったタスクで、従来のRLHF(例えばPPO)と同等またはそれ以上の性能を発揮することがわかりました。高度なハイパーパラメータ調整も不要で簡単に実装できます。

言語以外のAIへの応用はどうでしょうか。

画像生成AIでの活用例

たとえば画像生成の文脈で考えてみましょう!

テキストプロンプトに基づいて画像を生成するモデル(例: DALL·EやStable Diffusionなど)は、ユーザーの明確な好みを反映するのが難しい場合があります。というか、詳細に指定しようとすればするほど、かなり難しいですよね。

DPOを人間のフィードバックを用いて「好ましい画像」と「好ましくない画像」を比較し、その好みを直接モデルに反映します。

たとえば「リアルな風景」や「アニメ風キャラクター」など、生成画像のスタイルやク品質を特定の好みに合わせて調整することができます。

スタイル最適化の例

preferred_style = get_human_preference(image_A, image_B)
loss = dpo_loss(model_params, preferred_style)
model_params = optimize(loss)

具体的な例として、アニメ風のキャラクター生成において、「目の大きさ」や「線の太さ」といった特徴を人間の選好に基づいて調整できます。

品質向上の例
画質の改善においても、DPOは効果的です:

品質スコア = Σ(wi * quality_metric_i)

ここでwiは各品質指標の重みを表し、人間の選好データから学習されます。

画像生成における DPO 実装例

それでは、画像生成におけるDPOの実装についてみていきましょう。

DPO実装の全体像

Direct Preference Optimization(DPO)の実装において最も重要なのは、人間の選好をどのようにしてモデルに学習させるかという点です。今回紹介するPyTorchによる実装では、画像生成タスクを例に、選好学習の基本的な構造を示していきます。

DPOTrainerクラスの仕組み

DPOの中核となるのがDPOTrainerクラスです。

このクラスは、モデルの学習プロセス全体を管理します。初期化時には温度パラメータβと学習率を設定します。

βは選好の強さを制御するパラメータで、値が大きいほど望ましい出力と望ましくない出力の差が強調されます。最適化手法はadamです。

特に重要かつ特徴的なのが損失関数の実装部分です。compute_dpo_loss メソッドでは、望ましい出力と望ましくない出力のペアを受け取り、それぞれにスコアを付与します。その後、log(1 + exp(β(rejected - preferred)))という数式で表される損失を計算します。この損失関数により、望ましい出力のスコアが高く、望ましくない出力のスコアが低くなるように学習が進むというわけです。

import torch
import torch.nn as nn
import torch.optim as optim

class DPOTrainer:
    def __init__(self, model, beta=0.1, learning_rate=1e-5):
        self.model = model
        self.beta = beta
        self.optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    def compute_dpo_loss(self, preferred_outputs, rejected_outputs):
        # 選好スコアの計算
        preferred_scores = self.model(preferred_outputs)
        rejected_scores = self.model(rejected_outputs)
        
        # DPO損失関数の実装
        loss = torch.log(1 + torch.exp(self.beta * (rejected_scores - preferred_scores)))
        return loss.mean()
    
    def train_step(self, preferred_batch, rejected_batch):
        self.optimizer.zero_grad()
        loss = self.compute_dpo_loss(preferred_batch, rejected_batch)
        loss.backward()
        self.optimizer.step()
        return loss.item()

画像評価のためのニューラルネットワーク

画像の評価には、ImageDPOクラスとして実装された畳み込みニューラルネットワークを使用します。フツーのCNNです。一応説明をしておきますと、このネットワークはRGB画像を入力として受け取り、一連の畳み込み層で特徴を抽出し最初の畳み込み層では3チャネルの入力を64チャネルに拡張し抽出された特徴は最終的に1次元に変換され、全結合層によって1つの選好スコアへと変換されます。最後が選考スコアになるところがDPOっぽいところですね。

この構造で画像の視覚的特徴を効果的に捉え、その画像が望ましいものかどうかを数値化することができます。ここは一般的な話ですがあとは、ReLUで非線形にし、複雑な特徴の学習ができます。


# 画像生成への応用例
class ImageDPO(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64 * 32 * 32, 1)  # 画像サイズに応じて調整
        )
    
    def forward(self, x):
        return self.backbone(x)

# 使用例
def train_image_dpo():
    model = ImageDPO()
    trainer = DPOTrainer(model)
    
    # トレーニングループ
    for epoch in range(num_epochs):
        for preferred_batch, rejected_batch in data_loader:
            loss = trainer.train_step(preferred_batch, rejected_batch)
            print(f"Epoch {epoch}, Loss: {loss:.4f}")

学習プロセスの実装

実際の学習は、train_image_dpo関数によって制御されImageDPOモデルとDPOTrainerのインスタンスを作成しエポック単位で、データセットから望ましい画像と望ましくない画像のペアを順次取り出し、学習を進めています。各バッチの処理では、まず望ましい画像と望ましくない画像のペアをモデルに入力し、それぞれのスコアを計算します。その後、これらのスコアを基に損失を計算し、その損失を最小化するようにモデルのパラメータを更新します。

ところどころDPOっぽい部分がありますが、基本はおなじみのCNN系の画像用ネットワークですね。

結局重要なのは「選考データ」

実装解説をしましたが、DPOは非常にシンプルです。この手法がもたらす果実を得るには、結局はインプットデータとなります。DPO自体は何もしてくれません。入力次第です。つまり画像ペアのデータセットを適切に準備する必要がありこれには、人間の評価者による画像の選好データの収集が含まれます。さらに、モデルの学習過程をモニタリングし、適切に評価を行うことで、モデルの性能を確認し、必要に応じて調整を行うことができます。これらの要素を適切に組み合わせることで、効果的なDPO学習システムを構築することが可能となります。

なーんだ、結局データづくりか、と思ったかもしれませんが、RLHFをやったことのある人ならわかるとおもいますが、それに比べるとだいぶハッピーです。

動画生成への応用

さて、次は動画生成への応用についてもみてみましょう。

当社の動画生成AIでも実はDPOを取り入れています!

特に動画生成において重要なのは、時間軸の考慮です。

frame_consistency_score = Σt(similarity(frame_t, frame_t+1))
motion_naturalness = evaluate_motion_smoothness(frames)
total_score = α * frame_consistency_score + β * motion_naturalness

人物の動きの自然さを改善する場合、動画の元となる複数枚の静止画において、人間の評価者が「より自然な動き」を選択します。人間が「自然だな」とおもったデータが 好み=選好データとなりますので、これを用いてDPOで学習します。これにより、従来の補完手法よりも1段上の「自然さ」を実現することができます。

動画生成時のDPO適用実装例

動画生成のためのDPO実装は、画像生成と比べてやや複雑です。

特に重要なのは「時間軸に沿った一貫性の維持」です。

動画生成AIでは、いきなり動画ができるのではなく、複数枚の静止画(=シーケンスといいます)の集合体として動画ができます。

1枚ずつの静止画をパラパラアニメのようにつなぎ合わせることで動画にみせます。そこで各独立した静止画を自然にみせるためには、静止画間の時系列が重要になります。つまり「時間軸での一貫性の維持」が重要です。

(余談ですが、当社が目指しているAIヒューマンの実現の最も重要な要素も「一貫性の維持」です。画像でも性格でも、一貫性の維持が人間を人間らしくしていますので、非常に面白いテーマです。)

ここでは、VideoDPOクラスを中心に、動画生成のための実装について詳しく解説していきます。

3次元畳み込みを用いた特徴抽出

VideoDPOクラスの中核となるのは、3次元畳み込み層(Conv3D)を用いたframe_encoderです。この層では、空間的な特徴だけでなく、時間的な特徴も同時に抽出します。入力データは(バッチサイズ、チャネル数、フレーム数、高さ、幅)という5次元のテンソルとして扱います。

ネットワークは一般的な内容ですが、説明しますと、最初の畳み込み層では3チャネルの入力を64チャネルに拡張、その後MaxPooling層で特徴マップのサイズを削減しつぎの畳み込み層では64チャネルから128チャネルへと拡張して豊かな特徴表現を得ます。各畳み込み層の後にはおなじみReLUです。

時間的特徴のスコアリング

抽出された特徴は、temporal_scorerによって評価されます。このモジュールでは、まず多次元の特徴マップを一列に並べ替え、512ユニットを持つ中間層を経て、最終的に1つのスコアへと変換します。このスコアは、動画の質を総合的に評価する指標となります。

class VideoDPO(nn.Module):
    def __init__(self, num_frames=16):
        super().__init__()
        self.frame_encoder = nn.Sequential(
            nn.Conv3d(3, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1)),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=(2, 2, 2)),
            nn.Conv3d(64, 128, kernel_size=(3, 3, 3), padding=(1, 1, 1)),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=(2, 2, 2))
        )
        self.temporal_scorer = nn.Sequential(
            nn.Linear(128 * 4 * 4 * 4, 512),
            nn.ReLU(),
            nn.Linear(512, 1)
        )
    
    def forward(self, x):
        # x shape: (batch_size, channels, frames, height, width)
        features = self.frame_encoder(x)
        features = features.view(features.size(0), -1)
        score = self.temporal_scorer(features)
        return score

フレーム間の一貫性評価

ここでは先ほども述べたフレーム間の一貫性についてみてみましょう。

compute_temporal_consistency関数では、連続するフレーム間の差分を平均二乗誤差(MSE)を用いて計算します。

(MSE忘れちゃった場合は、当社のQiitaブログをお読みください^-^)

この値が小さいほど、動画の動きが滑らかで自然であることを意味し最終的な返り値は負の値となっており、これは一貫性が高いほど損失が小さくなるように設計されています。

def compute_temporal_consistency(frames):
    """フレーム間の一貫性スコアを計算"""
    consistency = 0
    for i in range(len(frames)-1):
        consistency += torch.nn.functional.mse_loss(frames[i], frames[i+1])
    return -consistency  # 差が小さいほど一貫性が高い

# 動画生成の学習ループ例
def train_video_dpo(model, train_loader, num_epochs=10):
    trainer = DPOTrainer(model)
    
    for epoch in range(num_epochs):
        for preferred_videos, rejected_videos in train_loader:
            # 基本的なDPO損失
            dpo_loss = trainer.train_step(preferred_videos, rejected_videos)
            
            # 時間的一貫性を考慮した追加の損失
            temporal_loss = compute_temporal_consistency(preferred_videos)
            
            # 総合的な損失
            total_loss = dpo_loss + 0.5 * temporal_loss
            
            print(f"Epoch {epoch}, DPO Loss: {dpo_loss:.4f}, "
                  f"Temporal Loss: {temporal_loss:.4f}")
学習プロセスの実装

train_video_dpo関数では、実際の学習プロセスを制御します。各エポックでは、望ましい動画と望ましくない動画のペアを使用してDPO損失を計算します。さらに、望ましい動画については時間的一貫性も評価し、これらを組み合わせた総合的な損失関数を最小化します。DPO損失と時間的一貫性損失を0.5という係数で重み付けして組み合わせることで、選好学習と動画の品質維持のバランスを取っています。各エポックごとに両方の損失値が出力されるため、学習の進行状況を詳細にモニタリングすることができます。

実践上の考慮点

さて、シンプルな実装を示しましたが、やはり入力データが重要です。つまり、入力データに「良い動画」と「悪い動画」のペアが必要となりますが、これらの収集は容易ではなく、また明確な評価基準が求められます。ここでも適切なデータセットの構築が、モデルの性能を左右します。

DPOの影響と今後の展望

DPOの登場以降、多くのAI研究機関がこの手法を採用し、様々な改良版や応用研究が発表されています。特に2024年に入ってからは、当社も含め画像生成や動画生成の分野でこのブームに乗っかっており応用研究が活発化しています。

さらに動画処理では、モバイルカメラとの相性から、より効率的な学習方法や、リアルタイム処理への適用などさらなる研究が進められています。

まとめ

2023年に登場したDPOは、その数学的シンプルさと実装の容易さから、言語モデルをはじめ、画像・動画生成を含む様々な分野で急速に採用が進んでいます。特に、主観的な品質評価が重要な視覚メディアの生成において、人間の好みを効率的に学習できる点が大きな利点となっています。

当社サービスもこのDPOを活用してさらに品質と効率の高いものをリリースしていく予定ですのでご期待くださいませ!

当社では、このように最新AIサイエンス、AIエンジニアリングの知見を総動員し人々のクリエイティビティを深化させるサービス、ソリューションをご提供しております。先端AIサイエンス、エンジニアリングの知見を使って世界を変革するチャレンジを少数精鋭で進めております。「腕に覚えアリ」という方はぜひ当社でのキャリアをご検討いただければ幸いです!

Read more

LLM推論基盤プロビジョニング講座 第3回 使用モデルの推論時消費メモリ見積もり

LLM推論基盤プロビジョニング講座 第3回 使用モデルの推論時消費メモリ見積もり

こんにちは!前回はLLMサービスへのリクエスト数見積もりについて解説しました。今回は7ステッププロセスの3番目、「使用モデルの推論時消費メモリ見積もり」について詳しく掘り下げていきます。 GPUメモリがリクエスト処理能力を決定する LLMサービス構築において、GPUが同時に処理できるリクエスト数はGPUメモリの消費量によって制約されます。 つまり、利用可能なGPUメモリがどれだけあるかによって、同時に何件のリクエストを処理できるかがほぼ決まります。 では、その具体例として、Llama3 8B(80億パラメータ)モデルをNVIDIA RTX A5000(24GB)にロードするケースを考えてみましょう。 このGPUには24GBのGPUメモリがありますが、すべてをリクエスト処理に使えるわけではありません。最初にモデル自体が一定量のメモリを消費し、残りの領域で実際のリクエスト処理を行います。 GPUメモリ消費の二大要素 GPUの消費メモリ量は主に以下の2つの要素によって決まります 1. モデルのフットプリント LLMをGPUに読み込んだときに最初に消費されるメモリ

By Qualiteg コンサルティング
システムとcondaのC++標準ライブラリ(libstdc++)のバージョン違い問題による事象と対処法解説

システムとcondaのC++標準ライブラリ(libstdc++)のバージョン違い問題による事象と対処法解説

こんにちは! 先日、dlibをつかったPythonアプリケーション(conda環境で動作する)作っていたところ、以下のようなエラーに遭遇しました。 ImportError: /home/mlu/anaconda3/envs/example_env/bin/../lib/libstdc++.so.6: version `GLIBCXX_3.4.32' not found (required by /home/mlu/anaconda3/envs/example_env/lib/python3.10/site-packages/_dlib_pybind11.cpython-310-x86_64-linux-gnu.so) 「dlib_pybind11モジュールがGLIBCXX_3.4.32を要求してるけど、みつからない!」という感じのエラーですね。

By Qualiteg プロダクト開発部
LLM推論基盤プロビジョニング講座 第2回 LLMサービスのリクエスト数を見積もる

LLM推論基盤プロビジョニング講座 第2回 LLMサービスのリクエスト数を見積もる

こんにちは! 今回はLLM推論基盤プロビジョニング講座 第2回です! STEP2 LLMサービスへのリクエスト数見積もり それでは、早速、LLM推論基盤プロビジョニングの第2ステップである「リクエスト数見積もり」の重要性と方法を解説いたします。 LLMサービスを構築する際に必要となるGPUノード数を適切に見積もるためには、まずサービスに対して想定されるリクエスト数を正確に予測する必要があります。 リクエスト数見積もりの基本的な考え方 LLMサービスへの想定リクエスト数から必要なGPUノード数を算出するプロセスは、サービス設計において非常に重要です。過小評価すればサービス品質が低下し、過大評価すれば無駄なコストが発生します。このバランスを適切に取るための基礎となるのがリクエスト数の見積もりです。 想定リクエスト数の諸元 リクエスト数を見積もるための5つの重要な要素(諸元)をみてみましょう。 1. DAU(Daily Active Users): 1日あたりの実際にサービスを利用するユーザー数です。これはサービスの規模を示す最も基本的な指標となります。 2. 1日

By Qualiteg コンサルティング
Zoom会議で肩が踊る?自動フレーミング映像安定化とAIによる性能向上の可能性

Zoom会議で肩が踊る?自動フレーミング映像安定化とAIによる性能向上の可能性

こんにちは! 本日は、自動フレーミング映像の安定化に関するアルゴリズム・ノウハウを解説いたします 第1章 問題の背景と目的 バストアップ映像を撮影する際、特にオンラインミーティングやYouTubeなどのトーク映像では、人物がうなずく、首を振るなどの自然な動作をした際に「首まわりや肩がフレーム内で上下に移動してしまう」という現象がしばしば起こります。これは、多くの場合カメラや撮影ソフトウェアが人物の「目や顔を画面中央に保とう」とする自動フレーミング機能の働きに起因します。 撮影対象の人物が頭を下げた際に、映像のフレーム全体が相対的に上方向へシフトし、その結果、本来動いていないはずの肩の部分が映像内で持ち上がっているように見えてしまう現象です。 本稿では、この問題を撮影後の後処理(ポストプロセッシング)のみを用いて、高速、高い精度かつロバストに解決する手法をご紹介します。 前半では、従来のCV(コンピュータービジョン)の手法を使い高速に処理する方法をご紹介します。後半では、AIを使用してより安定性の高い性能を実現する方法について考察します。 第2章 古典手法による肩の上下

By Qualiteg 研究部