Tuchuanhuhuhu commited on
Commit
a543a3d
·
1 Parent(s): b69c7d1

加强代码健壮性

Browse files
Files changed (2) hide show
  1. ChuanhuChatbot.py +9 -9
  2. utils.py +31 -14
ChuanhuChatbot.py CHANGED
@@ -69,26 +69,26 @@ with gr.Blocks(css=customCSS) as demo:
69
  with gr.Column():
70
  with gr.Row():
71
  with gr.Column(scale=6):
72
- templateFileSelectDropdown = gr.Dropdown(label="选择Prompt模板集合文件(.csv)", choices=get_template_names(plain=True), multiselect=False)
73
  with gr.Column(scale=1):
74
  templateRefreshBtn = gr.Button("🔄 刷新")
75
  templaeFileReadBtn = gr.Button("📂 读入模板")
76
  with gr.Row():
77
  with gr.Column(scale=6):
78
- templateSelectDropdown = gr.Dropdown(label="从Prompt模板中加载", choices=load_template(get_template_names(plain=True)[0], mode=1), multiselect=False)
79
  with gr.Column(scale=1):
80
  templateApplyBtn = gr.Button("⬇️ 应用")
81
- with gr.Accordion(label="保存/加载对话历史记录(在文本框中输入文件名,点击“保存对话”按钮,历史记录文件会被存储到Python文件旁边)", open=False):
82
  with gr.Column():
83
  with gr.Row():
84
  with gr.Column(scale=6):
85
  saveFileName = gr.Textbox(
86
  show_label=True, placeholder=f"在这里输入保存的文件名...", label="设置保存文件名", value="对话历史记录").style(container=True)
87
  with gr.Column(scale=1):
88
- saveBtn = gr.Button("💾 保存对话")
89
  with gr.Row():
90
  with gr.Column(scale=6):
91
- historyFileSelectDropdown = gr.Dropdown(label="从列表中加载对话", choices=get_history_names(plain=True), multiselect=False)
92
  with gr.Column(scale=1):
93
  historyRefreshBtn = gr.Button("🔄 刷新")
94
  historyReadBtn = gr.Button("📂 读入对话")
@@ -116,14 +116,14 @@ with gr.Blocks(css=customCSS) as demo:
116
  chatbot, history], show_progress=True)
117
  reduceTokenBtn.click(predict, [txt, top_p, temperature, keyTxt, chatbot, history,
118
  systemPromptTxt, FALSECONSTANT, TRUECOMSTANT], [chatbot, history, statusDisplay], show_progress=True)
119
- saveBtn.click(save_chat_history, [
120
  saveFileName, systemPromptTxt, history, chatbot], None, show_progress=True)
121
- saveBtn.click(get_history_names, None, [historyFileSelectDropdown])
122
  historyRefreshBtn.click(get_history_names, None, [historyFileSelectDropdown])
123
- historyReadBtn.click(load_chat_history, [historyFileSelectDropdown], [saveFileName, systemPromptTxt, history, chatbot], show_progress=True)
124
  templateRefreshBtn.click(get_template_names, None, [templateFileSelectDropdown])
125
  templaeFileReadBtn.click(load_template, [templateFileSelectDropdown], [promptTemplates, templateSelectDropdown], show_progress=True)
126
- templateApplyBtn.click(lambda x, y: x[y], [promptTemplates, templateSelectDropdown], [systemPromptTxt], show_progress=True)
127
 
128
  print("川虎的温馨提示:访问 http://localhost:7860 查看界面")
129
  # 默认开启本地服务器,默认可以直接从IP访问,默认不创建公开分享链接
 
69
  with gr.Column():
70
  with gr.Row():
71
  with gr.Column(scale=6):
72
+ templateFileSelectDropdown = gr.Dropdown(label="选择Prompt模板集合文件(.csv)", choices=get_template_names(plain=True), multiselect=False, value=get_template_names(plain=True)[0])
73
  with gr.Column(scale=1):
74
  templateRefreshBtn = gr.Button("🔄 刷新")
