DPO(直接選好最適化)の基礎から画像・動画AIへの応用まで
こんにちは Qualiteg研究部です!
本日は、2023年、AnthropicのRafael Rafailov、Archit Sharmaらの研究チームによって提案された「直接選好最適化(Direct Preference Optimization: DPO)」について、基礎から応用までを解説します。
この手法は、論文「Direct Preference Optimization: Your Language Model is Secretly a Reward Model」で発表され、AIの学習手法に大きな影響を与えています。この論文では、言語モデル(LM)の動作を人間の好みに調整する新しい手法「Direct Preference Optimization(DPO)」を提案していますが、最近では、VLMなど言語モデルに限らず応用が広がっています。
しかも、理論は比較的シンプルなので、じわりと人気があがっていますね!
DPOが生まれた背景
言語モデルは大規模データで事前学習されるため、幅広い知識と能力を持つが、その動作を制御するのは困難でした。
そのため、従来の言語モデルの学習では、人間の選好(好み)を反映させるために強化学習(RL)が使用されていましたが、実装が複雑で計算コストが高いという課題がありました。
「人間のフィードバックによる強化学習(RLHF)」は有用ですが、複雑で計算コストが高く大変ですよね。
Anthropicのチームはこれらの課題を解決するため、より直接的かつシンプルなアプローチとしてDPOを開発しました。この手法は発表後すぐに注目を集め、OpenAI、Google、DeepMindなど、多くの主要AI研究機関でも採用されています。
DPOの基本原理
従来のRLHFでは、報酬モデルを学習し、それを最大化するようにポリシーを調整していましたが、
DPOは、報酬モデルを明示的に学習せず、人間の好みデータを直接用いて言語モデルを調整する新しい手法となります。
数学的には以下のように表現されます
L(θ) = E[log(1 + exp(β(r_w(x_w) - r_w(x_l))))]
ここで:
- θ:モデルのパラメータ
- β:温度パラメータ
- r_w(x):選好スコア
- x_w, x_l:それぞれ「望ましい」出力と「望ましくない」出力
この単純な分類損失をつかい、損失関数を最小化することで、モデルは人間の選好に沿った出力を生成できるように学習しますので、安定性が高く計算効率も良いというわけです。
これによりDPOは、感情制御、要約、対話といったタスクで、従来のRLHF(例えばPPO)と同等またはそれ以上の性能を発揮することがわかりました。高度なハイパーパラメータ調整も不要で簡単に実装できます。
言語以外のAIへの応用はどうでしょうか。
画像生成AIでの活用例
たとえば画像生成の文脈で考えてみましょう!
テキストプロンプトに基づいて画像を生成するモデル(例: DALL·EやStable Diffusionなど)は、ユーザーの明確な好みを反映するのが難しい場合があります。というか、詳細に指定しようとすればするほど、かなり難しいですよね。
DPOを人間のフィードバックを用いて「好ましい画像」と「好ましくない画像」を比較し、その好みを直接モデルに反映します。
たとえば「リアルな風景」や「アニメ風キャラクター」など、生成画像のスタイルやク品質を特定の好みに合わせて調整することができます。
スタイル最適化の例
preferred_style = get_human_preference(image_A, image_B)
loss = dpo_loss(model_params, preferred_style)
model_params = optimize(loss)
具体的な例として、アニメ風のキャラクター生成において、「目の大きさ」や「線の太さ」といった特徴を人間の選好に基づいて調整できます。
品質向上の例
画質の改善においても、DPOは効果的です:
品質スコア = Σ(wi * quality_metric_i)
ここでwiは各品質指標の重みを表し、人間の選好データから学習されます。
画像生成における DPO 実装例
それでは、画像生成におけるDPOの実装についてみていきましょう。
DPO実装の全体像
Direct Preference Optimization(DPO)の実装において最も重要なのは、人間の選好をどのようにしてモデルに学習させるかという点です。今回紹介するPyTorchによる実装では、画像生成タスクを例に、選好学習の基本的な構造を示していきます。
DPOTrainerクラスの仕組み
DPOの中核となるのがDPOTrainerクラスです。
このクラスは、モデルの学習プロセス全体を管理します。初期化時には温度パラメータβと学習率を設定します。
βは選好の強さを制御するパラメータで、値が大きいほど望ましい出力と望ましくない出力の差が強調されます。最適化手法はadamです。
特に重要かつ特徴的なのが損失関数の実装部分です。compute_dpo_loss メソッドでは、望ましい出力と望ましくない出力のペアを受け取り、それぞれにスコアを付与します。その後、log(1 + exp(β(rejected - preferred)))という数式で表される損失を計算します。この損失関数により、望ましい出力のスコアが高く、望ましくない出力のスコアが低くなるように学習が進むというわけです。
import torch
import torch.nn as nn
import torch.optim as optim
class DPOTrainer:
def __init__(self, model, beta=0.1, learning_rate=1e-5):
self.model = model
self.beta = beta
self.optimizer = optim.Adam(model.parameters(), lr=learning_rate)
def compute_dpo_loss(self, preferred_outputs, rejected_outputs):
# 選好スコアの計算
preferred_scores = self.model(preferred_outputs)
rejected_scores = self.model(rejected_outputs)
# DPO損失関数の実装
loss = torch.log(1 + torch.exp(self.beta * (rejected_scores - preferred_scores)))
return loss.mean()
def train_step(self, preferred_batch, rejected_batch):
self.optimizer.zero_grad()
loss = self.compute_dpo_loss(preferred_batch, rejected_batch)
loss.backward()
self.optimizer.step()
return loss.item()
画像評価のためのニューラルネットワーク
画像の評価には、ImageDPOクラスとして実装された畳み込みニューラルネットワークを使用します。フツーのCNNです。一応説明をしておきますと、このネットワークはRGB画像を入力として受け取り、一連の畳み込み層で特徴を抽出し最初の畳み込み層では3チャネルの入力を64チャネルに拡張し抽出された特徴は最終的に1次元に変換され、全結合層によって1つの選好スコアへと変換されます。最後が選考スコアになるところがDPOっぽいところですね。
この構造で画像の視覚的特徴を効果的に捉え、その画像が望ましいものかどうかを数値化することができます。ここは一般的な話ですがあとは、ReLUで非線形にし、複雑な特徴の学習ができます。
# 画像生成への応用例
class ImageDPO(nn.Module):
def __init__(self):
super().__init__()
self.backbone = nn.Sequential(
nn.Conv2d(3, 64, 3, padding=1),
nn.ReLU(),
nn.Conv2d(64, 64, 3, padding=1),
nn.ReLU(),
nn.Flatten(),
nn.Linear(64 * 32 * 32, 1) # 画像サイズに応じて調整
)
def forward(self, x):
return self.backbone(x)
# 使用例
def train_image_dpo():
model = ImageDPO()
trainer = DPOTrainer(model)
# トレーニングループ
for epoch in range(num_epochs):
for preferred_batch, rejected_batch in data_loader:
loss = trainer.train_step(preferred_batch, rejected_batch)
print(f"Epoch {epoch}, Loss: {loss:.4f}")
学習プロセスの実装
実際の学習は、train_image_dpo関数によって制御されImageDPOモデルとDPOTrainerのインスタンスを作成しエポック単位で、データセットから望ましい画像と望ましくない画像のペアを順次取り出し、学習を進めています。各バッチの処理では、まず望ましい画像と望ましくない画像のペアをモデルに入力し、それぞれのスコアを計算します。その後、これらのスコアを基に損失を計算し、その損失を最小化するようにモデルのパラメータを更新します。
ところどころDPOっぽい部分がありますが、基本はおなじみのCNN系の画像用ネットワークですね。
結局重要なのは「選考データ」
実装解説をしましたが、DPOは非常にシンプルです。この手法がもたらす果実を得るには、結局はインプットデータとなります。DPO自体は何もしてくれません。入力次第です。つまり画像ペアのデータセットを適切に準備する必要がありこれには、人間の評価者による画像の選好データの収集が含まれます。さらに、モデルの学習過程をモニタリングし、適切に評価を行うことで、モデルの性能を確認し、必要に応じて調整を行うことができます。これらの要素を適切に組み合わせることで、効果的なDPO学習システムを構築することが可能となります。
なーんだ、結局データづくりか、と思ったかもしれませんが、RLHFをやったことのある人ならわかるとおもいますが、それに比べるとだいぶハッピーです。
動画生成への応用
さて、次は動画生成への応用についてもみてみましょう。
当社の動画生成AIでも実はDPOを取り入れています!
特に動画生成において重要なのは、時間軸の考慮です。
frame_consistency_score = Σt(similarity(frame_t, frame_t+1))
motion_naturalness = evaluate_motion_smoothness(frames)
total_score = α * frame_consistency_score + β * motion_naturalness
人物の動きの自然さを改善する場合、動画の元となる複数枚の静止画において、人間の評価者が「より自然な動き」を選択します。人間が「自然だな」とおもったデータが 好み=選好データとなりますので、これを用いてDPOで学習します。これにより、従来の補完手法よりも1段上の「自然さ」を実現することができます。
動画生成時のDPO適用実装例
動画生成のためのDPO実装は、画像生成と比べてやや複雑です。
特に重要なのは「時間軸に沿った一貫性の維持」です。
動画生成AIでは、いきなり動画ができるのではなく、複数枚の静止画(=シーケンスといいます)の集合体として動画ができます。
1枚ずつの静止画をパラパラアニメのようにつなぎ合わせることで動画にみせます。そこで各独立した静止画を自然にみせるためには、静止画間の時系列が重要になります。つまり「時間軸での一貫性の維持」が重要です。
(余談ですが、当社が目指しているAIヒューマンの実現の最も重要な要素も「一貫性の維持」です。画像でも性格でも、一貫性の維持が人間を人間らしくしていますので、非常に面白いテーマです。)
ここでは、VideoDPOクラスを中心に、動画生成のための実装について詳しく解説していきます。
3次元畳み込みを用いた特徴抽出
VideoDPOクラスの中核となるのは、3次元畳み込み層(Conv3D)を用いたframe_encoderです。この層では、空間的な特徴だけでなく、時間的な特徴も同時に抽出します。入力データは(バッチサイズ、チャネル数、フレーム数、高さ、幅)という5次元のテンソルとして扱います。
ネットワークは一般的な内容ですが、説明しますと、最初の畳み込み層では3チャネルの入力を64チャネルに拡張、その後MaxPooling層で特徴マップのサイズを削減しつぎの畳み込み層では64チャネルから128チャネルへと拡張して豊かな特徴表現を得ます。各畳み込み層の後にはおなじみReLUです。
時間的特徴のスコアリング
抽出された特徴は、temporal_scorerによって評価されます。このモジュールでは、まず多次元の特徴マップを一列に並べ替え、512ユニットを持つ中間層を経て、最終的に1つのスコアへと変換します。このスコアは、動画の質を総合的に評価する指標となります。
class VideoDPO(nn.Module):
def __init__(self, num_frames=16):
super().__init__()
self.frame_encoder = nn.Sequential(
nn.Conv3d(3, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1)),
nn.ReLU(),
nn.MaxPool3d(kernel_size=(2, 2, 2)),
nn.Conv3d(64, 128, kernel_size=(3, 3, 3), padding=(1, 1, 1)),
nn.ReLU(),
nn.MaxPool3d(kernel_size=(2, 2, 2))
)
self.temporal_scorer = nn.Sequential(
nn.Linear(128 * 4 * 4 * 4, 512),
nn.ReLU(),
nn.Linear(512, 1)
)
def forward(self, x):
# x shape: (batch_size, channels, frames, height, width)
features = self.frame_encoder(x)
features = features.view(features.size(0), -1)
score = self.temporal_scorer(features)
return score
フレーム間の一貫性評価
ここでは先ほども述べたフレーム間の一貫性についてみてみましょう。
compute_temporal_consistency関数では、連続するフレーム間の差分を平均二乗誤差(MSE)を用いて計算します。
(MSE忘れちゃった場合は、当社のQiitaブログをお読みください^-^)
この値が小さいほど、動画の動きが滑らかで自然であることを意味し最終的な返り値は負の値となっており、これは一貫性が高いほど損失が小さくなるように設計されています。
def compute_temporal_consistency(frames):
"""フレーム間の一貫性スコアを計算"""
consistency = 0
for i in range(len(frames)-1):
consistency += torch.nn.functional.mse_loss(frames[i], frames[i+1])
return -consistency # 差が小さいほど一貫性が高い
# 動画生成の学習ループ例
def train_video_dpo(model, train_loader, num_epochs=10):
trainer = DPOTrainer(model)
for epoch in range(num_epochs):
for preferred_videos, rejected_videos in train_loader:
# 基本的なDPO損失
dpo_loss = trainer.train_step(preferred_videos, rejected_videos)
# 時間的一貫性を考慮した追加の損失
temporal_loss = compute_temporal_consistency(preferred_videos)
# 総合的な損失
total_loss = dpo_loss + 0.5 * temporal_loss
print(f"Epoch {epoch}, DPO Loss: {dpo_loss:.4f}, "
f"Temporal Loss: {temporal_loss:.4f}")
学習プロセスの実装
train_video_dpo関数では、実際の学習プロセスを制御します。各エポックでは、望ましい動画と望ましくない動画のペアを使用してDPO損失を計算します。さらに、望ましい動画については時間的一貫性も評価し、これらを組み合わせた総合的な損失関数を最小化します。DPO損失と時間的一貫性損失を0.5という係数で重み付けして組み合わせることで、選好学習と動画の品質維持のバランスを取っています。各エポックごとに両方の損失値が出力されるため、学習の進行状況を詳細にモニタリングすることができます。
実践上の考慮点
さて、シンプルな実装を示しましたが、やはり入力データが重要です。つまり、入力データに「良い動画」と「悪い動画」のペアが必要となりますが、これらの収集は容易ではなく、また明確な評価基準が求められます。ここでも適切なデータセットの構築が、モデルの性能を左右します。
DPOの影響と今後の展望
DPOの登場以降、多くのAI研究機関がこの手法を採用し、様々な改良版や応用研究が発表されています。特に2024年に入ってからは、当社も含め画像生成や動画生成の分野でこのブームに乗っかっており応用研究が活発化しています。
さらに動画処理では、モバイルカメラとの相性から、より効率的な学習方法や、リアルタイム処理への適用などさらなる研究が進められています。
まとめ
2023年に登場したDPOは、その数学的シンプルさと実装の容易さから、言語モデルをはじめ、画像・動画生成を含む様々な分野で急速に採用が進んでいます。特に、主観的な品質評価が重要な視覚メディアの生成において、人間の好みを効率的に学習できる点が大きな利点となっています。
当社サービスもこのDPOを活用してさらに品質と効率の高いものをリリースしていく予定ですのでご期待くださいませ!
当社では、このように最新AIサイエンス、AIエンジニアリングの知見を総動員し人々のクリエイティビティを深化させるサービス、ソリューションをご提供しております。先端AIサイエンス、エンジニアリングの知見を使って世界を変革するチャレンジを少数精鋭で進めております。「腕に覚えアリ」という方はぜひ当社でのキャリアをご検討いただければ幸いです!