Spaces:
Sleeping
Sleeping
Tuchuanhuhuhu
commited on
Commit
·
a543a3d
1
Parent(s):
b69c7d1
加强代码健壮性
Browse files- ChuanhuChatbot.py +9 -9
- utils.py +31 -14
ChuanhuChatbot.py
CHANGED
@@ -69,26 +69,26 @@ with gr.Blocks(css=customCSS) as demo:
|
|
69 |
with gr.Column():
|
70 |
with gr.Row():
|
71 |
with gr.Column(scale=6):
|
72 |
-
templateFileSelectDropdown = gr.Dropdown(label="选择Prompt模板集合文件(.csv)", choices=get_template_names(plain=True), multiselect=False)
|
73 |
with gr.Column(scale=1):
|
74 |
templateRefreshBtn = gr.Button("🔄 刷新")
|
75 |
templaeFileReadBtn = gr.Button("📂 读入模板")
|
76 |
with gr.Row():
|
77 |
with gr.Column(scale=6):
|
78 |
-
templateSelectDropdown = gr.Dropdown(label="从Prompt模板中加载", choices=load_template(get_template_names(plain=True)[0], mode=1), multiselect=False)
|
79 |
with gr.Column(scale=1):
|
80 |
templateApplyBtn = gr.Button("⬇️ 应用")
|
81 |
-
with gr.Accordion(label="保存/加载对话历史记录
|
82 |
with gr.Column():
|
83 |
with gr.Row():
|
84 |
with gr.Column(scale=6):
|
85 |
saveFileName = gr.Textbox(
|
86 |
show_label=True, placeholder=f"在这里输入保存的文件名...", label="设置保存文件名", value="对话历史记录").style(container=True)
|
87 |
with gr.Column(scale=1):
|
88 |
-
|
89 |
with gr.Row():
|
90 |
with gr.Column(scale=6):
|
91 |
-
historyFileSelectDropdown = gr.Dropdown(label="从列表中加载对话", choices=get_history_names(plain=True), multiselect=False)
|
92 |
with gr.Column(scale=1):
|
93 |
historyRefreshBtn = gr.Button("🔄 刷新")
|
94 |
historyReadBtn = gr.Button("📂 读入对话")
|
@@ -116,14 +116,14 @@ with gr.Blocks(css=customCSS) as demo:
|
|
116 |
chatbot, history], show_progress=True)
|
117 |
reduceTokenBtn.click(predict, [txt, top_p, temperature, keyTxt, chatbot, history,
|
118 |
systemPromptTxt, FALSECONSTANT, TRUECOMSTANT], [chatbot, history, statusDisplay], show_progress=True)
|
119 |
-
|
120 |
saveFileName, systemPromptTxt, history, chatbot], None, show_progress=True)
|
121 |
-
|
122 |
historyRefreshBtn.click(get_history_names, None, [historyFileSelectDropdown])
|
123 |
-
historyReadBtn.click(load_chat_history, [historyFileSelectDropdown], [saveFileName, systemPromptTxt, history, chatbot], show_progress=True)
|
124 |
templateRefreshBtn.click(get_template_names, None, [templateFileSelectDropdown])
|
125 |
templaeFileReadBtn.click(load_template, [templateFileSelectDropdown], [promptTemplates, templateSelectDropdown], show_progress=True)
|
126 |
-
templateApplyBtn.click(
|
127 |
|
128 |
print("川虎的温馨提示:访问 http://localhost:7860 查看界面")
|
129 |
# 默认开启本地服务器,默认可以直接从IP访问,默认不创建公开分享链接
|
|
|
69 |
with gr.Column():
|
70 |
with gr.Row():
|
71 |
with gr.Column(scale=6):
|
72 |
+
templateFileSelectDropdown = gr.Dropdown(label="选择Prompt模板集合文件(.csv)", choices=get_template_names(plain=True), multiselect=False, value=get_template_names(plain=True)[0])
|
73 |
with gr.Column(scale=1):
|
74 |
templateRefreshBtn = gr.Button("🔄 刷新")
|
75 |
templaeFileReadBtn = gr.Button("📂 读入模板")
|
76 |
with gr.Row():
|
77 |
with gr.Column(scale=6):
|
78 |
+
templateSelectDropdown = gr.Dropdown(label="从Prompt模板中加载", choices=load_template(get_template_names(plain=True)[0], mode=1), multiselect=False, value=load_template(get_template_names(plain=True)[0], mode=1)[0])
|
79 |
with gr.Column(scale=1):
|
80 |
templateApplyBtn = gr.Button("⬇️ 应用")
|
81 |
+
with gr.Accordion(label="保存/加载对话历史记录", open=False):
|
82 |
with gr.Column():
|
83 |
with gr.Row():
|
84 |
with gr.Column(scale=6):
|
85 |
saveFileName = gr.Textbox(
|
86 |
show_label=True, placeholder=f"在这里输入保存的文件名...", label="设置保存文件名", value="对话历史记录").style(container=True)
|
87 |
with gr.Column(scale=1):
|
88 |
+
saveHistoryBtn = gr.Button("💾 保存对话")
|
89 |
with gr.Row():
|
90 |
with gr.Column(scale=6):
|
91 |
+
historyFileSelectDropdown = gr.Dropdown(label="从列表中加载对话", choices=get_history_names(plain=True), multiselect=False, value=get_history_names(plain=True)[0])
|
92 |
with gr.Column(scale=1):
|
93 |
historyRefreshBtn = gr.Button("🔄 刷新")
|
94 |
historyReadBtn = gr.Button("📂 读入对话")
|
|
|
116 |
chatbot, history], show_progress=True)
|
117 |
reduceTokenBtn.click(predict, [txt, top_p, temperature, keyTxt, chatbot, history,
|
118 |
systemPromptTxt, FALSECONSTANT, TRUECOMSTANT], [chatbot, history, statusDisplay], show_progress=True)
|
119 |
+
saveHistoryBtn.click(save_chat_history, [
|
120 |
saveFileName, systemPromptTxt, history, chatbot], None, show_progress=True)
|
121 |
+
saveHistoryBtn.click(get_history_names, None, [historyFileSelectDropdown])
|
122 |
historyRefreshBtn.click(get_history_names, None, [historyFileSelectDropdown])
|
123 |
+
historyReadBtn.click(load_chat_history, [historyFileSelectDropdown, systemPromptTxt, history, chatbot], [saveFileName, systemPromptTxt, history, chatbot], show_progress=True)
|
124 |
templateRefreshBtn.click(get_template_names, None, [templateFileSelectDropdown])
|
125 |
templaeFileReadBtn.click(load_template, [templateFileSelectDropdown], [promptTemplates, templateSelectDropdown], show_progress=True)
|
126 |
+
templateApplyBtn.click(get_template_content, [promptTemplates, templateSelectDropdown, systemPromptTxt], [systemPromptTxt], show_progress=True)
|
127 |
|
128 |
print("川虎的温馨提示:访问 http://localhost:7860 查看界面")
|
129 |
# 默认开启本地服务器,默认可以直接从IP访问,默认不创建公开分享链接
|
utils.py
CHANGED
@@ -210,15 +210,18 @@ def predict(inputs, top_p, temperature, openai_api_key, chatbot=[], history=[],
|
|
210 |
|
211 |
|
212 |
def delete_last_conversation(chatbot, history):
|
213 |
-
|
214 |
-
chatbot
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
history.pop()
|
219 |
history.pop()
|
|
|
220 |
print(history)
|
221 |
-
|
|
|
|
|
222 |
|
223 |
def save_chat_history(filename, system, history, chatbot):
|
224 |
if filename == "":
|
@@ -232,11 +235,16 @@ def save_chat_history(filename, system, history, chatbot):
|
|
232 |
json.dump(json_s, f)
|
233 |
|
234 |
|
235 |
-
def load_chat_history(filename):
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
|
|
|
|
|
|
|
|
|
|
240 |
|
241 |
def sorted_by_pinyin(list):
|
242 |
return sorted(list, key=lambda char: lazy_pinyin(char)[0][0])
|
@@ -250,6 +258,8 @@ def get_file_names(dir, plain=False, filetypes=[".json"]):
|
|
250 |
except FileNotFoundError:
|
251 |
files = []
|
252 |
files = sorted_by_pinyin(files)
|
|
|
|
|
253 |
if plain:
|
254 |
return files
|
255 |
else:
|
@@ -260,6 +270,7 @@ def get_history_names(plain=False):
|
|
260 |
|
261 |
def load_template(filename, mode=0):
|
262 |
lines = []
|
|
|
263 |
if filename.endswith(".json"):
|
264 |
with open(os.path.join(TEMPLATES_DIR, filename), "r", encoding="utf8") as f:
|
265 |
lines = json.load(f)
|
@@ -270,19 +281,25 @@ def load_template(filename, mode=0):
|
|
270 |
lines = list(reader)
|
271 |
lines = lines[1:]
|
272 |
if mode == 1:
|
273 |
-
return
|
274 |
elif mode == 2:
|
275 |
return {row[0]:row[1] for row in lines}
|
276 |
else:
|
277 |
-
|
|
|
278 |
|
279 |
def get_template_names(plain=False):
|
280 |
return get_file_names(TEMPLATES_DIR, plain, filetypes=[".csv", "json"])
|
281 |
|
|
|
|
|
|
|
|
|
|
|
|
|
282 |
def reset_state():
|
283 |
return [], []
|
284 |
|
285 |
-
|
286 |
def compose_system(system_prompt):
|
287 |
return {"role": "system", "content": system_prompt}
|
288 |
|
|
|
210 |
|
211 |
|
212 |
def delete_last_conversation(chatbot, history):
|
213 |
+
try:
|
214 |
+
if "☹️发生了错误" in chatbot[-1][1]:
|
215 |
+
chatbot.pop()
|
216 |
+
print(history)
|
217 |
+
return chatbot, history
|
218 |
history.pop()
|
219 |
history.pop()
|
220 |
+
chatbot.pop()
|
221 |
print(history)
|
222 |
+
return chatbot, history
|
223 |
+
except:
|
224 |
+
return chatbot, history
|
225 |
|
226 |
def save_chat_history(filename, system, history, chatbot):
|
227 |
if filename == "":
|
|
|
235 |
json.dump(json_s, f)
|
236 |
|
237 |
|
238 |
+
def load_chat_history(filename, system, history, chatbot):
|
239 |
+
try:
|
240 |
+
print("Loading from history...")
|
241 |
+
with open(os.path.join(HISTORY_DIR, filename), "r") as f:
|
242 |
+
json_s = json.load(f)
|
243 |
+
print(json_s)
|
244 |
+
return filename, json_s["system"], json_s["history"], json_s["chatbot"]
|
245 |
+
except FileNotFoundError:
|
246 |
+
print("File not found.")
|
247 |
+
return filename, system, history, chatbot
|
248 |
|
249 |
def sorted_by_pinyin(list):
|
250 |
return sorted(list, key=lambda char: lazy_pinyin(char)[0][0])
|
|
|
258 |
except FileNotFoundError:
|
259 |
files = []
|
260 |
files = sorted_by_pinyin(files)
|
261 |
+
if files == []:
|
262 |
+
files = [""]
|
263 |
if plain:
|
264 |
return files
|
265 |
else:
|
|
|
270 |
|
271 |
def load_template(filename, mode=0):
|
272 |
lines = []
|
273 |
+
print("Loading template...")
|
274 |
if filename.endswith(".json"):
|
275 |
with open(os.path.join(TEMPLATES_DIR, filename), "r", encoding="utf8") as f:
|
276 |
lines = json.load(f)
|
|
|
281 |
lines = list(reader)
|
282 |
lines = lines[1:]
|
283 |
if mode == 1:
|
284 |
+
return sorted_by_pinyin([row[0] for row in lines])
|
285 |
elif mode == 2:
|
286 |
return {row[0]:row[1] for row in lines}
|
287 |
else:
|
288 |
+
choices = sorted_by_pinyin([row[0] for row in lines])
|
289 |
+
return {row[0]:row[1] for row in lines}, gr.Dropdown.update(choices=choices, value=choices[0])
|
290 |
|
291 |
def get_template_names(plain=False):
|
292 |
return get_file_names(TEMPLATES_DIR, plain, filetypes=[".csv", "json"])
|
293 |
|
294 |
+
def get_template_content(templates, selection, original_system_prompt):
|
295 |
+
try:
|
296 |
+
return templates[selection]
|
297 |
+
except:
|
298 |
+
return original_system_prompt
|
299 |
+
|
300 |
def reset_state():
|
301 |
return [], []
|
302 |
|
|
|
303 |
def compose_system(system_prompt):
|
304 |
return {"role": "system", "content": system_prompt}
|
305 |
|