Spaces:
Runtime error
Runtime error
import gradio as gr | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
import torch | |
# モデルとトークナイザーの読み込み | |
model_name = "EleutherAI/Pythia-1b" | |
tokenizer = AutoTokenizer.from_pretrained(model_name) | |
model = AutoModelForCausalLM.from_pretrained(model_name, ignore_mismatched_sizes=True) | |
# 応答を生成する関数 | |
def respond(message, history, max_tokens, temperature, top_p): | |
# 入力履歴と新しいメッセージを連結 | |
if history is None: | |
history = [] | |
input_text = "" | |
for user_message, bot_response in history: | |
input_text += f"User: {user_message}\nAssistant: {bot_response}\n" | |
input_text += f"User: {message}\nAssistant:" | |
# トークナイズ | |
inputs = tokenizer(input_text, return_tensors="pt") | |
# モデルによる応答生成 | |
with torch.no_grad(): | |
outputs = model.generate( | |
inputs.input_ids, | |
max_length=inputs.input_ids.shape[1] + max_tokens, | |
do_sample=True, | |
top_p=top_p, | |
temperature=temperature, | |
) | |
# 応答をデコード | |
response = tokenizer.decode(outputs[0], skip_special_tokens=True) | |
# 最後のユーザー入力以降の応答部分を抽出 | |
response = response.split("Assistant:")[-1].strip() | |
# 応答と履歴を更新 | |
history.append((message, response)) | |
return response, history | |
# Gradioインターフェースの設定 | |
with gr.Blocks() as demo: | |
gr.Markdown("## AIチャット") | |
chatbot = gr.Chatbot() | |
msg = gr.Textbox(label="あなたのメッセージ", placeholder="ここにメッセージを入力...") | |
max_tokens = gr.Slider(1, 2048, value=512, step=1, label="新規トークン最大") | |
temperature = gr.Slider(0.1, 4.0, value=0.7, step=0.1, label="温度") | |
top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p (核サンプリング)") | |
send_button = gr.Button("送信") | |
clear = gr.Button("クリア") | |
def clear_history(): | |
return [], [] | |
send_button.click(respond, inputs=[msg, chatbot, max_tokens, temperature, top_p], outputs=[chatbot, chatbot]) | |
clear.click(clear_history, outputs=[chatbot]) | |
demo.launch() |