DPO(直接選好最適化)の基礎から画像・動画AIへの応用まで

DPO(直接選好最適化)の基礎から画像・動画AIへの応用まで
Photo by Diggity Marketing / Unsplash

こんにちは Qualiteg研究部です!

本日は、2023年、AnthropicのRafael Rafailov、Archit Sharmaらの研究チームによって提案された「直接選好最適化(Direct Preference Optimization: DPO)」について、基礎から応用までを解説します。

この手法は、論文「Direct Preference Optimization: Your Language Model is Secretly a Reward Model」で発表され、AIの学習手法に大きな影響を与えています。この論文では、言語モデル(LM)の動作を人間の好みに調整する新しい手法「Direct Preference Optimization(DPO)」を提案していますが、最近では、VLMなど言語モデルに限らず応用が広がっています。

しかも、理論は比較的シンプルなので、じわりと人気があがっていますね!

DPOが生まれた背景

言語モデルは大規模データで事前学習されるため、幅広い知識と能力を持つが、その動作を制御するのは困難でした。

そのため、従来の言語モデルの学習では、人間の選好(好み)を反映させるために強化学習(RL)が使用されていましたが、実装が複雑で計算コストが高いという課題がありました。

「人間のフィードバックによる強化学習(RLHF)」は有用ですが、複雑で計算コストが高く大変ですよね。

Anthropicのチームはこれらの課題を解決するため、より直接的かつシンプルなアプローチとしてDPOを開発しました。この手法は発表後すぐに注目を集め、OpenAI、Google、DeepMindなど、多くの主要AI研究機関でも採用されています。

DPOの基本原理

従来のRLHFでは、報酬モデルを学習し、それを最大化するようにポリシーを調整していましたが、
DPOは、報酬モデルを明示的に学習せず、人間の好みデータを直接用いて言語モデルを調整する新しい手法となります。

数学的には以下のように表現されます

L(θ) = E[log(1 + exp(β(r_w(x_w) - r_w(x_l))))]

ここで:

  • θ:モデルのパラメータ
  • β:温度パラメータ
  • r_w(x):選好スコア
  • x_w, x_l:それぞれ「望ましい」出力と「望ましくない」出力

この単純な分類損失をつかい、損失関数を最小化することで、モデルは人間の選好に沿った出力を生成できるように学習しますので、安定性が高く計算効率も良いというわけです。

これによりDPOは、感情制御、要約、対話といったタスクで、従来のRLHF(例えばPPO)と同等またはそれ以上の性能を発揮することがわかりました。高度なハイパーパラメータ調整も不要で簡単に実装できます。

言語以外のAIへの応用はどうでしょうか。

画像生成AIでの活用例

たとえば画像生成の文脈で考えてみましょう!

テキストプロンプトに基づいて画像を生成するモデル(例: DALL·EやStable Diffusionなど)は、ユーザーの明確な好みを反映するのが難しい場合があります。というか、詳細に指定しようとすればするほど、かなり難しいですよね。

DPOを人間のフィードバックを用いて「好ましい画像」と「好ましくない画像」を比較し、その好みを直接モデルに反映します。

たとえば「リアルな風景」や「アニメ風キャラクター」など、生成画像のスタイルやク品質を特定の好みに合わせて調整することができます。

スタイル最適化の例

preferred_style = get_human_preference(image_A, image_B)
loss = dpo_loss(model_params, preferred_style)
model_params = optimize(loss)

具体的な例として、アニメ風のキャラクター生成において、「目の大きさ」や「線の太さ」といった特徴を人間の選好に基づいて調整できます。

品質向上の例
画質の改善においても、DPOは効果的です:

品質スコア = Σ(wi * quality_metric_i)

ここでwiは各品質指標の重みを表し、人間の選好データから学習されます。

画像生成における DPO 実装例

それでは、画像生成におけるDPOの実装についてみていきましょう。

DPO実装の全体像

Direct Preference Optimization(DPO)の実装において最も重要なのは、人間の選好をどのようにしてモデルに学習させるかという点です。今回紹介するPyTorchによる実装では、画像生成タスクを例に、選好学習の基本的な構造を示していきます。

