Tuchuanhuhuhu committed on
Commit
a2dfe6a
·
1 Parent(s): b0a1d94

改进减少token逻辑

Browse files
Files changed (3) hide show
  1. ChuanhuChatbot.py +1 -1
  2. chat_func.py +15 -12
  3. utils.py +46 -26
ChuanhuChatbot.py CHANGED
@@ -359,7 +359,7 @@ with gr.Blocks(
359
  token_count,
360
  top_p,
361
  temperature,
362
- use_streaming_checkbox,
363
  model_select_dropdown,
364
  ],
365
  [chatbot, history, status_display, token_count],
 
359
  token_count,
360
  top_p,
361
  temperature,
362
+ gr.State(0),
363
  model_select_dropdown,
364
  ],
365
  [chatbot, history, status_display, token_count],
chat_func.py CHANGED
@@ -371,9 +371,8 @@ def predict(
371
  all_token_counts,
372
  top_p,
373
  temperature,
374
- stream=False,
375
  selected_model=selected_model,
376
- hidden=True,
377
  )
378
  for chatbot, history, status_text, all_token_counts in iter:
379
  status_text = f"Token 达到上限,已自动降低Token计数至 {status_text}"
@@ -410,9 +409,10 @@ def retry(
410
  stream=stream,
411
  selected_model=selected_model,
412
  )
413
- logging.info("重试完毕")
414
  for x in iter:
415
  yield x
 
416
 
417
 
418
  def reduce_token_size(
@@ -423,9 +423,8 @@ def reduce_token_size(
423
  token_count,
424
  top_p,
425
  temperature,
426
- stream=False,
427
  selected_model=MODELS[0],
428
- hidden=False,
429
  ):
430
  logging.info("开始减少token数量……")
431
  iter = predict(
@@ -437,17 +436,21 @@ def reduce_token_size(
437
  token_count,
438
  top_p,
439
  temperature,
440
- stream=stream,
441
  selected_model=selected_model,
442
  should_check_token_count=False,
443
  )
444
  logging.info(f"chatbot: {chatbot}")
 
445
  for chatbot, history, status_text, previous_token_count in iter:
446
- history = history[-2:]
447
- token_count = previous_token_count[-1:]
448
- if hidden:
449
- chatbot.pop()
450
- yield chatbot, history, construct_token_message(
451
- sum(token_count), stream=stream
 
 
 
452
  ), token_count
 
453
  logging.info("减少token数量完毕")
 
371
  all_token_counts,
372
  top_p,
373
  temperature,
374
+ max_token//2,
375
  selected_model=selected_model,
 
376
  )
377
  for chatbot, history, status_text, all_token_counts in iter:
378
  status_text = f"Token 达到上限,已自动降低Token计数至 {status_text}"
 
409
  stream=stream,
410
  selected_model=selected_model,
411
  )
412
+ logging.info("重试中……")
413
  for x in iter:
414
  yield x
415
+ logging.info("重试完毕")
416
 
417
 
418
  def reduce_token_size(
 
423
  token_count,
424
  top_p,
425
  temperature,
426
+ max_token_count,
427
  selected_model=MODELS[0],
 
428
  ):
429
  logging.info("开始减少token数量……")
430
  iter = predict(
 
436
  token_count,
437
  top_p,
438
  temperature,
 
439
  selected_model=selected_model,
440
  should_check_token_count=False,
441
  )
442
  logging.info(f"chatbot: {chatbot}")
443
+ flag = False
444
  for chatbot, history, status_text, previous_token_count in iter:
445
+ num_chat = find_n(previous_token_count, max_token_count)
446
+ if flag:
447
+ chatbot = chatbot[:-1]
448
+ flag = True
449
+ history = history[-2*num_chat:] if num_chat > 0 else []
450
+ token_count = previous_token_count[-num_chat:] if num_chat > 0 else []
451
+ msg = f"保留了最近{num_chat}轮对话"
452
+ yield chatbot, history, msg + "," + construct_token_message(
453
+ sum(token_count) if len(token_count) > 0 else 0,
454
  ), token_count
455
+ logging.info(msg)
456
  logging.info("减少token数量完毕")
utils.py CHANGED
@@ -37,9 +37,10 @@ def count_token(message):
37
  length = len(encoding.encode(input_str))
38
  return length
39
 
 
40
  def markdown_to_html_with_syntax_highlight(md_str):
