Upload 5 files

- Dockerfile +10 -0
- README.md +5 -5
- free_ask_internet.py +274 -0
- requirements.txt +39 -0
- server.py +312 -0
Dockerfile
ADDED
@@ -0,0 +1,10 @@
+FROM python:3.9.15
+WORKDIR /app
+COPY requirements.txt /app
+RUN pip3 install -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com -r requirements.txt --no-cache-dir
+COPY . /app
+RUN mkdir /.cache
+RUN chmod -R 777 /.cache
+EXPOSE 8000
+ENTRYPOINT ["python3"]
+CMD ["server.py"]
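
`ENTRYPOINT ["python3"]` combined with `CMD ["server.py"]` makes the container start `python3 server.py`, serving on the exposed port 8000. A minimal smoke test of a locally built image — the tag `net` and the port mapping are assumptions, not part of this upload:

```python
# Hypothetical local build/run (not part of the upload):
#   docker build -t net . && docker run -p 8000:8000 net
import requests

# /deem/v1/models is defined in server.py and returns a single model card.
resp = requests.get("http://localhost:8000/deem/v1/models")
resp.raise_for_status()
print(resp.json())  # expect: {"object": "list", "data": [{"id": "gpt-3.5-turbo", ...}]}
```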
README.md
CHANGED
@@ -1,10 +1,10 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: net
+emoji: 👩🎨
+colorFrom: red
+colorTo: yellow
 sdk: docker
 pinned: false
+app_port: 8000
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
free_ask_internet.py
ADDED
@@ -0,0 +1,274 @@
+# -*- coding: utf-8 -*-
+
+import json
+import os
+from pprint import pprint
+import requests
+import trafilatura
+from trafilatura import bare_extraction
+from concurrent.futures import ThreadPoolExecutor
+import concurrent.futures
+import openai
+import time
+from datetime import datetime
+from urllib.parse import urlparse
+import tldextract
+import platform
+import urllib.parse
+
+
+def extract_url_content(url):
+    # Download the page and strip it down to its main text content.
+    downloaded = trafilatura.fetch_url(url)
+    content = trafilatura.extract(downloaded)
+
+    return {"url": url, "content": content}
+
+
+def search_web_ref(query: str, debug=False):
+    # Query a SearXNG instance, then fetch and extract the top results' page text.
+    content_list = []
+
+    try:
+        safe_string = urllib.parse.quote_plus(":all !general " + query)
+
+        searxng_url = os.environ.get('SEARXNG_URL')
+        response = requests.get(searxng_url + '?q=' + safe_string + '&format=json')
+        response.raise_for_status()
+        search_results = response.json()
+
+        if debug:
+            print("JSON Response:")
+            pprint(search_results)
+        pending_urls = []
+
+        conv_links = []
+
+        if search_results.get('results'):
+            for item in search_results.get('results')[0:9]:
+                name = item.get('title')
+                snippet = item.get('content')
+                url = item.get('url')
+                pending_urls.append(url)
+
+                if url:
+                    url_parsed = urlparse(url)
+                    domain = url_parsed.netloc
+                    icon_url = url_parsed.scheme + '://' + url_parsed.netloc + '/favicon.ico'
+                    site_name = tldextract.extract(url).domain
+
+                    conv_links.append({
+                        'site_name': site_name,
+                        'icon_url': icon_url,
+                        'title': name,
+                        'url': url,
+                        'snippet': snippet
+                    })
+
+        results = []
+        futures = []
+
+        executor = ThreadPoolExecutor(max_workers=10)
+        for url in pending_urls:
+            futures.append(executor.submit(extract_url_content, url))
+        try:
+            for future in futures:
+                res = future.result(timeout=5)
+                results.append(res)
+        except concurrent.futures.TimeoutError:
+            print("Timed out while fetching page content")
+        executor.shutdown(wait=False, cancel_futures=True)
+
+        for content in results:
+            if content and content.get('content'):
+                item_dict = {
+                    "url": content.get('url'),
+                    "content": content.get('content'),
+                    "length": len(content.get('content'))
+                }
+                content_list.append(item_dict)
+                if debug:
+                    print("URL: {}".format(content.get('url')))
+                    print("=================")
+
+        return conv_links, content_list
+    except Exception:
+        raise
+
+
+def gen_prompt(question, content_list, lang="zh-CN", context_length_limit=11000, debug=False):
+    # Build a RAG prompt: numbered reference contexts followed by the user question.
+    limit_len = (context_length_limit - 2000)
+    if len(question) > limit_len:
+        question = question[0:limit_len]
+
+    ref_content = [item.get("content") for item in content_list]
+
+    answer_language = ' Simplified Chinese '
+    if lang == "zh-CN":
+        answer_language = ' Simplified Chinese '
+    if lang == "zh-TW":
+        answer_language = ' Traditional Chinese '
+    if lang == "en-US":
+        answer_language = ' English '
+
+    if len(ref_content) > 0:
+
+        if False:  # disabled Chinese-only variant of the prompt built in the else branch
+            prompts = '''
+您是一位由 nash_su 开发的大型语言人工智能助手。您将被提供一个用户问题,并需要撰写一个清晰、简洁且准确的答案。提供了一组与问题相关的上下文,每个都以 [[citation:x]] 这样的编号开头,x 代表一个数字。请在适当的情况下在句子末尾引用上下文。答案必须正确、精确,并以专家的中立和职业语气撰写。请将答案限制在 2000 个标记内。不要提供与问题无关的信息,也不要重复。如果给出的上下文信息不足,请在相关主题后写上“信息缺失:”。请按照引用编号 [citation:x] 的格式在答案中对应部分引用上下文。如果一句话源自多个上下文,请列出所有相关的引用编号,例如 [citation:3][citation:5],不要将引用集中在最后返回,而是在答案对应部分列出。除非是代码、特定的名称或引用编号,答案的语言应与问题相同。以下是上下文的内容集:
+''' + "\n\n" + "```"
+            ref_index = 1
+
+            for ref_text in ref_content:
+                prompts = prompts + "\n\n" + " [citation:{}] ".format(str(ref_index)) + ref_text
+                ref_index += 1
+
+            if len(prompts) >= limit_len:
+                prompts = prompts[0:limit_len]
+            prompts = prompts + '''
+```
+记住,不要一字不差的重复上下文内容。回答必须使用简体中文,如果回答很长,请尽量结构化、分段落总结。请按照引用编号 [citation:x] 的格式在答案中对应部分引用上下文。如果一句话源自多个上下文,请列出所有相关的引用编号,例如 [citation:3][citation:5],不要将引用集中在最后返回,而是在答案对应部分列出。下面是用户问题:
+''' + question
+        else:
+            prompts = '''
+You are a large language AI assistant developed by nash_su. You are given a user question, and please write a clean, concise and accurate answer to the question. You will be given a set of related contexts to the question, each starting with a reference number like [[citation:x]], where x is a number. Please use the context and cite the context at the end of each sentence if applicable.
+Your answer must be correct, accurate and written by an expert using an unbiased and professional tone. Please limit to 1024 tokens. Do not give any information that is not related to the question, and do not repeat. Say "information is missing on" followed by the related topic, if the given context does not provide sufficient information.
+
+Please cite the contexts with the reference numbers, in the format [citation:x]. If a sentence comes from multiple contexts, please list all applicable citations, like [citation:3][citation:5]. Other than code and specific names and citations, your answer must be written in the same language as the question.
+Here are the set of contexts:
+''' + "\n\n" + "```"
+            ref_index = 1
+
+            for ref_text in ref_content:
+                prompts = prompts + "\n\n" + " [citation:{}] ".format(str(ref_index)) + ref_text
+                ref_index += 1
+
+            if len(prompts) >= limit_len:
+                prompts = prompts[0:limit_len]
+            prompts = prompts + '''
+```
+Above are the reference contexts. Remember, don't repeat the context word for word. Answer in ''' + answer_language + '''. If the response is lengthy, structure it in paragraphs and summarize where possible. Cite the context using the format [citation:x] where x is the reference number. If a sentence originates from multiple contexts, list all relevant citation numbers, like [citation:3][citation:5]. Don't cluster the citations at the end but include them in the answer where they correspond.
+Remember, don't blindly repeat the contexts verbatim. And here is the user question:
+''' + question
+
+    else:
+        prompts = question
+
+    if debug:
+        print(prompts)
+        print("Total prompt length: " + str(len(prompts)))
+    return prompts
+
+
+def defaultchat(message, model: str, stream=True, debug=False):
+    openai.base_url = os.environ.get('OPENAI_BASE_URL')
+    openai.api_key = os.environ.get('OPENAI_API_KEY')
+    total_content = ""
+    # Stream the completion and yield tokens as they arrive.
+    for chunk in openai.chat.completions.create(
+        model=model,
+        messages=message,
+        stream=True,
+        max_tokens=3072, temperature=0.2
+    ):
+        stream_resp = chunk.model_dump()
+        token = stream_resp["choices"][0]["delta"].get("content", "")
+        if token:
+            total_content += token
+            yield token
+
+
+def ask_gpt(message, model_id, debug=False):
+    total_token = ""
+    for token in defaultchat(message, model_id):
+        if token:
+            total_token += token
+            yield token
+
+
+def summary_gpt(message, model: str, debug=False):
+    # Rewrite a follow-up question into a fully self-contained one.
+    msgs = []
+    msgs.append({"role": "system", "content": '作为一位专业的问题审核专家,你的任务是确保每一个提问都是清晰、具体并且没有模糊歧义的,不需要再根据额外的内容就可以理解你的提问。在审阅提问时,请遵循以下规则进行优化:替换模糊的代名词,确保所有的人称和名词都有明确的指代,不允许出现"你我他这那"等这种类似的代名词;如果提问中包含泛指的名词,请根据上下文明确的定语,补充具体的细节以提供完整的信息;最后,只允许输出经过你精确优化的问题,不要有任何多余的文字。举例说明,1-当提问者问:他在做什么?,你根据上下文你可以得知他是"小明",那么你优化问题后输出"小明在干什么?"2-当提问者问:他们乐队都有谁?,你根据上下文可以得知乐队是"小强乐队",那么你优化问题后输出"小强乐队都有谁?"'})
+    msgs.append({"role": "user", "content": str(message)})
+    json_data = {
+        "model": model,
+        "messages": msgs,
+        "temperature": 0.8,
+        "max_tokens": 2560,
+        "top_p": 1,
+        "frequency_penalty": 0,
+        "presence_penalty": 0,
+        "stop": None
+    }
+    apiurl = os.environ.get('OPENAI_BASE_URL')
+    pooltoken = os.environ.get('OPENAI_API_KEY')
+    headers = {
+        'Content-Type': 'application/json',
+        'Authorization': 'Bearer {}'.format(pooltoken),
+    }
+    response = requests.post(apiurl + '/chat/completions', headers=headers, json=json_data)
+    res = json.loads(response.text)['choices'][0]['message']['content']
+    return res
+
+
+def chat(prompt, model: str, stream=True, llm_auth_token=None, llm_base_url=None, using_custom_llm=False, debug=False):
+    # Use the caller-supplied LLM endpoint when requested, otherwise the environment default.
+    openai.base_url = llm_base_url if using_custom_llm and llm_base_url else os.environ.get('OPENAI_BASE_URL')
+    openai.api_key = llm_auth_token if using_custom_llm and llm_auth_token else os.environ.get('OPENAI_API_KEY')
+    total_content = ""
+    for chunk in openai.chat.completions.create(
+        model=model,
+        messages=[{
+            "role": "user",
+            "content": prompt
+        }],
+        stream=True,
+        max_tokens=3072, temperature=0.2
+    ):
+        stream_resp = chunk.model_dump()
+        token = stream_resp["choices"][0]["delta"].get("content", "")
+        if token:
+            total_content += token
+            yield token
+    if debug:
+        print(total_content)
+
+
+def ask_internet(query: str, model: str, debug=False):
+    # Search, build the cited prompt, stream the answer, then append references.
+    search_links, content_list = search_web_ref(query, debug=debug)
+    if debug:
+        print(content_list)
+    prompt = gen_prompt(query, content_list, context_length_limit=6000, debug=debug)
+    total_token = ""
+
+    for token in chat(prompt=prompt, model=model):
+        if token:
+            total_token += token
+            yield token
+    yield "\n\n"
+    # Whether to append the list of source references
+    if True:
+        yield "---"
+        yield "\nSearxng"
+        yield "参考资料:\n"
+        count = 1
+        for url_content in content_list:
+            url = url_content.get('url')
+            yield "*[{}. {}]({})*".format(str(count), url, url)
+            yield "\n"
+            count += 1
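
The module chains SearXNG search (`search_web_ref`), prompt assembly (`gen_prompt`) and a streaming OpenAI-compatible call (`chat`) into the `ask_internet` generator. A minimal usage sketch — the query string is an example, and the environment values are assumptions for a local setup:

```python
import os
import free_ask_internet

# Assumed local configuration, not part of the upload:
os.environ.setdefault("SEARXNG_URL", "http://localhost:8080/search")
os.environ.setdefault("OPENAI_BASE_URL", "http://localhost:11434/v1/")
os.environ.setdefault("OPENAI_API_KEY", "sk-placeholder")

# ask_internet is a generator: it streams answer tokens, then a reference list.
for token in free_ask_internet.ask_internet(query="What is SearXNG?", model="gpt-3.5-turbo"):
    print(token, end="", flush=True)
```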
requirements.txt
ADDED
@@ -0,0 +1,39 @@
+annotated-types==0.6.0
+anyio==4.3.0
+certifi==2024.2.2
+charset-normalizer==3.3.2
+click==8.1.7
+courlan==1.0.0
+dateparser==1.2.0
+distro==1.9.0
+exceptiongroup==1.2.0
+fastapi==0.110.1
+filelock==3.13.3
+h11==0.14.0
+htmldate==1.8.0
+httpcore==1.0.5
+httpx==0.27.0
+idna==3.6
+jusText==3.0.0
+langcodes==3.3.0
+lxml==5.1.1
+openai==1.16.2
+pydantic==2.6.4
+pydantic_core==2.16.3
+python-dateutil==2.9.0.post0
+pytz==2024.1
+regex==2023.12.25
+requests==2.31.0
+requests-file==2.0.0
+six==1.16.0
+sniffio==1.3.1
+sse-starlette==2.0.0
+starlette==0.37.2
+tld==0.13
+tldextract==5.1.2
+tqdm==4.66.2
+trafilatura==1.8.1
+typing_extensions==4.10.0
+tzlocal==5.2
+urllib3==2.2.1
+uvicorn==0.29.0
server.py
ADDED
@@ -0,0 +1,312 @@
+# -*- coding: utf-8 -*-
+
+import time
+import asyncio
+import uvicorn
+import sys
+import json
+import os
+from pprint import pprint
+import requests
+import trafilatura
+from trafilatura import bare_extraction
+from concurrent.futures import ThreadPoolExecutor
+import concurrent
+import openai
+from datetime import datetime
+from urllib.parse import urlparse
+import platform
+import urllib.parse
+import free_ask_internet
+from pydantic import BaseModel, Field
+from fastapi import FastAPI, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
+from contextlib import asynccontextmanager
+from typing import Any, Dict, List, Literal, Optional, Union
+from sse_starlette.sse import ServerSentEvent, EventSourceResponse
+from fastapi.responses import StreamingResponse
+
+app = FastAPI()
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# Cached page contents shared between get_search_refs and stream.
+search_results = []
+
+
+class ModelCard(BaseModel):
+    id: str
+    object: str = "model"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    owned_by: str = "owner"
+    root: Optional[str] = None
+    parent: Optional[str] = None
+    permission: Optional[list] = None
+
+
+class ModelList(BaseModel):
+    object: str = "list"
+    data: List[ModelCard] = []
+
+
+class ChatMessage(BaseModel):
+    role: Literal["user", "assistant", "system"]
+    content: str
+
+
+class DeltaMessage(BaseModel):
+    role: Optional[Literal["user", "assistant", "system"]] = None
+    content: Optional[str] = None
+
+
+class QueryRequest(BaseModel):
+    query: str
+    model: str
+    ask_type: Literal["search", "llm"]
+    llm_auth_token: Optional[str] = os.environ.get('OPENAI_API_KEY')
+    llm_base_url: Optional[str] = os.environ.get('OPENAI_BASE_URL')
+    using_custom_llm: Optional[bool] = False
+    lang: Optional[str] = "zh-CN"
+
+
+class ChatCompletionRequest(BaseModel):
+    model: str
+    messages: List[ChatMessage]
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
+    max_length: Optional[int] = None
+    stream: Optional[bool] = False
+
+
+class ChatCompletionResponseChoice(BaseModel):
+    index: int
+    message: ChatMessage
+    finish_reason: Literal["stop", "length"]
+
+
+class ChatCompletionResponseStreamChoice(BaseModel):
+    index: int
+    delta: DeltaMessage
+    finish_reason: Optional[Literal["stop", "length"]] = None
+
+
+class ChatCompletionResponse(BaseModel):
+    model: str
+    object: Literal["chat.completion", "chat.completion.chunk"]
+    choices: List[Union[ChatCompletionResponseChoice,
+                        ChatCompletionResponseStreamChoice]]
+    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
+
+
+class SearchItem(BaseModel):
+    url: str
+    icon_url: str
+    site_name: str
+    snippet: str
+    title: str
+
+
+class SearchItemList(BaseModel):
+    search_items: List[SearchItem] = []
+
+
+class SearchResp(BaseModel):
+    code: int
+    msg: str
+    data: List[SearchItem] = []
+
+
+@app.get("/deem/v1/models", response_model=ModelList)
+async def list_models():
+    # A single OpenAI-compatible model card is exposed.
+    model_card = ModelCard(id="gpt-3.5-turbo")
+    return ModelList(data=[model_card])
+
+
+@app.post("/deem/v1/chat/completions", response_model=ChatCompletionResponse)
+async def create_chat_completion(request: ChatCompletionRequest):
+    # Questions prefixed with '!' are answered with web search; the rest go straight to the LLM.
+    if request.messages[-1].role != "user":
+        raise HTTPException(status_code=400, detail="Invalid request")
+    query = request.messages[-1].content
+    if query[0] != '!':
+        print("Current question: gpt ---> {}".format(query))
+        generate = askgpt([msg.model_dump() for msg in request.messages], "", request.model)
+    else:
+        query = query[1:]
+        if len(request.messages) > 2:
+            # Condense the conversation into one self-contained question first.
+            message = '\n'.join([msg.content for msg in request.messages])
+            query = free_ask_internet.summary_gpt(message + "\n请根据以上的内容总结," + query + " 这个问题是要问什么?不要有模糊的代名词比如他/她之类的,不允许缺失上下文语境,需要明确提问的主题;最后只允许输出总结并完善语境的问题,不要有任何多余的文字!", request.model)
+        print("Current question: net ---> {}".format(query))
+        generate = predict(query, "", request.model)
+    return EventSourceResponse(generate, media_type="text/event-stream")
+
+
+def askgpt(query, history, model_id):
+    # Emit an OpenAI-style stream: role chunk, content chunks, stop chunk, [DONE].
+    choice_data = ChatCompletionResponseStreamChoice(
+        index=0,
+        delta=DeltaMessage(role="assistant"),
+        finish_reason=None
+    )
+    chunk = ChatCompletionResponse(model=model_id, choices=[
+        choice_data], object="chat.completion.chunk")
+    yield "{}".format(chunk.model_dump_json(exclude_unset=True))
+    new_response = ""
+    current_length = 0
+    for token in free_ask_internet.ask_gpt(query, model_id):
+        new_response += token
+        if len(new_response) == current_length:
+            continue
+
+        new_text = new_response[current_length:]
+        current_length = len(new_response)
+
+        choice_data = ChatCompletionResponseStreamChoice(
+            index=0,
+            delta=DeltaMessage(content=new_text, role="assistant"),
+            finish_reason=None
+        )
+        chunk = ChatCompletionResponse(model=model_id, choices=[
+            choice_data], object="chat.completion.chunk")
+        yield "{}".format(chunk.model_dump_json(exclude_unset=True))
+
+    choice_data = ChatCompletionResponseStreamChoice(
+        index=0,
+        delta=DeltaMessage(),
+        finish_reason="stop"
+    )
+    chunk = ChatCompletionResponse(model=model_id, choices=[
+        choice_data], object="chat.completion.chunk")
+    yield "{}".format(chunk.model_dump_json(exclude_unset=True))
+    yield '[DONE]'
+
+
+def predict(query: str, history: None, model_id: str):
+    # Same streaming shape as askgpt, but answers via web search (ask_internet).
+    choice_data = ChatCompletionResponseStreamChoice(
+        index=0,
+        delta=DeltaMessage(role="assistant"),
+        finish_reason=None
+    )
+    chunk = ChatCompletionResponse(model=model_id, choices=[
+        choice_data], object="chat.completion.chunk")
+    yield "{}".format(chunk.model_dump_json(exclude_unset=True))
+    new_response = ""
+    current_length = 0
+    for token in free_ask_internet.ask_internet(query=query, model=model_id):
+        new_response += token
+        if len(new_response) == current_length:
+            continue
+
+        new_text = new_response[current_length:]
+        current_length = len(new_response)
+
+        choice_data = ChatCompletionResponseStreamChoice(
+            index=0,
+            delta=DeltaMessage(content=new_text, role="assistant"),
+            finish_reason=None
+        )
+        chunk = ChatCompletionResponse(model=model_id, choices=[
+            choice_data], object="chat.completion.chunk")
+        yield "{}".format(chunk.model_dump_json(exclude_unset=True))
+
+    choice_data = ChatCompletionResponseStreamChoice(
+        index=0,
+        delta=DeltaMessage(),
+        finish_reason="stop"
+    )
+    chunk = ChatCompletionResponse(model=model_id, choices=[
+        choice_data], object="chat.completion.chunk")
+    yield "{}".format(chunk.model_dump_json(exclude_unset=True))
+    yield '[DONE]'
+
+
+@app.post("/api/search/get_search_refs", response_model=SearchResp)
+async def get_search_refs(request: QueryRequest):
+    # Run the search and cache page contents for the follow-up stream call.
+    global search_results
+    search_results = []
+    search_item_list = []
+    if request.ask_type == "search":
+        search_links, search_results = free_ask_internet.search_web_ref(request.query)
+        for search_item in search_links:
+            snippet = search_item.get("snippet")
+            url = search_item.get("url")
+            icon_url = search_item.get("icon_url")
+            site_name = search_item.get("site_name")
+            title = search_item.get("title")
+
+            si = SearchItem(snippet=snippet, url=url, icon_url=icon_url, site_name=site_name, title=title)
+
+            search_item_list.append(si)
+
+    resp = SearchResp(code=0, msg="success", data=search_item_list)
+
+    return resp
+
+
+def generator(prompt: str, model: str, llm_auth_token: str, llm_base_url: str, using_custom_llm=False, is_failed=False):
+    if is_failed:
+        yield "搜索失败,没有返回结果"
+    else:
+        total_token = ""
+        for token in free_ask_internet.chat(prompt=prompt, model=model, llm_auth_token=llm_auth_token, llm_base_url=llm_base_url, using_custom_llm=using_custom_llm, stream=True):
+            total_token += token
+            yield token
+
+
+@app.post("/api/search/stream/{search_uuid}")
+async def stream(search_uuid: str, request: QueryRequest):
+    global search_results
+
+    if request.ask_type == "llm":
+        answer_language = ' Simplified Chinese '
+        if request.lang == "zh-CN":
+            answer_language = ' Simplified Chinese '
+        if request.lang == "zh-TW":
+            answer_language = ' Traditional Chinese '
+        if request.lang == "en-US":
+            answer_language = ' English '
+        prompt = ' You are a large language AI assistant developed by nash_su. Answer the user question in ' + answer_language + '. And here is the user question: ' + request.query
+        generate = generator(prompt, model=request.model, llm_auth_token=request.llm_auth_token, llm_base_url=request.llm_base_url, using_custom_llm=request.using_custom_llm)
+    else:
+        prompt = None
+        limit_count = 10
+        # Wait up to ten seconds for get_search_refs to populate search_results.
+        while limit_count > 0:
+            try:
+                if len(search_results) > 0:
+                    prompt = free_ask_internet.gen_prompt(request.query, search_results, lang=request.lang, context_length_limit=8000)
+                    break
+                else:
+                    limit_count -= 1
+                    await asyncio.sleep(1)
+            except Exception as err:
+                limit_count -= 1
+                await asyncio.sleep(1)
+        total_token = ""
+        if prompt:
+            generate = generator(prompt, model=request.model, llm_auth_token=request.llm_auth_token, llm_base_url=request.llm_base_url, using_custom_llm=request.using_custom_llm)
+        else:
+            generate = generator(prompt, model=request.model, llm_auth_token=request.llm_auth_token, llm_base_url=request.llm_base_url, using_custom_llm=request.using_custom_llm, is_failed=True)
+
+    # return EventSourceResponse(generate, media_type="text/event-stream")
+    return StreamingResponse(generate, media_type="text/event-stream")
+
+
+def main():
+    global search_results
+    port = 8000
+
+    search_results = []
+
+    uvicorn.run(app, host='0.0.0.0', port=port, workers=1)
+
+
+if __name__ == "__main__":
+    main()
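
server.py exposes both an OpenAI-compatible endpoint (`/deem/v1/chat/completions`) and a two-step search API: `/api/search/get_search_refs` to run the search, then `/api/search/stream/{search_uuid}` to stream the answer. A minimal client sketch against the second flow, assuming a local deployment on port 8000; `demo-uuid` and the query are placeholders:

```python
import requests

BASE = "http://localhost:8000"  # assumed local deployment
payload = {"query": "What is SearXNG?", "model": "gpt-3.5-turbo", "ask_type": "search"}

# Step 1: run the search and collect reference links (also caches page contents server-side).
refs = requests.post(f"{BASE}/api/search/get_search_refs", json=payload).json()
print([item["url"] for item in refs["data"]])

# Step 2: stream the generated answer; the path segment is an arbitrary client-chosen id.
with requests.post(f"{BASE}/api/search/stream/demo-uuid", json=payload, stream=True) as resp:
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)
```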