DPOTrainerクラスの仕組み

DPOの中核となるのがDPOTrainerクラスです。

このクラスは、モデルの学習プロセス全体を管理します。初期化時には温度パラメータβと学習率を設定します。

βは選好の強さを制御するパラメータで、値が大きいほど望ましい出力と望ましくない出力の差が強調されます。最適化手法はadamです。

特に重要かつ特徴的なのが損失関数の実装部分です。compute_dpo_loss メソッドでは、望ましい出力と望ましくない出力のペアを受け取り、それぞれにスコアを付与します。その後、log(1 + exp(β(rejected - preferred)))という数式で表される損失を計算します。この損失関数により、望ましい出力のスコアが高く、望ましくない出力のスコアが低くなるように学習が進むというわけです。

import torch
import torch.nn as nn
import torch.optim as optim

class DPOTrainer:
    def __init__(self, model, beta=0.1, learning_rate=1e-5):
        self.model = model
        self.beta = beta
        self.optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    def compute_dpo_loss(self, preferred_outputs, rejected_outputs):
        # 選好スコアの計算
        preferred_scores = self.model(preferred_outputs)
        rejected_scores = self.model(rejected_outputs)
        
        # DPO損失関数の実装
        loss = torch.log(1 + torch.exp(self.beta * (rejected_scores - preferred_scores)))
        return loss.mean()
    
    def train_step(self, preferred_batch, rejected_batch):
        self.optimizer.zero_grad()
        loss = self.compute_dpo_loss(preferred_batch, rejected_batch)
        loss.backward()
        self.optimizer.step()
        return loss.item()

画像評価のためのニューラルネットワーク

画像の評価には、ImageDPOクラスとして実装された畳み込みニューラルネットワークを使用します。フツーのCNNです。一応説明をしておきますと、このネットワークはRGB画像を入力として受け取り、一連の畳み込み層で特徴を抽出し最初の畳み込み層では3チャネルの入力を64チャネルに拡張し抽出された特徴は最終的に1次元に変換され、全結合層によって1つの選好スコアへと変換されます。最後が選考スコアになるところがDPOっぽいところですね。

この構造で画像の視覚的特徴を効果的に捉え、その画像が望ましいものかどうかを数値化することができます。ここは一般的な話ですがあとは、ReLUで非線形にし、複雑な特徴の学習ができます。


# 画像生成への応用例
class ImageDPO(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64 * 32 * 32, 1)  # 画像サイズに応じて調整
        )
    
    def forward(self, x):
        return self.backbone(x)

# 使用例
def train_image_dpo():
    model = ImageDPO()
    trainer = DPOTrainer(model)
    
    # トレーニングループ
    for epoch in range(num_epochs):
        for preferred_batch, rejected_batch in data_loader:
            loss = trainer.train_step(preferred_batch, rejected_batch)
            print(f"Epoch {epoch}, Loss: {loss:.4f}")

学習プロセスの実装

実際の学習は、train_image_dpo関数によって制御されImageDPOモデルとDPOTrainerのインスタンスを作成しエポック単位で、データセットから望ましい画像と望ましくない画像のペアを順次取り出し、学習を進めています。各バッチの処理では、まず望ましい画像と望ましくない画像のペアをモデルに入力し、それぞれのスコアを計算します。その後、これらのスコアを基に損失を計算し、その損失を最小化するようにモデルのパラメータを更新します。

ところどころDPOっぽい部分がありますが、基本はおなじみのCNN系の画像用ネットワークですね。

結局重要なのは「選考データ」

実装解説をしましたが、DPOは非常にシンプルです。この手法がもたらす果実を得るには、結局はインプットデータとなります。DPO自体は何もしてくれません。入力次第です。つまり画像ペアのデータセットを適切に準備する必要がありこれには、人間の評価者による画像の選好データの収集が含まれます。さらに、モデルの学習過程をモニタリングし、適切に評価を行うことで、モデルの性能を確認し、必要に応じて調整を行うことができます。これらの要素を適切に組み合わせることで、効果的なDPO学習システムを構築することが可能となります。

