oggata commited on
Commit
96b5085
·
verified ·
1 Parent(s): 1dbd9cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +229 -58
app.py CHANGED
@@ -1,64 +1,235 @@
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
 
 
 
3
 
4
- """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
- """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
 
 
9
 
10
- def respond(
11
- message,
12
- history: list[tuple[str, str]],
13
- system_message,
14
- max_tokens,
15
- temperature,
16
- top_p,
17
- ):
18
- messages = [{"role": "system", "content": system_message}]
19
-
20
- for val in history:
21
- if val[0]:
22
- messages.append({"role": "user", "content": val[0]})
23
- if val[1]:
24
- messages.append({"role": "assistant", "content": val[1]})
25
-
26
- messages.append({"role": "user", "content": message})
27
-
28
- response = ""
29
-
30
- for message in client.chat_completion(
31
- messages,
32
- max_tokens=max_tokens,
33
- stream=True,
34
- temperature=temperature,
35
- top_p=top_p,
36
- ):
37
- token = message.choices[0].delta.content
38
-
39
- response += token
40
- yield response
41
-
42
-
43
- """
44
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
- """
46
- demo = gr.ChatInterface(
47
- respond,
48
- additional_inputs=[
49
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
50
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
51
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
52
- gr.Slider(
53
- minimum=0.1,
54
- maximum=1.0,
55
- value=0.95,
56
- step=0.05,
57
- label="Top-p (nucleus sampling)",
58
- ),
59
- ],
60
- )
61
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
 
63
  if __name__ == "__main__":
64
- demo.launch()
 
 
 
 
 
 
1
  import gradio as gr
2
+ import requests
3
+ import json
4
+ import os
5
+ from typing import List, Tuple
6
 
7
+ class JapaneseLLMChat:
8
+ def __init__(self):
9
+ # 利用可能な日本語LLMモデル
10
+ self.models = {
11
+ "cyberagent/open-calm-7b": "CyberAgent Open CALM 7B",
12
+ "rinna/japanese-gpt-neox-3.6b-instruction-sft": "Rinna GPT-NeoX 3.6B",
13
+ "matsuo-lab/weblab-10b-instruction-sft": "Matsuo Lab WebLab 10B",
14
+ "stabilityai/japanese-stablelm-instruct-alpha-7b": "Japanese StableLM 7B"
15
+ }
16
+
17
+ # デフォルトモデル
18
+ self.current_model = "cyberagent/open-calm-7b"
19
+
20
+ # HuggingFace API設定
21
+ self.api_url = "https://api-inference.huggingface.co/models/"
22
+ self.headers = {}
23
+
24
+ def set_api_key(self, api_key: str):
25
+ """APIキーを設定"""
26
+ if api_key.strip():
27
+ self.headers = {"Authorization": f"Bearer {api_key}"}
28
+ return "✅ APIキーが設定されました"
29
+ else:
30
+ return "❌ 有効なAPIキーを入力してください"
31
+
32
+ def set_model(self, model_name: str):
33
+ """使用するモデルを変更"""
34
+ self.current_model = model_name
35
+ return f"モデルを {self.models[model_name]} に変更しました"
36
+
37
+ def query_model(self, prompt: str, max_length: int = 200, temperature: float = 0.7) -> str:
38
+ """HuggingFace Inference APIにクエリを送信"""
39
+ if not self.headers:
40
+ return "❌ APIキーが設定されていません"
41
+
42
+ url = self.api_url + self.current_model
43
+
44
+ payload = {
45
+ "inputs": prompt,
46
+ "parameters": {
47
+ "max_length": max_length,
48
+ "temperature": temperature,
49
+ "do_sample": True,
50
+ "top_p": 0.95,
51
+ "return_full_text": False
52
+ }
53
+ }
54
+
55
+ try:
56
+ response = requests.post(url, headers=self.headers, json=payload, timeout=30)
57
+
58
+ if response.status_code == 200:
59
+ result = response.json()
60
+ if isinstance(result, list) and len(result) > 0:
61
+ generated_text = result[0].get("generated_text", "")
62
+ return generated_text.strip()
63
+ else:
64
+ return "❌ 予期しないレスポンス形式です"
65
+ elif response.status_code == 503:
66
+ return "⏳ モデルが読み込み中です。しばらく待ってから再試行してください。"
67
+ elif response.status_code == 401:
68
+ return "❌ APIキーが無効です"
69
+ else:
70
+ return f"❌ エラーが発生しました (ステータス: {response.status_code})"
71
+
72
+ except requests.exceptions.Timeout:
73
+ return "⏳ リクエストがタイムアウトしました。再試行してください。"
74
+ except requests.exceptions.RequestException as e:
75
+ return f"❌ 接続エラー: {str(e)}"
76
+
77
+ def chat_response(self, message: str, history: List[Tuple[str, str]],
78
+ max_length: int, temperature: float) -> Tuple[str, List[Tuple[str, str]]]:
79
+ """チャット応答を生成"""
80
+ if not message.strip():
81
+ return "", history
82
+
83
+ # 対話履歴を考慮したプロンプト作成
84
+ conversation_context = ""
85
+ for user_msg, bot_msg in history[-3:]: # 直近3回の会話を含める
86
+ conversation_context += f"ユーザー: {user_msg}\nアシスタント: {bot_msg}\n"
87
+
88
+ # プロンプトの構築
89
+ if self.current_model == "rinna/japanese-gpt-neox-3.6b-instruction-sft":
90
+ prompt = f"{conversation_context}ユーザー: {message}\nアシスタント:"
91
+ elif "instruct" in self.current_model.lower():
92
+ prompt = f"以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書いてください。\n\n### 指示:\n日本語で自然な会話を行ってください。\n\n### 入力:\n{conversation_context}ユーザー: {message}\n\n### 応答:\n"
93
+ else:
94
+ prompt = f"{conversation_context}ユーザー: {message}\nアシスタント:"
95
+
96
+ # モデルから応答を取得
97
+ response = self.query_model(prompt, max_length, temperature)
98
+
99
+ # 履歴に追加
100
+ history.append((message, response))
101
+
102
+ return "", history
103
 
