Tuchuanhuhuhu commited on
Commit
b8f3115
·
1 Parent(s): 5d31dec

修改代码格式

Browse files
Files changed (3) hide show
  1. ChuanhuChatbot.py +217 -47
  2. presets.py +21 -14
  3. utils.py +239 -54
ChuanhuChatbot.py CHANGED
@@ -7,12 +7,15 @@ import argparse
7
  from utils import *
8
  from presets import *
9
 
10
- logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s")
 
 
 
11
 
12
- my_api_key = "" # 在这里输入你的 API 密钥
13
 
14
- #if we are running in Docker
15
- if os.environ.get('dockerrun') == 'yes':
16
  dockerflag = True
17
  else:
18
  dockerflag = False
@@ -20,17 +23,21 @@ else:
20
  authflag = False
21
 
22
  if dockerflag:
23
- my_api_key = os.environ.get('my_api_key')
24
  if my_api_key == "empty":
25
  logging.error("Please give a api key!")
26
  sys.exit(1)
27
- #auth
28
- username = os.environ.get('USERNAME')
29
- password = os.environ.get('PASSWORD')
30
  if not (isinstance(username, type(None)) or isinstance(password, type(None))):
31
  authflag = True
32
  else:
33
- if not my_api_key and os.path.exists("api_key.txt") and os.path.getsize("api_key.txt"):
 
 
 
 
34
  with open("api_key.txt", "r") as f:
35
  my_api_key = f.read().strip()
36
  if os.path.exists("auth.json"):
@@ -58,44 +65,91 @@ with gr.Blocks(css=customCSS) as demo:
58
  with gr.Row(scale=1).style(equal_height=True):
59
  with gr.Column(scale=5):
60
  with gr.Row(scale=1):
61
- chatbot = gr.Chatbot().style(height=600) # .style(color_map=("#1D51EE", "#585A5B"))
 
 
62
  with gr.Row(scale=1):
63
  with gr.Column(scale=12):
64
- user_input = gr.Textbox(show_label=False, placeholder="在这里输入").style(
65
- container=False)
 
66
  with gr.Column(min_width=50, scale=1):
67
  submitBtn = gr.Button("🚀", variant="primary")
68
  with gr.Row(scale=1):
69
- emptyBtn = gr.Button("🧹 新的对话",)
 
 
70
  retryBtn = gr.Button("🔄 重新生成")
71
  delLastBtn = gr.Button("🗑️ 删除一条对话")
72
  reduceTokenBtn = gr.Button("♻️ 总结对话")
73
 
74
  with gr.Column():
75
- with gr.Column(min_width=50,scale=1):
76
  with gr.Tab(label="ChatGPT"):
77
- keyTxt = gr.Textbox(show_label=True, placeholder=f"OpenAI API-key...",value=my_api_key, type="password", visible=not HIDE_MY_KEY, label="API-Key")
78
- model_select_dropdown = gr.Dropdown(label="选择模型", choices=MODELS, multiselect=False, value=MODELS[0])
 
 
 
 
 
 
 
 
 
79
  with gr.Accordion("参数", open=False):
80
- top_p = gr.Slider(minimum=-0, maximum=1.0, value=1.0, step=0.05,
81
- interactive=True, label="Top-p (nucleus sampling)",)
82
- temperature = gr.Slider(minimum=-0, maximum=2.0, value=1.0,
83
- step=0.1, interactive=True, label="Temperature",)
84
- use_streaming_checkbox = gr.Checkbox(label="实时传输回答", value=True, visible=enable_streaming_option)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  use_websearch_checkbox = gr.Checkbox(label="使用在线搜索", value=False)
86
 
87
  with gr.Tab(label="Prompt"):
88
- systemPromptTxt = gr.Textbox(show_label=True, placeholder=f"在这里输入System Prompt...", label="System prompt", value=initial_prompt).style(container=True)
 
 
 
 
 
89
  with gr.Accordion(label="加载Prompt模板", open=True):
90
  with gr.Column():
91
  with gr.Row():
92
  with gr.Column(scale=6):
93
- templateFileSelectDropdown = gr.Dropdown(label="选择Prompt模板集合文件", choices=get_template_names(plain=True), multiselect=False, value=get_template_names(plain=True)[0])
 
 
 
 
 
94
  with gr.Column(scale=1):
95
  templateRefreshBtn = gr.Button("🔄 刷新")
96
  with gr.Row():
97
  with gr.Column():
98
- templateSelectDropdown = gr.Dropdown(label="从Prompt模板中加载", choices=load_template(get_template_names(plain=True)[0], mode=1), multiselect=False, value=load_template(get_template_names(plain=True)[0], mode=1)[0])
 
 
 
 
 
 
 
 
 
99
 
100
  with gr.Tab(label="保存/加载"):
101
  with gr.Accordion(label="保存/加载对话历史记录", open=True):
@@ -104,13 +158,22 @@ with gr.Blocks(css=customCSS) as demo:
104
  with gr.Row():
105
  with gr.Column(scale=6):
106
  saveFileName = gr.Textbox(
107
- show_label=True, placeholder=f"设置文件名: 默认为.json,可选为.md", label="设置保存文件名", value="对话历史记录").style(container=True)
 
 
 
 
108
  with gr.Column(scale=1):
109
  saveHistoryBtn = gr.Button("💾 保存对话")
110
  exportMarkdownBtn = gr.Button("📝 导出为Markdown")
111
  with gr.Row():
112
  with gr.Column(scale=6):
113
- historyFileSelectDropdown = gr.Dropdown(label="从列表中加载对话", choices=get_history_names(plain=True), multiselect=False, value=get_history_names(plain=True)[0])
 
 
 
 
 
114
  with gr.Column(scale=1):
115
  historyRefreshBtn = gr.Button("🔄 刷新")
116
  with gr.Row():
@@ -120,52 +183,159 @@ with gr.Blocks(css=customCSS) as demo:
120
  gr.Markdown(description)
121
 
122
  # Chatbot
