import os
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import torch
# Initialize global variables
model = None
tokenizer = None
# Fetch the Hugging Face token
HUGGING_FACE_TOKEN = os.getenv('HUGGINGFACE_TOKEN')
if not HUGGING_FACE_TOKEN:
    raise ValueError("The HUGGINGFACE_TOKEN environment variable is not set")
def load_model():
    global model, tokenizer
    if model is None:
        model_name = "Guchyos/gemma-2b-elyza-task"
        try:
            # Load the model configuration first
            config = AutoConfig.from_pretrained(
                model_name,
                token=HUGGING_FACE_TOKEN,
                trust_remote_code=True
            )
            # Load the tokenizer
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                token=HUGGING_FACE_TOKEN,
                trust_remote_code=True
            )
            # Load the model (full-precision CPU inference, quantization disabled)
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                config=config,
                torch_dtype=torch.float32,
                device_map="cpu",
                token=HUGGING_FACE_TOKEN,
                load_in_8bit=False,
                load_in_4bit=False,
                trust_remote_code=True
            )
            # Put the model in evaluation mode
            model.eval()
        except Exception as e:
            raise Exception(f"Failed to load the model: {str(e)}")
    return model, tokenizer
def predict(message, history):
    try:
        model, tokenizer = load_model()
        # Japanese prompt template ("質問" = question, "回答" = answer)
        prompt = f"質問: {message}\n\n回答:"
        inputs = tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=128,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id
            )
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Strip the echoed prompt so only the generated answer is returned
        return response.replace(prompt, "").strip()
    except Exception as e:
        return f"An error occurred: {str(e)}"
# Set up the Gradio interface
demo = gr.ChatInterface(
    fn=predict,
    title="💬 Gemma 2 for ELYZA-tasks",
    description="A Japanese LLM optimized for ELYZA-tasks-100-TV"
)
if __name__ == "__main__":
    demo.launch(share=True)
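
# Usage sketch: a minimal way to query this Space from a separate client
# process, assuming the gradio_client package is installed and the Space is
# published under the same ID as the model repo ("Guchyos/gemma-2b-elyza-task"
# here is an assumption). gr.ChatInterface exposes its chat endpoint under
# api_name="/chat".
#
#   from gradio_client import Client
#
#   client = Client("Guchyos/gemma-2b-elyza-task")  # hypothetical Space ID
#   reply = client.predict("日本で一番高い山は？", api_name="/chat")
#   print(reply)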