JohnSmith9982 committed
Commit 206f319 · 1 Parent(s): 8ba98ee

Delete modules/chat_func.py

Files changed (1)
  1. modules/chat_func.py +0 -497
modules/chat_func.py DELETED
@@ -1,497 +0,0 @@
-# -*- coding:utf-8 -*-
-from __future__ import annotations
-from typing import TYPE_CHECKING, List
-
-import logging
-import json
-import os
-import requests
-import urllib3
-
-from tqdm import tqdm
-import colorama
-from duckduckgo_search import ddg
-import asyncio
-import aiohttp
-
-
-from modules.presets import *
-from modules.llama_func import *
-from modules.utils import *
-from . import shared
-from modules.config import retrieve_proxy
-
-# logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s")
-
-if TYPE_CHECKING:
-    from typing import TypedDict
-
-    class DataframeData(TypedDict):
-        headers: List[str]
-        data: List[List[str | int | bool]]
-
-
-initial_prompt = "You are a helpful assistant."
-HISTORY_DIR = "history"
-TEMPLATES_DIR = "templates"
-
-@shared.state.switching_api_key  # this decorator is a no-op unless multi-API-key mode is enabled
-def get_response(
-    openai_api_key, system_prompt, history, temperature, top_p, stream, selected_model
-):
-    headers = {
-        "Content-Type": "application/json",
-        "Authorization": f"Bearer {openai_api_key}",
-    }
-
-    history = [construct_system(system_prompt), *history]
-
-    payload = {
-        "model": selected_model,
-        "messages": history,  # [{"role": "user", "content": f"{inputs}"}],
-        "temperature": temperature,  # 1.0,
-        "top_p": top_p,  # 1.0,
-        "n": 1,
-        "stream": stream,
-        "presence_penalty": 0,
-        "frequency_penalty": 0,
-    }
-    if stream:
-        timeout = timeout_streaming
-    else:
-        timeout = timeout_all
-
-
-    # if a custom API host is configured, send the request there; otherwise use the default endpoint
-    if shared.state.completion_url != COMPLETION_URL:
-        logging.info(f"Using custom API URL: {shared.state.completion_url}")
-
-    with retrieve_proxy():
-        response = requests.post(
-            shared.state.completion_url,
-            headers=headers,
-            json=payload,
-            stream=True,
-            timeout=timeout,
-        )
-
-    return response
-
-
-def stream_predict(
-    openai_api_key,
-    system_prompt,
-    history,
-    inputs,
-    chatbot,
-    all_token_counts,
-    top_p,
-    temperature,
-    selected_model,
-    fake_input=None,
-    display_append=""
-):
-    def get_return_value():
-        return chatbot, history, status_text, all_token_counts
-
-    logging.info("Streaming answer mode")
-    partial_words = ""
-    counter = 0
-    status_text = "Starting to stream the answer..."
-    history.append(construct_user(inputs))
-    history.append(construct_assistant(""))
-    if fake_input:
-        chatbot.append((fake_input, ""))
-    else:
-        chatbot.append((inputs, ""))
-    user_token_count = 0
-    if fake_input is not None:
-        input_token_count = count_token(construct_user(fake_input))
-    else:
-        input_token_count = count_token(construct_user(inputs))
-    if len(all_token_counts) == 0:
-        system_prompt_token_count = count_token(construct_system(system_prompt))
-        user_token_count = (
-            input_token_count + system_prompt_token_count
-        )
-    else:
-        user_token_count = input_token_count
-    all_token_counts.append(user_token_count)
-    logging.info(f"Input token count: {user_token_count}")
-    yield get_return_value()
-    try:
-        response = get_response(
-            openai_api_key,
-            system_prompt,
-            history,
-            temperature,
-            top_p,
-            True,
-            selected_model,
-        )
-    except requests.exceptions.ConnectTimeout:
-        status_text = (
-            standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
-        )
-        yield get_return_value()
-        return
-    except requests.exceptions.ReadTimeout:
-        status_text = standard_error_msg + read_timeout_prompt + error_retrieve_prompt
-        yield get_return_value()
-        return
-
-    yield get_return_value()
-    error_json_str = ""
-
-    if fake_input is not None:
-        history[-2] = construct_user(fake_input)
-    for chunk in tqdm(response.iter_lines()):
-        if counter == 0:
-            counter += 1
-            continue
-        counter += 1
-        # check whether each line is non-empty
-        if chunk:
-            chunk = chunk.decode()
-            chunklength = len(chunk)
-            try:
-                chunk = json.loads(chunk[6:])
-            except json.JSONDecodeError:
-                logging.info(chunk)
-                error_json_str += chunk
-                status_text = f"JSON parse error. Please reset the conversation. Received content: {error_json_str}"
-                yield get_return_value()
-                continue
-            # decode each line as response data is in bytes
-            if chunklength > 6 and "delta" in chunk["choices"][0]:
-                finish_reason = chunk["choices"][0]["finish_reason"]
-                status_text = construct_token_message(all_token_counts)
-                if finish_reason == "stop":
-                    yield get_return_value()
-                    break
-                try:
-                    partial_words = (
-                        partial_words + chunk["choices"][0]["delta"]["content"]
-                    )
-                except KeyError:
-                    status_text = (
-                        standard_error_msg
-                        + "No content found in the API reply. The token count most likely reached the limit. Please reset the conversation. Current token count: "
-                        + str(sum(all_token_counts))
-                    )
-                    yield get_return_value()
-                    break
-                history[-1] = construct_assistant(partial_words)
-                chatbot[-1] = (chatbot[-1][0], partial_words+display_append)
-                all_token_counts[-1] += 1
-                yield get_return_value()
-
-
-def predict_all(
-    openai_api_key,
-    system_prompt,
-    history,
-    inputs,
-    chatbot,
-    all_token_counts,
-    top_p,
-    temperature,
-    selected_model,
-    fake_input=None,
-    display_append=""
-):
-    logging.info("One-shot answer mode")
-    history.append(construct_user(inputs))
-    history.append(construct_assistant(""))
-    if fake_input:
-        chatbot.append((fake_input, ""))
-    else:
-        chatbot.append((inputs, ""))
-    if fake_input is not None:
-        all_token_counts.append(count_token(construct_user(fake_input)))
-    else:
-        all_token_counts.append(count_token(construct_user(inputs)))
-    try:
-        response = get_response(
-            openai_api_key,
-            system_prompt,
-            history,
-            temperature,
-            top_p,
-            False,
-            selected_model,
-        )
-    except requests.exceptions.ConnectTimeout:
-        status_text = (
-            standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
-        )
-        return chatbot, history, status_text, all_token_counts
-    except requests.exceptions.ProxyError:
-        status_text = standard_error_msg + proxy_error_prompt + error_retrieve_prompt
-        return chatbot, history, status_text, all_token_counts
-    except requests.exceptions.SSLError:
-        status_text = standard_error_msg + ssl_error_prompt + error_retrieve_prompt
-        return chatbot, history, status_text, all_token_counts
-    response = json.loads(response.text)
-    if fake_input is not None:
-        history[-2] = construct_user(fake_input)
-    try:
-        content = response["choices"][0]["message"]["content"]
-        history[-1] = construct_assistant(content)
-        chatbot[-1] = (chatbot[-1][0], content+display_append)
-        total_token_count = response["usage"]["total_tokens"]
-        if fake_input is not None:
-            all_token_counts[-1] += count_token(construct_assistant(content))
-        else:
-            all_token_counts[-1] = total_token_count - sum(all_token_counts)
-        status_text = construct_token_message(total_token_count)
-        return chatbot, history, status_text, all_token_counts
-    except KeyError:
-        status_text = standard_error_msg + str(response)
-        return chatbot, history, status_text, all_token_counts
-
-
-def predict(
-    openai_api_key,
-    system_prompt,
-    history,
-    inputs,
-    chatbot,
-    all_token_counts,
-    top_p,
-    temperature,
-    stream=False,
-    selected_model=MODELS[0],
-    use_websearch=False,
-    files=None,
-    reply_language="中文",
-    should_check_token_count=True,
-):  # repetition_penalty, top_k
-    from llama_index.indices.vector_store.base_query import GPTVectorStoreIndexQuery
-    from llama_index.indices.query.schema import QueryBundle
-    from langchain.llms import OpenAIChat
-
-
-    logging.info("Input: " + colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL)
-    if should_check_token_count:
-        yield chatbot+[(inputs, "")], history, "Generating the answer...", all_token_counts
-    if reply_language == "跟随问题语言(不稳定)":
-        reply_language = "the same language as the question, such as English, 中文, 日本語, Español, Français, or Deutsch."
-    old_inputs = None
-    display_reference = []
-    limited_context = False
-    if files:
-        limited_context = True
-        old_inputs = inputs
-        msg = "Loading index... (this may take a few minutes)"
-        logging.info(msg)
-        yield chatbot+[(inputs, "")], history, msg, all_token_counts
-        index = construct_index(openai_api_key, file_src=files)
-        msg = "Index built; fetching the answer..."
-        logging.info(msg)
-        yield chatbot+[(inputs, "")], history, msg, all_token_counts
-        with retrieve_proxy():
-            llm_predictor = LLMPredictor(llm=OpenAIChat(temperature=0, model_name=selected_model))
-            prompt_helper = PromptHelper(max_input_size=4096, num_output=5, max_chunk_overlap=20, chunk_size_limit=600)
-            from llama_index import ServiceContext
-            service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper)
-            query_object = GPTVectorStoreIndexQuery(index.index_struct, service_context=service_context, similarity_top_k=5, vector_store=index._vector_store, docstore=index._docstore)
-            query_bundle = QueryBundle(inputs)
-            nodes = query_object.retrieve(query_bundle)
-        reference_results = [n.node.text for n in nodes]
-        reference_results = add_source_numbers(reference_results, use_source=False)
-        display_reference = add_details(reference_results)
-        display_reference = "\n\n" + "".join(display_reference)
-        inputs = (
-            replace_today(PROMPT_TEMPLATE)
-            .replace("{query_str}", inputs)
-            .replace("{context_str}", "\n\n".join(reference_results))
-            .replace("{reply_language}", reply_language)
-        )
-    elif use_websearch:
-        limited_context = True
-        search_results = ddg(inputs, max_results=5)
-        old_inputs = inputs
-        reference_results = []
-        for idx, result in enumerate(search_results):
-            logging.info(f"Search result {idx + 1}: {result}")
-            domain_name = urllib3.util.parse_url(result["href"]).host
-            reference_results.append([result["body"], result["href"]])
-            display_reference.append(f"{idx+1}. [{domain_name}]({result['href']})\n")
-        reference_results = add_source_numbers(reference_results)
-        display_reference = "\n\n" + "".join(display_reference)
-        inputs = (
-            replace_today(WEBSEARCH_PTOMPT_TEMPLATE)
-            .replace("{query}", inputs)
-            .replace("{web_results}", "\n\n".join(reference_results))
-            .replace("{reply_language}", reply_language)
-        )
-    else:
-        display_reference = ""
-
-    if len(openai_api_key) == 0 and not shared.state.multi_api_key:
-        status_text = standard_error_msg + no_apikey_msg
-        logging.info(status_text)
-        chatbot.append((inputs, ""))
-        if len(history) == 0:
-            history.append(construct_user(inputs))
-            history.append("")
-            all_token_counts.append(0)
-        else:
-            history[-2] = construct_user(inputs)
-        yield chatbot+[(inputs, "")], history, status_text, all_token_counts
-        return
-    elif len(inputs.strip()) == 0:
-        status_text = standard_error_msg + no_input_msg
-        logging.info(status_text)
-        yield chatbot+[(inputs, "")], history, status_text, all_token_counts
-        return
-
-    if stream:
-        logging.info("Using streaming mode")
-        iter = stream_predict(
-            openai_api_key,
-            system_prompt,
-            history,
-            inputs,
-            chatbot,
-            all_token_counts,
-            top_p,
-            temperature,
-            selected_model,
-            fake_input=old_inputs,
-            display_append=display_reference
-        )
-        for chatbot, history, status_text, all_token_counts in iter:
-            if shared.state.interrupted:
-                shared.state.recover()
-                return
-            yield chatbot, history, status_text, all_token_counts
-    else:
-        logging.info("Not using streaming mode")
-        chatbot, history, status_text, all_token_counts = predict_all(
-            openai_api_key,
-            system_prompt,
-            history,
-            inputs,
-            chatbot,
-            all_token_counts,
-            top_p,
-            temperature,
-            selected_model,
-            fake_input=old_inputs,
-            display_append=display_reference
-        )
-        yield chatbot, history, status_text, all_token_counts
-
-    logging.info(f"Transfer complete. Current token counts: {all_token_counts}")
-    if len(history) > 1 and history[-1]["content"] != inputs:
-        logging.info(
-            "Answer: "
-            + colorama.Fore.BLUE
-            + f"{history[-1]['content']}"
-            + colorama.Style.RESET_ALL
-        )
-
-    if limited_context:
-        history = history[-4:]
-        all_token_counts = all_token_counts[-2:]
-        yield chatbot, history, status_text, all_token_counts
-
-    if stream:
-        max_token = MODEL_SOFT_TOKEN_LIMIT[selected_model]["streaming"]
-    else:
-        max_token = MODEL_SOFT_TOKEN_LIMIT[selected_model]["all"]
-
-    if sum(all_token_counts) > max_token and should_check_token_count:
-        print(all_token_counts)
-        count = 0
-        while sum(all_token_counts) > max_token - 500 and sum(all_token_counts) > 0:
-            count += 1
-            del all_token_counts[0]
-            del history[:2]
-        logging.info(status_text)
-        status_text = f"To avoid exceeding the token limit, the model forgot the earliest {count} rounds of conversation"
-        yield chatbot, history, status_text, all_token_counts
-
-
-def retry(
-    openai_api_key,
-    system_prompt,
-    history,
-    chatbot,
-    token_count,
-    top_p,
-    temperature,
-    stream=False,
-    selected_model=MODELS[0],
-    reply_language="中文",
-):
-    logging.info("Retrying...")
-    if len(history) == 0:
-        yield chatbot, history, f"{standard_error_msg}The context is empty", token_count
-        return
-    history.pop()
-    inputs = history.pop()["content"]
-    token_count.pop()
-    iter = predict(
-        openai_api_key,
-        system_prompt,
-        history,
-        inputs,
-        chatbot,
-        token_count,
-        top_p,
-        temperature,
-        stream=stream,
-        selected_model=selected_model,
-        reply_language=reply_language,
-    )
-    logging.info("Retrying...")
-    for x in iter:
-        yield x
-    logging.info("Retry finished")
-
-
-def reduce_token_size(
-    openai_api_key,
-    system_prompt,
-    history,
-    chatbot,
-    token_count,
-    top_p,
-    temperature,
-    max_token_count,
-    selected_model=MODELS[0],
-    reply_language="中文",
-):
-    logging.info("Starting to reduce the token count...")
-    iter = predict(
-        openai_api_key,
-        system_prompt,
-        history,
-        summarize_prompt,
-        chatbot,
-        token_count,
-        top_p,
-        temperature,
-        selected_model=selected_model,
-        should_check_token_count=False,
-        reply_language=reply_language,
-    )
-    logging.info(f"chatbot: {chatbot}")
-    flag = False
-    for chatbot, history, status_text, previous_token_count in iter:
-        num_chat = find_n(previous_token_count, max_token_count)
-        logging.info(f"previous_token_count: {previous_token_count}, keeping {num_chat} chats")
-        if flag:
-            chatbot = chatbot[:-1]
-        flag = True
-        history = history[-2*num_chat:] if num_chat > 0 else []
-        token_count = previous_token_count[-num_chat:] if num_chat > 0 else []
-        msg = f"Kept the most recent {num_chat} rounds of conversation"
-        yield chatbot, history, msg + ", " + construct_token_message(
-            token_count if len(token_count) > 0 else [0],
-        ), token_count
-    logging.info(msg)
-    logging.info("Finished reducing the token count")