123
- user_input.submit(predict, [keyTxt, systemPromptTxt, history, user_input, chatbot, token_count, top_p, temperature, use_streaming_checkbox, model_select_dropdown, use_websearch_checkbox], [chatbot, history, status_display, token_count], show_progress=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  user_input.submit(reset_textbox, [], [user_input])
125
 
126
- submitBtn.click(predict, [keyTxt, systemPromptTxt, history, user_input, chatbot, token_count, top_p, temperature, use_streaming_checkbox, model_select_dropdown, use_websearch_checkbox], [chatbot, history, status_display, token_count], show_progress=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  submitBtn.click(reset_textbox, [], [user_input])
128
 
129
- emptyBtn.click(reset_state, outputs=[chatbot, history, token_count, status_display], show_progress=True)
 
 
 
 
130
 
131
- retryBtn.click(retry, [keyTxt, systemPromptTxt, history, chatbot, token_count, top_p, temperature, use_streaming_checkbox, model_select_dropdown], [chatbot, history, status_display, token_count], show_progress=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
- delLastBtn.click(delete_last_conversation, [chatbot, history, token_count], [
134
- chatbot, history, token_count, status_display], show_progress=True)
 
 
 
 
135
 
136
- reduceTokenBtn.click(reduce_token_size, [keyTxt, systemPromptTxt, history, chatbot, token_count, top_p, temperature, use_streaming_checkbox, model_select_dropdown], [chatbot, history, status_display, token_count], show_progress=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
  # Template
139
  templateRefreshBtn.click(get_template_names, None, [templateFileSelectDropdown])
140
- templateFileSelectDropdown.change(load_template, [templateFileSelectDropdown], [promptTemplates, templateSelectDropdown], show_progress=True)
141
- templateSelectDropdown.change(get_template_content, [promptTemplates, templateSelectDropdown, systemPromptTxt], [systemPromptTxt], show_progress=True)
 
 
 
 
 
 
 
 
 
 
142
 
143
  # S&L
144
- saveHistoryBtn.click(save_chat_history, [saveFileName, systemPromptTxt, history, chatbot], downloadFile, show_progress=True)
 
 
 
 
 
145
  saveHistoryBtn.click(get_history_names, None, [historyFileSelectDropdown])
146
- exportMarkdownBtn.click(export_markdown, [saveFileName, systemPromptTxt, history, chatbot], downloadFile, show_progress=True)
 
 
 
 
 
147
  historyRefreshBtn.click(get_history_names, None, [historyFileSelectDropdown])
148
- historyFileSelectDropdown.change(load_chat_history, [historyFileSelectDropdown, systemPromptTxt, history, chatbot], [saveFileName, systemPromptTxt, history, chatbot], show_progress=True)
149
- downloadFile.change(load_chat_history, [downloadFile, systemPromptTxt, history, chatbot], [saveFileName, systemPromptTxt, history, chatbot])
 
 
 
 
 
 
 
 
 
150
 
151
 
152
- logging.info(colorama.Back.GREEN + "\n川虎的温馨提示:访问 http://localhost:7860 查看界面" + colorama.Style.RESET_ALL)
 
 
 
 
153
  # 默认开启本地服务器,默认可以直接从IP访问,默认不创建公开分享链接
154
  demo.title = "川虎ChatGPT 🚀"
155
 
156
  if __name__ == "__main__":
157
- #if running in Docker
158
  if dockerflag:
159
  if authflag:
160
- demo.queue().launch(server_name="0.0.0.0", server_port=7860,auth=(username, password))
 
 
161
  else:
162
  demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=False)
163
- #if not running in Docker
164
  else:
165
  if authflag:
166
  demo.queue().launch(share=False, auth=(username, password))
167
  else:
168
- demo.queue().launch(share=False) # 改为 share=True 可以创建公开分享链接
169
- #demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=False) # 可自定义端口
170
- #demo.queue().launch(server_name="0.0.0.0", server_port=7860,auth=("在这里填写用户名", "在这里填写密码")) # 可设置用户名与密码
171
- #demo.queue().launch(auth=("在这里填写用户名", "在这里填写密码")) # 适合Nginx反向代理
 
7
  from utils import *
8
  from presets import *
9
 
10
+ logging.basicConfig(
11
+ level=logging.INFO,
12
+ format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
13
+ )
14
 
15
+ my_api_key = "" # 在这里输入你的 API 密钥
16
 
17
+ # if we are running in Docker
18
+ if os.environ.get("dockerrun") == "yes":
19
  dockerflag = True
20
  else:
21
  dockerflag = False
 
23
  authflag = False
24
 
25
  if dockerflag:
26
+ my_api_key = os.environ.get("my_api_key")
27
  if my_api_key == "empty":
28
  logging.error("Please give a api key!")
29
  sys.exit(1)
30
+ # auth
31
+ username = os.environ.get("USERNAME")
32
+ password = os.environ.get("PASSWORD")
33
  if not (isinstance(username, type(None)) or isinstance(password, type(None))):
34
  authflag = True
35
  else:
36
+ if (
37
+ not my_api_key
38
+ and os.path.exists("api_key.txt")
39
+ and os.path.getsize("api_key.txt")
40
+ ):
41
  with open("api_key.txt", "r") as f:
42
  my_api_key = f.read().strip()
43
  if os.path.exists("auth.json"):
 
65
  with gr.Row(scale=1).style(equal_height=True):
66
  with gr.Column(scale=5):
67
  with gr.Row(scale=1):
68
+ chatbot = gr.Chatbot().style(
69
+ height=600
70
+ ) # .style(color_map=("#1D51EE", "#585A5B"))
71
  with gr.Row(scale=1):
72
  with gr.Column(scale=12):
73
+ user_input = gr.Textbox(
74
+ show_label=False, placeholder="在这里输入"
75
+ ).style(container=False)
76
  with gr.Column(min_width=50, scale=1):
77
  submitBtn = gr.Button("🚀", variant="primary")
78
  with gr.Row(scale=1):
79
+ emptyBtn = gr.Button(
80
+ "🧹 新的对话",
81
+ )
82
  retryBtn = gr.Button("🔄 重新生成")
83
  delLastBtn = gr.Button("🗑️ 删除一条对话")
84
  reduceTokenBtn = gr.Button("♻️ 总结对话")
85
 
86
  with gr.Column():
87
+ with gr.Column(min_width=50, scale=1):
88
  with gr.Tab(label="ChatGPT"):
89
+ keyTxt = gr.Textbox(
90
+ show_label=True,
91
+ placeholder=f"OpenAI API-key...",
92
+ value=my_api_key,
93
+ type="password",
94
+ visible=not HIDE_MY_KEY,
95
+ label="API-Key",
96
+ )
97
+ model_select_dropdown = gr.Dropdown(
98
+ label="选择模型", choices=MODELS, multiselect=False, value=MODELS[0]
99
+ )
100
  with gr.Accordion("参数", open=False):
101
+ top_p = gr.Slider(
102
+ minimum=-0,
103
+ maximum=1.0,
104
+ value=1.0,
105
+ step=0.05,
106
+ interactive=True,
107
+ label="Top-p (nucleus sampling)",
108
+ )
109
+ temperature = gr.Slider(
110
+ minimum=-0,
111
+ maximum=2.0,
112
+ value=1.0,
113
+ step=0.1,
114
+ interactive=True,
115
+ label="Temperature",
116
+ )
117
+ use_streaming_checkbox = gr.Checkbox(
118
+ label="实时传输回答", value=True, visible=enable_streaming_option
119
+ )
120
  use_websearch_checkbox = gr.Checkbox(label="使用在线搜索", value=False)
121
 
122
  with gr.Tab(label="Prompt"):
123
+ systemPromptTxt = gr.Textbox(
124
+ show_label=True,
125
+ placeholder=f"在这里输入System Prompt...",
126
+ label="System prompt",
127
+ value=initial_prompt,
128
+ ).style(container=True)
129
  with gr.Accordion(label="加载Prompt模板", open=True):
130
  with gr.Column():
131
  with gr.Row():
132
  with gr.Column(scale=6):
133
+ templateFileSelectDropdown = gr.Dropdown(
134
+ label="选择Prompt模板集合文件",
135
+ choices=get_template_names(plain=True),
136
+ multiselect=False,
137
+ value=get_template_names(plain=True)[0],
138
+ )
139
  with gr.Column(scale=1):
140
  templateRefreshBtn = gr.Button("🔄 刷新")
141
  with gr.Row():
142
  with gr.Column():
143
+ templateSelectDropdown = gr.Dropdown(
144
+ label="从Prompt模板中加载",
145
+ choices=load_template(
146
+ get_template_names(plain=True)[0], mode=1
147
+ ),
148
+ multiselect=False,
149
+ value=load_template(
150
+ get_template_names(plain=True)[0], mode=1
151
+ )[0],
152
+ )
153
 
154
  with gr.Tab(label="保存/加载"):
155
  with gr.Accordion(label="保存/加载对话历史记录", open=True):
 
158
  with gr.Row():
159
  with gr.Column(scale=6):
160
  saveFileName = gr.Textbox(
161
+ show_label=True,
162
+ placeholder=f"设置文件名: 默认为.json,可选为.md",
163
+ label="设置保存文件名",
164
+ value="对话历史记录",
165
+ ).style(container=True)
166
  with gr.Column(scale=1):
167
  saveHistoryBtn = gr.Button("💾 保存对话")
168
  exportMarkdownBtn = gr.Button("📝 导出为Markdown")
169
  with gr.Row():
170
  with gr.Column(scale=6):
171
+ historyFileSelectDropdown = gr.Dropdown(
172
+ label="从列表中加载对话",
173
+ choices=get_history_names(plain=True),
174
+ multiselect=False,
175
+ value=get_history_names(plain=True)[0],
176
+ )
177
  with gr.Column(scale=1):
178
  historyRefreshBtn = gr.Button("🔄 刷新")
179
  with gr.Row():
 
183
  gr.Markdown(description)
184
 
185
  # Chatbot
186
+ user_input.submit(
187
+ predict,
188
+ [
189
+ keyTxt,
190
+ systemPromptTxt,
191
+ history,
192
+ user_input,
193
+ chatbot,
194
+ token_count,
195
+ top_p,
196
+ temperature,
197
+ use_streaming_checkbox,
198
+ model_select_dropdown,
199
+ use_websearch_checkbox,
200
+ ],
201
+ [chatbot, history, status_display, token_count],
202
+ show_progress=True,
203
+ )
204
  user_input.submit(reset_textbox, [], [user_input])
205
 
206
+ submitBtn.click(
207
+ predict,
208
+ [
209
+ keyTxt,
210
+ systemPromptTxt,
211
+ history,
212
+ user_input,
213
+ chatbot,
214
+ token_count,
215
+ top_p,
216
+ temperature,
217
+ use_streaming_checkbox,
218
+ model_select_dropdown,
219
+ use_websearch_checkbox,
220
+ ],
221
+ [chatbot, history, status_display, token_count],
222
+ show_progress=True,
223
+ )
224
  submitBtn.click(reset_textbox, [], [user_input])
225
 
226
+ emptyBtn.click(
227
+ reset_state,
228
+ outputs=[chatbot, history, token_count, status_display],
229
+ show_progress=True,
230
+ )
231
 
232
+ retryBtn.click(
233
+ retry,
234
+ [
235
+ keyTxt,
236
+ systemPromptTxt,
237
+ history,
238
+ chatbot,
239
+ token_count,
240
+ top_p,
241
+ temperature,
242
+ use_streaming_checkbox,
243
+ model_select_dropdown,
244
+ ],
245
+ [chatbot, history, status_display, token_count],
246
+ show_progress=True,
247
+ )
248
 
249
+ delLastBtn.click(
250
+ delete_last_conversation,
251
+ [chatbot, history, token_count],
252
+ [chatbot, history, token_count, status_display],
253
+ show_progress=True,
254
+ )
255
 
256
+ reduceTokenBtn.click(
257
+ reduce_token_size,
258
+ [
259
+ keyTxt,
260
+ systemPromptTxt,
261
+ history,
262
+ chatbot,
263
+ token_count,
264
+ top_p,
265
+ temperature,
266
+ use_streaming_checkbox,
267
+ model_select_dropdown,
268
+ ],
269
+ [chatbot, history, status_display, token_count],
270
+ show_progress=True,
271
+ )
272
 
273
  # Template
274
  templateRefreshBtn.click(get_template_names, None, [templateFileSelectDropdown])
275
+ templateFileSelectDropdown.change(
276
+ load_template,
277
+ [templateFileSelectDropdown],
278
+ [promptTemplates, templateSelectDropdown],
279
+ show_progress=True,
280
+ )
281
+ templateSelectDropdown.change(
282
+ get_template_content,
283
+ [promptTemplates, templateSelectDropdown, systemPromptTxt],
284
+ [systemPromptTxt],
285
+ show_progress=True,
286
+ )
287
 
288
  # S&L
289
+ saveHistoryBtn.click(
290
+ save_chat_history,
291
+ [saveFileName, systemPromptTxt, history, chatbot],
292
+ downloadFile,
293
+ show_progress=True,
294
+ )
295
  saveHistoryBtn.click(get_history_names, None, [historyFileSelectDropdown])
296
+ exportMarkdownBtn.click(
297
+ export_markdown,
298
+ [saveFileName, systemPromptTxt, history, chatbot],
299
+ downloadFile,
300
+ show_progress=True,
301
+ )
302
  historyRefreshBtn.click(get_history_names, None, [historyFileSelectDropdown])
303
+ historyFileSelectDropdown.change(
304
+ load_chat_history,
305
+ [historyFileSelectDropdown, systemPromptTxt, history, chatbot],
306
+ [saveFileName, systemPromptTxt, history, chatbot],
307
+ show_progress=True,
308
+ )
309
+ downloadFile.change(
310
+ load_chat_history,
311
+ [downloadFile, systemPromptTxt, history, chatbot],
312
+ [saveFileName, systemPromptTxt, history, chatbot],
313
+ )
314
 
315
 
316
+ logging.info(
317
+ colorama.Back.GREEN
318
+ + "\n川虎的温馨提示:访问 http://localhost:7860 查看界面"
319
+ + colorama.Style.RESET_ALL
320
+ )
321
  # 默认开启本地服务器,默认可以直接从IP访问,默认不创建公开分享链接
322
  demo.title = "川虎ChatGPT 🚀"
323
 
324
  if __name__ == "__main__":
325
+ # if running in Docker
326
  if dockerflag:
327
  if authflag:
328
+ demo.queue().launch(
329
+ server_name="0.0.0.0", server_port=7860, auth=(username, password)
330
+ )
331
  else:
332
  demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=False)
333
+ # if not running in Docker
334
  else:
335
  if authflag:
336
  demo.queue().launch(share=False, auth=(username, password))
337
  else:
338
+ demo.queue().launch(share=False) # 改为 share=True 可以创建公开分享链接
339
+ # demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=False) # 可自定义端口
340
+ # demo.queue().launch(server_name="0.0.0.0", server_port=7860,auth=("在这里填写用户名", "在这里填写密码")) # 可设置用户名与密码
341
+ # demo.queue().launch(auth=("在这里填写用户名", "在这里填写密码")) # 适合Nginx反向代理
presets.py CHANGED
@@ -62,8 +62,15 @@ pre code {
62
  }
63
  """
64
 
65
- summarize_prompt = "你是谁?我们刚才聊了什么?" # 总结对话时的 prompt
66
- MODELS = ["gpt-3.5-turbo", "gpt-3.5-turbo-0301", "gpt-4","gpt-4-0314", "gpt-4-32k", "gpt-4-32k-0314"] # 可选的模型
 
 
 
 
 
 
 
67
  websearch_prompt = """Web search results:
68
 
69
  {web_results}
@@ -74,17 +81,17 @@ Query: {query}
74
  Reply in 中文"""
75
 
76
  # 错误信息
77
- standard_error_msg = "☹️发生了错误:" # 错误信息的标准前缀
78
- error_retrieve_prompt = "请检查网络连接,或者API-Key是否有效。" # 获取对话时发生错误
79
- connection_timeout_prompt = "连接超时,无法获取对话。" # 连接超时
80
- read_timeout_prompt = "读取超时,无法获取对话。" # 读取超时
81
- proxy_error_prompt = "代理错误,无法获取对话。" # 代理错误
82
- ssl_error_prompt = "SSL错误,无法获取对话。" # SSL 错误
83
- no_apikey_msg = "API key长度不是51位,请检查是否输入正确。" # API key 长度不足 51 位
84
 
85
- max_token_streaming = 3500 # 流式对话时的最大 token 数
86
- timeout_streaming = 30 # 流式对话时的超时时间
87
- max_token_all = 3500 # 非流式对话时的最大 token 数
88
- timeout_all = 200 # 非流式对话时的超时时间
89
  enable_streaming_option = True # 是否启用选择选择是否实时显示回答的勾选框
90
- HIDE_MY_KEY = False # 如果你想在UI中隐藏你的 API 密钥,将此值设置为 True
 
62
  }
63
  """
64
 
65
+ summarize_prompt = "你是谁?我们刚才聊了什么?" # 总结对话时的 prompt
66
+ MODELS = [
67
+ "gpt-3.5-turbo",
68
+ "gpt-3.5-turbo-0301",
69
+ "gpt-4",
70
+ "gpt-4-0314",
71
+ "gpt-4-32k",
72
+ "gpt-4-32k-0314",
73
+ ] # 可选的模型
74
  websearch_prompt = """Web search results:
75
 
76
  {web_results}
 
81
  Reply in 中文"""
82
 
83
  # 错误信息
84
+ standard_error_msg = "☹️发生了错误:" # 错误信息的标准前缀
85
+ error_retrieve_prompt = "请检查网络连接,或者API-Key是否有效。" # 获取对话时发生错误
86
+ connection_timeout_prompt = "连接超时,无法获取对话。" # 连接超时
87
+ read_timeout_prompt = "读取超时,无法获取对话。" # 读取超时
88
+ proxy_error_prompt = "代理错误,无法获取对话。" # 代理错误
89
+ ssl_error_prompt = "SSL错误,无法获取对话。" # SSL 错误
90
+ no_apikey_msg = "API key长度不是51位,请检查是否输入正确。" # API key 长度不足 51 位
91
 
92
+ max_token_streaming = 3500 # 流式对话时的最大 token 数
93
+ timeout_streaming = 30 # 流式对话时的超时时间
94
+ max_token_all = 3500 # 非流式对话时的最大 token 数
95
+ timeout_all = 200 # 非流式对话时的超时时间
96
  enable_streaming_option = True # 是否启用选择选择是否实时显示回答的勾选框
97
+ HIDE_MY_KEY = False # 如果你想在UI中隐藏你的 API 密钥,将此值设置为 True
utils.py CHANGED
@@ -4,10 +4,12 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Type
4
  import logging
5
  import json
6
  import gradio as gr
 
7
  # import openai
8
  import os
9
  import traceback
10
  import requests
 
11
  # import markdown
12
  import csv
13
  import mdtex2html
@@ -28,30 +30,33 @@ if TYPE_CHECKING:
28
  headers: List[str]
29
  data: List[List[str | int | bool]]
30
 
 
31
  initial_prompt = "You are a helpful assistant."
32
  API_URL = "https://api.openai.com/v1/chat/completions"
33
  HISTORY_DIR = "history"
34
  TEMPLATES_DIR = "templates"
35
 
 
36
  def postprocess(
37
- self, y: List[Tuple[str | None, str | None]]
38
- ) -> List[Tuple[str | None, str | None]]:
39
- """
40
- Parameters:
41
- y: List of tuples representing the message and response pairs. Each message and response should be a string, which may be in Markdown format.
42
- Returns:
43
- List of tuples representing the message and response. Each message and response will be a string of HTML.
44
- """
45
- if y is None:
46
- return []
47
- for i, (message, response) in enumerate(y):
48
- y[i] = (
49
- # None if message is None else markdown.markdown(message),
50
- # None if response is None else markdown.markdown(response),
51
- None if message is None else message,
52
- None if response is None else mdtex2html.convert(response),
53
- )
54
- return y
 
55
 
56
  def count_token(message):
57
  encoding = tiktoken.get_encoding("cl100k_base")
@@ -59,6 +64,7 @@ def count_token(message):
59
  length = len(encoding.encode(input_str))
60
  return length
61
 
 
62
  def parse_text(text):
63
  lines = text.split("\n")
64
  lines = [line for line in lines if line != ""]
@@ -66,11 +72,11 @@ def parse_text(text):
66
  for i, line in enumerate(lines):
67
  if "```" in line:
68
  count += 1
69
- items = line.split('`')
70
  if count % 2 == 1:
71
  lines[i] = f'<pre><code class="language-{items[-1]}">'
72
  else:
73
- lines[i] = f'<br></code></pre>'
74
  else:
75
  if i > 0:
76
  if count % 2 == 1:
@@ -86,29 +92,37 @@ def parse_text(text):
86
  line = line.replace("(", "&#40;")
87
  line = line.replace(")", "&#41;")
88
  line = line.replace("$", "&#36;")
89
- lines[i] = "<br>"+line
90
  text = "".join(lines)
91
  return text
92
 
 
93
  def construct_text(role, text):
94
  return {"role": role, "content": text}
95
 
 
96
  def construct_user(text):
97
  return construct_text("user", text)
98
 
 
99
  def construct_system(text):
100
  return construct_text("system", text)
101
 
 
102
  def construct_assistant(text):
103
  return construct_text("assistant", text)
104
 
 
105
  def construct_token_message(token, stream=False):
106
  return f"Token 计数: {token}"
107
 
108
- def get_response(openai_api_key, system_prompt, history, temperature, top_p, stream, selected_model):
 
 
 
109
  headers = {
110
  "Content-Type": "application/json",
111
- "Authorization": f"Bearer {openai_api_key}"
112
  }
113
 
114
  history = [construct_system(system_prompt), *history]
@@ -127,10 +141,23 @@ def get_response(openai_api_key, system_prompt, history, temperature, top_p, str
127
  timeout = timeout_streaming
128
  else:
129
  timeout = timeout_all
130
- response = requests.post(API_URL, headers=headers, json=payload, stream=True, timeout=timeout)
 
 
131
  return response
132
 
133
- def stream_predict(openai_api_key, system_prompt, history, inputs, chatbot, all_token_counts, top_p, temperature, selected_model):
 
 
 
 
 
 
 
 
 
 
 
134
  def get_return_value():
135
  return chatbot, history, status_text, all_token_counts
136
 
@@ -144,16 +171,28 @@ def stream_predict(openai_api_key, system_prompt, history, inputs, chatbot, all_
144
  user_token_count = 0
145
  if len(all_token_counts) == 0:
146
  system_prompt_token_count = count_token(construct_system(system_prompt))
147
- user_token_count = count_token(construct_user(inputs)) + system_prompt_token_count
 
 
148
  else:
149
  user_token_count = count_token(construct_user(inputs))
150
  all_token_counts.append(user_token_count)
151
  logging.info(f"输入token计数: {user_token_count}")
152
  yield get_return_value()
153
  try:
154
- response = get_response(openai_api_key, system_prompt, history, temperature, top_p, True, selected_model)
 
 
 
 
 
 
 
 
155
  except requests.exceptions.ConnectTimeout:
156
- status_text = standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
 
 
157
  yield get_return_value()
158
  return
159
  except requests.exceptions.ReadTimeout:
@@ -182,16 +221,24 @@ def stream_predict(openai_api_key, system_prompt, history, inputs, chatbot, all_
182
  yield get_return_value()
183
  continue
184
  # decode each line as response data is in bytes
185
- if chunklength > 6 and "delta" in chunk['choices'][0]:
186
- finish_reason = chunk['choices'][0]['finish_reason']
187
- status_text = construct_token_message(sum(all_token_counts), stream=True)
 
 
188
  if finish_reason == "stop":
189
  yield get_return_value()
190
  break
191
  try:
192
- partial_words = partial_words + chunk['choices'][0]["delta"]["content"]
 
 
193
  except KeyError:
194
- status_text = standard_error_msg + "API回复中找不到内容。很可能是Token计数达到上限了。请重置对话。当前Token计数: " + str(sum(all_token_counts))
 
 
 
 
195
  yield get_return_value()
196
  break
197
  history[-1] = construct_assistant(partial_words)
@@ -200,16 +247,36 @@ def stream_predict(openai_api_key, system_prompt, history, inputs, chatbot, all_
200
  yield get_return_value()
201
 
202
 
203
- def predict_all(openai_api_key, system_prompt, history, inputs, chatbot, all_token_counts, top_p, temperature, selected_model):
 
 
 
 
 
 
 
 
 
 
204
  logging.info("一次性回答模式")
205
  history.append(construct_user(inputs))
206
  history.append(construct_assistant(""))
207
  chatbot.append((parse_text(inputs), ""))
208
  all_token_counts.append(count_token(construct_user(inputs)))
209
  try:
210
- response = get_response(openai_api_key, system_prompt, history, temperature, top_p, False, selected_model)
 
 
 
 
 
 
 
 
211
  except requests.exceptions.ConnectTimeout:
212
- status_text = standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
 
 
213
  return chatbot, history, status_text, all_token_counts
214
  except requests.exceptions.ProxyError:
215
  status_text = standard_error_msg + proxy_error_prompt + error_retrieve_prompt
@@ -227,8 +294,21 @@ def predict_all(openai_api_key, system_prompt, history, inputs, chatbot, all_tok
227
  return chatbot, history, status_text, all_token_counts
228
 
229
 
230
- def predict(openai_api_key, system_prompt, history, inputs, chatbot, all_token_counts, top_p, temperature, stream=False, selected_model = MODELS[0], use_websearch_checkbox = False, should_check_token_count = True): # repetition_penalty, top_k
231
- logging.info("输入为:" +colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL)
 
 
 
 
 
 
 
 
 
 
 
 
 
232
  if use_websearch_checkbox:
233
  results = ddg(inputs, max_results=3)
234
  web_results = []
@@ -237,7 +317,11 @@ def predict(openai_api_key, system_prompt, history, inputs, chatbot, all_token_c
237
  web_results.append(f'[{idx+1}]"{result["body"]}"\nURL: {result["href"]}')
238
  web_results = "\n\n".join(web_results)
239
  today = datetime.datetime.today().strftime("%Y-%m-%d")
240
- inputs = websearch_prompt.replace("{current_date}", today).replace("{query}", inputs).replace("{web_results}", web_results)
 
 
 
 
241
  if len(openai_api_key) != 51:
242
  status_text = standard_error_msg + no_apikey_msg
243
  logging.info(status_text)
@@ -254,16 +338,41 @@ def predict(openai_api_key, system_prompt, history, inputs, chatbot, all_token_c
254
  yield chatbot, history, "开始生成回答……", all_token_counts
255
  if stream:
256
  logging.info("使用流式传输")
257
- iter = stream_predict(openai_api_key, system_prompt, history, inputs, chatbot, all_token_counts, top_p, temperature, selected_model)
 
 
 
 
 
 
 
 
 
 
258
  for chatbot, history, status_text, all_token_counts in iter:
259
  yield chatbot, history, status_text, all_token_counts
260
  else:
261
  logging.info("不使用流式传输")
262
- chatbot, history, status_text, all_token_counts = predict_all(openai_api_key, system_prompt, history, inputs, chatbot, all_token_counts, top_p, temperature, selected_model)
 
 
 
 
 
 
 
 
 
 
263
  yield chatbot, history, status_text, all_token_counts
264
  logging.info(f"传输完毕。当前token计数为{all_token_counts}")
265
- if len(history) > 1 and history[-1]['content'] != inputs:
266
- logging.info("回答为:" +colorama.Fore.BLUE + f"{history[-1]['content']}" + colorama.Style.RESET_ALL)
 
 
 
 
 
267
  if stream:
268
  max_token = max_token_streaming
269
  else:
@@ -272,13 +381,34 @@ def predict(openai_api_key, system_prompt, history, inputs, chatbot, all_token_c
272
  status_text = f"精简token中{all_token_counts}/{max_token}"
273
  logging.info(status_text)
274
  yield chatbot, history, status_text, all_token_counts
275
- iter = reduce_token_size(openai_api_key, system_prompt, history, chatbot, all_token_counts, top_p, temperature, stream=False, selected_model=selected_model, hidden=True)
 
 
 
 
 
 
 
 
 
 
 
276
  for chatbot, history, status_text, all_token_counts in iter:
277
  status_text = f"Token 达到上限,已自动降低Token计数至 {status_text}"
278
  yield chatbot, history, status_text, all_token_counts
279
 
280
 
281
- def retry(openai_api_key, system_prompt, history, chatbot, token_count, top_p, temperature, stream=False, selected_model = MODELS[0]):
 
 
 
 
 
 
 
 
 
 
282
  logging.info("重试中……")
283
  if len(history) == 0:
284
  yield chatbot, history, f"{standard_error_msg}上下文是空的", token_count
@@ -286,22 +416,58 @@ def retry(openai_api_key, system_prompt, history, chatbot, token_count, top_p, t
286
  history.pop()
287
  inputs = history.pop()["content"]
288
  token_count.pop()
289
- iter = predict(openai_api_key, system_prompt, history, inputs, chatbot, token_count, top_p, temperature, stream=stream, selected_model=selected_model)
 
 
 
 
 
 
 
 
 
 
 
290
  logging.info("重试完毕")
291
  for x in iter:
292
  yield x
293
 
294
 
295
- def reduce_token_size(openai_api_key, system_prompt, history, chatbot, token_count, top_p, temperature, stream=False, selected_model = MODELS[0], hidden=False):
 
 
 
 
 
 
 
 
 
 
 
296
  logging.info("开始减少token数量……")
297
- iter = predict(openai_api_key, system_prompt, history, summarize_prompt, chatbot, token_count, top_p, temperature, stream=stream, selected_model = selected_model, should_check_token_count=False)
 
 
 
 
 
 
 
 
 
 
 
 
298
  logging.info(f"chatbot: {chatbot}")
299
  for chatbot, history, status_text, previous_token_count in iter:
300
  history = history[-2:]
301
  token_count = previous_token_count[-1:]
302
  if hidden:
303
  chatbot.pop()
304
- yield chatbot, history, construct_token_message(sum(token_count), stream=stream), token_count
 
 
305
  logging.info("减少token数量完毕")
306
 
307
 
@@ -320,7 +486,12 @@ def delete_last_conversation(chatbot, history, previous_token_count):
320
  if len(previous_token_count) > 0:
321
  logging.info("删除了一组对话的token计数记录")
322
  previous_token_count.pop()
323
- return chatbot, history, previous_token_count, construct_token_message(sum(previous_token_count))
 
 
 
 
 
324
 
325
 
326
  def save_file(filename, system, history, chatbot):
@@ -340,6 +511,7 @@ def save_file(filename, system, history, chatbot):
340
  logging.info("保存对话历史完毕")
341
  return os.path.join(HISTORY_DIR, filename)
342
 
 
343
  def save_chat_history(filename, system, history, chatbot):
344
  if filename == "":
345
  return
@@ -347,6 +519,7 @@ def save_chat_history(filename, system, history, chatbot):
347
  filename += ".json"
348
  return save_file(filename, system, history, chatbot)
349
 
 
350
  def export_markdown(filename, system, history, chatbot):
351
  if filename == "":
352
  return
@@ -382,9 +555,11 @@ def load_chat_history(filename, system, history, chatbot):
382
  logging.info("没有找到对话历史文件,不执行任何操作")
383
  return filename, system, history, chatbot
384
 
 
385
  def sorted_by_pinyin(list):
386
  return sorted(list, key=lambda char: lazy_pinyin(char)[0][0])
387
 
 
388
  def get_file_names(dir, plain=False, filetypes=[".json"]):
389
  logging.info(f"获取文件名列表,目录为{dir},文件类型为{filetypes},是否为纯文本列表{plain}")
390
  files = []
@@ -401,10 +576,12 @@ def get_file_names(dir, plain=False, filetypes=[".json"]):
401
  else:
402
  return gr.Dropdown.update(choices=files)
403
 
 
404
  def get_history_names(plain=False):
405
  logging.info("获取历史记录文件名列表")
406
  return get_file_names(HISTORY_DIR, plain)
407
 
 
408
  def load_template(filename, mode=0):
409
  logging.info(f"加载模板文件{filename},模式为{mode}(0为返回字典和下拉菜单,1为返回下拉菜单,2为返回字典)")
410
  lines = []
@@ -414,22 +591,28 @@ def load_template(filename, mode=0):
414
  lines = json.load(f)
415
  lines = [[i["act"], i["prompt"]] for i in lines]
416
  else:
417
- with open(os.path.join(TEMPLATES_DIR, filename), "r", encoding="utf8") as csvfile:
 
 
418
  reader = csv.reader(csvfile)
419
  lines = list(reader)
420
  lines = lines[1:]
421
  if mode == 1:
422
  return sorted_by_pinyin([row[0] for row in lines])
423
  elif mode == 2:
424
- return {row[0]:row[1] for row in lines}
425
  else:
426
  choices = sorted_by_pinyin([row[0] for row in lines])
427
- return {row[0]:row[1] for row in lines}, gr.Dropdown.update(choices=choices, value=choices[0])
 
 
 
428
 
429
  def get_template_names(plain=False):
430
  logging.info("获取模板文件名列表")
431
  return get_file_names(TEMPLATES_DIR, plain, filetypes=[".csv", "json"])
432
 
 
433
  def get_template_content(templates, selection, original_system_prompt):
434
  logging.info(f"应用模板中,选择为{selection},原始系统提示为{original_system_prompt}")
435
  try:
@@ -437,9 +620,11 @@ def get_template_content(templates, selection, original_system_prompt):
437
  except:
438
  return original_system_prompt
439
 
 
440
  def reset_state():
441
  logging.info("重置状态")
442
  return [], [], [], construct_token_message(0)
443
 
 
444
  def reset_textbox():
445
- return gr.update(value='')
 
4
  import logging
5
  import json
6
  import gradio as gr
7
+
8
  # import openai
9
  import os
10
  import traceback
11
  import requests
12
+
13
  # import markdown
14
  import csv
15
  import mdtex2html
 
30
  headers: List[str]
31
  data: List[List[str | int | bool]]
32
 
33
+
34
  initial_prompt = "You are a helpful assistant."
35
  API_URL = "https://api.openai.com/v1/chat/completions"
36
  HISTORY_DIR = "history"
37
  TEMPLATES_DIR = "templates"
38
 
39
+
40
  def postprocess(
41
+ self, y: List[Tuple[str | None, str | None]]
42
+ ) -> List[Tuple[str | None, str | None]]:
43
+ """
44
+ Parameters:
45
+ y: List of tuples representing the message and response pairs. Each message and response should be a string, which may be in Markdown format.
46
+ Returns:
47
+ List of tuples representing the message and response. Each message and response will be a string of HTML.
48
+ """
49
+ if y is None:
50
+ return []
51
+ for i, (message, response) in enumerate(y):
52
+ y[i] = (
53
+ # None if message is None else markdown.markdown(message),
54
+ # None if response is None else markdown.markdown(response),
55
+ None if message is None else message,
56
+ None if response is None else mdtex2html.convert(response),
57
+ )
58
+ return y
59
+
60
 
61
  def count_token(message):
62
  encoding = tiktoken.get_encoding("cl100k_base")
 
64
  length = len(encoding.encode(input_str))
65
  return length
66
 
67
+
68
  def parse_text(text):
69
  lines = text.split("\n")
70
  lines = [line for line in lines if line != ""]
 
72
  for i, line in enumerate(lines):
73
  if "```" in line:
74
  count += 1
75
+ items = line.split("`")
76
  if count % 2 == 1:
77
  lines[i] = f'<pre><code class="language-{items[-1]}">'
78
  else:
79
+ lines[i] = f"<br></code></pre>"
80
  else:
81
  if i > 0:
82
  if count % 2 == 1:
 
92
  line = line.replace("(", "&#40;")
93
  line = line.replace(")", "&#41;")
94
  line = line.replace("$", "&#36;")
95
+ lines[i] = "<br>" + line
96
  text = "".join(lines)
97
  return text
98
 
99
+
100
  def construct_text(role, text):
101
  return {"role": role, "content": text}
102
 
103
+
104
  def construct_user(text):
105
  return construct_text("user", text)
106
 
107
+
108
  def construct_system(text):
109
  return construct_text("system", text)
110
 
111
+
112
  def construct_assistant(text):
113
  return construct_text("assistant", text)
114
 
115
+
116
  def construct_token_message(token, stream=False):
117
  return f"Token 计数: {token}"
118
 
119
+
120
+ def get_response(
121
+ openai_api_key, system_prompt, history, temperature, top_p, stream, selected_model
122
+ ):
123
  headers = {
124
  "Content-Type": "application/json",
125
+ "Authorization": f"Bearer {openai_api_key}",
126
  }
127
 
128
  history = [construct_system(system_prompt), *history]
 
141
  timeout = timeout_streaming
142
  else:
143
  timeout = timeout_all
144
+ response = requests.post(
145
+ API_URL, headers=headers, json=payload, stream=True, timeout=timeout
146
+ )
147
  return response
148
 
149
+
150
+ def stream_predict(
151
+ openai_api_key,
152
+ system_prompt,
153
+ history,
154
+ inputs,
155
+ chatbot,
156
+ all_token_counts,
157
+ top_p,
158
+ temperature,
159
+ selected_model,
160
+ ):
161
  def get_return_value():
162
  return chatbot, history, status_text, all_token_counts
163
 
 
171
  user_token_count = 0
172
  if len(all_token_counts) == 0:
173
  system_prompt_token_count = count_token(construct_system(system_prompt))
174
+ user_token_count = (
175
+ count_token(construct_user(inputs)) + system_prompt_token_count
176
+ )
177
  else:
178
  user_token_count = count_token(construct_user(inputs))
179
  all_token_counts.append(user_token_count)
180
  logging.info(f"输入token计数: {user_token_count}")
181
  yield get_return_value()
182
  try:
183
+ response = get_response(
184
+ openai_api_key,
185
+ system_prompt,
186
+ history,
187
+ temperature,
188
+ top_p,
189
+ True,
190
+ selected_model,
191
+ )
192
  except requests.exceptions.ConnectTimeout:
193
+ status_text = (
194
+ standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
195
+ )
196
  yield get_return_value()
197
  return
198
  except requests.exceptions.ReadTimeout:
 
221
  yield get_return_value()
222
  continue
223
  # decode each line as response data is in bytes
224
+ if chunklength > 6 and "delta" in chunk["choices"][0]:
225
+ finish_reason = chunk["choices"][0]["finish_reason"]
226
+ status_text = construct_token_message(
227
+ sum(all_token_counts), stream=True
228
+ )
229
  if finish_reason == "stop":
230
  yield get_return_value()
231
  break
232
  try:
233
+ partial_words = (
234
+ partial_words + chunk["choices"][0]["delta"]["content"]
235
+ )
236
  except KeyError:
237
+ status_text = (
238
+ standard_error_msg
239
+ + "API回复中找不到内容。很可能是Token计数达到上限了。请重置对话。当前Token计数: "
240
+ + str(sum(all_token_counts))
241
+ )
242
  yield get_return_value()
243
  break
244
  history[-1] = construct_assistant(partial_words)
 
247
  yield get_return_value()
248
 
249
 
250
+ def predict_all(
251
+ openai_api_key,
252
+ system_prompt,
253
+ history,
254
+ inputs,
255
+ chatbot,
256
+ all_token_counts,
257
+ top_p,
258
+ temperature,
259
+ selected_model,
260
+ ):
261
  logging.info("一次性回答模式")
262
  history.append(construct_user(inputs))
263
  history.append(construct_assistant(""))
264
  chatbot.append((parse_text(inputs), ""))
265
  all_token_counts.append(count_token(construct_user(inputs)))
266
  try:
267
+ response = get_response(
268
+ openai_api_key,
269
+ system_prompt,
270
+ history,
271
+ temperature,
272
+ top_p,
273
+ False,
274
+ selected_model,
275
+ )
276
  except requests.exceptions.ConnectTimeout:
277
+ status_text = (
278
+ standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
279
+ )
280
  return chatbot, history, status_text, all_token_counts
281
  except requests.exceptions.ProxyError:
282
  status_text = standard_error_msg + proxy_error_prompt + error_retrieve_prompt
 
294
  return chatbot, history, status_text, all_token_counts
295
 
296
 
297
+ def predict(
298
+ openai_api_key,
299
+ system_prompt,
300
+ history,
301
+ inputs,
302
+ chatbot,
303
+ all_token_counts,
304
+ top_p,
305
+ temperature,
306
+ stream=False,
307
+ selected_model=MODELS[0],
308
+ use_websearch_checkbox=False,
309
+ should_check_token_count=True,
310
+ ): # repetition_penalty, top_k
311
+ logging.info("输入为:" + colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL)
312
  if use_websearch_checkbox:
313
  results = ddg(inputs, max_results=3)
314
  web_results = []
 
317
  web_results.append(f'[{idx+1}]"{result["body"]}"\nURL: {result["href"]}')
318
  web_results = "\n\n".join(web_results)
319
  today = datetime.datetime.today().strftime("%Y-%m-%d")
320
+ inputs = (
321
+ websearch_prompt.replace("{current_date}", today)
322
+ .replace("{query}", inputs)
323
+ .replace("{web_results}", web_results)
324
+ )
325
  if len(openai_api_key) != 51:
326
  status_text = standard_error_msg + no_apikey_msg
327
  logging.info(status_text)
 
338
  yield chatbot, history, "开始生成回答……", all_token_counts
339
  if stream:
340
  logging.info("使用流式传输")
341
+ iter = stream_predict(
342
+ openai_api_key,
343
+ system_prompt,
344
+ history,
345
+ inputs,
346
+ chatbot,
347
+ all_token_counts,
348
+ top_p,
349
+ temperature,
350
+ selected_model,
351
+ )
352
  for chatbot, history, status_text, all_token_counts in iter:
353
  yield chatbot, history, status_text, all_token_counts
354
  else:
355
  logging.info("不使用流式传输")
356
+ chatbot, history, status_text, all_token_counts = predict_all(
357
+ openai_api_key,
358
+ system_prompt,
359
+ history,
360
+ inputs,
361
+ chatbot,
362
+ all_token_counts,
363
+ top_p,
364
+ temperature,
365
+ selected_model,
366
+ )
367
  yield chatbot, history, status_text, all_token_counts
368
  logging.info(f"传输完毕。当前token计数为{all_token_counts}")
369
+ if len(history) > 1 and history[-1]["content"] != inputs:
370
+ logging.info(
371
+ "回答为:"
372
+ + colorama.Fore.BLUE
373
+ + f"{history[-1]['content']}"
374
+ + colorama.Style.RESET_ALL
375
+ )
376
  if stream:
377
  max_token = max_token_streaming
378
  else:
 
381
  status_text = f"精简token中{all_token_counts}/{max_token}"
382
  logging.info(status_text)
383
  yield chatbot, history, status_text, all_token_counts
384
+ iter = reduce_token_size(
385
+ openai_api_key,
386
+ system_prompt,
387
+ history,
388
+ chatbot,
389
+ all_token_counts,
390
+ top_p,
391
+ temperature,
392
+ stream=False,
393
+ selected_model=selected_model,
394
+ hidden=True,
395
+ )
396
  for chatbot, history, status_text, all_token_counts in iter:
397
  status_text = f"Token 达到上限,已自动降低Token计数至 {status_text}"
398
  yield chatbot, history, status_text, all_token_counts
399
 
400
 
401
+ def retry(
402
+ openai_api_key,
403
+ system_prompt,
404
+ history,
405
+ chatbot,
406
+ token_count,
407
+ top_p,
408
+ temperature,
409
+ stream=False,
410
+ selected_model=MODELS[0],
411
+ ):
412
  logging.info("重试中……")
413
  if len(history) == 0:
414
  yield chatbot, history, f"{standard_error_msg}上下文是空的", token_count
 
416
  history.pop()
417
  inputs = history.pop()["content"]
418
  token_count.pop()
419
+ iter = predict(
420
+ openai_api_key,
421
+ system_prompt,
422
+ history,
423
+ inputs,
424
+ chatbot,
425
+ token_count,
426
+ top_p,
427
+ temperature,
428
+ stream=stream,
429
+ selected_model=selected_model,
430
+ )
431
  logging.info("重试完毕")
432
  for x in iter:
433
  yield x
434
 
435
 
436
+ def reduce_token_size(
437
+ openai_api_key,
438
+ system_prompt,
439
+ history,
440
+ chatbot,
441
+ token_count,
442
+ top_p,
443
+ temperature,
444
+ stream=False,
445
+ selected_model=MODELS[0],
446
+ hidden=False,
447
+ ):
448
  logging.info("开始减少token数量……")
449
+ iter = predict(
450
+ openai_api_key,
451
+ system_prompt,
452
+ history,
453
+ summarize_prompt,
454
+ chatbot,
455
+ token_count,
456
+ top_p,
457
+ temperature,
458
+ stream=stream,
459
+ selected_model=selected_model,
460
+ should_check_token_count=False,
461
+ )
462
  logging.info(f"chatbot: {chatbot}")
463
  for chatbot, history, status_text, previous_token_count in iter:
464
  history = history[-2:]
465
  token_count = previous_token_count[-1:]
466
  if hidden:
467
  chatbot.pop()
468
+ yield chatbot, history, construct_token_message(
469
+ sum(token_count), stream=stream
470
+ ), token_count
471
  logging.info("减少token数量完毕")
472
 
473
 
 
486
  if len(previous_token_count) > 0:
487
  logging.info("删除了一组对话的token计数记录")
488
  previous_token_count.pop()
489
+ return (
490
+ chatbot,
491
+ history,
492
+ previous_token_count,
493
+ construct_token_message(sum(previous_token_count)),
494
+ )
495
 
496
 
497
  def save_file(filename, system, history, chatbot):
 
511
  logging.info("保存对话历史完毕")
512
  return os.path.join(HISTORY_DIR, filename)
513
 
514
+
515
  def save_chat_history(filename, system, history, chatbot):
516
  if filename == "":
517
  return
 
519
  filename += ".json"
520
  return save_file(filename, system, history, chatbot)
521
 
522
+
523
  def export_markdown(filename, system, history, chatbot):
524
  if filename == "":
525
  return
 
555
  logging.info("没有找到对话历史文件,不执行任何操作")
556
  return filename, system, history, chatbot
557
 
558
+
559
  def sorted_by_pinyin(list):
560
  return sorted(list, key=lambda char: lazy_pinyin(char)[0][0])
561
 
562
+
563
  def get_file_names(dir, plain=False, filetypes=[".json"]):
564
  logging.info(f"获取文件名列表,目录为{dir},文件类型为{filetypes},是否为纯文本列表{plain}")
565
  files = []
 
576
  else:
577
  return gr.Dropdown.update(choices=files)
578
 
579
+
580
  def get_history_names(plain=False):
581
  logging.info("获取历史记录文件名列表")
582
  return get_file_names(HISTORY_DIR, plain)
583
 
584
+
585
  def load_template(filename, mode=0):
586
  logging.info(f"加载模板文件{filename},模式为{mode}(0为返回字典和下拉菜单,1为返回下拉菜单,2为返回字典)")
587
  lines = []
 
591
  lines = json.load(f)
592
  lines = [[i["act"], i["prompt"]] for i in lines]
593
  else:
594
+ with open(
595
+ os.path.join(TEMPLATES_DIR, filename), "r", encoding="utf8"
596
+ ) as csvfile:
597
  reader = csv.reader(csvfile)
598
  lines = list(reader)
599
  lines = lines[1:]
600
  if mode == 1:
601
  return sorted_by_pinyin([row[0] for row in lines])
602
  elif mode == 2:
603
+ return {row[0]: row[1] for row in lines}
604
  else:
605
  choices = sorted_by_pinyin([row[0] for row in lines])
606
+ return {row[0]: row[1] for row in lines}, gr.Dropdown.update(
607
+ choices=choices, value=choices[0]
608
+ )
609
+
610
 
611
  def get_template_names(plain=False):
612
  logging.info("获取模板文件名列表")
613
  return get_file_names(TEMPLATES_DIR, plain, filetypes=[".csv", "json"])
614
 
615
+
616
  def get_template_content(templates, selection, original_system_prompt):
617
  logging.info(f"应用模板中,选择为{selection},原始系统提示为{original_system_prompt}")
618
  try:
 
620
  except:
621
  return original_system_prompt
622
 
623
+
624
  def reset_state():
625
  logging.info("重置状态")
626
  return [], [], [], construct_token_message(0)
627
 
628
+
629
  def reset_textbox():
630
+ return gr.update(value="")