104
+ # チャットインスタンスを作成
105
+ chat_bot = JapaneseLLMChat()
106
 
107
+ # Gradio インターフェースの構築
108
+ def create_interface():
109
+ with gr.Blocks(
110
+ title="日本語LLMチャット",
111
+ theme=gr.themes.Soft(),
112
+ css="""
113
+ .gradio-container {
114
+ max-width: 1000px !important;
115
+ }
116
+ """
117
+ ) as demo:
118
+
119
+ gr.Markdown(
120
+ """
121
+ # 🤖 日本語LLMチャット
122
+ HuggingFace Inference APIを使用した日本語対話システム
123
+ """
124
+ )
125
+
126
+ with gr.Row():
127
+ with gr.Column(scale=2):
128
+ # APIキー設定
129
+ with gr.Group():
130
+ gr.Markdown("### 🔑 設定")
131
+ api_key_input = gr.Textbox(
132
+ label="HuggingFace API Token",
133
+ placeholder="hf_xxxxxxxxxxxxxxxxx",
134
+ type="password"
135
+ )
136
+ api_key_btn = gr.Button("APIキーを設定", variant="primary")
137
+ api_key_status = gr.Textbox(label="ステータス", interactive=False)
138
+
139
+ # モデル選択
140
+ with gr.Group():
141
+ gr.Markdown("### 🧠 モデル選択")
142
+ model_dropdown = gr.Dropdown(
143
+ choices=[(v, k) for k, v in chat_bot.models.items()],
144
+ value="cyberagent/open-calm-7b",
145
+ label="使用するモデル"
146
+ )
147
+ model_status = gr.Textbox(label="現在のモデル", interactive=False,
148
+ value=chat_bot.models[chat_bot.current_model])
149
+
150
+ # パラメータ設定
151
+ with gr.Group():
152
+ gr.Markdown("### ⚙️ 生成パラメータ")
153
+ max_length_slider = gr.Slider(
154
+ minimum=50, maximum=500, value=200,
155
+ label="最大生成長"
156
+ )
157
+ temperature_slider = gr.Slider(
158
+ minimum=0.1, maximum=2.0, value=0.7,
159
+ label="Temperature(創造性)"
160
+ )
161
+
162
+ with gr.Column(scale=3):
163
+ # チャットインターフェース
164
+ chatbot = gr.Chatbot(
165
+ height=500,
166
+ label="会話",
167
+ show_label=True,
168
+ avatar_images=["👤", "🤖"]
169
+ )
170
+
171
+ msg = gr.Textbox(
172
+ label="メッセージ",
173
+ placeholder="メッセージを入力してください...",
174
+ lines=2
175
+ )
176
+
177
+ with gr.Row():
178
+ send_btn = gr.Button("送信", variant="primary")
179
+ clear_btn = gr.Button("会話をクリア", variant="secondary")
180
+
181
+ # 使用方法の説明
182
+ with gr.Accordion("📖 使用方法", open=False):
183
+ gr.Markdown(
184
+ """
185
+ 1. **APIキーの設定**: HuggingFace(https://huggingface.co/settings/tokens)からAccess Tokenを取得し、上記フィールドに入力してください
186
+ 2. **モデル選択**: 使用したい日本語LLMを選択してください
187
+ 3. **パラメータ調整**: 必要に応じて生成パラメータを調整してください
188
+ 4. **チャット開始**: メッセージを入力して「送信」ボタンをクリックしてください
189
+
190
+ **注意**:
191
+ - 初回使用���はモデルの読み込みに時間がかかる場合があります
192
+ - 大きなモデル(7B以上)の使用には有料アカウントが必要な場合があります
193
+ """
194
+ )
195
+
196
+ # イベントハンドラーの設定
197
+ api_key_btn.click(
198
+ chat_bot.set_api_key,
199
+ inputs=[api_key_input],
200
+ outputs=[api_key_status]
201
+ )
202
+
203
+ model_dropdown.change(
204
+ chat_bot.set_model,
205
+ inputs=[model_dropdown],
206
+ outputs=[model_status]
207
+ )
208
+
209
+ send_btn.click(
210
+ chat_bot.chat_response,
211
+ inputs=[msg, chatbot, max_length_slider, temperature_slider],
212
+ outputs=[msg, chatbot]
213
+ )
214
+
215
+ msg.submit(
216
+ chat_bot.chat_response,
217
+ inputs=[msg, chatbot, max_length_slider, temperature_slider],
218
+ outputs=[msg, chatbot]
219
+ )
220
+
221
+ clear_btn.click(
222
+ lambda: ([], ""),
223
+ outputs=[chatbot, msg]
224
+ )
225
+
226
+ return demo
227
 
228
+ # アプリケーションの起動
229
  if __name__ == "__main__":
230
+ demo = create_interface()
231
+ demo.launch(
232
+ share=True,
233
+ server_name="0.0.0.0",
234
+ server_port=7860
235
+ )