75
  templaeFileReadBtn = gr.Button("📂 读入模板")
76
  with gr.Row():
77
  with gr.Column(scale=6):
78
+ templateSelectDropdown = gr.Dropdown(label="从Prompt模板中加载", choices=load_template(get_template_names(plain=True)[0], mode=1), multiselect=False, value=load_template(get_template_names(plain=True)[0], mode=1)[0])
79
  with gr.Column(scale=1):
80
  templateApplyBtn = gr.Button("⬇️ 应用")
81
+ with gr.Accordion(label="保存/加载对话历史记录", open=False):
82
  with gr.Column():
83
  with gr.Row():
84
  with gr.Column(scale=6):
85
  saveFileName = gr.Textbox(
86
  show_label=True, placeholder=f"在这里输入保存的文件名...", label="设置保存文件名", value="对话历史记录").style(container=True)
87
  with gr.Column(scale=1):
88
+ saveHistoryBtn = gr.Button("💾 保存对话")
89
  with gr.Row():
90
  with gr.Column(scale=6):
91
+ historyFileSelectDropdown = gr.Dropdown(label="从列表中加载对话", choices=get_history_names(plain=True), multiselect=False, value=get_history_names(plain=True)[0])
92
  with gr.Column(scale=1):
93
  historyRefreshBtn = gr.Button("🔄 刷新")
94
  historyReadBtn = gr.Button("📂 读入对话")
 
116
  chatbot, history], show_progress=True)
117
  reduceTokenBtn.click(predict, [txt, top_p, temperature, keyTxt, chatbot, history,
118
  systemPromptTxt, FALSECONSTANT, TRUECOMSTANT], [chatbot, history, statusDisplay], show_progress=True)
119
+ saveHistoryBtn.click(save_chat_history, [
120
  saveFileName, systemPromptTxt, history, chatbot], None, show_progress=True)
121
+ saveHistoryBtn.click(get_history_names, None, [historyFileSelectDropdown])
122
  historyRefreshBtn.click(get_history_names, None, [historyFileSelectDropdown])
123
+ historyReadBtn.click(load_chat_history, [historyFileSelectDropdown, systemPromptTxt, history, chatbot], [saveFileName, systemPromptTxt, history, chatbot], show_progress=True)
124
  templateRefreshBtn.click(get_template_names, None, [templateFileSelectDropdown])
125
  templaeFileReadBtn.click(load_template, [templateFileSelectDropdown], [promptTemplates, templateSelectDropdown], show_progress=True)
126
+ templateApplyBtn.click(get_template_content, [promptTemplates, templateSelectDropdown, systemPromptTxt], [systemPromptTxt], show_progress=True)
127
 
128
  print("川虎的温馨提示:访问 http://localhost:7860 查看界面")
129
  # 默认开启本地服务器,默认可以直接从IP访问,默认不创建公开分享链接