41
  def replacer(match):
42
- lang = match.group(1) or 'text'
43
  code = match.group(2)
44
 
45
  try:
@@ -50,60 +51,65 @@ def markdown_to_html_with_syntax_highlight(md_str):
50
  formatter = HtmlFormatter()
51
  highlighted_code = highlight(code, lexer, formatter)
52
 
53
- return f"<pre><code class=\"{lang}\">{highlighted_code}</code></pre>"
54
 
55
- code_block_pattern = r'```(\w+)?\n([\s\S]+?)\n```'
56
  md_str = re.sub(code_block_pattern, replacer, md_str, flags=re.MULTILINE)
57
 
58
  html_str = markdown(md_str)
59
  return html_str
60
 
 
61
  def normalize_markdown(md_text: str) -> str:
62
- lines = md_text.split('\n')
63
  normalized_lines = []
64
  inside_list = False
65
 
66
  for i, line in enumerate(lines):
67
- if re.match(r'^(\d+\.|-|\*|\+)\s', line.strip()):
68
- if not inside_list and i > 0 and lines[i - 1].strip() != '':
69
- normalized_lines.append('')
70
  inside_list = True
71
  normalized_lines.append(line)
72
- elif inside_list and line.strip() == '':
73
- if i < len(lines) - 1 and not re.match(r'^(\d+\.|-|\*|\+)\s', lines[i + 1].strip()):
 
 
74
  normalized_lines.append(line)
75
  continue
76
  else:
77
  inside_list = False
78
  normalized_lines.append(line)
79
 
80
- return '\n'.join(normalized_lines)
 
81
 
82
  def convert_mdtext(md_text):
83
- code_block_pattern = re.compile(r'```(.*?)(?:```|$)', re.DOTALL)
84
  code_blocks = code_block_pattern.findall(md_text)
85
  non_code_parts = code_block_pattern.split(md_text)[::2]
86
 
87
  result = []
88
- for non_code, code in zip(non_code_parts, code_blocks + ['']):
89
  if non_code.strip():
90
  non_code = normalize_markdown(non_code)
91
- result.append(mdtex2html.convert(non_code, extensions=['tables']))
92
  if code.strip():
93
- _, code = detect_language(code) # 暂时去除代码高亮功能,因为在大段代码的情况下会出现问题
94
  code = f"```{code}\n\n```"
95
  code = markdown_to_html_with_syntax_highlight(code)
96
  result.append(code)
97
  result = "".join(result)
98
  return result
99
 
 
100
  def detect_language(code):
101
  if code.startswith("\n"):
102
  first_line = ""
103
  else:
104
- first_line = code.strip().split('\n', 1)[0]
105
- language = first_line.lower() if first_line else ''
106
- code_without_language = code[len(first_line):].lstrip() if first_line else code
107
  return language, code_without_language
108
 
109
 
@@ -336,26 +342,40 @@ def replace_today(prompt):
336
  today = datetime.datetime.today().strftime("%Y-%m-%d")
337
  return prompt.replace("{current_date}", today)
338
 
 
339
  def get_geoip():
340
- response = requests.get('https://ipapi.co/json/', timeout=5)
341
  try:
342
  data = response.json()
343
  except:
344
- data = {
345
- "error": True,
346
- "reason" : "连接ipapi失败"
347
- }
348
  if "error" in data.keys():
349
  logging.warning(f"无法获取IP地址信息。\n{data}")
350
- if data['reason'] == "RateLimited":
351
- return f"获取IP地理位置失败,因为达到了检测IP的速率限制。聊天功能可能仍然可用,但请注意,如果您的IP地址在不受支持的地区,您可能会遇到问题。"
 
 
352
  else:
353
  return f"获取IP地理位置失败。原因:{data['reason']}。你仍然可以使用聊天功能。"
354
  else:
355
- country = data['country_name']
356
  if country == "China":
357
  text = "**您的IP区域:中国。请立即检查代理设置,在不受支持的地区使用API可能导致账号被封禁。**"
358
  else:
359
  text = f"您的IP区域:{country}。"
360
  logging.info(text)
361
- return text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  length = len(encoding.encode(input_str))
38
  return length
39
 
40
+
41
  def markdown_to_html_with_syntax_highlight(md_str):
42
  def replacer(match):
