ディープラーニングにおけるEMA(Exponential Moving Average)
こんにちは!
本日は、画像生成、動画生成モデルなどで重要な役割を果たしている EMA ※について解説してみたいとおもいます!
当社のAIアバター動画生成サービス「MotionVox™」でも役立っています!
といっても、画像生成のための専用技術というわけではなく、学習と推論(生成系も含む)というディープラーニングの運用の中で昨今かなり重宝されるテクニックとなっておりますので、基礎から実装までみていきたいとおもいます。
※EMAの読み方は私はエマと呼んでますが、イーエムエーって言ってる人もいます。どっちでもいいでしょう。
EMA の基礎知識
EMA(Exponential Moving Average=指数移動平均)は、ざっくりいえばモデルの重みを平均化する手法です。
実は株価分析などでも使われている古くからある概念なのですが、ディープラーニングでは比較的最近になって「あ、これ結構使えるんじゃね?」と重要性が認識されるようになりました。
(”EMA”に限らず、理論の積み上げではなく「やってみたら、使えんじゃん」っていうのがかなり多いのがディープラーニング界隈のもはや常識でしょう^^)
なぜEMAが重要か?
ご存じの通りディープラーニングでは「学習」と「推論(モデルを実際に利用することですね)」という2つの段階がありますね。
このとき、学習と推論では以下のようなことが関心事となります
- 学習時「新しいデータに素早く適応する必要がある」
- 推論時「安定した結果を出すことが重要である」
人間に例えると
- 学習時「新しい情報をどんどん吸収して変化していく状態」
- 推論時「学んだ知識を落ち着いて安定的に使う状態」
といった感じでしょうか。
EMAは、この「安定した推論」を実現するための技術です。
つまり、 いい感じの推論のためのテクニックとなります
EMAの原理
基本的な考え方
EMAは過去のデータを指数関数的な重みで平均化します。
$$
\text{EMA}t = \beta \cdot \text{EMA}{t-1} + (1-\beta) \cdot \theta_t
$$
この式をΣをつかって展開すると以下のようになりますね。
$$
\text{EMA}t = (1-\beta) \sum_{i=0}^{t} \beta^i \theta_{t-i} + \beta^t \text{EMA}_0
$$
$$ \begin{array}{l}
・\text{EMA}_t は時刻 t での指数移動平均\\
・\beta は平滑化係数 (0 \leq \beta < 1)\\
・\theta_t \text{は時刻 } t \text{での観測値}\\
・\text{EMA}_0 \text{は初期値}
\end{array} $$
上式の1項めをみてみましょう
$$ \begin{array}{l}
(1-\beta) \sum_{i=0}^{t} \beta^i \theta_{t-i} は過去のデータの重み付き和を表します\\
\text{ここで、}\beta^i \text{の部分で、古いデータほど指数関数的に重みが小さくなります} \\
(1-\beta) \text{は正規化項として機能します。}
\end{array} $$
つづいて2項めは、
$$ \begin{array}{l}
\beta^t \text{EMA}_0 は初期値の影響を表します。\\
このことから時間tが大きくなるほど、この項は0に近づきます\\
つまり、十分時間が経過すると初期値の影響は無視できるようになります\\
\end{array} $$
式をみても、よーわからん、という事もあるとおもいますので、今回はEMAの効果を体感できるEMAシミュレーターをご用意しましたので遊んでみましょう!
EMAシミュレーターで遊ぶ
まずはスライダーを左端(β = 0)まで動かしてみましょう。この状態では緑の線(EMA)が青い線(生データ)とほぼ同じ動きをしているのが分かります。これは新しいデータをほぼそのまま採用している状態です。
次に、スライダーを少しずつ右に動かしていきましょう。β = 0.3あたりまで動かすと、緑の線が少しずつ滑らかになっていきます。この辺りでは、まだデータの変化に素早く反応しながらも、わずかにノイズが削減されている様子が観察できます。
さらにβ = 0.7あたりまで動かしてみましょう。ここでは緑の線がかなり滑らかになり、短期的なノイズがより効果的に除去されています。一方で、青い線の急な変化に対する追従が少し遅れ始めているのが分かります。
最後にスライダーを右端(β = 1)近くまで動かしてみましょう。この状態では緑の線が非常に滑らかになり、ほとんど動かなくなります。これは過去のデータを非常に重視し、新しいデータの影響をほとんど受け付けない状態です。
いかがでしょうか?数式の意味が体感できましたでしょうか。
体感できたところで、EMAの性質を整理
シミュレーターをつかうとピンとくるとおもいますが、EMAの重みの性質は、
- 最新のデータほど大きな影響を持つ
- 古いデータの影響は指数関数的に減衰
- しかし、完全には消えない(履歴を保持)
さらに、もう1点メリットがあるのは、メモリ効率です。つまり、
- 単純移動平均と違い、全履歴を保持する必要がない
- 前回のEMA値だけを覚えておけばよい
ということです。
減衰率βの役割
こちらもシミュレーターでβの挙動が体感できたとおもいますが、EMAはβの値によって挙動が大きく変わります
- β = 0.9:新しい値の影響が強い(急激な変化に敏感)※
- β = 0.999:古い値の影響が強い(より安定的)
(※さきほどのEMAシミュレーターの体感でみると、β=0.9でもじゅうぶん、新しい値の影響は弱いのですが、ディープラーニングにおける影響という目線でみると、 0.999... の 9が何個並ぶか、みたいな部分の勝負になりますので、そからの目線ですと相対的にみて、「急激な変化に敏感」と表現しております)
例えば、β = 0.99の場合
新しい値の影響 = 1%、新しい値の影響 = 99%
となりますので、
- 一時的なノイズは抑制される
- 継続的な変化は徐々に反映される
- 急激な変動が緩和される
という特徴があります。
ディープラーニングにおけるEMA
さて、EMAというテクニックの特徴が理解できたところで、じゃあ、ディープラーニングではどのように役立ってくれるのか、みていきましょう!
パラメータの安定化
まず、ディープラーニングにおけるEMAの主な役割は、「推論用のモデルを作るための重み平均化テクニック」です。ここで混同しがちなのが、「学習」そのものの安定課では無いことです。
通常の学習プロセスでは、SGDやAdamといった最適化手法を使って重みを更新していきます。この学習過程で、モデルの重みは目的関数に向けて更新されますが、ミニバッチの確率的な性質により、重みは振動しながら収束していきます。
つまり、振動がだんだんちいさくなるようなあのグラフですね。
そこでEMAが活用されます。
学習中にバックグラウンドで、モデルの重みの移動平均を計算・保持しておきます。具体的には、各イテレーションでの重みに対して指数関数的な減衰を適用しながら平均を取っていきます。これにより、学習後半の重みの振動を滑らかにした値を得ることができます。
結果として、学習の最終段階での重みの振動に影響されない、より安定した推論用モデルを得ることがうれしみです。つまり、学習プロセス自体は通常通り行いながら、より汎化性能の高い推論用の重みを獲得できるということです。
メモリ効率
メモリ効率の観点からみると、EMAはみてのとおりとっても実装が簡単で効率的な手法です。現在のパラメータ値と過去の移動平均値のみを保持すれば良いため、追加のメモリ消費は最小限に抑えられます。計算面でも、単純な加重平均の計算のみで済むため、計算コストは極めて小さくなっています。これは特に大規模なニューラルネットワークの学習においてのうれしみですね。
非同期性でバッチ処理とも相性がよい
EMAはバッチ処理やオンライン学習など、様々な学習設定に柔軟に対応できます。
バッチ学習においては、各バッチ処理後のパラメータ更新を滑らかにすることで、バッチ間の変動を適切に調整します。また、オンライン学習のような逐次的なデータ処理においても、データストリームの特性に応じて過去の情報を適切に反映させることができます。特に、β値を適切に選択することで、新しい情報と過去の情報のバランスを調整し、様々な学習シナリオに対応することが可能です。
さらに、EMAは学習率のスケジューリングとも相性が良く、学習の後半でより安定したパラメータ更新を実現することができます。これは、モデルの収束性を改善し、最終的な性能向上につながります。
EMAの実装
基本実装
PyTorchでのEMAの基本的な実装を見てみましょう!
import copy
import torch
from typing import Optional
class EMAModel:
"""指数移動平均(EMA)モデルのラッパークラス
モデルのパラメータの指数移動平均を計算・保持します。
これにより、モデルの安定性とパフォーマンスを向上させることができます。
Args:
model (torch.nn.Module): EMAを適用する元のモデル
decay (float, optional): EMAの減衰率。EMA計算式の「β」のこと。デフォルトは0.999
device (Optional[torch.device], optional): EMAモデルを配置するデバイス。デフォルトはNone
"""
def __init__(
self,
model: torch.nn.Module,
decay: float = 0.999, # EMA計算式でいう「β」のこと
device: Optional[torch.device] = None
):
self.decay = decay
# デバイスの指定があれば使用、なければモデルと同じデバイスを使用
self.device = device if device is not None else next(model.parameters()).device
# モデルのディープコピーを作成し、指定デバイスに移動
self.ema_model = copy.deepcopy(model).to(self.device)
# 評価モードに設定(学習用の設定をオフ)
self.ema_model.eval()
# パラメータを勾配計算が不要な設定に
for param in self.ema_model.parameters():
param.requires_grad_(False)
def update(self, model: torch.nn.Module) -> None:
"""現在のモデルのパラメータを使用してEMAモデルを更新
Args:
model (torch.nn.Module): EMAの更新に使用する現在のモデル
"""
with torch.no_grad():
for ema_param, model_param in zip(
self.ema_model.parameters(),
model.parameters()
):
# モデルのパラメータをEMAモデルと同じデバイスに移動
model_param_on_device = model_param.to(self.device)
# EMAの更新式: 新しい値 = decay * 古い値 + (1 - decay) * 現在の値
ema_param.data.mul_(self.decay)
ema_param.data.add_(
model_param_on_device.data * (1.0 - self.decay)
)
def get_model(self) -> torch.nn.Module:
"""現在のEMAモデルを取得
Returns:
torch.nn.Module: 更新された状態のEMAモデル
"""
return self.ema_model
上記コード(EMAModel
クラス)の核となる部分を説明みてみましょう
class EMAModel:
def __init__(self, model, decay=0.999):
# ここでEMAモデルを作成
self.ema_model = copy.deepcopy(model) # ← 通常モデルのコピーを作成
self.decay = decay
# EMAモデルは学習しないので評価モードにする
self.ema_model.eval()
# 勾配計算も不要なのでオフに
for param in self.ema_model.parameters():
param.requires_grad_(False)
def update(self, model):
# ここでEMAモデルのパラメータを更新
with torch.no_grad():
for ema_param, model_param in zip(
self.ema_model.parameters(),
model.parameters()
):
# EMAの更新式
ema_param.data.mul_(self.decay)
ema_param.data.add_(
model_param.data * (1.0 - self.decay)
)
コードの主な機能は2つの部分から構成されています。
まず__init__
メソッドでEMAモデルを作成します。この時、copy.deepcopy(model)
を使って通常の学習モデルの完全なコピーを作ります。このコピーしたモデルがEMAモデルとなります。EMAモデルは学習を行わないため、eval()
モードに設定し、パラメータの勾配計算も不要なためオフにします。
次にupdate
メソッドでEMAモデルを更新します。このメソッドでは、通常の学習モデルの現在のパラメータを使ってEMAモデルのパラメータを更新します。更新の際はdecay
※(例:0.999)という値を使って重み付け平均を計算します。これにより、EMAモデルのパラメータは通常モデルのパラメータの移動平均となります。
このように、このコードは通常の学習モデルのコピーを作成し、そのコピーのパラメータを通常モデルのパラメータを使って徐々に更新していく、という一連の処理を実装しています。
(※decayは、元の式の β にあたります)
学習ループ
では、EMAModelを使用する側のコードをみていきましょう!
import torch
from torch import nn
import torch.optim as optim
from your_model import YourModel # あなたの元のモデル
from ema_model import EMAModel # 先ほど実装したEMAModel
# 1. モデルの準備
model = YourModel() # 通常の学習用モデル
ema = EMAModel(model, decay=0.999) # EMAモデルを作成
# 2. 学習に必要な設定
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 3. 学習ループ
num_epochs = 100
for epoch in range(num_epochs):
model.train() # 学習モードに設定
for batch_idx, (data, target) in enumerate(train_loader):
# 通常の学習ステップ
optimizer.zero_grad() # 勾配をリセット
output = model(data) # モデルで予測
loss = criterion(output, target) # 損失を計算
loss.backward() # 勾配を計算
optimizer.step() # モデルを更新
# EMAモデルの更新(重要:optimizerでモデルを更新した後に行う)
ema.update(model)
# エポック終了時の評価
model.eval() # 評価モードに設定
ema_model = ema.get_model() # EMAモデルを取得
# 通常モデルでの評価
with torch.no_grad():
# 評価用データセットで精度を計算
normal_acc = evaluate(model, val_loader)
# EMAモデルでの評価
with torch.no_grad():
# 評価用データセットで精度を計算
ema_acc = evaluate(ema_model, val_loader)
print(f'Epoch {epoch}:')
print(f' Normal Model Accuracy: {normal_acc:.4f}')
print(f' EMA Model Accuracy: {ema_acc:.4f}')
このコードでは、まず通常の学習用モデルを作成し、そのモデルを使ってEMAモデルを初期化します。学習に必要な損失関数や最適化アルゴリズムなどの設定は、通常の学習と同じように行います。EMAに関して特別な設定はとくにありません。
学習ループ内では、通常のモデルを普通に学習させます。
具体的には損失の計算、勾配の計算、そしてパラメータの更新という一連の流れを行います。各バッチの学習後、ema.update(model)
を呼び出してEMAモデルを更新します。この更新は必ずoptimizer.step()
でモデルのパラメータを更新した後に行う必要があります。
エポック終了時には通常モデルとEMAモデルの両方で評価を行い、性能を比較します。EMAモデルはema.get_model()
で取得できます。評価時にはEMAモデルは既にeval()
モードになっているため、ここも改めて設定する必要はありません。
decayパラメータ(元の数式のβにあたる)は一般的に0.999や0.9999が使われます。値が大きいほどモデルは安定しますが、更新が遅くなります。学習の初期は0.99など小さめの値を使い、後で大きくするという方法もあります。
このように実装することで、通常のモデルよりも安定した予測が可能になります。特にテスト時や本番環境での推論時にEMAモデルを使用することで、より安定した結果を得ることができます。モデルの保存時は、通常モデルとEMAモデルの両方を保存しましょう。
重みの更新過程を追跡する
import torch
from torch import nn
import torch.optim as optim
from your_model import YourModel
from ema_model import EMAModel
def track_parameter_changes(model, ema_model, param_index=0):
"""モデルのパラメータ変化を追跡して表示する関数"""
for name, param in model.named_parameters():
for ema_name, ema_param in ema_model.ema_model.named_parameters():
if ema_name == name:
track_param = param.data.flatten()[param_index]
ema_val = ema_param.data.flatten()[param_index]
print(f"{name}:")
print(f" Current: {track_param.item():.6f}")
print(f" EMA: {ema_val.item():.6f}")
break
# モデルの準備
model = YourModel()
ema = EMAModel(model, decay=0.999)
# 学習設定
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 学習ループ
num_epochs = 100
for epoch in range(num_epochs):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
# EMAモデルの更新
ema.update(model)
# 100バッチごとに重みの変化を表示
if batch_idx % 100 == 0:
print(f"\nEpoch {epoch}, Batch {batch_idx}")
print(f"Loss: {loss.item():.4f}")
track_parameter_changes(model, ema)
# エポック終了時の評価
model.eval()
ema_model = ema.get_model()
with torch.no_grad():
normal_acc = evaluate(model, val_loader)
ema_acc = evaluate(ema_model, val_loader)
print(f'\nEpoch {epoch} Summary:')
print(f' Normal Model Accuracy: {normal_acc:.4f}')
print(f' EMA Model Accuracy: {ema_acc:.4f}')
# エポック終了時の重みの状態を表示
print("\nParameter state at epoch end:")
track_parameter_changes(model, ema)
さっきの学習コードに上記をいれると
- 100バッチごとに現在の損失値と重みの状態を表示します
- エポック終了時にも重みの状態を確認できます
track_parameter_changes
関数ではモデルの同じ位置のパラメータを比較して、通常モデルとEMAモデルの値の違いを表示します
出力例
Epoch 0, Batch 0
Loss: 2.3045
conv1.weight:
Current: 0.023456
EMA: 0.023433
Epoch 0, Batch 100
Loss: 1.9876
conv1.weight:
Current: 0.025678
EMA: 0.024123
...
この出力例ではまず学習の進行状況について、Loss(損失値)が2.3045から1.9876に減少していることから、モデルが学習できていることがわかります。
重みの変化については、conv1.weight
というパラメータを例にとると、通常モデル(Current)の値が0.023456から0.025678と比較的大きく変化しているのに対し、EMAモデルの値は0.023433から0.024123と、より緩やかに変化していることがわかります。
これはEMAの特徴である「パラメータの急激な変化を抑制する」効果を示しています。通常モデルのパラメータが大きく変動する一方で、EMAモデルのパラメータはより滑らかに、穏やかに更新されていることが数値から確認できます。このような安定した更新により、EMAモデルはノイズの影響を受けにくく、より安定した予測が可能になります。
実際のユースケース
class StableDiffusionTrainer:
def __init__(self):
self.model = DiffusionModel()
self.ema_model = EMAModel(self.model)
def train_step(self, prompt, image):
# 通常の学習
self.optimizer.zero_grad()
loss = self.model.train_step(prompt, image)
loss.backward()
self.optimizer.step()
# EMAモデルの更新
self.ema_model.update(self.model)
def generate(self, prompt):
# 推論時はEMAモデルを使用
return self.ema_model.ema_model.generate(prompt)
画像生成・動画生成モデルにおけるEMA
画像生成モデル、特にStable Diffusionのような拡散モデルでは、EMAが重要な役割を果たしています。以下に実装例と共に解説します。
class StableDiffusionTrainer:
def __init__(self):
self.model = DiffusionModel()
self.ema_model = EMAModel(self.model)
def train_step(self, prompt, image):
# 通常の学習
self.optimizer.zero_grad()
loss = self.model.train_step(prompt, image)
loss.backward()
self.optimizer.step()
# EMAモデルの更新
self.ema_model.update(self.model)
def generate(self, prompt):
# 推論時はEMAモデルを使用
return self.ema_model.ema_model.generate(prompt)
画像生成モデルにおけるEMAの活用
画像生成モデル、特にStable Diffusionのような拡散モデルでは、EMAが重要な役割を果たしています。以下に実装例と共に解説します。
class StableDiffusionTrainer:
def __init__(self):
self.model = DiffusionModel()
self.ema_model = EMAModel(self.model)
def train_step(self, prompt, image):
# 通常の学習
self.optimizer.zero_grad()
loss = self.model.train_step(prompt, image)
loss.backward()
self.optimizer.step()
# EMAモデルの更新
self.ema_model.update(self.model)
def generate(self, prompt):
# 推論時はEMAモデルを使用
return self.ema_model.ema_model.generate(prompt)
このコードはStable Diffusionの学習プロセスを簡易化実装したものです。まず、クラスの初期化部分では2つのモデルを作成しています。1つは実際に学習を行うメインのDiffusionModelで、もう1つはEMAを適用するためのEMAModelです。
学習ステップの処理では、まず通常の勾配降下による学習を実行します。optimizer.zero_gradで勾配をリセットし、モデルによる学習ステップを実行して損失を計算し、その後backwardとoptimizerのステップを実行します。この通常の学習の後、EMAモデルの重みを更新します。この更新は毎学習ステップ後に行われ、徐々に安定した重みが形成されます
生成(推論)時には、これまでの説明のとおり、「通常のモデル」ではなく「EMAモデル」を使用します。EMAモデルが学習時のノイズが軽減されており、より安定した生成が可能となるからですね。
このような実装は、特に画像生成モデルの学習では、生成される画像の品質が学習中に大きく変動することがままあります。
ある時点ではいい感じの画像を生成していても、次の更新で品質が落ちることがほんとによくあります。これは学習の過程で重みが急激に変化することが原因です。
EMAを使用することで、重みの更新を滑らかにし、生成される画像の品質を安定させることができます。生成画像の安定性が向上し、学習データのノイズやバッチごとの変動の影響を軽減できます。
実際の応用例
実際のStable Diffusionの学習では、以下のように実装されることが多いです
class DiffusionTrainer:
def __init__(self):
# モデルの初期化
self.model = UNet()
# EMAモデルの設定(decay率は高めに設定)
self.ema_model = EMAModel(self.model, decay=0.9999)
def training_loop(self, dataloader):
for epoch in range(num_epochs):
for batch in dataloader:
# 通常の訓練ステップ
loss = self.train_step(batch)
# EMAの更新
self.ema_model.update(self.model)
if self.steps % 1000 == 0:
# 定期的に生成結果を確認
with torch.no_grad():
# EMAモデルを使用して画像を生成
samples = self.sample_images(
self.ema_model.get_model()
)
def sample_images(self, model):
"""画像生成の処理
通常のモデルではなくEMAモデルを使用"""
return model.generate(...)
画像生成モデルでEMAを使用する際は、いくつかの重要な点に注意が必要です。decay値は画像生成では0.9999など、(もはやさっきのシミュレーターではほとんど差分がわからないような)かなり高い値が使われることが多く、これにより重みの更新がより穏やかになります。また、定期的にEMAモデルで画像を生成し品質を確認することが重要です。
このように、EMAは画像生成モデルの品質と安定性を向上させる重要な技術となっており、特に長時間の学習が必要な大規模モデルでは、EMAの使用が標準的な手法として確立されてます。
まとめ
さて、すこし長くなりましたが、ここらでまとめに入りたいと思います。
まずはQ&A形式でふりかえってみましょう
Q: EMAって何?ディープラーニングでなぜ使うの?
EMA(Exponential Moving Average=指数移動平均)は、モデルの重みを平均化する手法です。ディープラーニングでは、特に推論時の安定性を高めるために使用されます。
Q: SGDとの関係は?最近の手法なの?
実は学習時の安定化というと、よく聞かれるのが、「EMAとSGDはどういう関係なのか?」です。
結論からいえば、別コンテクストの話なんです
SGD(確率的勾配降下法)は学習時の重み更新方法(オプティマイザー)です。AdamやSGD with Momentumなど、様々な派生がありますね。
で、EMAは学習済みの重みを平均化する手法です。←ここがわかりにくいんですよね
ここまでのブログ内容でたぶんそこはおさえてあるとおもうのですが、
- SGDは学習時の重み更新に使用
- EMAは学習後の重みの平滑化に使用
となります。
「SGDで学習を進める」→「その結果にEMAを適用する」→(繰り返す)という形で使用されますね。
Q: 推論でしか使わないの?
はい、
EMAは主に推論時に使用されます。
なぜなら
・学習時は新しい情報への素早い適応が必要
・推論時は安定した出力が重要 というトレードオフがあるため
です。
総括~EMAの重要性と応用~
ディープラーニングにおけるEMAは、「学習の柔軟性」と「推論の安定性」という相反する要求を両立させるための重要でナイスなテクニックです!
通常モデルで素早い学習を行いながら、EMAモデルで安定した推論を実現できるという1粒で2度おいしいを実現してくれるのが特徴的ですね。
この技術は特に画像生成・動画生成といった生成モデルの分野で大きな成果を上げています。Stable DiffusionやGANなどの生成モデルでは、出力の安定性が極めて重要であり、EMAの採用により一貫性のある高品質な生成結果が得られるようになりました。
また自己教師あり学習の分野でもEMAが活用され、より安定した特徴表現の学習を可能にしています。
EMAの大きな利点は、実装が比較的シンプルで計算コストが小さく、メモリ効率も良い点にあります。また、様々なディープラーニング手法に適用できる汎用性の高さも特筆すべき点です。
今後の展望としては、より大規模なモデルでの活用や、新しい学習アルゴリズムとの組み合わせ、そして適応的なdecay率の調整手法の研究などが期待されています。単純な仕組みながら、その効果は絶大で、現代のディープラーニングにおいて標準的な技術として確立されています。また、当社が開発した動画生成AI「MotionVox™」でも、学習時にEMAモデルを作成しており、パラメータ更新を滑らかにし、高い安定性を実現しています!
それでは、最後までお読みいただきありがとうございました!
また次回、お会いしましょう!