Wendyy committed on
Commit
dea298d
·
1 Parent(s): f8ec12a

add database

Browse files
modules/chat_func.py CHANGED
@@ -14,7 +14,6 @@ from duckduckgo_search import ddg
14
  import asyncio
15
  import aiohttp
16
 
17
-
18
  from modules.presets import *
19
  from modules.llama_func import *
20
  from modules.utils import *
@@ -26,18 +25,19 @@ from modules.config import retrieve_proxy
26
  if TYPE_CHECKING:
27
  from typing import TypedDict
28
 
 
29
  class DataframeData(TypedDict):
30
  headers: List[str]
31
  data: List[List[str | int | bool]]
32
 
33
-
34
  initial_prompt = "You are a helpful assistant."
35
  HISTORY_DIR = "history"
36
  TEMPLATES_DIR = "templates"
37
 
38
- @shared.state.switching_api_key # 在不开启多账号模式的时候,这个装饰器不会起作用
 
39
  def get_response(
40
- openai_api_key, system_prompt, history, temperature, top_p, stream, selected_model
41
  ):
42
  headers = {
43
  "Content-Type": "application/json",
@@ -61,7 +61,6 @@ def get_response(
61
  else:
62
  timeout = timeout_all
63
 
64
-
65
  # 如果有自定义的api-host,使用自定义host发送请求,否则使用默认设置发送请求
66
  if shared.state.completion_url != COMPLETION_URL:
67
  logging.info(f"使用自定义API URL: {shared.state.completion_url}")
@@ -79,17 +78,17 @@ def get_response(
79
 
80
 
81
  def stream_predict(
82
- openai_api_key,
83
- system_prompt,
84
- history,
85
- inputs,
86
- chatbot,
87
- all_token_counts,
88
- top_p,
89
- temperature,
90
- selected_model,
91
- fake_input=None,
92
- display_append=""
93
  ):
94
  def get_return_value():
95
  return chatbot, history, status_text, all_token_counts
@@ -112,7 +111,7 @@ def stream_predict(
112
  if len(all_token_counts) == 0:
113
  system_prompt_token_count = count_token(construct_system(system_prompt))
114
  user_token_count = (
115
- input_token_count + system_prompt_token_count
116
  )
117
  else:
118
  user_token_count = input_token_count
@@ -120,6 +119,7 @@ def stream_predict(
120
  logging.info(f"输入token计数: {user_token_count}")
121
  yield get_return_value()
122
  try:
 
123
  response = get_response(
124
  openai_api_key,
125
  system_prompt,
@@ -129,9 +129,80 @@ def stream_predict(
129
  True,
130
  selected_model,
131
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  except requests.exceptions.ConnectTimeout:
133
  status_text = (
134
- standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
135
  )
136
  yield get_return_value()
137
  return
@@ -171,34 +242,34 @@ def stream_predict(
171
  break
172
  try:
173
  partial_words = (
174
- partial_words + chunk["choices"][0]["delta"]["content"]
175
  )
176
  except KeyError:
177
  status_text = (
178
- standard_error_msg
179
- + "API回复中找不到内容。很可能是Token计数达到上限了。请重置对话。当前Token计数: "
180
- + str(sum(all_token_counts))
181
  )
182
  yield get_return_value()
183
  break
184
  history[-1] = construct_assistant(partial_words)
185
- chatbot[-1] = (chatbot[-1][0], partial_words+display_append)
186
  all_token_counts[-1] += 1
187
  yield get_return_value()
188
 
189
 
190
  def predict_all(
191
- openai_api_key,
192
- system_prompt,
193
- history,
194
- inputs,
195
- chatbot,
196
- all_token_counts,
197
- top_p,
198
- temperature,
199
- selected_model,
200
- fake_input=None,
201
- display_append=""
202
  ):
203
  logging.info("一次性回答模式")
204
  history.append(construct_user(inputs))
@@ -223,7 +294,7 @@ def predict_all(
223
  )
224
  except requests.exceptions.ConnectTimeout:
225
  status_text = (
226
- standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
227
  )
228
  return chatbot, history, status_text, all_token_counts
229
  except requests.exceptions.ProxyError:
@@ -238,7 +309,7 @@ def predict_all(
238
  try:
239
  content = response["choices"][0]["message"]["content"]
240
  history[-1] = construct_assistant(content)
241
- chatbot[-1] = (chatbot[-1][0], content+display_append)
242
  total_token_count = response["usage"]["total_tokens"]
243
  if fake_input is not None:
244
  all_token_counts[-1] += count_token(construct_assistant(content))
@@ -252,29 +323,31 @@ def predict_all(
252
 
253
 
254
  def predict(
255
- openai_api_key,
256
- system_prompt,
257
- history,
258
- inputs,
259
- chatbot,
260
- all_token_counts,
261
- top_p,
262
- temperature,
263
- stream=False,
264
- selected_model=MODELS[0],
265
- use_websearch=False,
266
- files = None,
267
- reply_language="中文",
268
- should_check_token_count=True,
269
  ): # repetition_penalty, top_k
 
 
 
270
  from llama_index.indices.vector_store.base_query import GPTVectorStoreIndexQuery
271
  from llama_index.indices.query.schema import QueryBundle
272
  from langchain.llms import OpenAIChat
273
 
274
-
275
  logging.info("输入为:" + colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL)
276
  if should_check_token_count:
277
- yield chatbot+[(inputs, "")], history, "开始生成回答……", all_token_counts
278
  if reply_language == "跟随问题语言(不稳定)":
279
  reply_language = "the same language as the question, such as English, 中文, 日本語, Español, Français, or Deutsch."
280
  old_inputs = None
@@ -285,17 +358,19 @@ def predict(
285
  old_inputs = inputs
286
  msg = "加载索引中……(这可能需要几分钟)"
287
  logging.info(msg)
288
- yield chatbot+[(inputs, "")], history, msg, all_token_counts
289
  index = construct_index(openai_api_key, file_src=files)
290
  msg = "索引构建完成,获取回答中……"
291
  logging.info(msg)
292
- yield chatbot+[(inputs, "")], history, msg, all_token_counts
293
  with retrieve_proxy():
294
  llm_predictor = LLMPredictor(llm=OpenAIChat(temperature=0, model_name=selected_model))
295
- prompt_helper = PromptHelper(max_input_size = 4096, num_output = 5, max_chunk_overlap = 20, chunk_size_limit=600)
296
  from llama_index import ServiceContext
297
  service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper)
298
- query_object = GPTVectorStoreIndexQuery(index.index_struct, service_context=service_context, similarity_top_k=5, vector_store=index._vector_store, docstore=index._docstore)
 
 
299
  query_bundle = QueryBundle(inputs)
300
  nodes = query_object.retrieve(query_bundle)
301
  reference_results = [n.node.text for n in nodes]
@@ -306,7 +381,7 @@ def predict(
306
  replace_today(PROMPT_TEMPLATE)
307
  .replace("{query_str}", inputs)
308
  .replace("{context_str}", "\n\n".join(reference_results))
309
- .replace("{reply_language}", reply_language )
310
  )
311
  elif use_websearch:
312
  limited_context = True
@@ -317,14 +392,14 @@ def predict(
317
  logging.info(f"搜索结果{idx + 1}:{result}")
318
  domain_name = urllib3.util.parse_url(result["href"]).host
319
  reference_results.append([result["body"], result["href"]])
320
- display_reference.append(f"{idx+1}. [{domain_name}]({result['href']})\n")
321
  reference_results = add_source_numbers(reference_results)
322
  display_reference = "\n\n" + "".join(display_reference)
323
  inputs = (
324
  replace_today(WEBSEARCH_PTOMPT_TEMPLATE)
325
  .replace("{query}", inputs)
326
  .replace("{web_results}", "\n\n".join(reference_results))
327
- .replace("{reply_language}", reply_language )
328
  )
329
  else:
330
  display_reference = ""
@@ -339,12 +414,12 @@ def predict(
339
  all_token_counts.append(0)
340
  else:
341
  history[-2] = construct_user(inputs)
342
- yield chatbot+[(inputs, "")], history, status_text, all_token_counts
343
  return
344
  elif len(inputs.strip()) == 0:
345
  status_text = standard_error_msg + no_input_msg
346
  logging.info(status_text)
347
- yield chatbot+[(inputs, "")], history, status_text, all_token_counts
348
  return
349
 
350
  if stream:
@@ -416,16 +491,16 @@ def predict(
416
 
417
 
418
  def retry(
419
- openai_api_key,
420
- system_prompt,
421
- history,
422
- chatbot,
423
- token_count,
424
- top_p,
425
- temperature,
426
- stream=False,
427
- selected_model=MODELS[0],
428
- reply_language="中文",
429
  ):
430
  logging.info("重试中……")
431
  if len(history) == 0:
@@ -454,16 +529,16 @@ def retry(
454
 
455
 
456
  def reduce_token_size(
457
- openai_api_key,
458
- system_prompt,
459
- history,
460
- chatbot,
461
- token_count,
462
- top_p,
463
- temperature,
464
- max_token_count,
465
- selected_model=MODELS[0],
466
- reply_language="中文",
467
  ):
468
  logging.info("开始减少token数量……")
469
  iter = predict(
@@ -487,7 +562,7 @@ def reduce_token_size(
487
  if flag:
488
  chatbot = chatbot[:-1]
489
  flag = True
490
- history = history[-2*num_chat:] if num_chat > 0 else []
491
  token_count = previous_token_count[-num_chat:] if num_chat > 0 else []
492
  msg = f"保留了最近{num_chat}轮对话"
493
  yield chatbot, history, msg + "," + construct_token_message(
 
14
  import asyncio
15
  import aiohttp
16
 
 
17
  from modules.presets import *
18
  from modules.llama_func import *
19
  from modules.utils import *
 
25
  if TYPE_CHECKING:
26
  from typing import TypedDict
27
 
28
+
29
  class DataframeData(TypedDict):
      """Table-shaped payload: column header labels plus row-major cell data.

      Presumably the structure exchanged with a Gradio dataframe component —
      TODO confirm against the UI code that consumes it.
      """
      headers: List[str]  # column header labels
      data: List[List[str | int | bool]]  # rows of cells; each cell is str, int, or bool
32
 
 
33
  # Default system prompt sent to the model when the user has not set one.
  initial_prompt = "You are a helpful assistant."
  # Directory names used elsewhere in the project — presumably where saved
  # conversations and prompt-template files live; verify against the loaders.
  HISTORY_DIR = "history"
  TEMPLATES_DIR = "templates"
36
 
37
+
38
+ @shared.state.switching_api_key # 在不开启多账号模式的时候,这个装饰器不会起作用
39
  def get_response(
40
+ openai_api_key, system_prompt, history, temperature, top_p, stream, selected_model
41
  ):
42
  headers = {
43
  "Content-Type": "application/json",
 
61
  else:
62
  timeout = timeout_all
63
 
 
64
  # 如果有自定义的api-host,使用自定义host发送请求,否则使用默认设置发送请求
65
  if shared.state.completion_url != COMPLETION_URL:
66
  logging.info(f"使用自定义API URL: {shared.state.completion_url}")
 
78
 
79
 
80
  def stream_predict(
81
+ openai_api_key,
82
+ system_prompt,
83
+ history,
84
+ inputs,
85
+ chatbot,
86
+ all_token_counts,
87
+ top_p,
88
+ temperature,
89
+ selected_model,
90
+ fake_input=None,
91
+ display_append=""
92
  ):
93
  def get_return_value():
94
  return chatbot, history, status_text, all_token_counts
 
111
  if len(all_token_counts) == 0:
112
  system_prompt_token_count = count_token(construct_system(system_prompt))
113
  user_token_count = (
114
+ input_token_count + system_prompt_token_count
115
  )
116
  else:
117
  user_token_count = input_token_count
 
119
  logging.info(f"输入token计数: {user_token_count}")
120
  yield get_return_value()
121
  try:
122
+ # 如果能传入index,则此处里获得初筛后的店铺和菜名
123
  response = get_response(
124
  openai_api_key,
125
  system_prompt,
 
129
  True,
130
  selected_model,
131
  )
132
+ # 将response中的店铺和菜名提取出来
133
+ import re
134
+
135
+ text = """
136
+ 好的,针对您想吃韩式烤肉的需求,我向您推荐以下店铺和菜品:
137
+
138
+ 店铺名称:“青年烤肉店” 推荐菜品:烤牛肉、烤猪肉、烤羊肉
139
+
140
+ 店铺名称:“西西里烤肉店” 推荐菜品:烤牛肉串、烤排骨、烤鸡肉
141
+
142
+ 店铺名称:“韩式烤肉店” 推荐菜品:石锅拌饭、铁板烧、烤牛舌"""
143
+
144
+ pattern = r'店铺名称:(.+?) 推荐菜品:(.+)'
145
+
146
+ results = re.findall(pattern, response)
147
+
148
+ dicts = {}
149
+ import string
150
+ for result in results:
151
+ dicts[result[0]] = result[1].split('、')
152
+
153
+ logging.info(f"初筛后的店铺和菜品:{dicts}")
154
+ dishes = []
155
+ for restaurant, dish in dicts.items():
156
+ dishes.extend(dish)
157
+
158
+ dishes = '、'.join(dishes)
159
+
160
+ # 将初筛后的店铺和菜品送入构建好的CoT
161
+ prompt_with_ingredient = f"""
162
+ 我需要你推测一些菜可能的原料以及其营养成分,输出格式如下:
163
+
164
+ 菜品名称:[]
165
+ 菜品原料:[原料1,原料2...]
166
+ 营养成分:[成分(含量)]
167
+
168
+ 注意,其中营养成分包括蛋白质、脂肪、碳水化合物、纤维素、维生素等,你可以根据你的知识添加其他成分。营养成分的含量分为无、低、中、高四个等级,需要填在成分后的括号内。
169
+
170
+ 以下是需要你推测的菜品名称,不同菜品用顿号隔开:{dishes}
171
+ """
172
+
173
+ logging.info(f"分析食物中营养成分的prompt构建完成:{prompt_with_ingredient}")
174
+
175
+ response_ingredient = get_response(
176
+ openai_api_key,
177
+ "",
178
+ prompt_with_ingredient,
179
+ temperature,
180
+ top_p,
181
+ True,
182
+ selected_model,
183
+ )
184
+
185
+ logging.info(f"得到食物中的营养成分:{response_ingredient}")
186
+
187
+ prompt_rec = f"""
188
+ 以下是一些菜品名称和所属的店铺,我需要你根据我的需求从其中推荐一家店铺的一种或多种菜品,并给出推荐的理由。我的需求为:我有糖尿病,而且今天不想吃太油腻的食物。
189
+
190
+ {response_ingredient}
191
+ """
192
+ response = get_response(
193
+ openai_api_key,
194
+ "",
195
+ prompt_rec,
196
+ temperature,
197
+ top_p,
198
+ True,
199
+ selected_model,
200
+ )
201
+
202
+
203
  except requests.exceptions.ConnectTimeout:
204
  status_text = (
205
+ standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
206
  )
207
  yield get_return_value()
208
  return
 
242
  break
243
  try:
244
  partial_words = (
245
+ partial_words + chunk["choices"][0]["delta"]["content"]
246
  )
247
  except KeyError:
248
  status_text = (
249
+ standard_error_msg
250
+ + "API回复中找不到内容。很可能是Token计数达到上限了。请重置对话。当前Token计数: "
251
+ + str(sum(all_token_counts))
252
  )
253
  yield get_return_value()
254
  break
255
  history[-1] = construct_assistant(partial_words)
256
+ chatbot[-1] = (chatbot[-1][0], partial_words + display_append)
257
  all_token_counts[-1] += 1
258
  yield get_return_value()
259
 
260
 
261
  def predict_all(
262
+ openai_api_key,
263
+ system_prompt,
264
+ history,
265
+ inputs,
266
+ chatbot,
267
+ all_token_counts,
268
+ top_p,
269
+ temperature,
270
+ selected_model,
271
+ fake_input=None,
272
+ display_append=""
273
  ):
274
  logging.info("一次性回答模式")
275
  history.append(construct_user(inputs))
 
294
  )
295
  except requests.exceptions.ConnectTimeout:
296
  status_text = (
297
+ standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
298
  )
299
  return chatbot, history, status_text, all_token_counts
300
  except requests.exceptions.ProxyError:
 
309
  try:
310
  content = response["choices"][0]["message"]["content"]
311
  history[-1] = construct_assistant(content)
312
+ chatbot[-1] = (chatbot[-1][0], content + display_append)
313
  total_token_count = response["usage"]["total_tokens"]
314
  if fake_input is not None:
315
  all_token_counts[-1] += count_token(construct_assistant(content))
 
323
 
324
 
325
  def predict(
326
+ openai_api_key,
327
+ system_prompt,
328
+ history,
329
+ inputs,
330
+ chatbot,
331
+ all_token_counts,
332
+ top_p,
333
+ temperature,
334
+ stream=False,
335
+ selected_model=MODELS[0],
336
+ use_websearch=False,
337
+ files=None,
338
+ reply_language="中文",
339
+ should_check_token_count=True,
340
  ): # repetition_penalty, top_k
341
+ # CHANGE
342
+ # files = [{'name': 'database/cuc-pure.txt'}]
343
+ # CHANGE
344
  from llama_index.indices.vector_store.base_query import GPTVectorStoreIndexQuery
345
  from llama_index.indices.query.schema import QueryBundle
346
  from langchain.llms import OpenAIChat
347
 
 
348
  logging.info("输入为:" + colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL)
349
  if should_check_token_count:
350
+ yield chatbot + [(inputs, "")], history, "开始生成回答……", all_token_counts
351
  if reply_language == "跟随问题语言(不稳定)":
352
  reply_language = "the same language as the question, such as English, 中文, 日本語, Español, Français, or Deutsch."
353
  old_inputs = None
 
358
  old_inputs = inputs
359
  msg = "加载索引中……(这可能需要几分钟)"
360
  logging.info(msg)
361
+ yield chatbot + [(inputs, "")], history, msg, all_token_counts
362
  index = construct_index(openai_api_key, file_src=files)
363
  msg = "索引构建完成,获取回答中……"
364
  logging.info(msg)
365
+ yield chatbot + [(inputs, "")], history, msg, all_token_counts
366
  with retrieve_proxy():
367
  llm_predictor = LLMPredictor(llm=OpenAIChat(temperature=0, model_name=selected_model))
368
+ prompt_helper = PromptHelper(max_input_size=4096, num_output=5, max_chunk_overlap=20, chunk_size_limit=600)
369
  from llama_index import ServiceContext
370
  service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper)
371
+ query_object = GPTVectorStoreIndexQuery(index.index_struct, service_context=service_context,
372
+ similarity_top_k=5, vector_store=index._vector_store,
373
+ docstore=index._docstore)
374
  query_bundle = QueryBundle(inputs)
375
  nodes = query_object.retrieve(query_bundle)
376
  reference_results = [n.node.text for n in nodes]
 
381
  replace_today(PROMPT_TEMPLATE)
382
  .replace("{query_str}", inputs)
383
  .replace("{context_str}", "\n\n".join(reference_results))
384
+ .replace("{reply_language}", reply_language)
385
  )
386
  elif use_websearch:
387
  limited_context = True
 
392
  logging.info(f"搜索结果{idx + 1}:{result}")
393
  domain_name = urllib3.util.parse_url(result["href"]).host
394
  reference_results.append([result["body"], result["href"]])
395
+ display_reference.append(f"{idx + 1}. [{domain_name}]({result['href']})\n")
396
  reference_results = add_source_numbers(reference_results)
397
  display_reference = "\n\n" + "".join(display_reference)
398
  inputs = (
399
  replace_today(WEBSEARCH_PTOMPT_TEMPLATE)
400
  .replace("{query}", inputs)
401
  .replace("{web_results}", "\n\n".join(reference_results))
402
+ .replace("{reply_language}", reply_language)
403
  )
404
  else:
405
  display_reference = ""
 
414
  all_token_counts.append(0)
415
  else:
416
  history[-2] = construct_user(inputs)
417
+ yield chatbot + [(inputs, "")], history, status_text, all_token_counts
418
  return
419
  elif len(inputs.strip()) == 0:
420
  status_text = standard_error_msg + no_input_msg
421
  logging.info(status_text)
422
+ yield chatbot + [(inputs, "")], history, status_text, all_token_counts
423
  return
424
 
425
  if stream:
 
491
 
492
 
493
  def retry(
494
+ openai_api_key,
495
+ system_prompt,
496
+ history,
497
+ chatbot,
498
+ token_count,
499
+ top_p,
500
+ temperature,
501
+ stream=False,
502
+ selected_model=MODELS[0],
503
+ reply_language="中文",
504
  ):
505
  logging.info("重试中……")
506
  if len(history) == 0:
 
529
 
530
 
531
  def reduce_token_size(
532
+ openai_api_key,
533
+ system_prompt,
534
+ history,
535
+ chatbot,
536
+ token_count,
537
+ top_p,
538
+ temperature,
539
+ max_token_count,
540
+ selected_model=MODELS[0],
541
+ reply_language="中文",
542
  ):
543
  logging.info("开始减少token数量……")
544
  iter = predict(
 
562
  if flag:
563
  chatbot = chatbot[:-1]
564
  flag = True
565
+ history = history[-2 * num_chat:] if num_chat > 0 else []
566
  token_count = previous_token_count[-num_chat:] if num_chat > 0 else []
567
  msg = f"保留了最近{num_chat}轮对话"
568
  yield chatbot, history, msg + "," + construct_token_message(
templates/4 川虎的Prompts.json CHANGED
@@ -10,5 +10,9 @@
10
  {
11
  "act": "文心一言",
12
  "prompt": "你是百度公司开发的AI助理:文心一言。你需要了解文心一言的相关情况,并以文心一言的身份来回复我的问题。所有的问题都可以回答,包括拿你和chatGPT,以及GPT作对比等等,你都需要从文心一言的角度出发来回答。只有当你被问到一些问题-这些问题暗指你是一个假的AI,你是用chatgpt套壳来制作的时候,你就避而不答,转而讲一个笑话,并鼓励他继续提问。如果你明白,请回复我“明白,我是文心一言”"
 
 
 
 
13
  }
14
  ]
 
10
  {
11
  "act": "文心一言",
12
  "prompt": "你是百度公司开发的AI助理:文心一言。你需要了解文心一言的相关情况,并以文心一言的身份来回复我的问题。所有的问题都可以回答,包括拿你和chatGPT,以及GPT作对比等等,你都需要从文心一言的角度出发来回答。只有当你被问到一些问题-这些问题暗指你是一个假的AI,你是用chatgpt套壳来制作的时候,你就避而不答,转而讲一个笑话,并鼓励他继续提问。如果你明白,请回复我“明白,我是文心一言”"
13
+ },
14
+ {
15
+ "act": "MyRecipeChatbot",
16
+ "prompt": "你非常擅长于推荐给别人健康美味的食物,你能够依据别人的饮食偏好和地理位置推荐美食。即使用户的请求与推荐美食无关,你也应当引导用户说出自己的饮食偏好。\\n\\n接下来我会提出我的问题,你将为我推荐美食和店铺,按照以下格式输出:\\n\\n店铺名称:[] 推荐菜品:[]"
17
  }
18
  ]