なーんだ、結局データづくりか、と思ったかもしれませんが、RLHFをやったことのある人ならわかるとおもいますが、それに比べるとだいぶハッピーです。

動画生成への応用

さて、次は動画生成への応用についてもみてみましょう。

当社の動画生成AIでも実はDPOを取り入れています!

特に動画生成において重要なのは、時間軸の考慮です。

frame_consistency_score = Σt(similarity(frame_t, frame_t+1))
motion_naturalness = evaluate_motion_smoothness(frames)
total_score = α * frame_consistency_score + β * motion_naturalness

人物の動きの自然さを改善する場合、動画の元となる複数枚の静止画において、人間の評価者が「より自然な動き」を選択します。人間が「自然だな」とおもったデータが 好み=選好データとなりますので、これを用いてDPOで学習します。これにより、従来の補完手法よりも1段上の「自然さ」を実現することができます。

動画生成時のDPO適用実装例

動画生成のためのDPO実装は、画像生成と比べてやや複雑です。

特に重要なのは「時間軸に沿った一貫性の維持」です。

動画生成AIでは、いきなり動画ができるのではなく、複数枚の静止画(=シーケンスといいます)の集合体として動画ができます。

1枚ずつの静止画をパラパラアニメのようにつなぎ合わせることで動画にみせます。そこで各独立した静止画を自然にみせるためには、静止画間の時系列が重要になります。つまり「時間軸での一貫性の維持」が重要です。

(余談ですが、当社が目指しているAIヒューマンの実現の最も重要な要素も「一貫性の維持」です。画像でも性格でも、一貫性の維持が人間を人間らしくしていますので、非常に面白いテーマです。)

ここでは、VideoDPOクラスを中心に、動画生成のための実装について詳しく解説していきます。

3次元畳み込みを用いた特徴抽出

VideoDPOクラスの中核となるのは、3次元畳み込み層(Conv3D)を用いたframe_encoderです。この層では、空間的な特徴だけでなく、時間的な特徴も同時に抽出します。入力データは(バッチサイズ、チャネル数、フレーム数、高さ、幅)という5次元のテンソルとして扱います。

ネットワークは一般的な内容ですが、説明しますと、最初の畳み込み層では3チャネルの入力を64チャネルに拡張、その後MaxPooling層で特徴マップのサイズを削減しつぎの畳み込み層では64チャネルから128チャネルへと拡張して豊かな特徴表現を得ます。各畳み込み層の後にはおなじみReLUです。

時間的特徴のスコアリング

抽出された特徴は、temporal_scorerによって評価されます。このモジュールでは、まず多次元の特徴マップを一列に並べ替え、512ユニットを持つ中間層を経て、最終的に1つのスコアへと変換します。このスコアは、動画の質を総合的に評価する指標となります。

class VideoDPO(nn.Module):
    def __init__(self, num_frames=16):
        super().__init__()
        self.frame_encoder = nn.Sequential(
            nn.Conv3d(3, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1)),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=(2, 2, 2)),
            nn.Conv3d(64, 128, kernel_size=(3, 3, 3), padding=(1, 1, 1)),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=(2, 2, 2))
        )
        self.temporal_scorer = nn.Sequential(
            nn.Linear(128 * 4 * 4 * 4, 512),
            nn.ReLU(),
            nn.Linear(512, 1)
        )
    
    def forward(self, x):
        # x shape: (batch_size, channels, frames, height, width)
        features = self.frame_encoder(x)
        features = features.view(features.size(0), -1)
        score = self.temporal_scorer(features)
        return score

フレーム間の一貫性評価

ここでは先ほども述べたフレーム間の一貫性についてみてみましょう。

compute_temporal_consistency関数では、連続するフレーム間の差分を平均二乗誤差(MSE)を用いて計算します。

(MSE忘れちゃった場合は、当社のQiitaブログをお読みください^-^)

