Tuchuanhuhuhu commited on
Commit
811a42d
·
1 Parent(s): 2242318

修复自动重试的问题,现在采用截断策略实现无限对话

Browse files
Files changed (2) hide show
  1. ChuanhuChatbot.py +34 -19
  2. utils.py +0 -0
ChuanhuChatbot.py CHANGED
@@ -7,6 +7,7 @@ import traceback
7
  import requests
8
  # import markdown
9
  import csv
 
10
 
11
  my_api_key = "" # 在这里输入你的 API 密钥
12
  HIDE_MY_KEY = False # 如果你想在UI中隐藏你的 API 密钥,将此值设置为 True
@@ -48,16 +49,18 @@ def parse_text(text):
48
  count += 1
49
  items = line.split('`')
50
  if count % 2 == 1:
51
- lines[i] = f'<pre><code class="{items[-1]}">'
52
- firstline = True
53
  else:
54
- lines[i] = f'</code></pre>'
55
  else:
56
  if i > 0:
57
  if count % 2 == 1:
 
 
58
  line = line.replace("`", "\`")
59
- line = line.replace("\"", "`\"`")
60
- line = line.replace("\'", "`\'`")
 
61
  # line = line.replace("&", "&amp;")
62
  line = line.replace("<", "&lt;")
63
  line = line.replace(">", "&gt;")
@@ -74,10 +77,10 @@ def parse_text(text):
74
  text = "".join(lines)
75
  return text
76
 
77
- def predict(inputs, top_p, temperature, openai_api_key, chatbot=[], history=[], system_prompt=initial_prompt, retry=False, summary=False, summary_on_crash = False, stream = True): # repetition_penalty, top_k
78
 
79
- if summary:
80
- stream = False
81
 
