File size: 6,240 Bytes
2592b96
9d917ac
205a913
 
 
 
9d917ac
 
205a913
 
9d917ac
 
205a913
7b617f0
65a5a95
6fd7b13
4f00b4e
b4b8293
2bf74f7
dbd3992
0e5d2cd
0744d77
7b617f0
2824c72
0e5d2cd
61e913e
205a913
72bb22a
205a913
 
 
 
 
 
 
72bb22a
 
 
 
448a0ea
72bb22a
448a0ea
72bb22a
9d917ac
2cd9ba8
 
9d917ac
 
 
 
 
 
 
 
205a913
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4c7cfa
 
2004aa9
 
 
a4c7cfa
205a913
 
 
 
 
 
 
 
540e7a8
 
 
328a888
8a55503
 
a8b07c3
 
205a913
 
9b1b384
 
9d917ac
 
205a913
 
9d917ac
 
 
205a913
 
9d917ac
205a913
 
 
 
 
 
 
84abc70
 
205a913
 
 
 
 
 
 
 
 
 
9d917ac
 
 
 
 
205a913
9d917ac
 
205a913
 
 
 
 
 
 
 
02609ac
9d917ac
 
 
205a913
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import spaces
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import warnings
warnings.filterwarnings("ignore")

"""
Sarashinaモデルを使用したGradioチャットボット
Hugging Face Transformersライブラリを使用してローカルでモデルを実行
"""

# モデルとトークナイザーの初期化
MODEL_NAME = "sbintuitions/sarashina2.2-3b-instruct-v0.1" # Sarashina2 3B instruceted
# MODEL_NAME = "sbintuitions/sarashina2-7b" # Sarashina2 7B
# MODEL_NAME = "sbintuitions/sarashina2-13b" # Sarashina2 13B
# MODEL_NAME = "sbintuitions/sarashina2-70b" # Sarashina2 70B
# MODEL_NAME = "sbintuitions/sarashina1-65b" # Sarashina1 65B
# MODEL_NAME = "elyza/Llama-3-ELYZA-JP-8B" # ELYZA-JP-8B
# MODEL_NAME = "lightblue/ao-karasu-72B" # ao-karasu-72B
# MODEL_NAME = "llm-jp/llm-jp-3-13b-instruct" # llm-jp-3-13b-instruct
# MODEL_NAME = "llm-jp/llm-jp-3-172b-instruct3" # llm-jp-3-172b-instruct3
# MODEL_NAME = "shisa-ai/shisa-v2-unphi4-14b" # shisa-v2-unphi4-14b
# MODEL_NAME = "shisa-ai/shisa-v2-qwen2.5-32b" # shisa-v2-qwen2.5-32b



print("モデルを読み込み中〜...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto" if torch.cuda.is_available() else None,
    trust_remote_code=True
)
print("モデルの読み込みが完了しました〜。")

print(f"Is CUDA available: {torch.cuda.is_available()}")
# True
print("あ")
print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
print("い")
# Tesla T4

@spaces.GPU

def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """
    チャットボットの応答を生成する関数
    Gradio ChatInterfaceの標準形式に対応
    """
    try:
        # システムメッセージと会話履歴を含むプロンプトを構築
        conversation = ""
        if system_message.strip():
            conversation += f"システム: {system_message}\n"
        
        # 会話履歴を追加
        for user_msg, bot_msg in history:
            if user_msg:
                conversation += f"ユーザー: {user_msg}\n"
            if bot_msg:
                conversation += f"アシスタント: {bot_msg}\n"
        
        # 現在のメッセージを追加
        conversation += f"ユーザー: {message}\nアシスタント: "
        
        # トークン化
        inputs = tokenizer.encode(conversation, return_tensors="pt")
        
        # GPU使用時はCUDAに移動
        if torch.cuda.is_available():
            inputs = inputs.cuda()
        
        # 応答生成(ストリーミング対応)
        response = ""
        with torch.no_grad():
            # 一度に生成してからストリーミング風に出力
            outputs = model.generate(
                inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.1
            )
        
        # 生成されたテキストをデコード
        generated = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # 変換できるかテスト用!!
        #    import json
        #    # レスポンス用の辞書を作るときに
        #    return json.dumps({"result": generated}, ensure_ascii=False)

        # 応答部分のみを抽出
        full_response = generated[len(conversation):].strip()
        
        # 不要な部分を除去
        if "ユーザー:" in full_response:
            full_response = full_response.split("ユーザー:")[0].strip()
        
        # ストリーミング風の出力
        #for i in range(len(full_response)):
        #    response = full_response[:i+1]
        #    yield response

        #response = full_response[:len(full_response)] #追加
        #yield response #追加
        #yield full_response #追加
        return full_response #追加
            
    except Exception as e:
        #yield f"エラーが発生しました: {str(e)}"
        return f"エラーが発生しました: {str(e)}" #追加

"""
Gradio ChatInterfaceを使用したシンプルなチャットボット
カスタマイズ可能なパラメータを含む
"""
demo = gr.ChatInterface(
    respond,
    title="🤖 Sarashina Chatbot",
    description="Sarashina2.2-3b-instruct モデルを使用した日本語チャットボットです。",
    additional_inputs=[
        gr.Textbox(
            value="あなたは親切で知識豊富な日本語アシスタントです。ユーザーの質問に丁寧に答えてください。", 
            label="システムメッセージ",
            lines=3
        ),
        gr.Slider(
            minimum=1, 
            maximum=8192, 
            value=4096, 
            step=1, 
            label="最大新規トークン数"
        ),
        gr.Slider(
            minimum=0.1, 
            maximum=2.0, 
            value=0.7, 
            step=0.1, 
            label="Temperature (創造性)"
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (多様性制御)",
        ),
    ],
    theme=gr.themes.Soft(),
    examples=[
        ["こんにちは!今日はどんなことを話しましょうか?"],
        ["日本の文化について教えてください。"],
        ["簡単なレシピを教えてもらえますか?"],
        ["プログラミングについて質問があります。"],
    ],
    cache_examples=False,
    #streaming=False  # 追加 ← これで return のみ受け付ける同期モードに
)

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_api=True,  # API documentation を表示
        debug=True
    )