43
+ lang = match.group(1) or "text"
44
  code = match.group(2)
45
 
46
  try:
 
51
  formatter = HtmlFormatter()
52
  highlighted_code = highlight(code, lexer, formatter)
53
 
54
+ return f'<pre><code class="{lang}">{highlighted_code}</code></pre>'
55
 
56
+ code_block_pattern = r"```(\w+)?\n([\s\S]+?)\n```"
57
  md_str = re.sub(code_block_pattern, replacer, md_str, flags=re.MULTILINE)
58
 
59
  html_str = markdown(md_str)
60
  return html_str
61
 
62
+
63
  def normalize_markdown(md_text: str) -> str:
64
+ lines = md_text.split("\n")
65
  normalized_lines = []
66
  inside_list = False
67
 
68
  for i, line in enumerate(lines):
69
+ if re.match(r"^(\d+\.|-|\*|\+)\s", line.strip()):
70
+ if not inside_list and i > 0 and lines[i - 1].strip() != "":
71
+ normalized_lines.append("")
72
  inside_list = True
73
  normalized_lines.append(line)
74
+ elif inside_list and line.strip() == "":
75
+ if i < len(lines) - 1 and not re.match(
76
+ r"^(\d+\.|-|\*|\+)\s", lines[i + 1].strip()
77
+ ):
78
  normalized_lines.append(line)
79
  continue
80
  else:
81
  inside_list = False
82
  normalized_lines.append(line)
83
 
84
+ return "\n".join(normalized_lines)
85
+
86
 
87
  def convert_mdtext(md_text):
88
+ code_block_pattern = re.compile(r"```(.*?)(?:```|$)", re.DOTALL)
89
  code_blocks = code_block_pattern.findall(md_text)
90
  non_code_parts = code_block_pattern.split(md_text)[::2]
91
 
92
  result = []
93
+ for non_code, code in zip(non_code_parts, code_blocks + [""]):
94
  if non_code.strip():
95
  non_code = normalize_markdown(non_code)
96
+ result.append(mdtex2html.convert(non_code, extensions=["tables"]))
97
  if code.strip():
98
+ _, code = detect_language(code) # 暂时去除代码高亮功能,因为在大段代码的情况下会出现问题
99
  code = f"```{code}\n\n```"
100
  code = markdown_to_html_with_syntax_highlight(code)
101
  result.append(code)
102
  result = "".join(result)
103
  return result
104
 
105
+
106
  def detect_language(code):
107
  if code.startswith("\n"):
108
  first_line = ""
109
  else:
110
+ first_line = code.strip().split("\n", 1)[0]
111
+ language = first_line.lower() if first_line else ""
112
+ code_without_language = code[len(first_line) :].lstrip() if first_line else code
113
  return language, code_without_language
114
 
115
 
 
342
  today = datetime.datetime.today().strftime("%Y-%m-%d")
343
  return prompt.replace("{current_date}", today)
344
 
345
+
346
  def get_geoip():
347
+ response = requests.get("https://ipapi.co/json/", timeout=5)
348
  try:
349
  data = response.json()
350
  except:
351
+ data = {"error": True, "reason": "连接ipapi失败"}
 
 
 
352
  if "error" in data.keys():
353
  logging.warning(f"无法获取IP地址信息。\n{data}")
354
+ if data["reason"] == "RateLimited":
355
+ return (
356
+ f"获取IP地理位置失败,因为达到了检测IP的速率限制。聊天功能可能仍然可用,但请注意,如果您的IP地址在不受支持的地区,您可能会遇到问题。"
357
+ )
358
  else:
359
  return f"获取IP地理位置失败。原因:{data['reason']}。你仍然可以使用聊天功能。"
360
  else:
361
+ country = data["country_name"]
362
  if country == "China":
363
  text = "**您的IP区域:中国。请立即检查代理设置,在不受支持的地区使用API可能导致账号被封禁。**"
364
  else:
365
  text = f"您的IP区域:{country}。"
366
  logging.info(text)
367
+ return text
368
+
369
+
370
+ def find_n(lst, max_num):
371
+ n = len(lst)
372
+ total = sum(lst)
373
+
374
+ if total < max_num:
375
+ return n
376
+
377
+ for i in range(len(lst)):
378
+ if total - lst[i] < max_num:
379
+ return n - i -1
380
+ total = total - lst[i]
381
+ return 1