LLMサンプリングにおける3つのペナルティ
こんにちは!(株)Qualiteg プロダクト開発部です!
今日の昼食はみんなでイタリアンレストランに行きました。3種のチーズピザが好評でした。
さて、本日はLLMにおける3種のチーズならぬ、3種のペナルティをご紹介します。
ChatStreamでは、ペナルティを含む、多彩なプリセットサンプリングアルゴリズムを搭載しています。モデルや目的にあったサンプリングを行うことで、より自然な応答生成を行うことができます。
テキスト生成におけるペナルティの役割
自然言語処理(NLP)、LLMの世界で重要な概念の一つである「ペナルティ」とは何でしょうか。
ペナルティとは?
ペナルティとは、LLMがテキストを生成する際に特定のトークン(単語や文字列)の出現を調整するために使用される仕組みのことです。
テキスト生成モデルが自然で多様な文章を作り出すためには、同じ単語やフレーズが何度も繰り返されるのを防ぐ必要があります。そこで登場するのが、ペナルティです。
LLMが同じような文章、単語を何度も生成してしまうことはわりと頻繁に起こりますので、適切なペナルティを設定します。
ペナルティの種類
テキスト生成におけるペナルティにはいくつかの種類がありますが、今回は特に重要な3つのペナルティについて説明します。
- Repetition Penalty:
- 範囲: 繰り返し出現するすべてのトークンやフレーズ。
- ペナルティ基準: すべての過去のトークンに基づく。
- 使用例: 過去に生成されたトークンすべてに対して、ペナルティを適用。
- Frequency Penalty:
- 範囲: 特定のトークンの出現頻度。
- ペナルティ基準: 各トークンの出現回数に基づく。
- 使用例: 生成中に特定のトークンが何回生成されたかを追跡し、その頻度に基づいてペナルティを適用。
- Presence Penalty:
- 範囲: 一度でも生成されたトークン。
- ペナルティ基準: トークンが一度でも生成されたかどうかに基づく。
- 使用例: すでに生成されたトークンに対して、再出現時にペナルティを適用。
ペナルティの実装
では、実際にペナルティを実装していきましょう。
ChatStream のサンプリングクラスとしてペナルティを実装する
サンプリングクラスは AbstractLogitsProcessor
をオーバーライドします。
from chatstream.token_samplers.logits_processor import AbstractLogitsProcessor
AbstractLogitsProcessor
はシンプルな抽象クラスで、以下のようになっています。
from abc import ABC, abstractmethod
class AbstractLogitsProcessor(ABC):
@abstractmethod
def process(self, logits, params):
pass
@abstractmethod
def get_name(self):
pass
Repetition Penalty
さっそく、Repetition Penalty を実装してみましょう。
計算手法として乗算型と減算型の二種類を指定できるようにしており、あるトークンがすでに生成された場合、そのトークンのログ確率をペナルティ値で割る(乗算)か、ペナルティ値を引く(減算)というオペレーションを実装していますい。すべての過去のトークンが対象となります。
from chatstream.token_samplers.logits_processor import AbstractLogitsProcessor
class RepetitionPenaltyProcessor(AbstractLogitsProcessor):
"""
繰り返しのトークンに対してペナルティを適用するプロセッサ。
このクラスは、過去に使用されたトークンのlogitsを調整することで、繰り返しの出現を抑制する。
ペナルティの適用方法は、乗算または減算のいずれかで、パラメータで指定できる。
乗算ペナルティ計算の基本は logits[token_id] /= penalty で、過去に出現した token_id のロジット値を減らしていき
出現確率を下げることで繰り返し同じトークンが出力されることを抑制する
"""
def __init__(self):
pass
def process(self, logits, params):
past_tokens = params.get("past_tokens", None)
penalty = params.get("penalty", None)
penalty_method = params.get("penalty_method", "multiplicative")
# 過去のトークンのlogitsにペナルティを適用
if penalty is not None and past_tokens is not None:
# logitsのコピーを作成(引数として渡されたlogitsの非破壊保証)
adjusted_logits = logits.clone()
# penaltyの値の型を確認する
if not isinstance(penalty, (int, float)):
raise ValueError(f"penalty should be a scalar value, but got {penalty}({type(penalty)})")
# ペナルティの適用方法に応じてlogitsを更新する
if penalty_method == "multiplicative":
if penalty != 1.0:
for token_id in set(past_tokens):
adjusted_logits[token_id] /= penalty
elif penalty_method == "subtractive":
for token_id in set(past_tokens):
adjusted_logits[token_id] -= penalty
else:
raise ValueError(f"Unknown penalty_method: {penalty_method}")
else:
adjusted_logits=logits
return {"name": "RepetitionPenaltyProcessor", "type": "logits", "logits": adjusted_logits}
def get_name(self):
return "rep_penalty"
Frequency Penalty
Frequency Penaltyは以下のようになります。
トークンが出現するたびに、各トークンの出現回数に基づきそのトークンのログ確率をペナルティ値で累積的に割る(乗算)か、ペナルティ値を累積的に引く(減算)というオペレーションを実装しています。
class FrequencyPenaltyProcessor(AbstractLogitsProcessor):
"""
生成されたトークンの出現頻度に基づいてペナルティを適用するプロセッサ。
このクラスは、生成中に各トークンが出現した回数を追跡し、頻繁に出現するトークンにペナルティを適用する。
ペナルティの適用方法は乗算または減算のいずれかで、パラメータで指定できる。
"""
def __init__(self):
self.token_counts = {}
def process(self, logits, params):
penalty = params.get("penalty", None)
penalty_method = params.get("penalty_method", "multiplicative")
# logitsのコピーを作成
adjusted_logits = logits.clone()
if penalty is not None:
if not isinstance(penalty, (int, float)):
raise ValueError(f"penalty should be a scalar value, but got {penalty}({type(penalty)})")
for token_id, count in self.token_counts.items():
if penalty_method == "multiplicative":
adjusted_logits[token_id] /= (penalty ** count)
elif penalty_method == "subtractive":
adjusted_logits[token_id] -= (penalty * count)
else:
raise ValueError(f"Unknown penalty_method: {penalty_method}")
return {"name": "FrequencyPenaltyProcessor", "type": "logits", "logits": adjusted_logits}
def update_token_counts(self, token_ids):
for token_id in token_ids:
if token_id in self.token_counts:
self.token_counts[token_id] += 1
else:
self.token_counts[token_id] = 1
def get_name(self):
return "freq_penalty"
Presence Penalty
Presence Penalty は以下のようになります。
トークンが一度でも生成されたかどうかに基づき一度生成されたトークンのログ確率をペナルティ値で割る(乗算)か、ペナルティ値を引く(減算)というオペレーションを実装しています。
class PresencePenaltyProcessor(AbstractLogitsProcessor):
"""
生成されたトークンの存在に基づいてペナルティを適用するプロセッサ。
このクラスは、特定のトークンがすでに出現しているかどうかを追跡し、存在するトークンにペナルティを適用する。
ペナルティの適用方法は乗算または減算のいずれかで、パラメータで指定できる。
"""
def __init__(self):
self.seen_tokens = set()
def process(self, logits, params):
penalty = params.get("penalty", None)
penalty_method = params.get("penalty_method", "multiplicative")
# logitsのコピーを作成
adjusted_logits = logits.clone()
if penalty is not None:
if not isinstance(penalty, (int, float)):
raise ValueError(f"penalty should be a scalar value, but got {penalty}({type(penalty)})")
for token_id in self.seen_tokens:
if penalty_method == "multiplicative":
adjusted_logits[token_id] /= penalty
elif penalty_method == "subtractive":
adjusted_logits[token_id] -= penalty
else:
raise ValueError(f"Unknown penalty_method: {penalty_method}")
return {"name": "PresencePenaltyProcessor", "type": "logits", "logits": adjusted_logits}
def update_seen_tokens(self, token_ids):
self.seen_tokens.update(token_ids)
def get_name(self):
return "presence_penalty"
他のサンプリングパラメータとペナルティ
さて、他のサンプリングパラメータとの関係についてもみておきましょう。
まず、上のように実装したペナルティのおさらいですが、
ペナルティ
- Repetition Penalty: 特定のトークンやフレーズの繰り返しを防ぎます。
- Frequency Penalty: 生成されたトークンの出現頻度に基づいてペナルティを適用し、頻繁に出現するトークンを抑制します。
- Presence Penalty: 一度生成されたトークンが再度出現するのを防ぎます。
Top-k, Top-p, Temperature
次は、この3つです。また3種ですね。
この3つは特によく登場するサンプリング手法です。
top-k
、top-p
、および temperature
は生成されるテキストの質と多様性を制御するためのパラメータです。簡単に説明すると、
Top-k
目的: 最も確率の高い k
個のトークンだけを考慮します。
- 方法: 確率の高い
k
個のトークンを選び、その中から次のトークンをランダムに選択します。 - 計算方法:
- トークンの確率を降順に並べ替えます。
- 上位
k
個のトークンを選びます。 - その中から次のトークンを選択します。
Top-p (または Nucleus Sampling)
目的: トークンの累積確率が p
(例:0.9)となるまでトークンを選択します。
- 方法: 確率の高いトークンを累積確率が
p
になるまで選び、その中から次のトークンをランダムに選択します。 - 計算方法:
- トークンの確率を降順に並べ替えます。
- 確率の累積和が
p
を超えるまでトークンを選びます。 - その中から次のトークンを選択します。
Temperature
目的: 生成されるテキストのランダム性を制御します。
- 方法:
temperature
を用いてトークンの確率分布を調整します。 - 計算方法:
- 各トークンのログ確率を
temperature
で割ります。 - これにより、確率分布がスムーズになります(高温度:分布が平坦に、低温度:分布が尖ります)。
- 各トークンのログ確率を
ペナルティと top_k,top_p,temperatureの計算シナリオ
たとえば、Qualiteg May Change the World with ChatStream
というテキスト生成を行うシナリオで考えてみましょう。
ペナルティとTop-k, Top-p, Temperatureの相互作用
Qualiteg May Change the World with ChatStream
という文章生成において、どのようにペナルティが適用され、top-k
、top-p
、temperature
とどのように連携するか以下にをみていきます。
1. ペナルティの適用
- ペナルティの適用:
- Repetition Penalty、Frequency Penalty、Presence Penalty が適用され、特定のトークンのログ確率が調整されます。
- 例えば、「Qualiteg」という単語がすでに何度か出現している場合、その単語のログ確率が低くなります。
2. Temperatureの適用
- Temperatureの適用:
- 調整されたログ確率は
temperature
によってさらにスケールされます。 - これにより、分布の形状が変わり、生成されるトークンのランダム性が増減します。
- 高温度(例:1.2)の場合、分布が平坦になり、ランダム性が増します。
- 低温度(例:0.7)の場合、分布が尖り、最も高い確率のトークンが選ばれやすくなります。
- 調整されたログ確率は
3. Top-kの適用
- Top-kの適用:
top-k
が適用され、上位k
個のトークンのみが選択肢として残ります。- これにより、最も確率の高いトークンの中から次のトークンが選ばれます。
- 低確率のトークンが除外され、生成されるテキストの質が高まります。
4. Top-pの適用
- Top-pの適用:
top-p
が適用され、累積確率がp
を超えるまでトークンが選択されます。- これにより、確率の高いトークンの中から次のトークンがランダムに選ばれます。
- こちらも低確率のトークンを除外し、生成されるテキストの質が高まります。
サンプリング計算の具体例
例えば、以下のような設定のとき、どのように計算されていくかを具体的にみていきましょう
-
設定:
- Repetition Penalty: 1.2
- Frequency Penalty: 0.8
- Presence Penalty: 1.5
- Temperature: 0.7
- Top-k: 50
- Top-p: 0.9
-
ステップ 1: 初期のログ確率の計算
まず、モデルが各トークンの初期のログ確率(logits)を計算します。例えば、以下のような初期ログ確率が得られたとします:
["Qualiteg": 2.0, "May": 1.5, "Change": 1.0, "the": 0.5, "World": 0.3, "with": 0.2, "ChatStream": 0.1]
-
ステップ 2: ペナルティの適用
次に、各ペナルティを適用します。以下の例では、すでに「Qualiteg」が1回出現しており、他のトークンは初めて出現するものとします。- Repetition Penalty: 「Qualiteg」がすでに出現しているため、
logits["Qualiteg"]
に 1.2 のペナルティを適用します。
logits["Qualiteg"] /= 1.2 2.0 / 1.2 ≈ 1.67
- Frequency Penalty: 出現頻度に基づくペナルティを適用します。「Qualiteg」は1回出現しているため、頻度ペナルティを適用します。
logits["Qualiteg"] *= 0.8 1.67 * 0.8 ≈ 1.34
- Presence Penalty: 「Qualiteg」がすでに存在しているため、
logits["Qualiteg"]
に 1.5 のペナルティを適用します。
logits["Qualiteg"] /= 1.5 1.34 / 1.5 ≈ 0.89
ペナルティを適用した後のログ確率
["Qualiteg": 0.89, "May": 1.5, "Change": 1.0, "the": 0.5, "World": 0.3, "with": 0.2, "ChatStream": 0.1]
- Repetition Penalty: 「Qualiteg」がすでに出現しているため、
-
ステップ 3: Temperatureの適用
次に、temperature
を適用して確率分布を調整します。temperature
が 0.7 に設定されている場合、各ログ確率は 0.7 で割られます。logits["Qualiteg"] /= 0.7 0.89 / 0.7 ≈ 1.27 logits["May"] /= 0.7 1.5 / 0.7 ≈ 2.14 logits["Change"] /= 0.7 1.0 / 0.7 ≈ 1.43 logits["the"] /= 0.7 0.5 / 0.7 ≈ 0.71 logits["World"] /= 0.7 0.3 / 0.7 ≈ 0.43
調整後のログ確率:
["Qualiteg": 1.27, "May": 2.14, "Change": 1.43, "the": 0.71, "World": 0.43, "with": 0.29, "ChatStream": 0.14]
-
ステップ 4: Top-kの適用
次に、top-k
を適用します。ここではtop-k=50
ですが、上位5個のトークンのみを示します。上位トークン:
["May": 2.14, "Change": 1.43, "Qualiteg": 1.27, "the": 0.71, "World": 0.43]
-
ステップ 5: Top-pの適用
最後に、top-p
を適用します。ここではtop-p=0.9
です。累積確率が 0.9 を超えるまでトークンを選択します。累積確率の計算:
-
確率の計算:
May
: exp(2.14) ≈ 8.50Change
: exp(1.43) ≈ 4.18Qualiteg
: exp(1.27) ≈ 3.56the
: exp(0.71) ≈ 2.03World
: exp(0.43) ≈ 1.54
合計:8.50 + 4.18 + 3.56 + 2.03 + 1.54 ≈ 19.81
-
累積確率の計算:
May
: 8.50 / 19.81 ≈ 0.43Change
: 4.18 / 19.81 ≈ 0.21Qualiteg
: 3.56 / 19.81 ≈ 0.18- ここまでの累積確率:0.43 + 0.21 + 0.18 ≈ 0.82
the
: 2.03 / 19.81 ≈ 0.10(累積確率:0.82 + 0.10 ≈ 0.92)
累積確率が 0.9 を超えたため、
the
までのトークンが選択肢に残ります["May": 2.14, "Change": 1.43, "Qualiteg": 1.27, "the": 0.71]
-
-
ステップ 6: トークンの選択
最終的に残ったトークンからランダムに次のトークンが選ばれます。この例では、["May", "Change", "Qualiteg", "the"]
の中から1つが選ばれます。
このシナリオでは、ペナルティ、temperature
、top-k
、top-p
がどのように組み合わさってテキスト生成に影響を与えるかをみてきました。
ペナルティは特定のトークンの出現確率を調整し、temperature
は分布の形状を変え、top-k
と top-p
は選択肢を絞り込むことで、最終的に生成されるテキストの質と多様性を制御することが実際の計算過程を追うことで理解できたとおもいます。
まとめ
今日は、3種類のペナルティとその周辺にあるサンプリング手法をみてきました。
ChatStreamにも今日作成した PenaltyProcessorを取り込むことが可能です。(ただし、ChatStreamにはすでにPenaltyProcessorのプリセット実装が存在しますが) 基本的に、logitsをどのようにサンプリングするかはサービス提供者の自由ですので、好きなProcessorを好きな順序で組み合わせることができます。
サンプリングを適用したいモデルに対して、もっとも好ましい出力となるようなサンプリングクラス(関数)の組み合わせ方、実際の値をどうするか、などは実際のモデルの入出力結果を測定して判断していくことになるとおもいます。このあたりは、ノウハウのかたまりでもあるので、もしご興味があればQualitegにぜひご相談ください。
[付録]ペナルティの比較
ペナルティタイプ | 目的 | 適用方法 | ペナルティの例 |
---|---|---|---|
Repetition Penalty | 特定のトークンやフレーズが繰り返されるのを防ぐ。 | 過去に生成されたすべてのトークンのログ確率(logits)に対してペナルティを適用する。 | 例えば、あるトークンがすでに生成された場合、そのトークンのログ確率をペナルティ値で割る(乗算)か、ペナルティ値を引く(減算)。 |
Frequency Penalty | 生成されたトークンの出現頻度に基づいてペナルティを適用し、頻繁に出現するトークンを抑制する。 | 各トークンが生成された回数に基づいてペナルティを適用する。トークンが出現するたびに、そのトークンの出現確率を低減させる。 | トークンが出現するたびに、そのトークンのログ確率をペナルティ値で累積的に割る(乗算)か、ペナルティ値を累積的に引く(減算)。 |
Presence Penalty | すでに生成されたトークンが再度出現するのを防ぐ。 | トークンが一度でも生成されたかどうかに基づいてペナルティを適用する。一度生成されたトークンには再出現の際にペナルティが適用される。 | 一度生成されたトークンのログ確率をペナルティ値で割る(乗算)か、ペナルティ値を引く(減算)。 |