JohnSmith9982 commited on
Commit
3e8b0ee
·
1 Parent(s): f082264

Upload 31 files

Browse files
app.py CHANGED
@@ -5,10 +5,11 @@ import sys
5
 
6
  import gradio as gr
7
 
8
- from utils import *
9
- from presets import *
10
- from overwrites import *
11
- from chat_func import *
 
12
 
13
  logging.basicConfig(
14
  level=logging.DEBUG,
@@ -44,7 +45,7 @@ else:
44
  with open("api_key.txt", "r") as f:
45
  my_api_key = f.read().strip()
46
  if os.path.exists("auth.json"):
47
- with open("auth.json", "r") as f:
48
  auth = json.load(f)
49
  username = auth["username"]
50
  password = auth["password"]
@@ -54,78 +55,16 @@ else:
54
  gr.Chatbot.postprocess = postprocess
55
  PromptHelper.compact_text_chunks = compact_text_chunks
56
 
57
- with open("custom.css", "r", encoding="utf-8") as f:
58
  customCSS = f.read()
59
 
60
- with gr.Blocks(
61
- css=customCSS,
62
- theme=gr.themes.Soft(
63
- primary_hue=gr.themes.Color(
64
- c50="#02C160",
65
- c100="rgba(2, 193, 96, 0.2)",
66
- c200="#02C160",
67
- c300="rgba(2, 193, 96, 0.32)",
68
- c400="rgba(2, 193, 96, 0.32)",
69
- c500="rgba(2, 193, 96, 1.0)",
70
- c600="rgba(2, 193, 96, 1.0)",
71
- c700="rgba(2, 193, 96, 0.32)",
72
- c800="rgba(2, 193, 96, 0.32)",
73
- c900="#02C160",
74
- c950="#02C160",
75
- ),
76
- secondary_hue=gr.themes.Color(
77
- c50="#576b95",
78
- c100="#576b95",
79
- c200="#576b95",
80
- c300="#576b95",
81
- c400="#576b95",
82
- c500="#576b95",
83
- c600="#576b95",
84
- c700="#576b95",
85
- c800="#576b95",
86
- c900="#576b95",
87
- c950="#576b95",
88
- ),
89
- neutral_hue=gr.themes.Color(
90
- name="gray",
91
- c50="#f9fafb",
92
- c100="#f3f4f6",
93
- c200="#e5e7eb",
94
- c300="#d1d5db",
95
- c400="#B2B2B2",
96
- c500="#808080",
97
- c600="#636363",
98
- c700="#515151",
99
- c800="#393939",
100
- c900="#272727",
101
- c950="#171717",
102
- ),
103
- radius_size=gr.themes.sizes.radius_sm,
104
- ).set(
105
- button_primary_background_fill="#06AE56",
106
- button_primary_background_fill_dark="#06AE56",
107
- button_primary_background_fill_hover="#07C863",
108
- button_primary_border_color="#06AE56",
109
- button_primary_border_color_dark="#06AE56",
110
- button_primary_text_color="#FFFFFF",
111
- button_primary_text_color_dark="#FFFFFF",
112
- button_secondary_background_fill="#F2F2F2",
113
- button_secondary_background_fill_dark="#2B2B2B",
114
- button_secondary_text_color="#393939",
115
- button_secondary_text_color_dark="#FFFFFF",
116
- # background_fill_primary="#F7F7F7",
117
- # background_fill_primary_dark="#1F1F1F",
118
- block_title_text_color="*primary_500",
119
- block_title_background_fill="*primary_100",
120
- input_background_fill="#F6F6F6",
121
- ),
122
- ) as demo:
123
  history = gr.State([])
124
  token_count = gr.State([])
125
  promptTemplates = gr.State(load_template(get_template_names(plain=True)[0], mode=2))
126
  user_api_key = gr.State(my_api_key)
127
- TRUECOMSTANT = gr.State(True)
128
- FALSECONSTANT = gr.State(False)
129
  topic = gr.State("未命名对话历史记录")
130
 
131
  with gr.Row():
@@ -139,10 +78,11 @@ with gr.Blocks(
139
  with gr.Row(scale=1):
140
  with gr.Column(scale=12):
141
  user_input = gr.Textbox(
142
- show_label=False, placeholder="在这里输入"
143
  ).style(container=False)
144
  with gr.Column(min_width=70, scale=1):
145
  submitBtn = gr.Button("发送", variant="primary")
 
146
  with gr.Row(scale=1):
147
  emptyBtn = gr.Button(
148
  "🧹 新的对话",
@@ -162,6 +102,7 @@ with gr.Blocks(
162
  visible=not HIDE_MY_KEY,
163
  label="API-Key",
164
  )
 
165
  model_select_dropdown = gr.Dropdown(
166
  label="选择模型", choices=MODELS, multiselect=False, value=MODELS[0]
167
  )
@@ -169,6 +110,12 @@ with gr.Blocks(
169
  label="实时传输回答", value=True, visible=enable_streaming_option
170
  )
171
  use_websearch_checkbox = gr.Checkbox(label="使用在线搜索", value=False)
 
 
 
 
 
 
172
  index_files = gr.Files(label="上传索引文件", type="file", multiple=True)
173
 
174
  with gr.Tab(label="Prompt"):
@@ -234,8 +181,8 @@ with gr.Blocks(
234
  downloadFile = gr.File(interactive=True)
235
 
236
  with gr.Tab(label="高级"):
237
- default_btn = gr.Button("🔙 恢复默认设置")
238
  gr.Markdown("# ⚠️ 务必谨慎更改 ⚠️\n\n如果无法使用请恢复默认设置")
 
239
 
240
  with gr.Accordion("参数", open=False):
241
  top_p = gr.Slider(
@@ -255,35 +202,33 @@ with gr.Blocks(
255
  label="Temperature",
256
  )
257
 
258
- apiurlTxt = gr.Textbox(
259
- show_label=True,
260
- placeholder=f"在这里输入API地址...",
261
- label="API地址",
262
- value="https://api.openai.com/v1/chat/completions",
263
- lines=2,
264
- )
265
- changeAPIURLBtn = gr.Button("🔄 切换API地址")
266
- proxyTxt = gr.Textbox(
267
- show_label=True,
268
- placeholder=f"在这里输入代理地址...",
269
- label="代理地址(示例:http://127.0.0.1:10809)",
270
- value="",
271
- lines=2,
272
- )
273
- changeProxyBtn = gr.Button("🔄 设置代理地址")
 
274
 
275
  gr.Markdown(description)
276
 
277
- keyTxt.submit(submit_key, keyTxt, [user_api_key, status_display])
278
- keyTxt.change(submit_key, keyTxt, [user_api_key, status_display])
279
- # Chatbot
280
- user_input.submit(
281
- predict,
282
- [
283
  user_api_key,
284
  systemPromptTxt,
285
  history,
286
- user_input,
287
  chatbot,
288
  token_count,
289
  top_p,
@@ -292,39 +237,52 @@ with gr.Blocks(
292
  model_select_dropdown,
293
  use_websearch_checkbox,
294
  index_files,
 
295
  ],
296
- [chatbot, history, status_display, token_count],
297
  show_progress=True,
298
  )
299
- user_input.submit(reset_textbox, [], [user_input])
300
 
301
- submitBtn.click(
302
- predict,
303
- [
304
- user_api_key,
305
- systemPromptTxt,
306
- history,
307
- user_input,
308
- chatbot,
309
- token_count,
310
- top_p,
311
- temperature,
312
- use_streaming_checkbox,
313
- model_select_dropdown,
314
- use_websearch_checkbox,
315
- index_files,
316
- ],
317
- [chatbot, history, status_display, token_count],
318
  show_progress=True,
319
  )
320
- submitBtn.click(reset_textbox, [], [user_input])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
321
 
322
  emptyBtn.click(
323
  reset_state,
324
  outputs=[chatbot, history, token_count, status_display],
325
  show_progress=True,
326
  )
 
327
 
 
328
  retryBtn.click(
329
  retry,
330
  [
@@ -337,10 +295,12 @@ with gr.Blocks(
337
  temperature,
338
  use_streaming_checkbox,
339
  model_select_dropdown,
 
340
  ],
341
  [chatbot, history, status_display, token_count],
342
  show_progress=True,
343
  )
 
344
 
345
  delLastBtn.click(
346
  delete_last_conversation,
@@ -361,10 +321,15 @@ with gr.Blocks(
361
  temperature,
362
  gr.State(0),
363
  model_select_dropdown,
 
364
  ],
365
  [chatbot, history, status_display, token_count],
366
  show_progress=True,
367
  )
 
 
 
 
368
 
369
  # Template
370
  templateRefreshBtn.click(get_template_names, None, [templateFileSelectDropdown])
@@ -438,18 +403,32 @@ if __name__ == "__main__":
438
  # if running in Docker
439
  if dockerflag:
440
  if authflag:
441
- demo.queue().launch(
442
- server_name="0.0.0.0", server_port=7860, auth=(username, password),
443
- favicon_path="./assets/favicon.png"
 
 
444
  )
445
  else:
446
- demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=False, favicon_path="./assets/favicon.png")
 
 
 
 
 
447
  # if not running in Docker
448
  else:
449
  if authflag:
450
- demo.queue().launch(share=False, auth=(username, password), favicon_path="./assets/favicon.png", inbrowser=True)
 
 
 
 
 
451
  else:
452
- demo.queue().launch(share=False, favicon_path="./assets/favicon.ico", inbrowser=True) # 改为 share=True 可以创建公开分享链接
453
- # demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=False) # 可自定义端口
454
- # demo.queue().launch(server_name="0.0.0.0", server_port=7860,auth=("在这里填写用户名", "在这里填写密码")) # 可设置用户名与密码
455
- # demo.queue().launch(auth=("在这里填写用户名", "在这里填写密码")) # 适合Nginx反向代理
 
 
 
5
 
6
  import gradio as gr
7
 
8
+ from modules.utils import *
9
+ from modules.presets import *
10
+ from modules.overwrites import *
11
+ from modules.chat_func import *
12
+ from modules.openai_func import get_usage
13
 
14
  logging.basicConfig(
15
  level=logging.DEBUG,
 
45
  with open("api_key.txt", "r") as f:
46
  my_api_key = f.read().strip()
47
  if os.path.exists("auth.json"):
48
+ with open("auth.json", "r", encoding='utf-8') as f:
49
  auth = json.load(f)
50
  username = auth["username"]
51
  password = auth["password"]
 
55
  gr.Chatbot.postprocess = postprocess
56
  PromptHelper.compact_text_chunks = compact_text_chunks
57
 
58
+ with open("assets/custom.css", "r", encoding="utf-8") as f:
59
  customCSS = f.read()
60
 
61
+ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  history = gr.State([])
63
  token_count = gr.State([])
64
  promptTemplates = gr.State(load_template(get_template_names(plain=True)[0], mode=2))
65
  user_api_key = gr.State(my_api_key)
66
+ user_question = gr.State("")
67
+ outputing = gr.State(False)
68
  topic = gr.State("未命名对话历史记录")
69
 
70
  with gr.Row():
 
78
  with gr.Row(scale=1):
79
  with gr.Column(scale=12):
80
  user_input = gr.Textbox(
81
+ show_label=False, placeholder="在这里输入", interactive=True
82
  ).style(container=False)
83
  with gr.Column(min_width=70, scale=1):
84
  submitBtn = gr.Button("发送", variant="primary")
85
+ cancelBtn = gr.Button("取消", variant="secondary", visible=False)
86
  with gr.Row(scale=1):
87
  emptyBtn = gr.Button(
88
  "🧹 新的对话",
 
102
  visible=not HIDE_MY_KEY,
103
  label="API-Key",
104
  )
105
+ usageTxt = gr.Markdown(get_usage(my_api_key), elem_id="usage_display")
106
  model_select_dropdown = gr.Dropdown(
107
  label="选择模型", choices=MODELS, multiselect=False, value=MODELS[0]
108
  )
 
110
  label="实时传输回答", value=True, visible=enable_streaming_option
111
  )
112
  use_websearch_checkbox = gr.Checkbox(label="使用在线搜索", value=False)
113
+ language_select_dropdown = gr.Dropdown(
114
+ label="选择回复语言(针对搜索&索引功能)",
115
+ choices=REPLY_LANGUAGES,
116
+ multiselect=False,
117
+ value=REPLY_LANGUAGES[0],
118
+ )
119
  index_files = gr.Files(label="上传索引文件", type="file", multiple=True)
120
 
121
  with gr.Tab(label="Prompt"):
 
181
  downloadFile = gr.File(interactive=True)
182
 
183
  with gr.Tab(label="高级"):
 
184
  gr.Markdown("# ⚠️ 务必谨慎更改 ⚠️\n\n如果无法使用请恢复默认设置")
185
+ default_btn = gr.Button("🔙 恢复默认设置")
186
 
187
  with gr.Accordion("参数", open=False):
188
  top_p = gr.Slider(
 
202
  label="Temperature",
203
  )
204
 
205
+ with gr.Accordion("网络设置", open=False):
206
+ apiurlTxt = gr.Textbox(
207
+ show_label=True,
208
+ placeholder=f"在这里输入API地址...",
209
+ label="API地址",
210
+ value="https://api.openai.com/v1/chat/completions",
211
+ lines=2,
212
+ )
213
+ changeAPIURLBtn = gr.Button("🔄 切换API地址")
214
+ proxyTxt = gr.Textbox(
215
+ show_label=True,
216
+ placeholder=f"在这里输入代理地址...",
217
+ label="代理地址(示例:http://127.0.0.1:10809)",
218
+ value="",
219
+ lines=2,
220
+ )
221
+ changeProxyBtn = gr.Button("🔄 设置代理地址")
222
 
223
  gr.Markdown(description)
224
 
225
+ chatgpt_predict_args = dict(
226
+ fn=predict,
227
+ inputs=[
 
 
 
228
  user_api_key,
229
  systemPromptTxt,
230
  history,
231
+ user_question,
232
  chatbot,
233
  token_count,
234
  top_p,
 
237
  model_select_dropdown,
238
  use_websearch_checkbox,
239
  index_files,
240
+ language_select_dropdown,
241
  ],
242
+ outputs=[chatbot, history, status_display, token_count],
243
  show_progress=True,
244
  )
 
245
 
246
+ start_outputing_args = dict(
247
+ fn=start_outputing,
248
+ inputs=[],
249
+ outputs=[submitBtn, cancelBtn],
 
 
 
 
 
 
 
 
 
 
 
 
 
250
  show_progress=True,
251
  )
252
+
253
+ end_outputing_args = dict(
254
+ fn=end_outputing, inputs=[], outputs=[submitBtn, cancelBtn]
255
+ )
256
+
257
+ reset_textbox_args = dict(
258
+ fn=reset_textbox, inputs=[], outputs=[user_input]
259
+ )
260
+
261
+ transfer_input_args = dict(
262
+ fn=transfer_input, inputs=[user_input], outputs=[user_question, user_input], show_progress=True
263
+ )
264
+
265
+ get_usage_args = dict(
266
+ fn=get_usage, inputs=[user_api_key], outputs=[usageTxt], show_progress=False
267
+ )
268
+
269
+ # Chatbot
270
+ cancelBtn.click(cancel_outputing, [], [])
271
+
272
+ user_input.submit(**transfer_input_args).then(**chatgpt_predict_args).then(**end_outputing_args)
273
+ user_input.submit(**get_usage_args)
274
+
275
+ submitBtn.click(**transfer_input_args).then(**chatgpt_predict_args).then(**end_outputing_args)
276
+ submitBtn.click(**get_usage_args)
277
 
278
  emptyBtn.click(
279
  reset_state,
280
  outputs=[chatbot, history, token_count, status_display],
281
  show_progress=True,
282
  )
283
+ emptyBtn.click(**reset_textbox_args)
284
 
285
+ retryBtn.click(**reset_textbox_args)
286
  retryBtn.click(
287
  retry,
288
  [
 
295
  temperature,
296
  use_streaming_checkbox,
297
  model_select_dropdown,
298
+ language_select_dropdown,
299
  ],
300
  [chatbot, history, status_display, token_count],
301
  show_progress=True,
302
  )
303
+ retryBtn.click(**get_usage_args)
304
 
305
  delLastBtn.click(
306
  delete_last_conversation,
 
321
  temperature,
322
  gr.State(0),
323
  model_select_dropdown,
324
+ language_select_dropdown,
325
  ],
326
  [chatbot, history, status_display, token_count],
327
  show_progress=True,
328
  )
329
+ reduceTokenBtn.click(**get_usage_args)
330
+
331
+ # ChatGPT
332
+ keyTxt.change(submit_key, keyTxt, [user_api_key, status_display]).then(**get_usage_args)
333
 
334
  # Template
335
  templateRefreshBtn.click(get_template_names, None, [templateFileSelectDropdown])
 
403
  # if running in Docker
404
  if dockerflag:
405
  if authflag:
406
+ demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
407
+ server_name="0.0.0.0",
408
+ server_port=7860,
409
+ auth=(username, password),
410
+ favicon_path="./assets/favicon.ico",
411
  )
412
  else:
413
+ demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
414
+ server_name="0.0.0.0",
415
+ server_port=7860,
416
+ share=False,
417
+ favicon_path="./assets/favicon.ico",
418
+ )
419
  # if not running in Docker
420
  else:
421
  if authflag:
422
+ demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
423
+ share=False,
424
+ auth=(username, password),
425
+ favicon_path="./assets/favicon.ico",
426
+ inbrowser=True,
427
+ )
428
  else:
429
+ demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
430
+ share=False, favicon_path="./assets/favicon.ico", inbrowser=True
431
+ ) # 改为 share=True 可以创建公开分享链接
432
+ # demo.queue(concurrency_count=CONCURRENT_COUNT).launch(server_name="0.0.0.0", server_port=7860, share=False) # 可自定义端口
433
+ # demo.queue(concurrency_count=CONCURRENT_COUNT).launch(server_name="0.0.0.0", server_port=7860,auth=("在这里填写用户名", "在这里填写密码")) # 可设置用户名与密码
434
+ # demo.queue(concurrency_count=CONCURRENT_COUNT).launch(auth=("在这里填写用户名", "在这里填写密码")) # 适合Nginx反向代理
assets/custom.css ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ :root {
2
+ --chatbot-color-light: #F3F3F3;
3
+ --chatbot-color-dark: #121111;
4
+ }
5
+
6
+ /* status_display */
7
+ #status_display {
8
+ display: flex;
9
+ min-height: 2.5em;
10
+ align-items: flex-end;
11
+ justify-content: flex-end;
12
+ }
13
+ #status_display p {
14
+ font-size: .85em;
15
+ font-family: monospace;
16
+ color: var(--body-text-color-subdued);
17
+ }
18
+
19
+ #chuanhu_chatbot, #status_display {
20
+ transition: all 0.6s;
21
+ }
22
+
23
+ /* usage_display */
24
+ #usage_display {
25
+ height: 1em;
26
+ }
27
+ #usage_display p{
28
+ padding: 0 1em;
29
+ font-size: .85em;
30
+ font-family: monospace;
31
+ color: var(--body-text-color-subdued);
32
+ }
33
+ /* list */
34
+ ol:not(.options), ul:not(.options) {
35
+ padding-inline-start: 2em !important;
36
+ }
37
+
38
+ /* 亮色 */
39
+ #chuanhu_chatbot {
40
+ background-color: var(--chatbot-color-light) !important;
41
+ }
42
+ [data-testid = "bot"] {
43
+ background-color: #FFFFFF !important;
44
+ }
45
+ [data-testid = "user"] {
46
+ background-color: #95EC69 !important;
47
+ }
48
+ /* 对话气泡 */
49
+ [class *= "message"] {
50
+ border-radius: var(--radius-xl) !important;
51
+ border: none;
52
+ padding: var(--spacing-xl) !important;
53
+ font-size: var(--text-md) !important;
54
+ line-height: var(--line-md) !important;
55
+ min-height: calc(var(--text-md)*var(--line-md) + 2*var(--spacing-xl));
56
+ min-width: calc(var(--text-md)*var(--line-md) + 2*var(--spacing-xl));
57
+ }
58
+ [data-testid = "bot"] {
59
+ max-width: 85%;
60
+ border-bottom-left-radius: 0 !important;
61
+ }
62
+ [data-testid = "user"] {
63
+ max-width: 85%;
64
+ width: auto !important;
65
+ border-bottom-right-radius: 0 !important;
66
+ }
67
+ /* 表格 */
68
+ table {
69
+ margin: 1em 0;
70
+ border-collapse: collapse;
71
+ empty-cells: show;
72
+ }
73
+ td,th {
74
+ border: 1.2px solid var(--border-color-primary) !important;
75
+ padding: 0.2em;
76
+ }
77
+ thead {
78
+ background-color: rgba(175,184,193,0.2);
79
+ }
80
+ thead th {
81
+ padding: .5em .2em;
82
+ }
83
+ /* 行内代码 */
84
+ code {
85
+ display: inline;
86
+ white-space: break-spaces;
87
+ border-radius: 6px;
88
+ margin: 0 2px 0 2px;
89
+ padding: .2em .4em .1em .4em;
90
+ background-color: rgba(175,184,193,0.2);
91
+ }
92
+ /* 代码块 */
93
+ pre code {
94
+ display: block;
95
+ overflow: auto;
96
+ white-space: pre;
97
+ background-color: hsla(0, 0%, 0%, 80%)!important;
98
+ border-radius: 10px;
99
+ padding: 1.4em 1.2em 0em 1.4em;
100
+ margin: 1.2em 2em 1.2em 0.5em;
101
+ color: #FFF;
102
+ box-shadow: 6px 6px 16px hsla(0, 0%, 0%, 0.2);
103
+ }
104
+ /* 代码高亮样式 */
105
+ .highlight .hll { background-color: #49483e }
106
+ .highlight .c { color: #75715e } /* Comment */
107
+ .highlight .err { color: #960050; background-color: #1e0010 } /* Error */
108
+ .highlight .k { color: #66d9ef } /* Keyword */
109
+ .highlight .l { color: #ae81ff } /* Literal */
110
+ .highlight .n { color: #f8f8f2 } /* Name */
111
+ .highlight .o { color: #f92672 } /* Operator */
112
+ .highlight .p { color: #f8f8f2 } /* Punctuation */
113
+ .highlight .ch { color: #75715e } /* Comment.Hashbang */
114
+ .highlight .cm { color: #75715e } /* Comment.Multiline */
115
+ .highlight .cp { color: #75715e } /* Comment.Preproc */
116
+ .highlight .cpf { color: #75715e } /* Comment.PreprocFile */
117
+ .highlight .c1 { color: #75715e } /* Comment.Single */
118
+ .highlight .cs { color: #75715e } /* Comment.Special */
119
+ .highlight .gd { color: #f92672 } /* Generic.Deleted */
120
+ .highlight .ge { font-style: italic } /* Generic.Emph */
121
+ .highlight .gi { color: #a6e22e } /* Generic.Inserted */
122
+ .highlight .gs { font-weight: bold } /* Generic.Strong */
123
+ .highlight .gu { color: #75715e } /* Generic.Subheading */
124
+ .highlight .kc { color: #66d9ef } /* Keyword.Constant */
125
+ .highlight .kd { color: #66d9ef } /* Keyword.Declaration */
126
+ .highlight .kn { color: #f92672 } /* Keyword.Namespace */
127
+ .highlight .kp { color: #66d9ef } /* Keyword.Pseudo */
128
+ .highlight .kr { color: #66d9ef } /* Keyword.Reserved */
129
+ .highlight .kt { color: #66d9ef } /* Keyword.Type */
130
+ .highlight .ld { color: #e6db74 } /* Literal.Date */
131
+ .highlight .m { color: #ae81ff } /* Literal.Number */
132
+ .highlight .s { color: #e6db74 } /* Literal.String */
133
+ .highlight .na { color: #a6e22e } /* Name.Attribute */
134
+ .highlight .nb { color: #f8f8f2 } /* Name.Builtin */
135
+ .highlight .nc { color: #a6e22e } /* Name.Class */
136
+ .highlight .no { color: #66d9ef } /* Name.Constant */
137
+ .highlight .nd { color: #a6e22e } /* Name.Decorator */
138
+ .highlight .ni { color: #f8f8f2 } /* Name.Entity */
139
+ .highlight .ne { color: #a6e22e } /* Name.Exception */
140
+ .highlight .nf { color: #a6e22e } /* Name.Function */
141
+ .highlight .nl { color: #f8f8f2 } /* Name.Label */
142
+ .highlight .nn { color: #f8f8f2 } /* Name.Namespace */
143
+ .highlight .nx { color: #a6e22e } /* Name.Other */
144
+ .highlight .py { color: #f8f8f2 } /* Name.Property */
145
+ .highlight .nt { color: #f92672 } /* Name.Tag */
146
+ .highlight .nv { color: #f8f8f2 } /* Name.Variable */
147
+ .highlight .ow { color: #f92672 } /* Operator.Word */
148
+ .highlight .w { color: #f8f8f2 } /* Text.Whitespace */
149
+ .highlight .mb { color: #ae81ff } /* Literal.Number.Bin */
150
+ .highlight .mf { color: #ae81ff } /* Literal.Number.Float */
151
+ .highlight .mh { color: #ae81ff } /* Literal.Number.Hex */
152
+ .highlight .mi { color: #ae81ff } /* Literal.Number.Integer */
153
+ .highlight .mo { color: #ae81ff } /* Literal.Number.Oct */
154
+ .highlight .sa { color: #e6db74 } /* Literal.String.Affix */
155
+ .highlight .sb { color: #e6db74 } /* Literal.String.Backtick */
156
+ .highlight .sc { color: #e6db74 } /* Literal.String.Char */
157
+ .highlight .dl { color: #e6db74 } /* Literal.String.Delimiter */
158
+ .highlight .sd { color: #e6db74 } /* Literal.String.Doc */
159
+ .highlight .s2 { color: #e6db74 } /* Literal.String.Double */
160
+ .highlight .se { color: #ae81ff } /* Literal.String.Escape */
161
+ .highlight .sh { color: #e6db74 } /* Literal.String.Heredoc */
162
+ .highlight .si { color: #e6db74 } /* Literal.String.Interpol */
163
+ .highlight .sx { color: #e6db74 } /* Literal.String.Other */
164
+ .highlight .sr { color: #e6db74 } /* Literal.String.Regex */
165
+ .highlight .s1 { color: #e6db74 } /* Literal.String.Single */
166
+ .highlight .ss { color: #e6db74 } /* Literal.String.Symbol */
167
+ .highlight .bp { color: #f8f8f2 } /* Name.Builtin.Pseudo */
168
+ .highlight .fm { color: #a6e22e } /* Name.Function.Magic */
169
+ .highlight .vc { color: #f8f8f2 } /* Name.Variable.Class */
170
+ .highlight .vg { color: #f8f8f2 } /* Name.Variable.Global */
171
+ .highlight .vi { color: #f8f8f2 } /* Name.Variable.Instance */
172
+ .highlight .vm { color: #f8f8f2 } /* Name.Variable.Magic */
173
+ .highlight .il { color: #ae81ff } /* Literal.Number.Integer.Long */
modules/__pycache__/chat_func.cpython-39.pyc ADDED
Binary file (8.82 kB). View file
 
modules/__pycache__/llama_func.cpython-39.pyc ADDED
Binary file (4.6 kB). View file
 
modules/__pycache__/openai_func.cpython-39.pyc ADDED
Binary file (1.79 kB). View file
 
modules/__pycache__/overwrites.cpython-39.pyc ADDED
Binary file (2.61 kB). View file
 
modules/__pycache__/presets.cpython-39.pyc ADDED
Binary file (4.72 kB). View file
 
modules/__pycache__/shared.cpython-39.pyc ADDED
Binary file (1.08 kB). View file
 
modules/__pycache__/utils.cpython-39.pyc ADDED
Binary file (14 kB). View file
 
modules/chat_func.py ADDED
@@ -0,0 +1,473 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding:utf-8 -*-
2
+ from __future__ import annotations
3
+ from typing import TYPE_CHECKING, List
4
+
5
+ import logging
6
+ import json
7
+ import os
8
+ import requests
9
+ import urllib3
10
+
11
+ from tqdm import tqdm
12
+ import colorama
13
+ from duckduckgo_search import ddg
14
+ import asyncio
15
+ import aiohttp
16
+
17
+ from modules.presets import *
18
+ from modules.llama_func import *
19
+ from modules.utils import *
20
+ import modules.shared as shared
21
+
22
+ # logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s")
23
+
24
+ if TYPE_CHECKING:
25
+ from typing import TypedDict
26
+
27
+ class DataframeData(TypedDict):
28
+ headers: List[str]
29
+ data: List[List[str | int | bool]]
30
+
31
+
32
+ initial_prompt = "You are a helpful assistant."
33
+ HISTORY_DIR = "history"
34
+ TEMPLATES_DIR = "templates"
35
+
36
+ def get_response(
37
+ openai_api_key, system_prompt, history, temperature, top_p, stream, selected_model
38
+ ):
39
+ headers = {
40
+ "Content-Type": "application/json",
41
+ "Authorization": f"Bearer {openai_api_key}",
42
+ }
43
+
44
+ history = [construct_system(system_prompt), *history]
45
+
46
+ payload = {
47
+ "model": selected_model,
48
+ "messages": history, # [{"role": "user", "content": f"{inputs}"}],
49
+ "temperature": temperature, # 1.0,
50
+ "top_p": top_p, # 1.0,
51
+ "n": 1,
52
+ "stream": stream,
53
+ "presence_penalty": 0,
54
+ "frequency_penalty": 0,
55
+ }
56
+ if stream:
57
+ timeout = timeout_streaming
58
+ else:
59
+ timeout = timeout_all
60
+
61
+ # 获取环境变量中的代理设置
62
+ http_proxy = os.environ.get("HTTP_PROXY") or os.environ.get("http_proxy")
63
+ https_proxy = os.environ.get("HTTPS_PROXY") or os.environ.get("https_proxy")
64
+
65
+ # 如果存在代理设置,使用它们
66
+ proxies = {}
67
+ if http_proxy:
68
+ logging.info(f"使用 HTTP 代理: {http_proxy}")
69
+ proxies["http"] = http_proxy
70
+ if https_proxy:
71
+ logging.info(f"使用 HTTPS 代理: {https_proxy}")
72
+ proxies["https"] = https_proxy
73
+
74
+ # 如果有自定义的api-url,使用自定义url发送请求,否则使用默认设置发送请求
75
+ if shared.state.api_url != API_URL:
76
+ logging.info(f"使用自定义API URL: {shared.state.api_url}")
77
+ if proxies:
78
+ response = requests.post(
79
+ shared.state.api_url,
80
+ headers=headers,
81
+ json=payload,
82
+ stream=True,
83
+ timeout=timeout,
84
+ proxies=proxies,
85
+ )
86
+ else:
87
+ response = requests.post(
88
+ shared.state.api_url,
89
+ headers=headers,
90
+ json=payload,
91
+ stream=True,
92
+ timeout=timeout,
93
+ )
94
+ return response
95
+
96
+
97
+ def stream_predict(
98
+ openai_api_key,
99
+ system_prompt,
100
+ history,
101
+ inputs,
102
+ chatbot,
103
+ all_token_counts,
104
+ top_p,
105
+ temperature,
106
+ selected_model,
107
+ fake_input=None,
108
+ display_append=""
109
+ ):
110
+ def get_return_value():
111
+ return chatbot, history, status_text, all_token_counts
112
+
113
+ logging.info("实时回答模式")
114
+ partial_words = ""
115
+ counter = 0
116
+ status_text = "开始实时传输回答……"
117
+ history.append(construct_user(inputs))
118
+ history.append(construct_assistant(""))
119
+ if fake_input:
120
+ chatbot.append((fake_input, ""))
121
+ else:
122
+ chatbot.append((inputs, ""))
123
+ user_token_count = 0
124
+ if len(all_token_counts) == 0:
125
+ system_prompt_token_count = count_token(construct_system(system_prompt))
126
+ user_token_count = (
127
+ count_token(construct_user(inputs)) + system_prompt_token_count
128
+ )
129
+ else:
130
+ user_token_count = count_token(construct_user(inputs))
131
+ all_token_counts.append(user_token_count)
132
+ logging.info(f"输入token计数: {user_token_count}")
133
+ yield get_return_value()
134
+ try:
135
+ response = get_response(
136
+ openai_api_key,
137
+ system_prompt,
138
+ history,
139
+ temperature,
140
+ top_p,
141
+ True,
142
+ selected_model,
143
+ )
144
+ except requests.exceptions.ConnectTimeout:
145
+ status_text = (
146
+ standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
147
+ )
148
+ yield get_return_value()
149
+ return
150
+ except requests.exceptions.ReadTimeout:
151
+ status_text = standard_error_msg + read_timeout_prompt + error_retrieve_prompt
152
+ yield get_return_value()
153
+ return
154
+
155
+ yield get_return_value()
156
+ error_json_str = ""
157
+
158
+ for chunk in response.iter_lines():
159
+ if counter == 0:
160
+ counter += 1
161
+ continue
162
+ counter += 1
163
+ # check whether each line is non-empty
164
+ if chunk:
165
+ chunk = chunk.decode()
166
+ chunklength = len(chunk)
167
+ try:
168
+ chunk = json.loads(chunk[6:])
169
+ except json.JSONDecodeError:
170
+ logging.info(chunk)
171
+ error_json_str += chunk
172
+ status_text = f"JSON解析错误。请重置对话。收到的内容: {error_json_str}"
173
+ yield get_return_value()
174
+ continue
175
+ # decode each line as response data is in bytes
176
+ if chunklength > 6 and "delta" in chunk["choices"][0]:
177
+ finish_reason = chunk["choices"][0]["finish_reason"]
178
+ status_text = construct_token_message(
179
+ sum(all_token_counts), stream=True
180
+ )
181
+ if finish_reason == "stop":
182
+ yield get_return_value()
183
+ break
184
+ try:
185
+ partial_words = (
186
+ partial_words + chunk["choices"][0]["delta"]["content"]
187
+ )
188
+ except KeyError:
189
+ status_text = (
190
+ standard_error_msg
191
+ + "API回复中找不到内容。很可能是Token计数达到上限了。请重置对话。当前Token计数: "
192
+ + str(sum(all_token_counts))
193
+ )
194
+ yield get_return_value()
195
+ break
196
+ history[-1] = construct_assistant(partial_words)
197
+ chatbot[-1] = (chatbot[-1][0], partial_words+display_append)
198
+ all_token_counts[-1] += 1
199
+ yield get_return_value()
200
+
201
+
202
+ def predict_all(
203
+ openai_api_key,
204
+ system_prompt,
205
+ history,
206
+ inputs,
207
+ chatbot,
208
+ all_token_counts,
209
+ top_p,
210
+ temperature,
211
+ selected_model,
212
+ fake_input=None,
213
+ display_append=""
214
+ ):
215
+ logging.info("一次性回答模式")
216
+ history.append(construct_user(inputs))
217
+ history.append(construct_assistant(""))
218
+ if fake_input:
219
+ chatbot.append((fake_input, ""))
220
+ else:
221
+ chatbot.append((inputs, ""))
222
+ all_token_counts.append(count_token(construct_user(inputs)))
223
+ try:
224
+ response = get_response(
225
+ openai_api_key,
226
+ system_prompt,
227
+ history,
228
+ temperature,
229
+ top_p,
230
+ False,
231
+ selected_model,
232
+ )
233
+ except requests.exceptions.ConnectTimeout:
234
+ status_text = (
235
+ standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
236
+ )
237
+ return chatbot, history, status_text, all_token_counts
238
+ except requests.exceptions.ProxyError:
239
+ status_text = standard_error_msg + proxy_error_prompt + error_retrieve_prompt
240
+ return chatbot, history, status_text, all_token_counts
241
+ except requests.exceptions.SSLError:
242
+ status_text = standard_error_msg + ssl_error_prompt + error_retrieve_prompt
243
+ return chatbot, history, status_text, all_token_counts
244
+ response = json.loads(response.text)
245
+ content = response["choices"][0]["message"]["content"]
246
+ history[-1] = construct_assistant(content)
247
+ chatbot[-1] = (chatbot[-1][0], content+display_append)
248
+ total_token_count = response["usage"]["total_tokens"]
249
+ all_token_counts[-1] = total_token_count - sum(all_token_counts)
250
+ status_text = construct_token_message(total_token_count)
251
+ return chatbot, history, status_text, all_token_counts
252
+
253
+
254
+ def predict(
255
+ openai_api_key,
256
+ system_prompt,
257
+ history,
258
+ inputs,
259
+ chatbot,
260
+ all_token_counts,
261
+ top_p,
262
+ temperature,
263
+ stream=False,
264
+ selected_model=MODELS[0],
265
+ use_websearch=False,
266
+ files = None,
267
+ reply_language="中文",
268
+ should_check_token_count=True,
269
+ ): # repetition_penalty, top_k
270
+ logging.info("输入为:" + colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL)
271
+ yield chatbot+[(inputs, "")], history, "开始生成回答……", all_token_counts
272
+ if reply_language == "跟随问题语言(不稳定)":
273
+ reply_language = "the same language as the question, such as English, 中文, 日本語, Español, Français, or Deutsch."
274
+ if files:
275
+ msg = "构建索引中……(这可能需要比较久的时间)"
276
+ logging.info(msg)
277
+ yield chatbot+[(inputs, "")], history, msg, all_token_counts
278
+ index = construct_index(openai_api_key, file_src=files)
279
+ msg = "索引构建完成,获取回答中……"
280
+ yield chatbot+[(inputs, "")], history, msg, all_token_counts
281
+ history, chatbot, status_text = chat_ai(openai_api_key, index, inputs, history, chatbot, reply_language)
282
+ yield chatbot, history, status_text, all_token_counts
283
+ return
284
+
285
+ old_inputs = ""
286
+ link_references = []
287
+ if use_websearch:
288
+ search_results = ddg(inputs, max_results=5)
289
+ old_inputs = inputs
290
+ web_results = []
291
+ for idx, result in enumerate(search_results):
292
+ logging.info(f"搜索结果{idx + 1}:{result}")
293
+ domain_name = urllib3.util.parse_url(result["href"]).host
294
+ web_results.append(f'[{idx+1}]"{result["body"]}"\nURL: {result["href"]}')
295
+ link_references.append(f"{idx+1}. [{domain_name}]({result['href']})\n")
296
+ link_references = "\n\n" + "".join(link_references)
297
+ inputs = (
298
+ replace_today(WEBSEARCH_PTOMPT_TEMPLATE)
299
+ .replace("{query}", inputs)
300
+ .replace("{web_results}", "\n\n".join(web_results))
301
+ .replace("{reply_language}", reply_language )
302
+ )
303
+ else:
304
+ link_references = ""
305
+
306
+ if len(openai_api_key) != 51:
307
+ status_text = standard_error_msg + no_apikey_msg
308
+ logging.info(status_text)
309
+ chatbot.append((inputs, ""))
310
+ if len(history) == 0:
311
+ history.append(construct_user(inputs))
312
+ history.append("")
313
+ all_token_counts.append(0)
314
+ else:
315
+ history[-2] = construct_user(inputs)
316
+ yield chatbot+[(inputs, "")], history, status_text, all_token_counts
317
+ return
318
+ elif len(inputs.strip()) == 0:
319
+ status_text = standard_error_msg + no_input_msg
320
+ logging.info(status_text)
321
+ yield chatbot+[(inputs, "")], history, status_text, all_token_counts
322
+ return
323
+
324
+ if stream:
325
+ logging.info("使用流式传输")
326
+ iter = stream_predict(
327
+ openai_api_key,
328
+ system_prompt,
329
+ history,
330
+ inputs,
331
+ chatbot,
332
+ all_token_counts,
333
+ top_p,
334
+ temperature,
335
+ selected_model,
336
+ fake_input=old_inputs,
337
+ display_append=link_references
338
+ )
339
+ for chatbot, history, status_text, all_token_counts in iter:
340
+ if shared.state.interrupted:
341
+ shared.state.recover()
342
+ return
343
+ yield chatbot, history, status_text, all_token_counts
344
+ else:
345
+ logging.info("不使用流式传输")
346
+ chatbot, history, status_text, all_token_counts = predict_all(
347
+ openai_api_key,
348
+ system_prompt,
349
+ history,
350
+ inputs,
351
+ chatbot,
352
+ all_token_counts,
353
+ top_p,
354
+ temperature,
355
+ selected_model,
356
+ fake_input=old_inputs,
357
+ display_append=link_references
358
+ )
359
+ yield chatbot, history, status_text, all_token_counts
360
+
361
+ logging.info(f"传输完毕。当前token计数为{all_token_counts}")
362
+ if len(history) > 1 and history[-1]["content"] != inputs:
363
+ logging.info(
364
+ "回答为:"
365
+ + colorama.Fore.BLUE
366
+ + f"{history[-1]['content']}"
367
+ + colorama.Style.RESET_ALL
368
+ )
369
+
370
+ if stream:
371
+ max_token = max_token_streaming
372
+ else:
373
+ max_token = max_token_all
374
+
375
+ if sum(all_token_counts) > max_token and should_check_token_count:
376
+ status_text = f"精简token中{all_token_counts}/{max_token}"
377
+ logging.info(status_text)
378
+ yield chatbot, history, status_text, all_token_counts
379
+ iter = reduce_token_size(
380
+ openai_api_key,
381
+ system_prompt,
382
+ history,
383
+ chatbot,
384
+ all_token_counts,
385
+ top_p,
386
+ temperature,
387
+ max_token//2,
388
+ selected_model=selected_model,
389
+ )
390
+ for chatbot, history, status_text, all_token_counts in iter:
391
+ status_text = f"Token 达到上限,已自动降低Token计数至 {status_text}"
392
+ yield chatbot, history, status_text, all_token_counts
393
+
394
+
395
+ def retry(
396
+ openai_api_key,
397
+ system_prompt,
398
+ history,
399
+ chatbot,
400
+ token_count,
401
+ top_p,
402
+ temperature,
403
+ stream=False,
404
+ selected_model=MODELS[0],
405
+ reply_language="中文",
406
+ ):
407
+ logging.info("重试中……")
408
+ if len(history) == 0:
409
+ yield chatbot, history, f"{standard_error_msg}上下文是空的", token_count
410
+ return
411
+ history.pop()
412
+ inputs = history.pop()["content"]
413
+ token_count.pop()
414
+ iter = predict(
415
+ openai_api_key,
416
+ system_prompt,
417
+ history,
418
+ inputs,
419
+ chatbot,
420
+ token_count,
421
+ top_p,
422
+ temperature,
423
+ stream=stream,
424
+ selected_model=selected_model,
425
+ reply_language=reply_language,
426
+ )
427
+ logging.info("重试中……")
428
+ for x in iter:
429
+ yield x
430
+ logging.info("重试完毕")
431
+
432
+
433
+ def reduce_token_size(
434
+ openai_api_key,
435
+ system_prompt,
436
+ history,
437
+ chatbot,
438
+ token_count,
439
+ top_p,
440
+ temperature,
441
+ max_token_count,
442
+ selected_model=MODELS[0],
443
+ reply_language="中文",
444
+ ):
445
+ logging.info("开始减少token数量……")
446
+ iter = predict(
447
+ openai_api_key,
448
+ system_prompt,
449
+ history,
450
+ summarize_prompt,
451
+ chatbot,
452
+ token_count,
453
+ top_p,
454
+ temperature,
455
+ selected_model=selected_model,
456
+ should_check_token_count=False,
457
+ reply_language=reply_language,
458
+ )
459
+ logging.info(f"chatbot: {chatbot}")
460
+ flag = False
461
+ for chatbot, history, status_text, previous_token_count in iter:
462
+ num_chat = find_n(previous_token_count, max_token_count)
463
+ if flag:
464
+ chatbot = chatbot[:-1]
465
+ flag = True
466
+ history = history[-2*num_chat:] if num_chat > 0 else []
467
+ token_count = previous_token_count[-num_chat:] if num_chat > 0 else []
468
+ msg = f"保留了最近{num_chat}轮对话"
469
+ yield chatbot, history, msg + "," + construct_token_message(
470
+ sum(token_count) if len(token_count) > 0 else 0,
471
+ ), token_count
472
+ logging.info(msg)
473
+ logging.info("减少token数量完毕")
modules/llama_func.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+
4
+ from llama_index import GPTSimpleVectorIndex
5
+ from llama_index import download_loader
6
+ from llama_index import (
7
+ Document,
8
+ LLMPredictor,
9
+ PromptHelper,
10
+ QuestionAnswerPrompt,
11
+ RefinePrompt,
12
+ )
13
+ from langchain.llms import OpenAI
14
+ import colorama
15
+
16
+
17
+ from modules.presets import *
18
+ from modules.utils import *
19
+
20
+
21
+ def get_documents(file_src):
22
+ documents = []
23
+ index_name = ""
24
+ logging.debug("Loading documents...")
25
+ logging.debug(f"file_src: {file_src}")
26
+ for file in file_src:
27
+ logging.debug(f"file: {file.name}")
28
+ index_name += file.name
29
+ if os.path.splitext(file.name)[1] == ".pdf":
30
+ logging.debug("Loading PDF...")
31
+ CJKPDFReader = download_loader("CJKPDFReader")
32
+ loader = CJKPDFReader()
33
+ documents += loader.load_data(file=file.name)
34
+ elif os.path.splitext(file.name)[1] == ".docx":
35
+ logging.debug("Loading DOCX...")
36
+ DocxReader = download_loader("DocxReader")
37
+ loader = DocxReader()
38
+ documents += loader.load_data(file=file.name)
39
+ elif os.path.splitext(file.name)[1] == ".epub":
40
+ logging.debug("Loading EPUB...")
41
+ EpubReader = download_loader("EpubReader")
42
+ loader = EpubReader()
43
+ documents += loader.load_data(file=file.name)
44
+ else:
45
+ logging.debug("Loading text file...")
46
+ with open(file.name, "r", encoding="utf-8") as f:
47
+ text = add_space(f.read())
48
+ documents += [Document(text)]
49
+ index_name = sha1sum(index_name)
50
+ return documents, index_name
51
+
52
+
53
+ def construct_index(
54
+ api_key,
55
+ file_src,
56
+ max_input_size=4096,
57
+ num_outputs=1,
58
+ max_chunk_overlap=20,
59
+ chunk_size_limit=600,
60
+ embedding_limit=None,
61
+ separator=" ",
62
+ num_children=10,
63
+ max_keywords_per_chunk=10,
64
+ ):
65
+ os.environ["OPENAI_API_KEY"] = api_key
66
+ chunk_size_limit = None if chunk_size_limit == 0 else chunk_size_limit
67
+ embedding_limit = None if embedding_limit == 0 else embedding_limit
68
+ separator = " " if separator == "" else separator
69
+
70
+ llm_predictor = LLMPredictor(
71
+ llm=OpenAI(model_name="gpt-3.5-turbo-0301", openai_api_key=api_key)
72
+ )
73
+ prompt_helper = PromptHelper(
74
+ max_input_size,
75
+ num_outputs,
76
+ max_chunk_overlap,
77
+ embedding_limit,
78
+ chunk_size_limit,
79
+ separator=separator,
80
+ )
81
+ documents, index_name = get_documents(file_src)
82
+ if os.path.exists(f"./index/{index_name}.json"):
83
+ logging.info("找到了缓存的索引文件,加载中……")
84
+ return GPTSimpleVectorIndex.load_from_disk(f"./index/{index_name}.json")
85
+ else:
86
+ try:
87
+ logging.debug("构建索引中……")
88
+ index = GPTSimpleVectorIndex(
89
+ documents, llm_predictor=llm_predictor, prompt_helper=prompt_helper
90
+ )
91
+ os.makedirs("./index", exist_ok=True)
92
+ index.save_to_disk(f"./index/{index_name}.json")
93
+ return index
94
+ except Exception as e:
95
+ print(e)
96
+ return None
97
+
98
+
99
+ def chat_ai(
100
+ api_key,
101
+ index,
102
+ question,
103
+ context,
104
+ chatbot,
105
+ reply_language,
106
+ ):
107
+ os.environ["OPENAI_API_KEY"] = api_key
108
+
109
+ logging.info(f"Question: {question}")
110
+
111
+ response, chatbot_display, status_text = ask_ai(
112
+ api_key,
113
+ index,
114
+ question,
115
+ replace_today(PROMPT_TEMPLATE),
116
+ REFINE_TEMPLATE,
117
+ SIM_K,
118
+ INDEX_QUERY_TEMPRATURE,
119
+ context,
120
+ reply_language,
121
+ )
122
+ if response is None:
123
+ status_text = "查询失败,请换个问法试试"
124
+ return context, chatbot
125
+ response = response
126
+
127
+ context.append({"role": "user", "content": question})
128
+ context.append({"role": "assistant", "content": response})
129
+ chatbot.append((question, chatbot_display))
130
+
131
+ os.environ["OPENAI_API_KEY"] = ""
132
+ return context, chatbot, status_text
133
+
134
+
135
+ def ask_ai(
136
+ api_key,
137
+ index,
138
+ question,
139
+ prompt_tmpl,
140
+ refine_tmpl,
141
+ sim_k=1,
142
+ temprature=0,
143
+ prefix_messages=[],
144
+ reply_language="中文",
145
+ ):
146
+ os.environ["OPENAI_API_KEY"] = api_key
147
+
148
+ logging.debug("Index file found")
149
+ logging.debug("Querying index...")
150
+ llm_predictor = LLMPredictor(
151
+ llm=OpenAI(
152
+ temperature=temprature,
153
+ model_name="gpt-3.5-turbo-0301",
154
+ prefix_messages=prefix_messages,
155
+ )
156
+ )
157
+
158
+ response = None # Initialize response variable to avoid UnboundLocalError
159
+ qa_prompt = QuestionAnswerPrompt(prompt_tmpl.replace("{reply_language}", reply_language))
160
+ rf_prompt = RefinePrompt(refine_tmpl.replace("{reply_language}", reply_language))
161
+ response = index.query(
162
+ question,
163
+ llm_predictor=llm_predictor,
164
+ similarity_top_k=sim_k,
165
+ text_qa_template=qa_prompt,
166
+ refine_template=rf_prompt,
167
+ response_mode="compact",
168
+ )
169
+
170
+ if response is not None:
171
+ logging.info(f"Response: {response}")
172
+ ret_text = response.response
173
+ nodes = []
174
+ for index, node in enumerate(response.source_nodes):
175
+ brief = node.source_text[:25].replace("\n", "")
176
+ nodes.append(
177
+ f"<details><summary>[{index+1}]\t{brief}...</summary><p>{node.source_text}</p></details>"
178
+ )
179
+ new_response = ret_text + "\n----------\n" + "\n\n".join(nodes)
180
+ logging.info(
181
+ f"Response: {colorama.Fore.BLUE}{ret_text}{colorama.Style.RESET_ALL}"
182
+ )
183
+ os.environ["OPENAI_API_KEY"] = ""
184
+ return ret_text, new_response, f"查询消耗了{llm_predictor.last_token_usage} tokens"
185
+ else:
186
+ logging.warning("No response found, returning None")
187
+ os.environ["OPENAI_API_KEY"] = ""
188
+ return None
189
+
190
+
191
+ def add_space(text):
192
+ punctuations = {",": ", ", "。": "。 ", "?": "? ", "!": "! ", ":": ": ", ";": "; "}
193
+ for cn_punc, en_punc in punctuations.items():
194
+ text = text.replace(cn_punc, en_punc)
195
+ return text
modules/openai_func.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import logging
3
+ from modules.presets import timeout_all, BALANCE_API_URL,standard_error_msg,connection_timeout_prompt,error_retrieve_prompt,read_timeout_prompt
4
+ from modules import shared
5
+ import os
6
+
7
+
8
+ def get_usage_response(openai_api_key):
9
+ headers = {
10
+ "Content-Type": "application/json",
11
+ "Authorization": f"Bearer {openai_api_key}",
12
+ }
13
+
14
+ timeout = timeout_all
15
+
16
+ # 获取环境变量中的代理设置
17
+ http_proxy = os.environ.get("HTTP_PROXY") or os.environ.get("http_proxy")
18
+ https_proxy = os.environ.get(
19
+ "HTTPS_PROXY") or os.environ.get("https_proxy")
20
+
21
+ # 如果存在代理设置,使用它们
22
+ proxies = {}
23
+ if http_proxy:
24
+ logging.info(f"使用 HTTP 代理: {http_proxy}")
25
+ proxies["http"] = http_proxy
26
+ if https_proxy:
27
+ logging.info(f"使用 HTTPS 代理: {https_proxy}")
28
+ proxies["https"] = https_proxy
29
+
30
+ # 如果有代理,使用代理发送请求,否则使用默认设置发送请求
31
+ """
32
+ 暂不支持修改
33
+ if shared.state.balance_api_url != BALANCE_API_URL:
34
+ logging.info(f"使用自定义BALANCE API URL: {shared.state.balance_api_url}")
35
+ """
36
+ if proxies:
37
+ response = requests.get(
38
+ BALANCE_API_URL,
39
+ headers=headers,
40
+ timeout=timeout,
41
+ proxies=proxies,
42
+ )
43
+ else:
44
+ response = requests.get(
45
+ BALANCE_API_URL,
46
+ headers=headers,
47
+ timeout=timeout,
48
+ )
49
+ return response
50
+
51
+ def get_usage(openai_api_key):
52
+ try:
53
+ response=get_usage_response(openai_api_key=openai_api_key)
54
+ logging.debug(response.json())
55
+ try:
56
+ balance = response.json().get("total_available") if response.json().get(
57
+ "total_available") else 0
58
+ total_used = response.json().get("total_used") if response.json().get(
59
+ "total_used") else 0
60
+ except Exception as e:
61
+ logging.error(f"API使用情况解析失败:"+str(e))
62
+ balance = 0
63
+ total_used=0
64
+ return f"**API使用情况**(已用/余额)\u3000{total_used}$ / {balance}$"
65
+ except requests.exceptions.ConnectTimeout:
66
+ status_text = standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
67
+ return status_text
68
+ except requests.exceptions.ReadTimeout:
69
+ status_text = standard_error_msg + read_timeout_prompt + error_retrieve_prompt
70
+ return status_text
modules/overwrites.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import logging
3
+
4
+ from llama_index import Prompt
5
+ from typing import List, Tuple
6
+ import mdtex2html
7
+
8
+ from modules.presets import *
9
+ from modules.llama_func import *
10
+
11
+
12
+ def compact_text_chunks(self, prompt: Prompt, text_chunks: List[str]) -> List[str]:
13
+ logging.debug("Compacting text chunks...🚀🚀🚀")
14
+ combined_str = [c.strip() for c in text_chunks if c.strip()]
15
+ combined_str = [f"[{index+1}] {c}" for index, c in enumerate(combined_str)]
16
+ combined_str = "\n\n".join(combined_str)
17
+ # resplit based on self.max_chunk_overlap
18
+ text_splitter = self.get_text_splitter_given_prompt(prompt, 1, padding=1)
19
+ return text_splitter.split_text(combined_str)
20
+
21
+
22
+ def postprocess(
23
+ self, y: List[Tuple[str | None, str | None]]
24
+ ) -> List[Tuple[str | None, str | None]]:
25
+ """
26
+ Parameters:
27
+ y: List of tuples representing the message and response pairs. Each message and response should be a string, which may be in Markdown format.
28
+ Returns:
29
+ List of tuples representing the message and response. Each message and response will be a string of HTML.
30
+ """
31
+ if y is None or y == []:
32
+ return []
33
+ user, bot = y[-1]
34
+ if not detect_converted_mark(user):
35
+ user = convert_asis(user)
36
+ if not detect_converted_mark(bot):
37
+ bot = convert_mdtext(bot)
38
+ y[-1] = (user, bot)
39
+ return y
40
+
41
+ with open("./assets/custom.js", "r", encoding="utf-8") as f, open("./assets/Kelpy-Codos.js", "r", encoding="utf-8") as f2:
42
+ customJS = f.read()
43
+ kelpyCodos = f2.read()
44
+
45
+ def reload_javascript():
46
+ print("Reloading javascript...")
47
+ js = f'<script>{customJS}</script><script>{kelpyCodos}</script>'
48
+ def template_response(*args, **kwargs):
49
+ res = GradioTemplateResponseOriginal(*args, **kwargs)
50
+ res.body = res.body.replace(b'</html>', f'{js}</html>'.encode("utf8"))
51
+ res.init_headers()
52
+ return res
53
+
54
+ gr.routes.templates.TemplateResponse = template_response
55
+
56
+ GradioTemplateResponseOriginal = gr.routes.templates.TemplateResponse
modules/presets.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding:utf-8 -*-
2
+ import gradio as gr
3
+
4
+ # ChatGPT 设置
5
+ initial_prompt = "You are a helpful assistant."
6
+ API_URL = "https://api.openai.com/v1/chat/completions"
7
+ BALANCE_API_URL="https://api.openai.com/dashboard/billing/credit_grants"
8
+ HISTORY_DIR = "history"
9
+ TEMPLATES_DIR = "templates"
10
+
11
+ # 错误信息
12
+ standard_error_msg = "☹️发生了错误:" # 错误信息的标准前缀
13
+ error_retrieve_prompt = "请检查网络连接,或者API-Key是否有效。" # 获取对话时发生错误
14
+ connection_timeout_prompt = "连接超时,无法获取对话。" # 连接超时
15
+ read_timeout_prompt = "读取超时,无法获取对话。" # 读取超时
16
+ proxy_error_prompt = "代理错误,无法获取对话。" # 代理错误
17
+ ssl_error_prompt = "SSL错误,无法获取对话。" # SSL 错误
18
+ no_apikey_msg = "API key长度不是51位,请检查是否输入正确。" # API key 长度不足 51 位
19
+ no_input_msg = "请输入对话内容。" # 未输入对话内容
20
+
21
+ max_token_streaming = 3500 # 流式对话时的最大 token 数
22
+ timeout_streaming = 10 # 流式对话时的超时时间
23
+ max_token_all = 3500 # 非流式对话时的最大 token 数
24
+ timeout_all = 200 # 非流式对话时的超时时间
25
+ enable_streaming_option = True # 是否启用选择选择是否实时显示回答的勾选框
26
+ HIDE_MY_KEY = False # 如果你想在UI中隐藏你的 API 密钥,将此值设置为 True
27
+ CONCURRENT_COUNT = 100 # 允许同时使用的用户数量
28
+
29
+ SIM_K = 5
30
+ INDEX_QUERY_TEMPRATURE = 1.0
31
+
32
+ title = """<h1 align="left" style="min-width:200px; margin-top:0;">川虎ChatGPT 🚀</h1>"""
33
+ description = """\
34
+ <div align="center" style="margin:16px 0">
35
+
36
+ 由Bilibili [土川虎虎虎](https://space.bilibili.com/29125536) 和 [明昭MZhao](https://space.bilibili.com/24807452)开发
37
+
38
+ 访问川虎ChatGPT的 [GitHub项目](https://github.com/GaiZhenbiao/ChuanhuChatGPT) 下载最新版脚本
39
+
40
+ 此App使用 `gpt-3.5-turbo` 大语言模型
41
+ </div>
42
+ """
43
+
44
+ summarize_prompt = "你是谁?我们刚才聊了什么?" # 总结对话时的 prompt
45
+
46
+ MODELS = [
47
+ "gpt-3.5-turbo",
48
+ "gpt-3.5-turbo-0301",
49
+ "gpt-4",
50
+ "gpt-4-0314",
51
+ "gpt-4-32k",
52
+ "gpt-4-32k-0314",
53
+ ] # 可选的模型
54
+
55
+ REPLY_LANGUAGES = [
56
+ "中文",
57
+ "English",
58
+ "日本語",
59
+ "Español",
60
+ "Français",
61
+ "Deutsch",
62
+ "跟随问题语言(不稳定)"
63
+ ]
64
+
65
+
66
+ WEBSEARCH_PTOMPT_TEMPLATE = """\
67
+ Web search results:
68
+
69
+ {web_results}
70
+ Current date: {current_date}
71
+
72
+ Instructions: Using the provided web search results, write a comprehensive reply to the given query. Make sure to cite results using [[number](URL)] notation after the reference. If the provided search results refer to multiple subjects with the same name, write separate answers for each subject.
73
+ Query: {query}
74
+ Reply in {reply_language}
75
+ """
76
+
77
+ PROMPT_TEMPLATE = """\
78
+ Context information is below.
79
+ ---------------------
80
+ {context_str}
81
+ ---------------------
82
+ Current date: {current_date}.
83
+ Using the provided context information, write a comprehensive reply to the given query.
84
+ Make sure to cite results using [number] notation after the reference.
85
+ If the provided context information refer to multiple subjects with the same name, write separate answers for each subject.
86
+ Use prior knowledge only if the given context didn't provide enough information.
87
+ Answer the question: {query_str}
88
+ Reply in {reply_language}
89
+ """
90
+
91
+ REFINE_TEMPLATE = """\
92
+ The original question is as follows: {query_str}
93
+ We have provided an existing answer: {existing_answer}
94
+ We have the opportunity to refine the existing answer
95
+ (only if needed) with some more context below.
96
+ ------------
97
+ {context_msg}
98
+ ------------
99
+ Given the new context, refine the original answer to better
100
+ Reply in {reply_language}
101
+ If the context isn't useful, return the original answer.
102
+ """
103
+
104
+ ALREADY_CONVERTED_MARK = "<!-- ALREADY CONVERTED BY PARSER. -->"
105
+
106
+ small_and_beautiful_theme = gr.themes.Soft(
107
+ primary_hue=gr.themes.Color(
108
+ c50="#02C160",
109
+ c100="rgba(2, 193, 96, 0.2)",
110
+ c200="#02C160",
111
+ c300="rgba(2, 193, 96, 0.32)",
112
+ c400="rgba(2, 193, 96, 0.32)",
113
+ c500="rgba(2, 193, 96, 1.0)",
114
+ c600="rgba(2, 193, 96, 1.0)",
115
+ c700="rgba(2, 193, 96, 0.32)",
116
+ c800="rgba(2, 193, 96, 0.32)",
117
+ c900="#02C160",
118
+ c950="#02C160",
119
+ ),
120
+ secondary_hue=gr.themes.Color(
121
+ c50="#576b95",
122
+ c100="#576b95",
123
+ c200="#576b95",
124
+ c300="#576b95",
125
+ c400="#576b95",
126
+ c500="#576b95",
127
+ c600="#576b95",
128
+ c700="#576b95",
129
+ c800="#576b95",
130
+ c900="#576b95",
131
+ c950="#576b95",
132
+ ),
133
+ neutral_hue=gr.themes.Color(
134
+ name="gray",
135
+ c50="#f9fafb",
136
+ c100="#f3f4f6",
137
+ c200="#e5e7eb",
138
+ c300="#d1d5db",
139
+ c400="#B2B2B2",
140
+ c500="#808080",
141
+ c600="#636363",
142
+ c700="#515151",
143
+ c800="#393939",
144
+ c900="#272727",
145
+ c950="#171717",
146
+ ),
147
+ radius_size=gr.themes.sizes.radius_sm,
148
+ ).set(
149
+ button_primary_background_fill="#06AE56",
150
+ button_primary_background_fill_dark="#06AE56",
151
+ button_primary_background_fill_hover="#07C863",
152
+ button_primary_border_color="#06AE56",
153
+ button_primary_border_color_dark="#06AE56",
154
+ button_primary_text_color="#FFFFFF",
155
+ button_primary_text_color_dark="#FFFFFF",
156
+ button_secondary_background_fill="#F2F2F2",
157
+ button_secondary_background_fill_dark="#2B2B2B",
158
+ button_secondary_text_color="#393939",
159
+ button_secondary_text_color_dark="#FFFFFF",
160
+ # background_fill_primary="#F7F7F7",
161
+ # background_fill_primary_dark="#1F1F1F",
162
+ block_title_text_color="*primary_500",
163
+ block_title_background_fill="*primary_100",
164
+ input_background_fill="#F6F6F6",
165
+ )
modules/shared.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from modules.presets import API_URL
2
+
3
+ class State:
4
+ interrupted = False
5
+ api_url = API_URL
6
+
7
+ def interrupt(self):
8
+ self.interrupted = True
9
+
10
+ def recover(self):
11
+ self.interrupted = False
12
+
13
+ def set_api_url(self, api_url):
14
+ self.api_url = api_url
15
+
16
+ def reset_api_url(self):
17
+ self.api_url = API_URL
18
+ return self.api_url
19
+
20
+ def reset_all(self):
21
+ self.interrupted = False
22
+ self.api_url = API_URL
23
+
24
+ state = State()
modules/utils.py ADDED
@@ -0,0 +1,424 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding:utf-8 -*-
2
+ from __future__ import annotations
3
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Type
4
+ import logging
5
+ import json
6
+ import os
7
+ import datetime
8
+ import hashlib
9
+ import csv
10
+ import requests
11
+ import re
12
+ import html
13
+
14
+ import gradio as gr
15
+ from pypinyin import lazy_pinyin
16
+ import tiktoken
17
+ import mdtex2html
18
+ from markdown import markdown
19
+ from pygments import highlight
20
+ from pygments.lexers import get_lexer_by_name
21
+ from pygments.formatters import HtmlFormatter
22
+
23
+ from modules.presets import *
24
+ import modules.shared as shared
25
+
26
+ logging.basicConfig(
27
+ level=logging.INFO,
28
+ format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
29
+ )
30
+
31
+ if TYPE_CHECKING:
32
+ from typing import TypedDict
33
+
34
+ class DataframeData(TypedDict):
35
+ headers: List[str]
36
+ data: List[List[str | int | bool]]
37
+
38
+
39
+ def count_token(message):
40
+ encoding = tiktoken.get_encoding("cl100k_base")
41
+ input_str = f"role: {message['role']}, content: {message['content']}"
42
+ length = len(encoding.encode(input_str))
43
+ return length
44
+
45
+
46
+ def markdown_to_html_with_syntax_highlight(md_str):
47
+ def replacer(match):
48
+ lang = match.group(1) or "text"
49
+ code = match.group(2)
50
+
51
+ try:
52
+ lexer = get_lexer_by_name(lang, stripall=True)
53
+ except ValueError:
54
+ lexer = get_lexer_by_name("text", stripall=True)
55
+
56
+ formatter = HtmlFormatter()
57
+ highlighted_code = highlight(code, lexer, formatter)
58
+
59
+ return f'<pre><code class="{lang}">{highlighted_code}</code></pre>'
60
+
61
+ code_block_pattern = r"```(\w+)?\n([\s\S]+?)\n```"
62
+ md_str = re.sub(code_block_pattern, replacer, md_str, flags=re.MULTILINE)
63
+
64
+ html_str = markdown(md_str)
65
+ return html_str
66
+
67
+
68
+ def normalize_markdown(md_text: str) -> str:
69
+ lines = md_text.split("\n")
70
+ normalized_lines = []
71
+ inside_list = False
72
+
73
+ for i, line in enumerate(lines):
74
+ if re.match(r"^(\d+\.|-|\*|\+)\s", line.strip()):
75
+ if not inside_list and i > 0 and lines[i - 1].strip() != "":
76
+ normalized_lines.append("")
77
+ inside_list = True
78
+ normalized_lines.append(line)
79
+ elif inside_list and line.strip() == "":
80
+ if i < len(lines) - 1 and not re.match(
81
+ r"^(\d+\.|-|\*|\+)\s", lines[i + 1].strip()
82
+ ):
83
+ normalized_lines.append(line)
84
+ continue
85
+ else:
86
+ inside_list = False
87
+ normalized_lines.append(line)
88
+
89
+ return "\n".join(normalized_lines)
90
+
91
+
92
+ def convert_mdtext(md_text):
93
+ code_block_pattern = re.compile(r"```(.*?)(?:```|$)", re.DOTALL)
94
+ inline_code_pattern = re.compile(r"`(.*?)`", re.DOTALL)
95
+ code_blocks = code_block_pattern.findall(md_text)
96
+ non_code_parts = code_block_pattern.split(md_text)[::2]
97
+
98
+ result = []
99
+ for non_code, code in zip(non_code_parts, code_blocks + [""]):
100
+ if non_code.strip():
101
+ non_code = normalize_markdown(non_code)
102
+ if inline_code_pattern.search(non_code):
103
+ result.append(markdown(non_code, extensions=["tables"]))
104
+ else:
105
+ result.append(mdtex2html.convert(non_code, extensions=["tables"]))
106
+ if code.strip():
107
+ # _, code = detect_language(code) # 暂时去除代码高亮功能,因为在大段代码的情况下会出现问题
108
+ # code = code.replace("\n\n", "\n") # 暂时去除代码中的空行,因为在大段代码的情况下会出现问题
109
+ code = f"\n```{code}\n\n```"
110
+ code = markdown_to_html_with_syntax_highlight(code)
111
+ result.append(code)
112
+ result = "".join(result)
113
+ result += ALREADY_CONVERTED_MARK
114
+ return result
115
+
116
+
117
+ def convert_asis(userinput):
118
+ return f"<p style=\"white-space:pre-wrap;\">{html.escape(userinput)}</p>"+ALREADY_CONVERTED_MARK
119
+
120
+ def detect_converted_mark(userinput):
121
+ if userinput.endswith(ALREADY_CONVERTED_MARK):
122
+ return True
123
+ else:
124
+ return False
125
+
126
+
127
+ def detect_language(code):
128
+ if code.startswith("\n"):
129
+ first_line = ""
130
+ else:
131
+ first_line = code.strip().split("\n", 1)[0]
132
+ language = first_line.lower() if first_line else ""
133
+ code_without_language = code[len(first_line) :].lstrip() if first_line else code
134
+ return language, code_without_language
135
+
136
+
137
+ def construct_text(role, text):
138
+ return {"role": role, "content": text}
139
+
140
+
141
+ def construct_user(text):
142
+ return construct_text("user", text)
143
+
144
+
145
+ def construct_system(text):
146
+ return construct_text("system", text)
147
+
148
+
149
+ def construct_assistant(text):
150
+ return construct_text("assistant", text)
151
+
152
+
153
+ def construct_token_message(token, stream=False):
154
+ return f"Token 计数: {token}"
155
+
156
+
157
+ def delete_last_conversation(chatbot, history, previous_token_count):
158
+ if len(chatbot) > 0 and standard_error_msg in chatbot[-1][1]:
159
+ logging.info("由于包含报错信息,只删除chatbot记录")
160
+ chatbot.pop()
161
+ return chatbot, history
162
+ if len(history) > 0:
163
+ logging.info("删除了一组对话历史")
164
+ history.pop()
165
+ history.pop()
166
+ if len(chatbot) > 0:
167
+ logging.info("删除了一组chatbot对话")
168
+ chatbot.pop()
169
+ if len(previous_token_count) > 0:
170
+ logging.info("删除了一组对话的token计数记录")
171
+ previous_token_count.pop()
172
+ return (
173
+ chatbot,
174
+ history,
175
+ previous_token_count,
176
+ construct_token_message(sum(previous_token_count)),
177
+ )
178
+
179
+
180
+ def save_file(filename, system, history, chatbot):
181
+ logging.info("保存对话历史中……")
182
+ os.makedirs(HISTORY_DIR, exist_ok=True)
183
+ if filename.endswith(".json"):
184
+ json_s = {"system": system, "history": history, "chatbot": chatbot}
185
+ print(json_s)
186
+ with open(os.path.join(HISTORY_DIR, filename), "w") as f:
187
+ json.dump(json_s, f)
188
+ elif filename.endswith(".md"):
189
+ md_s = f"system: \n- {system} \n"
190
+ for data in history:
191
+ md_s += f"\n{data['role']}: \n- {data['content']} \n"
192
+ with open(os.path.join(HISTORY_DIR, filename), "w", encoding="utf8") as f:
193
+ f.write(md_s)
194
+ logging.info("保存对话历史完毕")
195
+ return os.path.join(HISTORY_DIR, filename)
196
+
197
+
198
+ def save_chat_history(filename, system, history, chatbot):
199
+ if filename == "":
200
+ return
201
+ if not filename.endswith(".json"):
202
+ filename += ".json"
203
+ return save_file(filename, system, history, chatbot)
204
+
205
+
206
+ def export_markdown(filename, system, history, chatbot):
207
+ if filename == "":
208
+ return
209
+ if not filename.endswith(".md"):
210
+ filename += ".md"
211
+ return save_file(filename, system, history, chatbot)
212
+
213
+
214
+ def load_chat_history(filename, system, history, chatbot):
215
+ logging.info("加载对话历史中……")
216
+ if type(filename) != str:
217
+ filename = filename.name
218
+ try:
219
+ with open(os.path.join(HISTORY_DIR, filename), "r") as f:
220
+ json_s = json.load(f)
221
+ try:
222
+ if type(json_s["history"][0]) == str:
223
+ logging.info("历史记录格式为旧版,正在转换……")
224
+ new_history = []
225
+ for index, item in enumerate(json_s["history"]):
226
+ if index % 2 == 0:
227
+ new_history.append(construct_user(item))
228
+ else:
229
+ new_history.append(construct_assistant(item))
230
+ json_s["history"] = new_history
231
+ logging.info(new_history)
232
+ except:
233
+ # 没有对话历史
234
+ pass
235
+ logging.info("加载对话历史完毕")
236
+ return filename, json_s["system"], json_s["history"], json_s["chatbot"]
237
+ except FileNotFoundError:
238
+ logging.info("没有找到对话历史文件,不执行任何操作")
239
+ return filename, system, history, chatbot
240
+
241
+
242
+ def sorted_by_pinyin(list):
243
+ return sorted(list, key=lambda char: lazy_pinyin(char)[0][0])
244
+
245
+
246
+ def get_file_names(dir, plain=False, filetypes=[".json"]):
247
+ logging.info(f"获取文件名列表,目录为{dir},文件类型为{filetypes},是否为纯文本列表{plain}")
248
+ files = []
249
+ try:
250
+ for type in filetypes:
251
+ files += [f for f in os.listdir(dir) if f.endswith(type)]
252
+ except FileNotFoundError:
253
+ files = []
254
+ files = sorted_by_pinyin(files)
255
+ if files == []:
256
+ files = [""]
257
+ if plain:
258
+ return files
259
+ else:
260
+ return gr.Dropdown.update(choices=files)
261
+
262
+
263
+ def get_history_names(plain=False):
264
+ logging.info("获取历史记录文件名列表")
265
+ return get_file_names(HISTORY_DIR, plain)
266
+
267
+
268
+ def load_template(filename, mode=0):
269
+ logging.info(f"加载模板文件{filename},模式为{mode}(0为返回字典和下拉菜单,1为返回下拉菜单,2为返回字典)")
270
+ lines = []
271
+ logging.info("Loading template...")
272
+ if filename.endswith(".json"):
273
+ with open(os.path.join(TEMPLATES_DIR, filename), "r", encoding="utf8") as f:
274
+ lines = json.load(f)
275
+ lines = [[i["act"], i["prompt"]] for i in lines]
276
+ else:
277
+ with open(
278
+ os.path.join(TEMPLATES_DIR, filename), "r", encoding="utf8"
279
+ ) as csvfile:
280
+ reader = csv.reader(csvfile)
281
+ lines = list(reader)
282
+ lines = lines[1:]
283
+ if mode == 1:
284
+ return sorted_by_pinyin([row[0] for row in lines])
285
+ elif mode == 2:
286
+ return {row[0]: row[1] for row in lines}
287
+ else:
288
+ choices = sorted_by_pinyin([row[0] for row in lines])
289
+ return {row[0]: row[1] for row in lines}, gr.Dropdown.update(
290
+ choices=choices, value=choices[0]
291
+ )
292
+
293
+
294
+ def get_template_names(plain=False):
295
+ logging.info("获取模板文件名列表")
296
+ return get_file_names(TEMPLATES_DIR, plain, filetypes=[".csv", "json"])
297
+
298
+
299
+ def get_template_content(templates, selection, original_system_prompt):
300
+ logging.info(f"应用模板中,选择为{selection},原始系统提示为{original_system_prompt}")
301
+ try:
302
+ return templates[selection]
303
+ except:
304
+ return original_system_prompt
305
+
306
+
307
+ def reset_state():
308
+ logging.info("重置状态")
309
+ return [], [], [], construct_token_message(0)
310
+
311
+
312
+ def reset_textbox():
313
+ logging.debug("重置文本框")
314
+ return gr.update(value="")
315
+
316
+
317
+ def reset_default():
318
+ newurl = shared.state.reset_api_url()
319
+ os.environ.pop("HTTPS_PROXY", None)
320
+ os.environ.pop("https_proxy", None)
321
+ return gr.update(value=newurl), gr.update(value=""), "API URL 和代理已重置"
322
+
323
+
324
+ def change_api_url(url):
325
+ shared.state.set_api_url(url)
326
+ msg = f"API地址更改为了{url}"
327
+ logging.info(msg)
328
+ return msg
329
+
330
+
331
+ def change_proxy(proxy):
332
+ os.environ["HTTPS_PROXY"] = proxy
333
+ msg = f"代理更改为了{proxy}"
334
+ logging.info(msg)
335
+ return msg
336
+
337
+
338
+ def hide_middle_chars(s):
339
+ if len(s) <= 8:
340
+ return s
341
+ else:
342
+ head = s[:4]
343
+ tail = s[-4:]
344
+ hidden = "*" * (len(s) - 8)
345
+ return head + hidden + tail
346
+
347
+
348
+ def submit_key(key):
349
+ key = key.strip()
350
+ msg = f"API密钥更改为了{hide_middle_chars(key)}"
351
+ logging.info(msg)
352
+ return key, msg
353
+
354
+
355
+ def sha1sum(filename):
356
+ sha1 = hashlib.sha1()
357
+ sha1.update(filename.encode("utf-8"))
358
+ return sha1.hexdigest()
359
+
360
+
361
+ def replace_today(prompt):
362
+ today = datetime.datetime.today().strftime("%Y-%m-%d")
363
+ return prompt.replace("{current_date}", today)
364
+
365
+
366
+ def get_geoip():
367
+ response = requests.get("https://ipapi.co/json/", timeout=5)
368
+ try:
369
+ data = response.json()
370
+ except:
371
+ data = {"error": True, "reason": "连接ipapi失败"}
372
+ if "error" in data.keys():
373
+ logging.warning(f"无法获取IP地址信息。\n{data}")
374
+ if data["reason"] == "RateLimited":
375
+ return (
376
+ f"获取IP地理位置失败,因为达到了检测IP的速率限制。聊天功能可能仍然可用,但请注意,如果您的IP地址在不受支持的地区,您可能会遇到问题。"
377
+ )
378
+ else:
379
+ return f"获取IP地理位置失败。原因:{data['reason']}。你仍然可以使用聊天功能。"
380
+ else:
381
+ country = data["country_name"]
382
+ if country == "China":
383
+ text = "**您的IP区域:中国。请立即检查代理设置,在不受支持的地区使用API可能导致账号被封禁。**"
384
+ else:
385
+ text = f"您的IP区域:{country}。"
386
+ logging.info(text)
387
+ return text
388
+
389
+
390
+ def find_n(lst, max_num):
391
+ n = len(lst)
392
+ total = sum(lst)
393
+
394
+ if total < max_num:
395
+ return n
396
+
397
+ for i in range(len(lst)):
398
+ if total - lst[i] < max_num:
399
+ return n - i - 1
400
+ total = total - lst[i]
401
+ return 1
402
+
403
+
404
+ def start_outputing():
405
+ logging.debug("显示取消按钮,隐藏发送按钮")
406
+ return gr.Button.update(visible=False), gr.Button.update(visible=True)
407
+
408
+
409
+ def end_outputing():
410
+ return (
411
+ gr.Button.update(visible=True),
412
+ gr.Button.update(visible=False),
413
+ )
414
+
415
+
416
+ def cancel_outputing():
417
+ logging.info("中止输出……")
418
+ shared.state.interrupt()
419
+
420
+ def transfer_input(inputs):
421
+ # 一次性返回,降低延迟
422
+ textbox = reset_textbox()
423
+ outputing = start_outputing()
424
+ return inputs, gr.update(value="")