Spaces:
Sleeping
Sleeping
neoguojing
committed on
Commit
·
4d10a94
1
Parent(s):
494b300
finish rag
Browse files- .gitignore +3 -0
- app.py +51 -36
- llm.py +89 -0
- requirements.txt +2 -1
.gitignore
CHANGED
@@ -2,3 +2,6 @@
|
|
2 |
__pycache__/
|
3 |
*.bin
|
4 |
.vscode/
|
|
|
|
|
|
|
|
2 |
__pycache__/
|
3 |
*.bin
|
4 |
.vscode/
|
5 |
+
files/input/ir2023_ashare.pdf
|
6 |
+
knowledge_bases/中国移动.faiss
|
7 |
+
knowledge_bases/中国移动.pkl
|
app.py
CHANGED
@@ -133,23 +133,23 @@ def create_ui():
|
|
133 |
components["db_view"] = gr.Dataframe(
|
134 |
headers=["列表"],
|
135 |
datatype=["str"],
|
136 |
-
row_count=
|
137 |
col_count=(1, "fixed"),
|
138 |
interactive=False
|
139 |
)
|
140 |
with gr.Column(scale=2):
|
141 |
-
|
|
|
142 |
components["db_name"] = gr.Textbox(label="名称", info="请输入库名称", lines=1, value="")
|
143 |
-
|
144 |
components["db_submit_btn"] = gr.Button(value="提交")
|
|
|
145 |
with gr.Row():
|
146 |
with gr.Column(scale=2):
|
147 |
components["db_input"] = gr.Textbox(label="关键词", lines=1, value="")
|
148 |
-
|
149 |
with gr.Column(scale=1):
|
150 |
-
components["db_test_select"] = gr.Dropdown(
|
151 |
-
|
152 |
-
)
|
153 |
components["dbtest_submit_btn"] = gr.Button(value="检索")
|
154 |
with gr.Row():
|
155 |
with gr.Group():
|
@@ -157,16 +157,22 @@ def create_ui():
|
|
157 |
|
158 |
with gr.Tab("问答"):
|
159 |
with gr.Row():
|
160 |
-
with gr.Column():
|
|
|
|
|
|
|
|
|
|
|
|
|
161 |
with gr.Group():
|
162 |
components["chatbot"] = gr.Chatbot(
|
163 |
-
[(None,"
|
164 |
elem_id="chatbot",
|
165 |
bubble_full_width=False,
|
166 |
height=600
|
167 |
)
|
168 |
components["chat_input"] = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)
|
169 |
-
components["db_select"] = gr.CheckboxGroup(
|
170 |
create_event_handlers()
|
171 |
demo.load(init,None,gradio("db_view"))
|
172 |
return demo
|
@@ -236,6 +242,10 @@ def create_event_handlers():
|
|
236 |
do_search, gradio('db_test_select','db_input'), gradio('db_search_result')
|
237 |
)
|
238 |
|
|
|
|
|
|
|
|
|
239 |
def do_refernce(algo_type,input_image):
|
240 |
# def do_refernce():
|
241 |
print("input image",input_image)
|
@@ -307,9 +317,6 @@ def do_sam_everything(im):
|
|
307 |
|
308 |
return images
|
309 |
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
def point_to_mask(pil_image):
|
314 |
# 遍历每个像素
|
315 |
width, height = pil_image.size
|
@@ -337,11 +344,11 @@ def do_llm_request(history, message):
|
|
337 |
return history, gr.MultimodalTextbox(value=None, interactive=False)
|
338 |
|
339 |
def do_llm_response(history,selected_dbs):
|
|
|
340 |
user_input = history[-1][0]
|
341 |
prompt = ""
|
342 |
quote = ""
|
343 |
-
|
344 |
-
if selected_dbs is not None and len(selected_dbs) != 0:
|
345 |
knowledge = knowledgeBase.retrieve_documents(selected_dbs,user_input)
|
346 |
print("do_llm_response context:",knowledge)
|
347 |
prompt = f'''
|
@@ -349,8 +356,8 @@ def do_llm_response(history,selected_dbs):
|
|
349 |
背景2:{knowledge[1]["content"]}
|
350 |
背景3:{knowledge[2]["content"]}
|
351 |
基于以上事实回答问题:{user_input}
|
352 |
-
'''
|
353 |
-
|
354 |
quote = f'''
|
355 |
> 文档:{knowledge[0]["meta"]["source"]},页码:{knowledge[0]["meta"]["page"]}
|
356 |
> 文档:{knowledge[1]["meta"]["source"]},页码:{knowledge[1]["meta"]["page"]}
|
@@ -358,33 +365,41 @@ def do_llm_response(history,selected_dbs):
|
|
358 |
'''
|
359 |
else:
|
360 |
prompt = user_input
|
361 |
-
|
362 |
-
response = llm(prompt)
|
363 |
history[-1][1] = ""
|
364 |
-
|
365 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
366 |
for character in response:
|
367 |
history[-1][1] += character
|
368 |
time.sleep(0.01)
|
369 |
yield history
|
370 |
|
371 |
-
def llm(input):
|
372 |
-
import requests
|
373 |
-
API_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.3"
|
374 |
-
headers = {"Authorization": "Bearer "}
|
375 |
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
|
385 |
-
|
386 |
-
|
|
|
|
|
|
|
|
|
|
|
387 |
|
|
|
388 |
|
389 |
|
390 |
def file_handler(file_objs,name):
|
|
|
133 |
components["db_view"] = gr.Dataframe(
|
134 |
headers=["列表"],
|
135 |
datatype=["str"],
|
136 |
+
row_count=2,
|
137 |
col_count=(1, "fixed"),
|
138 |
interactive=False
|
139 |
)
|
140 |
with gr.Column(scale=2):
|
141 |
+
with gr.Row():
|
142 |
+
with gr.Column(scale=2):
|
143 |
components["db_name"] = gr.Textbox(label="名称", info="请输入库名称", lines=1, value="")
|
144 |
+
with gr.Column(scale=2):
|
145 |
components["db_submit_btn"] = gr.Button(value="提交")
|
146 |
+
components["file_upload"] = gr.File(elem_id='file_upload',file_count='multiple',label='文档上传', file_types=[".pdf", ".doc", '.docx', '.json', '.csv'])
|
147 |
with gr.Row():
|
148 |
with gr.Column(scale=2):
|
149 |
components["db_input"] = gr.Textbox(label="关键词", lines=1, value="")
|
|
|
150 |
with gr.Column(scale=1):
|
151 |
+
components["db_test_select"] = gr.Dropdown(knowledgeBase.get_bases(),multiselect=True, label="知识库选择")
|
152 |
+
with gr.Column(scale=1):
|
|
|
153 |
components["dbtest_submit_btn"] = gr.Button(value="检索")
|
154 |
with gr.Row():
|
155 |
with gr.Group():
|
|
|
157 |
|
158 |
with gr.Tab("问答"):
|
159 |
with gr.Row():
|
160 |
+
with gr.Column(scale=1):
|
161 |
+
with gr.Group():
|
162 |
+
components["ak"] = gr.Textbox(label="appid")
|
163 |
+
components["sk"] = gr.Textbox(label="secret")
|
164 |
+
components["llm_client"] =gr.Radio(["Wenxin", "Tongyi","Huggingface"],value="Wenxin", label="llm")
|
165 |
+
components["llm_setting_btn"] = gr.Button(value="设置")
|
166 |
+
with gr.Column(scale=2):
|
167 |
with gr.Group():
|
168 |
components["chatbot"] = gr.Chatbot(
|
169 |
+
[(None,"你好,有什么需要帮助的?")],
|
170 |
elem_id="chatbot",
|
171 |
bubble_full_width=False,
|
172 |
height=600
|
173 |
)
|
174 |
components["chat_input"] = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)
|
175 |
+
components["db_select"] = gr.CheckboxGroup(knowledgeBase.get_bases(),label="知识库", info="可选择1个或多个知识库")
|
176 |
create_event_handlers()
|
177 |
demo.load(init,None,gradio("db_view"))
|
178 |
return demo
|
|
|
242 |
do_search, gradio('db_test_select','db_input'), gradio('db_search_result')
|
243 |
)
|
244 |
|
245 |
+
components['llm_setting_btn'].click(
|
246 |
+
llm, gradio('ak','sk','llm_client'), None
|
247 |
+
)
|
248 |
+
|
249 |
def do_refernce(algo_type,input_image):
|
250 |
# def do_refernce():
|
251 |
print("input image",input_image)
|
|
|
317 |
|
318 |
return images
|
319 |
|
|
|
|
|
|
|
320 |
def point_to_mask(pil_image):
|
321 |
# 遍历每个像素
|
322 |
width, height = pil_image.size
|
|
|
344 |
return history, gr.MultimodalTextbox(value=None, interactive=False)
|
345 |
|
346 |
def do_llm_response(history,selected_dbs):
|
347 |
+
print("do_llm_response:",history,selected_dbs)
|
348 |
user_input = history[-1][0]
|
349 |
prompt = ""
|
350 |
quote = ""
|
351 |
+
if len(selected_dbs) > 0:
|
|
|
352 |
knowledge = knowledgeBase.retrieve_documents(selected_dbs,user_input)
|
353 |
print("do_llm_response context:",knowledge)
|
354 |
prompt = f'''
|
|
|
356 |
背景2:{knowledge[1]["content"]}
|
357 |
背景3:{knowledge[2]["content"]}
|
358 |
基于以上事实回答问题:{user_input}
|
359 |
+
'''
|
360 |
+
|
361 |
quote = f'''
|
362 |
> 文档:{knowledge[0]["meta"]["source"]},页码:{knowledge[0]["meta"]["page"]}
|
363 |
> 文档:{knowledge[1]["meta"]["source"]},页码:{knowledge[1]["meta"]["page"]}
|
|
|
365 |
'''
|
366 |
else:
|
367 |
prompt = user_input
|
368 |
+
|
|
|
369 |
history[-1][1] = ""
|
370 |
+
if llm_client is None:
|
371 |
+
gr.Warning("请先设置大模型")
|
372 |
+
response = "模型参数未设置"
|
373 |
+
else:
|
374 |
+
print("do_llm_response prompt:",prompt)
|
375 |
+
response = llm_client(prompt)
|
376 |
+
response = response.removeprefix(prompt)
|
377 |
+
response += quote
|
378 |
+
|
379 |
for character in response:
|
380 |
history[-1][1] += character
|
381 |
time.sleep(0.01)
|
382 |
yield history
|
383 |
|
|
|
|
|
|
|
|
|
384 |
|
385 |
+
llm_client = None
|
386 |
+
def llm(ak,sk,client):
    """Store the credentials and pick the LLM backend used for chat replies.

    Args:
        ak: Application id / access key forwarded to ``llm.init_param``.
        sk: Secret / API key forwarded to ``llm.init_param``.
        client: Backend label: "Wenxin", "Tongyi" or "Huggingface".
            Any other value leaves the previously selected backend in place.

    Returns:
        The callable now stored in the module-level ``llm_client``.

    Note:
        When both credentials are blank only the toast message differs
        ("reset" instead of "set"); ``llm_client`` is still assigned.
    """
    global llm_client
    # Local import: inside this function the alias refers to the llm module,
    # while the module-level name `llm` is this function itself.
    import llm as backend

    backend.init_param(ak, sk)

    choices = (
        ("Wenxin", backend.baidu_client),
        ("Tongyi", backend.qwen_agent_app),
        ("Huggingface", backend.hg_client),
    )
    for label, handler in choices:
        if client == label:
            llm_client = handler
            break

    both_blank = ak == "" and sk == ""
    gr.Info("重置成功" if both_blank else "设置成功")

    return llm_client