utils.py CHANGED
@@ -210,15 +210,18 @@ def predict(inputs, top_p, temperature, openai_api_key, chatbot=[], history=[],
210
 
211
 
212
  def delete_last_conversation(chatbot, history):
213
- if "☹️发生了错误" in chatbot[-1][1]:
214
- chatbot.pop()
215
- print(history)
216
- return chatbot, history
217
- if len(history) > 0:
218
  history.pop()
219
  history.pop()
 
220
  print(history)
221
- return chatbot, history
 
 
222
 
223
  def save_chat_history(filename, system, history, chatbot):
224
  if filename == "":
@@ -232,11 +235,16 @@ def save_chat_history(filename, system, history, chatbot):
232
  json.dump(json_s, f)
233
 
234
 
235
- def load_chat_history(filename):
236
- with open(os.path.join(HISTORY_DIR, filename), "r") as f:
237
- json_s = json.load(f)
238
- print(json_s)
239
- return filename, json_s["system"], json_s["history"], json_s["chatbot"]
 
 
 
 
 
240
 
241
  def sorted_by_pinyin(list):
242
  return sorted(list, key=lambda char: lazy_pinyin(char)[0][0])
@@ -250,6 +258,8 @@ def get_file_names(dir, plain=False, filetypes=[".json"]):
250
  except FileNotFoundError:
251
  files = []
252
  files = sorted_by_pinyin(files)
 
 
253
  if plain:
254
  return files
255
  else:
@@ -260,6 +270,7 @@ def get_history_names(plain=False):
260
 
261
  def load_template(filename, mode=0):
262
  lines = []
 
263
  if filename.endswith(".json"):
264
  with open(os.path.join(TEMPLATES_DIR, filename), "r", encoding="utf8") as f:
265
  lines = json.load(f)
@@ -270,19 +281,25 @@ def load_template(filename, mode=0):
270
  lines = list(reader)
271
  lines = lines[1:]
272
  if mode == 1:
273
- return sorted([row[0] for row in lines])
274
  elif mode == 2:
275
  return {row[0]:row[1] for row in lines}
276
  else:
277
- return {row[0]:row[1] for row in lines}, gr.Dropdown.update(choices=sorted_by_pinyin([row[0] for row in lines]))
 
278
 
279
  def get_template_names(plain=False):
280
  return get_file_names(TEMPLATES_DIR, plain, filetypes=[".csv", "json"])
281
 
 
 
 
 
 
 
282
  def reset_state():
283
  return [], []
284
 
285
-
286
  def compose_system(system_prompt):
287
  return {"role": "system", "content": system_prompt}
288
 
 
210
 
211
 
212
  def delete_last_conversation(chatbot, history):
213
+ try:
214
+ if "☹️发生了错误" in chatbot[-1][1]:
215
+ chatbot.pop()
216
+ print(history)
217
+ return chatbot, history
218
  history.pop()
219
  history.pop()
220
+ chatbot.pop()
221
  print(history)
222
+ return chatbot, history
223
+ except:
224
+ return chatbot, history
225
 
226
  def save_chat_history(filename, system, history, chatbot):
227
  if filename == "":
 
235
  json.dump(json_s, f)
236
 
237
 
238
+ def load_chat_history(filename, system, history, chatbot):
239
+ try:
240
+ print("Loading from history...")
241
+ with open(os.path.join(HISTORY_DIR, filename), "r") as f:
242
+ json_s = json.load(f)
243
+ print(json_s)
244
+ return filename, json_s["system"], json_s["history"], json_s["chatbot"]
245
+ except FileNotFoundError:
246
+ print("File not found.")
247
+ return filename, system, history, chatbot
248
 
249
  def sorted_by_pinyin(list):
250
  return sorted(list, key=lambda char: lazy_pinyin(char)[0][0])
 
258
  except FileNotFoundError:
259
  files = []
260
  files = sorted_by_pinyin(files)
261
+ if files == []:
262
+ files = [""]
263
  if plain:
264
  return files
265
  else:
 
270
 
271
  def load_template(filename, mode=0):
272
  lines = []
273
+ print("Loading template...")
274
  if filename.endswith(".json"):
275
  with open(os.path.join(TEMPLATES_DIR, filename), "r", encoding="utf8") as f:
276
  lines = json.load(f)
 
281
  lines = list(reader)
282
  lines = lines[1:]
283
  if mode == 1:
284
+ return sorted_by_pinyin([row[0] for row in lines])
285
  elif mode == 2:
286
  return {row[0]:row[1] for row in lines}
287
  else:
288
+ choices = sorted_by_pinyin([row[0] for row in lines])
289
+ return {row[0]:row[1] for row in lines}, gr.Dropdown.update(choices=choices, value=choices[0])
290
 
291
  def get_template_names(plain=False):
292
  return get_file_names(TEMPLATES_DIR, plain, filetypes=[".csv", "json"])
293
 
294
+ def get_template_content(templates, selection, original_system_prompt):
295
+ try:
296
+ return templates[selection]
297
+ except:
298
+ return original_system_prompt
299
+
300
  def reset_state():
301
  return [], []
302
 
 
303
  def compose_system(system_prompt):
304
  return {"role": "system", "content": system_prompt}
305