入出力の保存、応答の評価、データセットのアップロード機能を追加

#1
Files changed (1) hide show
  1. app.py +85 -21
app.py CHANGED
@@ -1,7 +1,13 @@
1
  import gradio as gr
2
- # from huggingface_hub import InferenceClient
3
  from openai import OpenAI
4
  import os
 
 
 
 
 
 
 
5
  openai_api_key = os.getenv('api_key')
6
  openai_api_base = os.getenv('url')
7
  model_name = "weblab-GENIAC/Tanuki-8x8B-dpo-v1.0"
@@ -14,6 +20,54 @@ client = OpenAI(
14
  base_url=openai_api_base,
15
  )
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  def respond(
19
  message,
@@ -35,8 +89,7 @@ def respond(
35
  messages.append({"role": "user", "content": message})
36
 
37
  response = ""
38
-
39
- for message in client.chat.completions.create(
40
  model=model_name,
41
  messages=messages,
42
  max_tokens=max_tokens,
@@ -44,24 +97,33 @@ def respond(
44
  temperature=temperature,
45
  top_p=top_p,
46
  ):
47
- token = message.choices[0].delta.content
48
-
49
- # response += token
50
- if token is not None:
51
- response += (token)
52
- if response.find("### 指示:")>0:
53
- response=response.replace("### 指示:","")
54
- break
55
  yield response
56
-
57
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  """
59
  For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
60
  """
61
 
62
  description = """
63
  ### [Tanuki-8x8B-dpo-v1.0](https://huggingface.co/weblab-GENIAC/Tanuki-8x8B-dpo-v1.0)との会話(期間限定での公開)
64
- - 人工知能開発のため、原則として**このChatBotの入出力データは全て著作権フリー(CC0)で公開予定です**ので、ご注意ください。著作物、個人情報、機密情報、誹謗中傷などのデータを入力しないでください。
 
65
  - **上記の条件に同意する場合のみ**、以下のChatbotを利用してください。
66
  """
67
 
@@ -71,8 +133,8 @@ FOOTER = """### 注意
71
  - コンテクスト長が4096までなので、あまり会話が長くなると、エラーで停止します。ページを再読み込みしてください。
72
  - GPUサーバー���不安定なので、応答しないことがあるかもしれません。"""
73
 
74
-
75
  def run():
 
76
  chatbot = gr.Chatbot(
77
  elem_id="chatbot",
78
  scale=1,
@@ -82,7 +144,7 @@ def run():
82
  )
83
  with gr.Blocks(fill_height=True) as demo:
84
  gr.Markdown(HEADER)
85
- gr.ChatInterface(
86
  fn=respond,
87
  stop_btn="Stop Generation",
88
  cache_examples=False,
@@ -92,9 +154,11 @@ def run():
92
  label="Parameters", open=False, render=False
93
  ),
94
  additional_inputs=[
95
- gr.Textbox(value="以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。",
96
- label="System message(試験用: 変えると出力が壊れる可能性)",
 
97
  render=False,),
 
98
  gr.Slider(
99
  minimum=1,
100
  maximum=4096,
@@ -125,10 +189,10 @@ def run():
125
  ],
126
  analytics_enabled=False,
127
  )
 
128
  gr.Markdown(FOOTER)
129
- demo.queue(max_size=256, api_open=False)
130
- demo.launch(share=False, quiet=True)
131
-
132
 
133
  if __name__ == "__main__":
134
  run()
 
1
  import gradio as gr
 
2
  from openai import OpenAI
3
  import os
4
+ import json
5
+ from datetime import datetime
6
+ from zoneinfo import ZoneInfo
7
+ import uuid
8
+ from pathlib import Path
9
+ from huggingface_hub import CommitScheduler
10
+
11
  openai_api_key = os.getenv('api_key')
12
  openai_api_base = os.getenv('url')
13
  model_name = "weblab-GENIAC/Tanuki-8x8B-dpo-v1.0"
 
20
  base_url=openai_api_base,
21
  )
22
 
23
+ # Define the file where to save the data. Use UUID to make sure not to overwrite existing data from a previous run.
24
+ feedback_file = Path("user_feedback/") / f"data_{uuid.uuid4()}.json"
25
+ feedback_folder = feedback_file.parent
26
+
27
+ # Schedule regular uploads. Remote repo and local folder are created if they don't already exist.
28
+ scheduler = CommitScheduler(
29
+ repo_id="kanhatakeyama/TanukiChat", # Replace with your actual repo ID
30
+ repo_type="dataset",
31
+ folder_path=feedback_folder,
32
+ path_in_repo="data",
33
+ every=1, # Upload every 1 minutes
34
+ )
35
+
36
+ def save_or_update_conversation(conversation_id, message, response, message_index, liked=None):
37
+ """
38
+ Save or update conversation data in a JSON Lines file.
39
+ If the entry already exists (same id and message_index), update the 'label' field.
40
+ Otherwise, append a new entry.
41
+ """
42
+ with scheduler.lock:
43
+ # Read existing data
44
+ data = []
45
+ if feedback_file.exists():
46
+ with feedback_file.open("r") as f:
47
+ data = [json.loads(line) for line in f if line.strip()]
48
+
49
+ # Find if an entry with the same id and message_index exists
50
+ entry_index = next((i for i, entry in enumerate(data) if entry['id'] == conversation_id and entry['message_index'] == message_index), None)
51
+
52
+ if entry_index is not None:
53
+ # Update existing entry
54
+ data[entry_index]['label'] = liked
55
+ else:
56
+ # Append new entry
57
+ data.append({
58
+ "id": conversation_id,
59
+ "timestamp": datetime.now(ZoneInfo("Asia/Tokyo")).isoformat(),
60
+ "prompt": message,
61
+ "completion": response,
62
+ "message_index": message_index,
63
+ "label": liked
64
+ })
65
+
66
+ # Write updated data back to file
67
+ with feedback_file.open("w") as f:
68
+ for entry in data:
69
+ f.write(json.dumps(entry) + "\n")
70
+
71
 
72
  def respond(
73
  message,
 
89
  messages.append({"role": "user", "content": message})
90
 
91
  response = ""
92
+ for chunk in client.chat.completions.create(
 
93
  model=model_name,
94
  messages=messages,
95
  max_tokens=max_tokens,
 
97
  temperature=temperature,
98
  top_p=top_p,
99
  ):
100
+ if chunk.choices[0].delta.content is not None:
101
+ response += chunk.choices[0].delta.content
 
 
 
 
 
 
102
  yield response
103
+
104
+ # Save conversation after the full response is generated
105
+ message_index = len(history)
106
+ save_or_update_conversation(conversation_id, message, response, message_index)
107
+
108
+ def vote(data: gr.LikeData, history, conversation_id):
109
+ """
110
+ Update user feedback (like/dislike) in the local file.
111
+ """
112
+ message_index = data.index[0]
113
+ liked = data.liked
114
+ save_or_update_conversation(conversation_id, None, None, message_index, liked)
115
+
116
+ def create_conversation_id():
117
+ return str(uuid.uuid4())
118
+
119
  """
120
  For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
121
  """
122
 
123
  description = """
124
  ### [Tanuki-8x8B-dpo-v1.0](https://huggingface.co/weblab-GENIAC/Tanuki-8x8B-dpo-v1.0)との会話(期間限定での公開)
125
+ - 人工知能開発のため、原則として**このChatBotの入出力データは全て著作権フリー(CC0)で公開する**ため、ご注意ください。著作物、個人情報、機密情報、誹謗中傷などのデータを入力しないでください。
126
+ - データセットはこちらで公開しています。  https://huggingface.co/datasets/kanhatakeyama/TanukiChat
127
  - **上記の条件に同意する場合のみ**、以下のChatbotを利用してください。
128
  """
129
 
 
133
  - コンテクスト長が4096までなので、あまり会話が長くなると、エラーで停止します。ページを再読み込みしてください。
134
  - GPUサーバー���不安定なので、応答しないことがあるかもしれません。"""
135
 
 
136
  def run():
137
+ conversation_id = gr.State(create_conversation_id)
138
  chatbot = gr.Chatbot(
139
  elem_id="chatbot",
140
  scale=1,
 
144
  )
145
  with gr.Blocks(fill_height=True) as demo:
146
  gr.Markdown(HEADER)
147
+ chat_interface = gr.ChatInterface(
148
  fn=respond,
149
  stop_btn="Stop Generation",
150
  cache_examples=False,
 
154
  label="Parameters", open=False, render=False
155
  ),
156
  additional_inputs=[
157
+ additional_inputs=[
158
+ gr.Textbox(value="以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。",
159
+ label="System message(試験用: 変えると性能が低下する可能性があります。)",
160
  render=False,),
161
+ conversation_id,
162
  gr.Slider(
163
  minimum=1,
164
  maximum=4096,
 
189
  ],
190
  analytics_enabled=False,
191
  )
192
+ chatbot.like(vote, [chatbot, conversation_id], None)
193
  gr.Markdown(FOOTER)
194
+ demo.queue(max_size=256, api_open=True)
195
+ demo.launch(share=True, quiet=True)
 
196
 
197
  if __name__ == "__main__":
198
  run()