|
403 |
|
404 |
|
405 |
def file_handler(file_objs,name):
|
llm.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
import json
|
3 |
+
from http import HTTPStatus
|
4 |
+
from dashscope import Application
|
5 |
+
|
6 |
+
# Module-level credentials read by the client functions in this module.
# Both start out as empty strings.
ak = ""
sk = ""


def init_param(access_key, secret_key):
    """Replace the module-level credentials.

    Args:
        access_key: Value stored in the global ``ak``.
        secret_key: Value stored in the global ``sk``.
    """
    global ak, sk
    ak, sk = access_key, secret_key
|
13 |
+
|
14 |
+
|
15 |
+
def baidu_client(input):
    """Send a single-turn prompt to the Baidu Wenxin ``ernie-lite-8k`` chat endpoint.

    Args:
        input: Prompt text, sent as the only ``user`` message.

    Returns:
        The ``result`` field of the JSON reply, or ``""`` when either
        module-level credential is empty or the reply has no ``result`` field.
    """
    if ak == "" or sk == "":
        return ""

    url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k?access_token=" + get_access_token()

    payload = json.dumps({
        "temperature": 0.95,
        "top_p": 0.7,
        "penalty_score": 1,
        "messages": [
            {
                "role": "user",
                "content": input
            }
        ],
        "system": ""
    })
    headers = {
        'Content-Type': 'application/json'
    }

    response = requests.request("POST", url, headers=headers, data=payload)

    print("baidu_client",response.text)
    # A reply without "result" (presumably an error payload, e.g. after a
    # failed token request) used to raise KeyError here. Return "" instead,
    # the same failure value the other client functions in this module use.
    return response.json().get("result", "")