この値が小さいほど、動画の動きが滑らかで自然であることを意味し最終的な返り値は負の値となっており、これは一貫性が高いほど損失が小さくなるように設計されています。

def compute_temporal_consistency(frames):
    """フレーム間の一貫性スコアを計算"""
    consistency = 0
    for i in range(len(frames)-1):
        consistency += torch.nn.functional.mse_loss(frames[i], frames[i+1])
    return -consistency  # 差が小さいほど一貫性が高い

# 動画生成の学習ループ例
def train_video_dpo(model, train_loader, num_epochs=10):
    trainer = DPOTrainer(model)
    
    for epoch in range(num_epochs):
        for preferred_videos, rejected_videos in train_loader:
            # 基本的なDPO損失
            dpo_loss = trainer.train_step(preferred_videos, rejected_videos)
            
            # 時間的一貫性を考慮した追加の損失
            temporal_loss = compute_temporal_consistency(preferred_videos)
            
            # 総合的な損失
            total_loss = dpo_loss + 0.5 * temporal_loss
            
            print(f"Epoch {epoch}, DPO Loss: {dpo_loss:.4f}, "
                  f"Temporal Loss: {temporal_loss:.4f}")
学習プロセスの実装

train_video_dpo関数では、実際の学習プロセスを制御します。各エポックでは、望ましい動画と望ましくない動画のペアを使用してDPO損失を計算します。さらに、望ましい動画については時間的一貫性も評価し、これらを組み合わせた総合的な損失関数を最小化します。DPO損失と時間的一貫性損失を0.5という係数で重み付けして組み合わせることで、選好学習と動画の品質維持のバランスを取っています。各エポックごとに両方の損失値が出力されるため、学習の進行状況を詳細にモニタリングすることができます。

実践上の考慮点

さて、シンプルな実装を示しましたが、やはり入力データが重要です。つまり、入力データに「良い動画」と「悪い動画」のペアが必要となりますが、これらの収集は容易ではなく、また明確な評価基準が求められます。ここでも適切なデータセットの構築が、モデルの性能を左右します。

DPOの影響と今後の展望

DPOの登場以降、多くのAI研究機関がこの手法を採用し、様々な改良版や応用研究が発表されています。特に2024年に入ってからは、当社も含め画像生成や動画生成の分野でこのブームに乗っかっており応用研究が活発化しています。

さらに動画処理では、モバイルカメラとの相性から、より効率的な学習方法や、リアルタイム処理への適用などさらなる研究が進められています。

まとめ

2023年に登場したDPOは、その数学的シンプルさと実装の容易さから、言語モデルをはじめ、画像・動画生成を含む様々な分野で急速に採用が進んでいます。特に、主観的な品質評価が重要な視覚メディアの生成において、人間の好みを効率的に学習できる点が大きな利点となっています。

当社サービスもこのDPOを活用してさらに品質と効率の高いものをリリースしていく予定ですのでご期待くださいませ!

当社では、このように最新AIサイエンス、AIエンジニアリングの知見を総動員し人々のクリエイティビティを深化させるサービス、ソリューションをご提供しております。先端AIサイエンス、エンジニアリングの知見を使って世界を変革するチャレンジを少数精鋭で進めております。「腕に覚えアリ」という方はぜひ当社でのキャリアをご検討いただければ幸いです!

Read more

PyTorchの重いCUDA処理を非同期化したらメモリリークした話と、その解決策

PyTorchの重いCUDA処理を非同期化したらメモリリークした話と、その解決策

