Tuchuanhuhuhu commited on
Commit
2c5812c
·
1 Parent(s): 3fe8fc4

加入中止回答的功能

Browse files
ChuanhuChatbot.py CHANGED
@@ -5,10 +5,10 @@ import sys
5
 
6
  import gradio as gr
7
 
8
- from utils import *
9
- from presets import *
10
- from overwrites import *
11
- from chat_func import *
12
 
13
  logging.basicConfig(
14
  level=logging.DEBUG,
@@ -54,7 +54,7 @@ else:
54
  gr.Chatbot.postprocess = postprocess
55
  PromptHelper.compact_text_chunks = compact_text_chunks
56
 
57
- with open("custom.css", "r", encoding="utf-8") as f:
58
  customCSS = f.read()
59
 
60
  with gr.Blocks(
@@ -124,8 +124,7 @@ with gr.Blocks(
124
  token_count = gr.State([])
125
  promptTemplates = gr.State(load_template(get_template_names(plain=True)[0], mode=2))
126
  user_api_key = gr.State(my_api_key)
127
- TRUECOMSTANT = gr.State(True)
128
- FALSECONSTANT = gr.State(False)
129
  topic = gr.State("未命名对话历史记录")
130
 
131
  with gr.Row():
@@ -275,12 +274,9 @@ with gr.Blocks(
275
 
276
  gr.Markdown(description)
277
 
278
- keyTxt.submit(submit_key, keyTxt, [user_api_key, status_display])
279
- keyTxt.change(submit_key, keyTxt, [user_api_key, status_display])
280
- # Chatbot
281
- user_input.submit(
282
- predict,
283
- [
284
  user_api_key,
285
  systemPromptTxt,
286
  history,
@@ -294,40 +290,45 @@ with gr.Blocks(
294
  use_websearch_checkbox,
295
  index_files,
296
  ],
297
- [chatbot, history, status_display, token_count],
298
  show_progress=True,
299
  )
300
- user_input.submit(reset_textbox, [], [user_input])
301
 
302
- # submitBtn.click(return_cancel_btn, [], [submitBtn, cancelBtn])
303
- submitBtn.click(
304
- predict,
305
- [
306
- user_api_key,
307
- systemPromptTxt,
308
- history,
309
- user_input,
310
- chatbot,
311
- token_count,
312
- top_p,
313
- temperature,
314
- use_streaming_checkbox,
315
- model_select_dropdown,
316
- use_websearch_checkbox,
317
- index_files,
318
- ],
319
- [chatbot, history, status_display, token_count],
320
- show_progress=True,
 
 
 
 
 
 
 
321
  )
322
- submitBtn.click(reset_textbox, [], [user_input])
323
 
324
  emptyBtn.click(
325
  reset_state,
326
  outputs=[chatbot, history, token_count, status_display],
327
  show_progress=True,
328
- )
329
 
330
- retryBtn.click(
331
  retry,
332
  [
333
  user_api_key,
@@ -342,7 +343,7 @@ with gr.Blocks(
342
  ],
343
  [chatbot, history, status_display, token_count],
344
  show_progress=True,
345
- )
346
 
347
  delLastBtn.click(
348
  delete_last_conversation,
@@ -441,17 +442,31 @@ if __name__ == "__main__":
441
  if dockerflag:
442
  if authflag:
443
  demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
444
- server_name="0.0.0.0", server_port=7860, auth=(username, password),
445
- favicon_path="./assets/favicon.png"
 
 
446
  )
447
  else:
448
- demo.queue(concurrency_count=CONCURRENT_COUNT).launch(server_name="0.0.0.0", server_port=7860, share=False, favicon_path="./assets/favicon.png")
 
 
 
 
 
449
  # if not running in Docker
450
  else:
451
  if authflag:
452
- demo.queue(concurrency_count=CONCURRENT_COUNT).launch(share=False, auth=(username, password), favicon_path="./assets/favicon.png", inbrowser=True)
 
 
 
 
 
453
  else:
454
- demo.queue(concurrency_count=CONCURRENT_COUNT).launch(share=False, favicon_path="./assets/favicon.ico", inbrowser=True) # 改为 share=True 可以创建公开分享链接
 
 
455
  # demo.queue(concurrency_count=CONCURRENT_COUNT).launch(server_name="0.0.0.0", server_port=7860, share=False) # 可自定义端口
456
  # demo.queue(concurrency_count=CONCURRENT_COUNT).launch(server_name="0.0.0.0", server_port=7860,auth=("在这里填写用户名", "在这里填写密码")) # 可设置用户名与密码
457
  # demo.queue(concurrency_count=CONCURRENT_COUNT).launch(auth=("在这里填写用户名", "在这里填写密码")) # 适合Nginx反向代理
 
5
 
6
  import gradio as gr
7
 
8
+ from modules.utils import *
9
+ from modules.presets import *
10
+ from modules.overwrites import *
11
+ from modules.chat_func import *
12
 
13
  logging.basicConfig(
14
  level=logging.DEBUG,
 
54
  gr.Chatbot.postprocess = postprocess
55
  PromptHelper.compact_text_chunks = compact_text_chunks
56
 
57
+ with open("assets/custom.css", "r", encoding="utf-8") as f:
58
  customCSS = f.read()
59
 
60
  with gr.Blocks(
 
124
  token_count = gr.State([])
125
  promptTemplates = gr.State(load_template(get_template_names(plain=True)[0], mode=2))
126
  user_api_key = gr.State(my_api_key)
127
+ outputing = gr.State(False)
 
128
  topic = gr.State("未命名对话历史记录")
129
 
130
  with gr.Row():
 
274
 
275
  gr.Markdown(description)
276
 
277
+ chatgpt_predict_args = dict(
278
+ fn=predict,
279
+ inputs=[
 
 
 
280
  user_api_key,
281
  systemPromptTxt,
282
  history,
 
290
  use_websearch_checkbox,
291
  index_files,
292
  ],
293
+ outputs=[chatbot, history, status_display, token_count],
294
  show_progress=True,
295
  )
 
296
 
297
+ start_outputing_args = dict(
298
+ fn=start_outputing, inputs=[], outputs=[submitBtn, cancelBtn], show_progress=True
299
+ )
300
+
301
+ end_outputing_args = dict(
302
+ fn=end_outputing, inputs=[], outputs=[submitBtn, cancelBtn]
303
+ )
304
+
305
+ reset_textbox_args = dict(
306
+ fn=reset_textbox, inputs=[], outputs=[user_input], show_progress=True
307
+ )
308
+
309
+ keyTxt.submit(submit_key, keyTxt, [user_api_key, status_display])
310
+ keyTxt.change(submit_key, keyTxt, [user_api_key, status_display])
311
+ # Chatbot
312
+ cancelBtn.click(cancel_outputing, [], [])
313
+
314
+ user_input.submit(**start_outputing_args).then(
315
+ **chatgpt_predict_args
316
+ ).then(**reset_textbox_args).then(
317
+ **end_outputing_args
318
+ )
319
+ submitBtn.click(**start_outputing_args).then(
320
+ **chatgpt_predict_args
321
+ ).then(**reset_textbox_args).then(
322
+ **end_outputing_args
323
  )
 
324
 
325
  emptyBtn.click(
326
  reset_state,
327
  outputs=[chatbot, history, token_count, status_display],
328
  show_progress=True,
329
+ ).then(**reset_textbox_args)
330
 
331
+ retryBtn.click(**start_outputing_args).then(
332
  retry,
333
  [
334
  user_api_key,
 
343
  ],
344
  [chatbot, history, status_display, token_count],
345
  show_progress=True,
346
+ ).then(**end_outputing_args)
347
 
348
  delLastBtn.click(
349
  delete_last_conversation,
 
442
  if dockerflag:
443
  if authflag:
444
  demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
445
+ server_name="0.0.0.0",
446
+ server_port=7860,
447
+ auth=(username, password),
448
+ favicon_path="./assets/favicon.png",
449
  )
450
  else:
451
+ demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
452
+ server_name="0.0.0.0",
453
+ server_port=7860,
454
+ share=False,
455
+ favicon_path="./assets/favicon.png",
456
+ )
457
  # if not running in Docker
458
  else:
459
  if authflag:
460
+ demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
461
+ share=False,
462
+ auth=(username, password),
463
+ favicon_path="./assets/favicon.png",
464
+ inbrowser=True,
465
+ )
466
  else:
467
+ demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
468
+ share=False, favicon_path="./assets/favicon.ico", inbrowser=True
469
+ ) # 改为 share=True 可以创建公开分享链接
470
  # demo.queue(concurrency_count=CONCURRENT_COUNT).launch(server_name="0.0.0.0", server_port=7860, share=False) # 可自定义端口
471
  # demo.queue(concurrency_count=CONCURRENT_COUNT).launch(server_name="0.0.0.0", server_port=7860,auth=("在这里填写用户名", "在这里填写密码")) # 可设置用户名与密码
472
  # demo.queue(concurrency_count=CONCURRENT_COUNT).launch(auth=("在这里填写用户名", "在这里填写密码")) # 适合Nginx反向代理
custom.css → assets/custom.css RENAMED
File without changes
chat_func.py → modules/chat_func.py RENAMED
@@ -14,9 +14,10 @@ from duckduckgo_search import ddg
14
  import asyncio
15
  import aiohttp
16
 
17
- from presets import *
18
- from llama_func import *
19
- from utils import *
 
20
 
21
  # logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s")
22
 
@@ -29,7 +30,6 @@ if TYPE_CHECKING:
29
 
30
 
31
  initial_prompt = "You are a helpful assistant."
32
- API_URL = "https://api.openai.com/v1/chat/completions"
33
  HISTORY_DIR = "history"
34
  TEMPLATES_DIR = "templates"
35
 
@@ -65,16 +65,18 @@ def get_response(
65
  # 如果存在代理设置,使用它们
66
  proxies = {}
67
  if http_proxy:
68
- logging.info(f"Using HTTP proxy: {http_proxy}")
69
  proxies["http"] = http_proxy
70
  if https_proxy:
71
- logging.info(f"Using HTTPS proxy: {https_proxy}")
72
  proxies["https"] = https_proxy
73
 
74
  # 如果有代理,使用代理发送请求,否则使用默认设置发送请求
 
 
75
  if proxies:
76
  response = requests.post(
77
- API_URL,
78
  headers=headers,
79
  json=payload,
80
  stream=True,
@@ -83,7 +85,7 @@ def get_response(
83
  )
84
  else:
85
  response = requests.post(
86
- API_URL,
87
  headers=headers,
88
  json=payload,
89
  stream=True,
@@ -268,10 +270,10 @@ def predict(
268
  if files:
269
  msg = "构建索引中……(这可能需要比较久的时间)"
270
  logging.info(msg)
271
- yield chatbot, history, msg, all_token_counts
272
  index = construct_index(openai_api_key, file_src=files)
273
  msg = "索引构建完成,获取回答中……"
274
- yield chatbot, history, msg, all_token_counts
275
  history, chatbot, status_text = chat_ai(openai_api_key, index, inputs, history, chatbot)
276
  yield chatbot, history, status_text, all_token_counts
277
  return
@@ -306,10 +308,15 @@ def predict(
306
  all_token_counts.append(0)
307
  else:
308
  history[-2] = construct_user(inputs)
309
- yield chatbot, history, status_text, all_token_counts
 
 
 
 
 
310
  return
311
 
312
- yield chatbot, history, "开始生成回答……", all_token_counts
313
 
314
  if stream:
315
  logging.info("使用流式传输")
@@ -327,6 +334,9 @@ def predict(
327
  display_append=link_references
328
  )
329
  for chatbot, history, status_text, all_token_counts in iter:
 
 
 
330
  yield chatbot, history, status_text, all_token_counts
331
  else:
332
  logging.info("不使用流式传输")
 
14
  import asyncio
15
  import aiohttp
16
 
17
+ from modules.presets import *
18
+ from modules.llama_func import *
19
+ from modules.utils import *
20
+ import modules.shared as shared
21
 
22
  # logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s")
23
 
 
30
 
31
 
32
  initial_prompt = "You are a helpful assistant."
 
33
  HISTORY_DIR = "history"
34
  TEMPLATES_DIR = "templates"
35
 
 
65
  # 如果存在代理设置,使用它们
66
  proxies = {}
67
  if http_proxy:
68
+ logging.info(f"使用 HTTP 代理: {http_proxy}")
69
  proxies["http"] = http_proxy
70
  if https_proxy:
71
+ logging.info(f"使用 HTTPS 代理: {https_proxy}")
72
  proxies["https"] = https_proxy
73
 
74
  # 如果有代理,使用代理发送请求,否则使用默认设置发送请求
75
+ if shared.state.api_url != API_URL:
76
+ logging.info(f"使用自定义API URL: {shared.state.api_url}")
77
  if proxies:
78
  response = requests.post(
79
+ shared.state.api_url,
80
  headers=headers,
81
  json=payload,
82
  stream=True,
 
85
  )
86
  else:
87
  response = requests.post(
88
+ shared.state.api_url,
89
  headers=headers,
90
  json=payload,
91
  stream=True,
 
270
  if files:
271
  msg = "构建索引中……(这可能需要比较久的时间)"
272
  logging.info(msg)
273
+ yield chatbot+[(inputs, "")], history, msg, all_token_counts
274
  index = construct_index(openai_api_key, file_src=files)
275
  msg = "索引构建完成,获取回答中……"
276
+ yield chatbot+[(inputs, "")], history, msg, all_token_counts
277
  history, chatbot, status_text = chat_ai(openai_api_key, index, inputs, history, chatbot)
278
  yield chatbot, history, status_text, all_token_counts
279
  return
 
308
  all_token_counts.append(0)
309
  else:
310
  history[-2] = construct_user(inputs)
311
+ yield chatbot+[(inputs, "")], history, status_text, all_token_counts
312
+ return
313
+ elif len(inputs.strip()) == 0:
314
+ status_text = standard_error_msg + no_input_msg
315
+ logging.info(status_text)
316
+ yield chatbot+[(inputs, "")], history, status_text, all_token_counts
317
  return
318
 
319
+ yield chatbot+[(inputs, "")], history, "开始生成回答……", all_token_counts
320
 
321
  if stream:
322
  logging.info("使用流式传输")
 
334
  display_append=link_references
335
  )
336
  for chatbot, history, status_text, all_token_counts in iter:
337
+ if shared.state.interrupted:
338
+ shared.state.recover()
339
+ return
340
  yield chatbot, history, status_text, all_token_counts
341
  else:
342
  logging.info("不使用流式传输")
llama_func.py → modules/llama_func.py RENAMED
@@ -14,8 +14,8 @@ from langchain.llms import OpenAI
14
  import colorama
15
 
16
 
17
- from presets import *
18
- from utils import *
19
 
20
 
21
  def get_documents(file_src):
 
14
  import colorama
15
 
16
 
17
+ from modules.presets import *
18
+ from modules.utils import *
19
 
20
 
21
  def get_documents(file_src):
overwrites.py → modules/overwrites.py RENAMED
@@ -5,8 +5,8 @@ from llama_index import Prompt
5
  from typing import List, Tuple
6
  import mdtex2html
7
 
8
- from presets import *
9
- from llama_func import *
10
 
11
 
12
  def compact_text_chunks(self, prompt: Prompt, text_chunks: List[str]) -> List[str]:
@@ -51,5 +51,5 @@ def reload_javascript():
51
  return res
52
 
53
  gr.routes.templates.TemplateResponse = template_response
54
-
55
  GradioTemplateResponseOriginal = gr.routes.templates.TemplateResponse
 
5
  from typing import List, Tuple
6
  import mdtex2html
7
 
8
+ from modules.presets import *
9
+ from modules.llama_func import *
10
 
11
 
12
  def compact_text_chunks(self, prompt: Prompt, text_chunks: List[str]) -> List[str]:
 
51
  return res
52
 
53
  gr.routes.templates.TemplateResponse = template_response
54
+
55
  GradioTemplateResponseOriginal = gr.routes.templates.TemplateResponse
presets.py → modules/presets.py RENAMED
@@ -14,9 +14,10 @@ read_timeout_prompt = "读取超时,无法获取对话。" # 读取超时
14
  proxy_error_prompt = "代理错误,无法获取对话。" # 代理错误
15
  ssl_error_prompt = "SSL错误,无法获取对话。" # SSL 错误
16
  no_apikey_msg = "API key长度不是51位,请检查是否输入正确。" # API key 长度不足 51 位
 
17
 
18
  max_token_streaming = 3500 # 流式对话时的最大 token 数
19
- timeout_streaming = 30 # 流式对话时的超时时间
20
  max_token_all = 3500 # 非流式对话时的最大 token 数
21
  timeout_all = 200 # 非流式对话时的超时时间
22
  enable_streaming_option = True # 是否启用选择选择是否实时显示回答的勾选框
 
14
  proxy_error_prompt = "代理错误,无法获取对话。" # 代理错误
15
  ssl_error_prompt = "SSL错误,无法获取对话。" # SSL 错误
16
  no_apikey_msg = "API key长度不是51位,请检查是否输入正确。" # API key 长度不足 51 位
17
+ no_input_msg = "请输入对话内容。" # 未输入对话内容
18
 
19
  max_token_streaming = 3500 # 流式对话时的最大 token 数
20
+ timeout_streaming = 5 # 流式对话时的超时时间
21
  max_token_all = 3500 # 非流式对话时的最大 token 数
22
  timeout_all = 200 # 非流式对话时的超时时间
23
  enable_streaming_option = True # 是否启用选择选择是否实时显示回答的勾选框
modules/shared.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from modules.presets import API_URL
2
+
3
+ class State:
4
+ interrupted = False
5
+ api_url = API_URL
6
+
7
+ def interrupt(self):
8
+ self.interrupted = True
9
+
10
+ def recover(self):
11
+ self.interrupted = False
12
+
13
+ def set_api_url(self, api_url):
14
+ self.api_url = api_url
15
+
16
+ def reset_api_url(self):
17
+ self.api_url = API_URL
18
+ return self.api_url
19
+
20
+ def reset_all(self):
21
+ self.interrupted = False
22
+ self.api_url = API_URL
23
+
24
+ state = State()
utils.py → modules/utils.py RENAMED
@@ -19,9 +19,13 @@ from pygments import highlight
19
  from pygments.lexers import get_lexer_by_name
20
  from pygments.formatters import HtmlFormatter
21
 
22
- from presets import *
 
23
 
24
- # logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s")
 
 
 
25
 
26
  if TYPE_CHECKING:
27
  from typing import TypedDict
@@ -107,10 +111,12 @@ def convert_mdtext(md_text):
107
  result = "".join(result)
108
  return result
109
 
 
110
  def convert_user(userinput):
111
  userinput = userinput.replace("\n", "<br>")
112
  return f"<pre>{userinput}</pre>"
113
 
 
114
  def detect_language(code):
115
  if code.startswith("\n"):
116
  first_line = ""
@@ -297,20 +303,19 @@ def reset_state():
297
 
298
 
299
  def reset_textbox():
 
300
  return gr.update(value="")
301
 
302
 
303
  def reset_default():
304
- global API_URL
305
- API_URL = "https://api.openai.com/v1/chat/completions"
306
  os.environ.pop("HTTPS_PROXY", None)
307
  os.environ.pop("https_proxy", None)
308
- return gr.update(value=API_URL), gr.update(value=""), "API URL 和代理已重置"
309
 
310
 
311
  def change_api_url(url):
312
- global API_URL
313
- API_URL = url
314
  msg = f"API地址更改为了{url}"
315
  logging.info(msg)
316
  return msg
@@ -384,13 +389,22 @@ def find_n(lst, max_num):
384
 
385
  for i in range(len(lst)):
386
  if total - lst[i] < max_num:
387
- return n - i -1
388
  total = total - lst[i]
389
  return 1
390
 
391
- def return_cancel_btn():
392
- return gr.Button.update(
393
- visible=False
394
- ), gr.Button.update(
395
- visible=True
 
 
 
 
396
  )
 
 
 
 
 
 
19
  from pygments.lexers import get_lexer_by_name
20
  from pygments.formatters import HtmlFormatter
21
 
22
+ from modules.presets import *
23
+ import modules.shared as shared
24
 
25
+ logging.basicConfig(
26
+ level=logging.INFO,
27
+ format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
28
+ )
29
 
30
  if TYPE_CHECKING:
31
  from typing import TypedDict
 
111
  result = "".join(result)
112
  return result
113
 
114
+
115
  def convert_user(userinput):
116
  userinput = userinput.replace("\n", "<br>")
117
  return f"<pre>{userinput}</pre>"
118
 
119
+
120
  def detect_language(code):
121
  if code.startswith("\n"):
122
  first_line = ""
 
303
 
304
 
305
  def reset_textbox():
306
+ logging.debug("重置文本框")
307
  return gr.update(value="")
308
 
309
 
310
  def reset_default():
311
+ newurl = shared.state.reset_all()
 
312
  os.environ.pop("HTTPS_PROXY", None)
313
  os.environ.pop("https_proxy", None)
314
+ return gr.update(value=newurl), gr.update(value=""), "API URL 和代理已重置"
315
 
316
 
317
  def change_api_url(url):
318
+ shared.state.set_api_url(url)
 
319
  msg = f"API地址更改为了{url}"
320
  logging.info(msg)
321
  return msg
 
389
 
390
  for i in range(len(lst)):
391
  if total - lst[i] < max_num:
392
+ return n - i - 1
393
  total = total - lst[i]
394
  return 1
395
 
396
+
397
+ def start_outputing():
398
+ logging.debug("显示取消按钮,隐藏发送按钮")
399
+ return gr.Button.update(visible=False), gr.Button.update(visible=True)
400
+
401
+ def end_outputing():
402
+ return (
403
+ gr.Button.update(visible=True),
404
+ gr.Button.update(visible=False),
405
  )
406
+
407
+
408
+ def cancel_outputing():
409
+ logging.info("中止输出……")
410
+ shared.state.interrupt()