|
42 |
+
|
43 |
+
|
44 |
+
def get_access_token():
    """Request an access token from Baidu using the module-level AK/SK.

    Returns:
        The ``access_token`` field of the JSON reply passed through ``str()``.
        When the field is missing this is the literal string ``"None"``,
        not the ``None`` object.
    """
    token_endpoint = "https://aip.baidubce.com/oauth/2.0/token"
    query = {
        "grant_type": "client_credentials",
        "client_id": ak,
        "client_secret": sk,
    }
    reply = requests.post(token_endpoint, params=query)
    token = reply.json().get("access_token")
    return str(token)
|
52 |
+
|
53 |
+
|
54 |
+
def qwen_agent_app(input):
    """Send a prompt to a DashScope application and return its text output.

    The module-level ``ak`` is passed as the application id and ``sk`` as the
    API key.

    Args:
        input: Prompt text forwarded to ``Application.call``.

    Returns:
        ``response.output["text"]`` on HTTP 200; ``""`` when either credential
        is empty or the call reports any other status.
    """
    if ak == "" or sk == "":
        return ""

    reply = Application.call(app_id=ak, prompt=input, api_key=sk)

    # Failure path first: log the diagnostics and fall back to an empty answer.
    if reply.status_code != HTTPStatus.OK:
        print('request_id=%s, code=%s, message=%s\n' % (reply.request_id, reply.status_code, reply.message))
        return ""

    print('request_id=%s\n output=%s\n usage=%s\n' % (reply.request_id, reply.output, reply.usage))
    return reply.output["text"]