82
  headers = {
83
  "Content-Type": "application/json",
@@ -88,7 +91,7 @@ def predict(inputs, top_p, temperature, openai_api_key, chatbot=[], history=[],
88
 
89
  print(f"chat_counter - {chat_counter}")
90
 
91
- messages = [compose_system(system_prompt)]
92
  if chat_counter:
93
  for index in range(0, 2*chat_counter, 2):
94
  temp1 = {}
@@ -104,6 +107,8 @@ def predict(inputs, top_p, temperature, openai_api_key, chatbot=[], history=[],
104
  else:
105
  messages[-1]['content'] = temp2['content']
106
  if retry and chat_counter:
 
 
107
  messages.pop()
108
  elif summary:
109
  history = [*[i["content"] for i in messages[-2:]], "我们刚刚聊了什么?"]
@@ -115,6 +120,7 @@ def predict(inputs, top_p, temperature, openai_api_key, chatbot=[], history=[],
115
  temp3["content"] = inputs
116
  messages.append(temp3)
117
  chat_counter += 1
 
118
  # messages
119
  payload = {
120
  "model": "gpt-3.5-turbo",
@@ -131,9 +137,16 @@ def predict(inputs, top_p, temperature, openai_api_key, chatbot=[], history=[],
131
  history.append(inputs)
132
  else:
133
  print("精简中...")
 
 
134
  # make a POST request to the API endpoint using the requests.post method, passing in stream=True
135
- response = requests.post(API_URL, headers=headers,
136
- json=payload, stream=True)
 
 
 
 
 
137
 
138
  token_counter = 0
139
  partial_words = ""
@@ -157,15 +170,17 @@ def predict(inputs, top_p, temperature, openai_api_key, chatbot=[], history=[],
157
  break
158
  except Exception as e:
159
  traceback.print_exc()
160
- print("Context 过长,正在尝试精简……")
161
- chatbot.pop()
162
- chatbot, history, status_text = next(predict(inputs, top_p, temperature, openai_api_key, chatbot, history, system_prompt, retry, summary=True, summary_on_crash=True, stream=False))
163
- yield chatbot, history, status_text
164
- if not "ERROR" in status_text:
165
- print("精简完成,正在尝试重新生成……")
166
- yield next(predict(inputs, top_p, temperature, openai_api_key, chatbot, history, system_prompt, retry, summary=False, summary_on_crash=True, stream=False))
167
  else:
168
- print("精简出错了,可能是网络原因。")
 
 
 
 
169
  break
170
  chunkjson = json.loads(chunk.decode()[6:])
171
  status_text = f"id: {chunkjson['id']}, finish_reason: {chunkjson['choices'][0]['finish_reason']}"
 
7
  import requests
8
  # import markdown
9
  import csv
10
+ from utils import ChuanhuChatbot
11
 
12
  my_api_key = "" # 在这里输入你的 API 密钥
13
  HIDE_MY_KEY = False # 如果你想在UI中隐藏你的 API 密钥,将此值设置为 True
 
49
  count += 1
50
  items = line.split('`')
51
  if count % 2 == 1:
52
+ lines[i] = f'<pre><code class="language-{items[-1]}">'
 
53
  else:
54
+ lines[i] = f'<br></code></pre>'
55
  else:
56
  if i > 0:
57
  if count % 2 == 1:
58
+ # line = line.replace("‘", "'")
59
+ # line = line.replace("“", '"')
60
  line = line.replace("`", "\`")
61
+ # line = line.replace("\"", "`\"`")
62
+ # line = line.replace("\'", "`\'`")
63
+ # line = line.replace("'``'", "''")
64
  # line = line.replace("&", "&amp;")
65
  line = line.replace("<", "&lt;")
66
  line = line.replace(">", "&gt;")
 
77
  text = "".join(lines)
78
  return text
79
 
80
+ def predict(inputs, top_p, temperature, openai_api_key, chatbot=[], history=[], system_prompt=initial_prompt, retry=False, summary=False, retry_on_crash = False, stream = True): # repetition_penalty, top_k
81
 
82
+ if retry_on_crash:
83
+ retry = True
84
 
85
  headers = {
86
  "Content-Type": "application/json",
 
91
 
92
  print(f"chat_counter - {chat_counter}")
93
 
94
+ messages = []
95
  if chat_counter:
96
  for index in range(0, 2*chat_counter, 2):
97
  temp1 = {}
 
107
  else:
108
  messages[-1]['content'] = temp2['content']
109
  if retry and chat_counter:
110
+ if retry_on_crash:
111
+ messages = messages[-6:]
112
  messages.pop()
113
  elif summary:
114
  history = [*[i["content"] for i in messages[-2:]], "我们刚刚聊了什么?"]
 
120
  temp3["content"] = inputs
121
  messages.append(temp3)
122
  chat_counter += 1
123
+ messages = [compose_system(system_prompt), *messages]
124
  # messages
125
  payload = {
126
  "model": "gpt-3.5-turbo",
 
137
  history.append(inputs)
138
  else:
139
  print("精简中...")
140
+
141
+ print(f"payload: {payload}")
142
  # make a POST request to the API endpoint using the requests.post method, passing in stream=True
143
+ try:
144
+ response = requests.post(API_URL, headers=headers, json=payload, stream=True)
145
+ except:
146
+ history.append("")
147
+ chatbot.append(inputs, "")
148
+ yield history, chatbot, f"出现了网络错误"
149
+ return
150
 
151
  token_counter = 0
152
  partial_words = ""
 
170
  break
171
  except Exception as e:
172
  traceback.print_exc()
173
+ if not retry_on_crash:
174
+ print("正在尝试使用缩短的context重新生成……")
175
+ chatbot.pop()
176
+ history.append("")
177
+ yield next(predict(inputs, top_p, temperature, openai_api_key, chatbot, history, system_prompt, retry, summary=False, retry_on_crash=True, stream=False))
 
 
178
  else:
179
+ msg = "☹️发生了错误:生成失败,请检查网络"
180
+ print(msg)
181
+ history.append(inputs, "")
182
+ chatbot.append(inputs, msg)
183
+ yield chatbot, history, "status: ERROR"
184
  break
185
  chunkjson = json.loads(chunk.decode()[6:])
186
  status_text = f"id: {chunkjson['id']}, finish_reason: {chunkjson['choices'][0]['finish_reason']}"
utils.py ADDED
The diff for this file is too large to render. See raw diff