Commit 4c9761e · Parent(s): 2037c5b
Delete predict.py
predict.py · DELETED (+0 -191)
```python
# Borrows from the https://github.com/GaiZhenbiao/ChuanhuChatGPT project

import json
import gradio as gr
import logging
import traceback
import requests
import importlib

# config_private.py holds your secrets, such as the API key and proxy URL.
# On import, a private config_private file (kept out of git) is tried first;
# if it exists, it overrides the settings in config.py.
try: from config_private import proxies, API_URL, API_KEY, TIMEOUT_SECONDS, MAX_RETRY, LLM_MODEL
except: from config import proxies, API_URL, API_KEY, TIMEOUT_SECONDS, MAX_RETRY, LLM_MODEL

timeout_bot_msg = '[local] Request timeout, network error. please check proxy settings in config.py.'
```
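The fallback import above requires config.py and config_private.py to expose the same six names. As a point of reference, a minimal config.py might look like the sketch below; every value is an illustrative placeholder, not taken from the original repository:

```python
# config.py -- minimal sketch; all values here are placeholders
API_KEY = "sk-..."                                      # your OpenAI API key
API_URL = "https://api.openai.com/v1/chat/completions"  # chat completions endpoint
proxies = None                  # or e.g. {"https": "http://127.0.0.1:7890"}
TIMEOUT_SECONDS = 25            # per-request timeout for streaming calls
MAX_RETRY = 2                   # how many times to retry after a timeout
LLM_MODEL = "gpt-3.5-turbo"     # passed through as payload["model"]
```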
```python
def get_full_error(chunk, stream_response):
    """
    Drain the rest of the stream to recover the complete error message
    returned by OpenAI.
    """
    while True:
        try:
            chunk += next(stream_response)
        except:
            break
    return chunk
```
```python
def predict_no_ui(inputs, top_p, temperature, history=[]):
    """
    Send a request to chatGPT and wait for the complete reply in one shot,
    without showing intermediate output. A simplified version of predict().
    Useful for large payloads, or for building multi-threaded / nested features.

    inputs: the query for this request
    top_p, temperature: chatGPT sampling parameters
    history: the list of previous messages
    (Note: if inputs or history is too long, the token limit is exceeded
    and ConnectionAbortedError is raised.)
    """
    headers, payload = generate_payload(inputs, top_p, temperature, history, system_prompt="", stream=False)

    retry = 0
    while True:
        try:
            # make a POST request to the API endpoint, stream=False
            response = requests.post(API_URL, headers=headers, proxies=proxies,
                                     json=payload, stream=False, timeout=TIMEOUT_SECONDS*2)
            break
        except requests.exceptions.ReadTimeout as e:
            retry += 1
            traceback.print_exc()
            if MAX_RETRY != 0: print(f'Request timed out, retrying ({retry}/{MAX_RETRY}) ...')
            if retry > MAX_RETRY: raise TimeoutError

    try:
        result = json.loads(response.text)["choices"][0]["message"]["content"]
        return result
    except Exception as e:
        if "choices" not in response.text: print(response.text)
        raise ConnectionAbortedError("Malformed JSON, possibly because the input is too long: " + response.text)
```
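A minimal usage sketch for predict_no_ui, assuming a config like the one above; the history list alternates user and assistant turns:

```python
# Hypothetical usage; not part of the original file.
answer = predict_no_ui(
    inputs="Now summarize that in one sentence.",
    top_p=1.0,
    temperature=1.0,
    history=["What is Gradio?",
             "Gradio is a Python library for building simple web UIs."],
)
print(answer)  # the assistant's full reply, returned in one piece
```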
```python
def predict(inputs, top_p, temperature, chatbot=[], history=[], system_prompt='',
            stream=True, additional_fn=None):
    """
    Send a request to chatGPT and stream the output.
    Used for the basic chat functionality.
    inputs: the query for this request
    top_p, temperature: chatGPT sampling parameters
    history: the list of previous messages (note: if inputs or history is
        too long, the token limit is exceeded)
    chatbot: the conversation list shown in the WebUI; modify it and yield it
        back to update the chat interface directly
    additional_fn: which button was clicked; see functional.py for the buttons
    """
    if additional_fn is not None:
        import functional
        importlib.reload(functional)
        functional = functional.get_functionals()
        inputs = functional[additional_fn]["Prefix"] + inputs + functional[additional_fn]["Suffix"]

    if stream:
        raw_input = inputs
        logging.info(f'[raw_input] {raw_input}')
        chatbot.append((inputs, ""))
        yield chatbot, history, "Waiting for response"

    headers, payload = generate_payload(inputs, top_p, temperature, history, system_prompt, stream)
    history.append(inputs); history.append(" ")

    retry = 0
    while True:
        try:
            # make a POST request to the API endpoint, stream=True
            response = requests.post(API_URL, headers=headers, proxies=proxies,
                                     json=payload, stream=True, timeout=TIMEOUT_SECONDS)
            break
        except:
            retry += 1
            chatbot[-1] = (chatbot[-1][0], timeout_bot_msg)
            retry_msg = f", retrying ({retry}/{MAX_RETRY}) ..." if MAX_RETRY > 0 else ""
            yield chatbot, history, "Request timed out" + retry_msg
            if retry > MAX_RETRY: raise TimeoutError

    gpt_replying_buffer = ""

    is_head_of_the_stream = True
    if stream:
        stream_response = response.iter_lines()
        while True:
            chunk = next(stream_response)
            if is_head_of_the_stream:
                # the first frame of the stream carries no content
                is_head_of_the_stream = False; continue

            if chunk:
                try:
                    chunkjson = json.loads(chunk.decode()[6:])
                    if len(chunkjson['choices'][0]["delta"]) == 0:
                        # an empty delta marks the end of the stream;
                        # gpt_replying_buffer is now complete
                        logging.info(f'[response] {gpt_replying_buffer}')
                        break
                    # handle the body of the stream
                    status_text = f"finish_reason: {chunkjson['choices'][0]['finish_reason']}"
                    # an exception here usually means the text is too long;
                    # see the output of get_full_error for details
                    gpt_replying_buffer = gpt_replying_buffer + chunkjson['choices'][0]["delta"]["content"]
                    history[-1] = gpt_replying_buffer
                    chatbot[-1] = (history[-2], history[-1])
                    yield chatbot, history, status_text

                except Exception as e:
                    traceback.print_exc()
                    yield chatbot, history, "Malformed JSON"
                    chunk = get_full_error(chunk, stream_response)
                    error_msg = chunk.decode()
                    if "reduce the length" in error_msg:
                        chatbot[-1] = (chatbot[-1][0], "[Local Message] Input (or history) is too long, please reduce input or clear history by refreshing this page.")
                        history = []
                    elif "Incorrect API key" in error_msg:
                        chatbot[-1] = (chatbot[-1][0], "[Local Message] Incorrect API key provided.")
                    else:
                        from toolbox import regular_txt_to_markdown
                        tb_str = regular_txt_to_markdown(traceback.format_exc())
                        chatbot[-1] = (chatbot[-1][0], f"[Local Message] Json Error \n\n {tb_str} \n\n {regular_txt_to_markdown(chunk.decode()[4:])}")
                    yield chatbot, history, "Malformed JSON: " + error_msg
                    return
```
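The chunk.decode()[6:] slicing in the loop above strips the "data: " prefix (six characters) that the OpenAI streaming endpoint places in front of each server-sent-events frame. A sketch of how one frame is parsed; the payload shape is illustrative, not captured output:

```python
import json

# One raw SSE frame as bytes (illustrative shape):
chunk = b'data: {"choices": [{"delta": {"content": "Hello"}, "finish_reason": null}]}'

chunkjson = json.loads(chunk.decode()[6:])   # drop the leading 'data: '
delta = chunkjson['choices'][0]['delta']
print(delta["content"])                      # -> Hello

# The closing frame of a stream carries an empty delta ({}), which is what
# predict() uses as its end-of-stream marker.
```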
```python
def generate_payload(inputs, top_p, temperature, history, system_prompt, stream):
    """
    Assemble all the information, pick the LLM model, and build the HTTP
    headers and payload, ready to send the request.
    """
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {API_KEY}"
    }

    conversation_cnt = len(history) // 2

    messages = [{"role": "system", "content": system_prompt}]
    if conversation_cnt:
        for index in range(0, 2*conversation_cnt, 2):
            what_i_have_asked = {}
            what_i_have_asked["role"] = "user"
            what_i_have_asked["content"] = history[index]
            what_gpt_answer = {}
            what_gpt_answer["role"] = "assistant"
            what_gpt_answer["content"] = history[index+1]
            if what_i_have_asked["content"] != "":
                # skip turns whose answer is empty or was a timeout message
                if what_gpt_answer["content"] == "": continue
                if what_gpt_answer["content"] == timeout_bot_msg: continue
                messages.append(what_i_have_asked)
                messages.append(what_gpt_answer)
            else:
                messages[-1]['content'] = what_gpt_answer['content']

    what_i_ask_now = {}
    what_i_ask_now["role"] = "user"
    what_i_ask_now["content"] = inputs
    messages.append(what_i_ask_now)

    payload = {
        "model": LLM_MODEL,
        "messages": messages,
        "temperature": temperature,  # 1.0,
        "top_p": top_p,  # 1.0,
        "n": 1,
        "stream": stream,
        "presence_penalty": 0,
        "frequency_penalty": 0,
    }

    print(f" {LLM_MODEL} : {conversation_cnt} : {inputs}")
    return headers, payload
```
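To make the message assembly concrete, here is a sketch of what generate_payload builds for a two-turn history, assuming the module above is importable and configured:

```python
# Hypothetical call; not part of the original file.
headers, payload = generate_payload(
    inputs="And how do I do it in Rust?",
    top_p=1.0, temperature=1.0,
    history=["How do I read a file in Python?",
             "Use open() inside a with statement ..."],
    system_prompt="You are a helpful assistant.",
    stream=True,
)
# payload["messages"] now holds:
# [{"role": "system",    "content": "You are a helpful assistant."},
#  {"role": "user",      "content": "How do I read a file in Python?"},
#  {"role": "assistant", "content": "Use open() inside a with statement ..."},
#  {"role": "user",      "content": "And how do I do it in Rust?"}]
```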