|
69 |
+
|
70 |
+
|
71 |
+
def hg_client(input):
    """Query the Hugging Face Inference API (Mistral-7B-Instruct-v0.3).

    The module-level ``sk`` is sent as the bearer token.

    Args:
        input: Prompt text sent as the ``inputs`` field.

    Returns:
        ``generated_text`` of the first item in the reply, or ``""`` when
        ``sk`` is empty or the reply is not a non-empty list of dicts.
    """
    if sk == "":
        return ""

    API_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.3"
    headers = {"Authorization": f"Bearer {sk}"}

    response = requests.post(API_URL, headers=headers, json={"inputs": input})
    output = response.json()
    print(output)

    # A dict reply (presumably an error payload such as {"error": ...}) has
    # len() > 0, so the old `output[0]` raised KeyError. Only index into the
    # reply once it is known to be a non-empty list.
    if isinstance(output, list) and output:
        first = output[0]
        if isinstance(first, dict):
            return first.get('generated_text', "")
    return ""
|
requirements.txt
CHANGED
@@ -20,4 +20,5 @@ faiss-cpu==1.8.0
|
|
20 |
pypdf==4.2.0
|
21 |
langchain==0.2.5
|
22 |
langchain-community==0.2.5
|
23 |
-
transformers==4.32.1
|
|
|
|
20 |
pypdf==4.2.0
|
21 |
langchain==0.2.5
|
22 |
langchain-community==0.2.5
|
23 |
+
transformers==4.32.1
|
24 |
+
dashscope==1.20.0
|