File size: 6,240 Bytes
2592b96 9d917ac 205a913 9d917ac 205a913 9d917ac 205a913 7b617f0 65a5a95 6fd7b13 4f00b4e b4b8293 2bf74f7 dbd3992 0e5d2cd 0744d77 7b617f0 2824c72 0e5d2cd 61e913e 205a913 72bb22a 205a913 72bb22a 448a0ea 72bb22a 448a0ea 72bb22a 9d917ac 2cd9ba8 9d917ac 205a913 a4c7cfa 2004aa9 a4c7cfa 205a913 540e7a8 328a888 8a55503 a8b07c3 205a913 9b1b384 9d917ac 205a913 9d917ac 205a913 9d917ac 205a913 84abc70 205a913 9d917ac 205a913 9d917ac 205a913 02609ac 9d917ac 205a913 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 |
import spaces
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import warnings
warnings.filterwarnings("ignore")
"""
Sarashinaモデルを使用したGradioチャットボット
Hugging Face Transformersライブラリを使用してローカルでモデルを実行
"""
# モデルとトークナイザーの初期化
MODEL_NAME = "sbintuitions/sarashina2.2-3b-instruct-v0.1" # Sarashina2 3B instruceted
# MODEL_NAME = "sbintuitions/sarashina2-7b" # Sarashina2 7B
# MODEL_NAME = "sbintuitions/sarashina2-13b" # Sarashina2 13B
# MODEL_NAME = "sbintuitions/sarashina2-70b" # Sarashina2 70B
# MODEL_NAME = "sbintuitions/sarashina1-65b" # Sarashina1 65B
# MODEL_NAME = "elyza/Llama-3-ELYZA-JP-8B" # ELYZA-JP-8B
# MODEL_NAME = "lightblue/ao-karasu-72B" # ao-karasu-72B
# MODEL_NAME = "llm-jp/llm-jp-3-13b-instruct" # llm-jp-3-13b-instruct
# MODEL_NAME = "llm-jp/llm-jp-3-172b-instruct3" # llm-jp-3-172b-instruct3
# MODEL_NAME = "shisa-ai/shisa-v2-unphi4-14b" # shisa-v2-unphi4-14b
# MODEL_NAME = "shisa-ai/shisa-v2-qwen2.5-32b" # shisa-v2-qwen2.5-32b
print("モデルを読み込み中〜...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
device_map="auto" if torch.cuda.is_available() else None,
trust_remote_code=True
)
print("モデルの読み込みが完了しました〜。")
print(f"Is CUDA available: {torch.cuda.is_available()}")
# True
print("あ")
print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
print("い")
# Tesla T4
@spaces.GPU
def respond(
message,
history: list[tuple[str, str]],
system_message,
max_tokens,
temperature,
top_p,
):
"""
チャットボットの応答を生成する関数
Gradio ChatInterfaceの標準形式に対応
"""
try:
# システムメッセージと会話履歴を含むプロンプトを構築
conversation = ""
if system_message.strip():
conversation += f"システム: {system_message}\n"
# 会話履歴を追加
for user_msg, bot_msg in history:
if user_msg:
conversation += f"ユーザー: {user_msg}\n"
if bot_msg:
conversation += f"アシスタント: {bot_msg}\n"
# 現在のメッセージを追加
conversation += f"ユーザー: {message}\nアシスタント: "
# トークン化
inputs = tokenizer.encode(conversation, return_tensors="pt")
# GPU使用時はCUDAに移動
if torch.cuda.is_available():
inputs = inputs.cuda()
# 応答生成(ストリーミング対応)
response = ""
with torch.no_grad():
# 一度に生成してからストリーミング風に出力
outputs = model.generate(
inputs,
max_new_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
do_sample=True,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id,
repetition_penalty=1.1
)
# 生成されたテキストをデコード
generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
# 変換できるかテスト用!!
# import json
# # レスポンス用の辞書を作るときに
# return json.dumps({"result": generated}, ensure_ascii=False)
# 応答部分のみを抽出
full_response = generated[len(conversation):].strip()
# 不要な部分を除去
if "ユーザー:" in full_response:
full_response = full_response.split("ユーザー:")[0].strip()
# ストリーミング風の出力
#for i in range(len(full_response)):
# response = full_response[:i+1]
# yield response
#response = full_response[:len(full_response)] #追加
#yield response #追加
#yield full_response #追加
return full_response #追加
except Exception as e:
#yield f"エラーが発生しました: {str(e)}"
return f"エラーが発生しました: {str(e)}" #追加
"""
Gradio ChatInterfaceを使用したシンプルなチャットボット
カスタマイズ可能なパラメータを含む
"""
demo = gr.ChatInterface(
respond,
title="🤖 Sarashina Chatbot",
description="Sarashina2.2-3b-instruct モデルを使用した日本語チャットボットです。",
additional_inputs=[
gr.Textbox(
value="あなたは親切で知識豊富な日本語アシスタントです。ユーザーの質問に丁寧に答えてください。",
label="システムメッセージ",
lines=3
),
gr.Slider(
minimum=1,
maximum=8192,
value=4096,
step=1,
label="最大新規トークン数"
),
gr.Slider(
minimum=0.1,
maximum=2.0,
value=0.7,
step=0.1,
label="Temperature (創造性)"
),
gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.95,
step=0.05,
label="Top-p (多様性制御)",
),
],
theme=gr.themes.Soft(),
examples=[
["こんにちは!今日はどんなことを話しましょうか?"],
["日本の文化について教えてください。"],
["簡単なレシピを教えてもらえますか?"],
["プログラミングについて質問があります。"],
],
cache_examples=False,
#streaming=False # 追加 ← これで return のみ受け付ける同期モードに
)
if __name__ == "__main__":
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=False,
show_api=True, # API documentation を表示
debug=True
) |