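"""Gradio chat demo for the Guchyos/gemma-2b-elyza-task model (a Gemma 2
fine-tune for ELYZA-tasks), served with CPU-only inference via transformers."""
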
import os
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import torch

# Globals populated lazily on first use by load_model()
model = None
tokenizer = None

# Read the Hugging Face access token (required for gated/private models)
HUGGING_FACE_TOKEN = os.getenv('HUGGINGFACE_TOKEN')
if not HUGGING_FACE_TOKEN:
    raise ValueError("The HUGGINGFACE_TOKEN environment variable is not set")

def load_model():
    global model, tokenizer
    if model is None:
        model_name = "Guchyos/gemma-2b-elyza-task"
        try:
            # Load the model configuration first
            config = AutoConfig.from_pretrained(
                model_name,
                token=HUGGING_FACE_TOKEN,
                trust_remote_code=True
            )
            
            # Load the tokenizer
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                token=HUGGING_FACE_TOKEN,
                trust_remote_code=True
            )
            
            # Load the model weights on CPU in float32
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                config=config,
                torch_dtype=torch.float32,
                device_map="cpu",
                token=HUGGING_FACE_TOKEN,
                trust_remote_code=True
            )
            
            # Put the model in evaluation mode
            model.eval()
            
        except Exception as e:
            raise RuntimeError(f"Failed to load the model: {e}") from e
    return model, tokenizer

def predict(message, history):
    try:
        model, tokenizer = load_model()
        # Japanese prompt format ("Question: ... Answer:") the model expects
        prompt = f"質問: {message}\n\n回答:"
        inputs = tokenizer(prompt, return_tensors="pt")
        
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=128,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id
            )
            
        # Decode only the newly generated tokens; stripping the prompt from the
        # decoded string can fail when tokenization does not round-trip exactly
        new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    
    except Exception as e:
        return f"An error occurred: {e}"

# Set up the Gradio chat interface
demo = gr.ChatInterface(
    fn=predict,
    title="💬 Gemma 2 for ELYZA-tasks",
    description="A Japanese LLM optimized for ELYZA-tasks-100-TV"
)

if __name__ == "__main__":
    demo.launch(share=True)
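
# To run locally (assuming this file is saved as app.py):
#   export HUGGINGFACE_TOKEN=<your HF access token>
#   python app.py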