RakutenAI-7B-chat を使用したチャットアプリケーションを5分で作る
こんにちは、株式会社 Qualiteg プロダクト開発部です。
今日は、 RakutenAI-7B-chat と ChatStream 0.7.0 を使用して本格的なチャットアプリケーションを作っていきましょう。
RakutenAI-7B-chat は Mistral 7B を日本語継続学習させたモデルで、チャットチューニングが行われており、 日本語LLM リーダーボード https://wandb.ai/wandb-japan/llm-leaderboard/reports/Nejumi-LLM-Neo--Vmlldzo2MTkyMTU0でも上位にランクされている期待大のモデルです。
ソースコード
早速ですが、以下がソースコードとなります。
4bit 量子化をしているため、使用する GPU は A4000 (16GB) 程度で快適に動作します。
import logging
import torch
import uvicorn
from fastapi import FastAPI
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from chatstream import ChatStream, ChatPromptRakutenMistral as ChatPrompt, LoadTime, TokenSamplerIsok
from chatstream.fastersession.faster_session_rdb_store import FasterSessionRdbStore
from chatstream.util.session.chat_stream_session_on_set_value_listener import chat_stream_session_on_set_value_listener
"""
'Rakuten/RakutenAI-7B-chat' の ChatStream Server のサンプルプログラム
- 開発用構成です(本番では、ChatStreamPool によるスケールアウトや、Qualiteg SunsetServer など、セキュアで堅牢なロードバランサー、リバースプロキシ導入を推奨しています)
- 単体起動用(スケールアウトモードのノードではなく、シングルインスタンスでWebアプリケーション、WebAPIサーバーとして振る舞います)
- モデルは 4bit 量子化 として扱います
"""
num_gpus = 1 # このノードで使用する GPU数
device = torch.device("cuda")
model_path = 'Rakuten/RakutenAI-7B-chat'
use_fast = True
# 4bit 量子化で使用するときの config
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
# モデル読み込み(LoadTime を使用することで、進捗表示をしながら読み込みする)
model = LoadTime(name=model_path, hf=True,
fn=lambda: AutoModelForCausalLM.from_pretrained(model_path,
quantization_config=quantization_config,
device_map="auto"))()
model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_path)
# ChatStream で使用するデータベース情報を設定する
database_def = {
"type": "rdbms", # "memory","rdbms","mongo","redis"
"rdbms": {
"db_url": "[your_db_url]" # DBを指定してください
},
"db": None # DBオブジェクトを直接する場合はこちらを記述してください
}
server_host_info = {'protocol': 'http', 'host': 'localhost', 'port': 9999}
num_of_concurrent_executions = 10 # 最大同時文章生成数
max_new_tokens = 512 # 最大トークン生成数
tokens_per_sec = 6 # num_of_concurrent_executions 件の同時接続があったときトークン生成速度(tokens/sec)
text_generation_timeout_sec = (max_new_tokens / tokens_per_sec) + 10 # 1回あたりの文章生成タイムアウト時間。+10 は余裕時間。
# ChatStream インスタンスを生成する
chat_stream = ChatStream(
num_of_concurrent_executions=num_of_concurrent_executions, # 最大同時文章生成数
max_queue_size=5, # 最大文章生成数に達したときの待ち行列の大きさ。この大きさを超えるリクエストがあるとき Too many request エラーとなる。
model_id='rakuten__rakuten_ai_7b_chat', # モデルID
model_support_languages=['ja', 'en'], # モデルがサポートしている言語。
model_desc={
'disp_name': {'en': 'RakutenAI-7B-chat', 'ja': 'RakutenAI-7B-chat', }, # UI表示用モデル名
'about': {
'ja': '2024/3/21 にリリースされた Mistral-7B-v0.1 ベースの日本語LLM', # UI 表示用説明文(日本語)
'en': 'Japanese LLM based on Mistral-7B-v0.1 released on 3/21/2024', # UI 表示用説明文(English)
},
'default_utterance_hints': [
{'utterance': {'ja': "Who starred in the movie 'Titanic' released in 1997?", # UI 表示サンプル発話(日本語)
'en': "Who starred in the movie 'Titanic' released in 1997?", # UI 表示サンプル発話(English)
},
'desc': {'ja': '映画タイタニックの主演は?', # UI 表示サンプル発話の説明文(日本語)
'en': '', # UI 表示サンプル発話の説明文(English)
}
},
]
},
server_host_info=server_host_info,
model=model,
tokenizer=tokenizer,
num_gpus=num_gpus,
device=device,
chat_prompt_clazz=ChatPrompt, # このモデル用の ChatPrompt をセットする
add_special_tokens=False, # 特殊トークン追加の有無
text_generation_timeout_sec=text_generation_timeout_sec, # タイムアウトは同時ユーザー数が最大のときのトークン生成速度xmax_new_tokens から計算する
max_new_tokens=max_new_tokens, # 1回あたりの最大トークン生成数
context_len=1024, # コンテクスト長をセットする
temperature=0.7, # サンプリングパラメータ temperature をセットする
top_k=10, # サンプリングパラメータ top K
# top_p=0.9, # サンプリングパラメータ top P をセットする
repetition_penalty=1.05, # サンプリングパラメータ 繰り返しペナルティ をセットする
database=database_def, # データベース情報をセットする
client_roles={
"user": {
"apis": {
"allow": "all", # [DefaultApiNames.CHAT_STREAM, ],
"auth_method": "nothing", # 本 ChatStream は単体起動するため 認証無し とする(スケールアウトモードの場合は適切なサーバー認証をセットします)
"use_session": True, # 本 ChatStream は単体起動するため use_session:True とする。(スケールアウトモードで起動するときは use_session:False とします)
}
},
}, # ロールをセットする。
locale='ja',
token_sampler=TokenSamplerIsok(), # TokenSamplerHft() # TokenSamplerIsok() #
seed=42,
)
chat_stream.logger.setLevel(logging.DEBUG)
# セッションデータを RDBMS に保存する ストア
rdb_store = FasterSessionRdbStore(database_def=database_def,
on_set_value_listener=chat_stream_session_on_set_value_listener, # シリアライズできないオブジェクトをセッションの永続化時にスキップするヘルパー
)
# memory_store = get_chat_stream_session_memory_store() # セッションデータをメモリに保存する
# file_store = get_chat_stream_session_file_store() # セッションデータをファイルに保存する
# ChatStreamに追加するミドルウェアの設定用dict
mw_opts = {
"faster_session": {
"secret_key": "chatstream-default-session-secret-key",
"store": rdb_store, # session をファイルに保存するファイルストアを取得する
# "same_site":"Strict", # set-cookie の same_site 属性、デフォルトは Strict
# "is_http_only":True,# set-cookie の http_only 属性、デフォルトは True
# "is_secure":True,# set-cookie の secure 属性、デフォルトは True
# "max_age":0 # set-cookie の max_age 属性、デフォルトは True
},
}
# FastAPI インスタンスを作る
app = FastAPI()
# 必要なミドルウェアを自動的に追加する(手動で設定することも可能です)
chat_stream.append_middlewares(app, opts=mw_opts)
# 必要なAPIを自動的に追加する(生やす)
# ここではすべての API を追加していますが、用途に応じてAPIを選択することも可能です
# 各 URLパスの具体的な内容は default_api_paths.py を参照してください
chat_stream.append_apis(app, {"all": True})
@app.on_event("startup")
async def startup():
# Web サーバーの起動後
# ChatStreamのキューイングシステムを開始する
await chat_stream.start_queue_worker()
def start_server():
# Web サーバーを起動する
uvicorn.run(app, host=server_host_info.get('host'), port=server_host_info.get('port'))
def main():
start_server()
if __name__ == "__main__":
main()
モデルへの入出力に使用する ChatPrompt クラスは、最新版の ChatStream に同梱されている ChatPromptRakutenMistral を使用します。
または、以下の記事を参考に自ら作成することも可能です。
https://blog.qualiteg.com/chatprompt_rakuten_ai_7b_chat/
さて、さっそくこのコードを実行して、チャットを試してみましょう
無事起動し、Web ブラウザからチャットを試すことができました!
Qualitegプロダクト開発部では、HuggingFaceに最新のモデルが発表された都度、迅速にChatStream へのポーティングを行っています。
そのため、最新のモデルでもほぼコードを書かずに、すぐにお試しいただけます。今回も、ほぼボイラープレートのみで本格 LLM チャットを実装することができました。
それでは、また次回のLLMでお会いしましょう!