こんにちは!Qualitegプロダクト開発部です! 今回は同期メソッドを非同期メソッド(async)化しただけなのに、思わぬメモリリーク※に見舞われたお話です。 深層学習モデルを使った動画処理システムを開発していた時のことです。 「処理の進捗をリアルタイムでWebSocketで通知したい」という要件があり、「単にasync/awaitを使えばいいだけでしょ?」と軽く考えていたら、思わぬ落とし穴にはまりました。 プロ仕様のGPUを使っていたにも関わらず、メモリ不足でクラッシュしてしまいました。 この記事では、その原因と解決策、そして学んだ教訓を詳しく共有したいと思います。同じような問題に直面している方の参考になれば幸いです。 ※ 厳密には「メモリリーク」ではなく「メモリの解放遅延」ですが、 実用上の影響は同じなので、この記事では便宜上「メモリリーク」と表現します。 背景:なぜ進捗通知は非同期である必要があるのか モダンなWebアプリケーションの要求 最近のWebアプリケーション開発では、ユーザー体験を向上させるため、長時間かかる処理の進捗をリアルタイムで表示することが

By Qualiteg プロダクト開発部
ゼロトラスト時代のLLMセキュリティ完全ガイド:ガーディアンエージェントへの進化を見据えて

ゼロトラスト時代のLLMセキュリティ完全ガイド:ガーディアンエージェントへの進化を見据えて

こんにちは! 今日はセキュリティの新たな考え方「ゼロトラスト」とLLMを中心としたAIセキュリティについて解説いたします! はじめに 3つのパラダイムシフトが同時に起きている いま、企業のIT環境では3つの大きな変革が起ころうとしています。 1つ目は「境界防御からゼロトラストへ」というセキュリティモデルの転換。 2つ目は「LLMの爆発的普及」による新たなリスクの出現。 そして3つ目は「AIエージェント時代の到来」とそれに伴う「ガーディアンエージェント」という新概念の登場です。 これらは別々の出来事のように見えて、実は密接に関連しています。本記事では、この3つの変革がどのように結びつき、企業がどのような対策を取るべきかを解説いたします 目次 1. はじめに:3つのパラダイムシフトが同時に起きている 2. 第1の変革:ゼロトラストという新しいセキュリティ思想 3. 第2の変革:LLM時代の到来とその影響 4. 第3の変革:AIエージェントとガーディアンエージェント 5. 3つの変革を統合する:実践的なアプローチ 6. 実装のベストプラクティス 7. 日本

By Qualiteg コンサルティング
発話音声からリアルなリップシンクを生成する技術 第4回:LSTMの学習と限界、そしてTransformerへ

発話音声からリアルなリップシンクを生成する技術 第4回:LSTMの学習と限界、そしてTransformerへ

1. 位置損失 (L_position) - 口の形の正確さ 時間 口の開き 正解 予測 L_position = Σᵢ wᵢ × ||y_pred - y_true||² 各時点での予測値と正解値の差を計算。重要なパラメータ(顎の開き、口の開き)には大きな重みを付けます。 jaw_open: ×2.0 mouth_open: ×2.0 その他: ×1.0 2. 速度損失 (L_velocity) - 動きの速さ 時間 速度 t→t+1 v = y[t] -

By Qualiteg 研究部, Qualiteg コンサルティング
大企業のAIセキュリティを支える基盤技術 - 今こそ理解するActive Directory 第1回 基本概念の理解

大企業のAIセキュリティを支える基盤技術 - 今こそ理解するActive Directory 第1回 基本概念の理解

こんにちは! 今回から数回にわたり Active Directory について解説してまいります。 Active Directory(AD:アクティブディレクトリー)は、Microsoft が開発したディレクトリサービスであり、今日の大企業における IT インフラストラクチャーにおいて、もはやデファクトスタンダードと言っても過言ではない存在となっており、組織内のユーザー、コンピューター、その他のリソースを一元的に管理するための基盤として広く採用されています。 AIセキュリティの現実:単独では機能しない ChatGPTやClaudeなどの生成AIが企業に急速に普及する中、「AIセキュリティ」という言葉が注目を集めています。情報漏洩の防止、不適切な利用の検知、コンプライアンスの確保など、企業が取り組むべき課題は山積みです。 しかし、ここで注意しなければいけない事実があります。それは、 AIセキュリティソリューションは、それ単体では企業環境で限定的な効果しか期待できない ということです。 企業が直面する本質的な課題 AIセキュリティツールを導入する際、企業のIT部門

By Qualiteg コンサルティング