# rag/app.py (author: guoerjun, commit 61eb5cb: "fix")
import gradio as gr
import numpy as np
import time
from pathlib import Path
from retriever import knowledgeBase
import llm
# Absolute path of this script and of the "files/input" directory next to it;
# the latter is used as the root of the file browser in create_ui().
current_file_path = Path(__file__).resolve()
absolute_path = current_file_path.parent.joinpath("files", "input").resolve()

# Registry of UI widgets keyed by name; populated by create_ui() and read
# through gradio(...).
components = {}

# Miscellaneous runtime parameters.
# NOTE(review): not read anywhere in the visible code — confirm before removing.
params = dict(algo_type=None, input_image=None)
def gradio(*keys):
    """Look up registered UI components by key.

    Accepts either several keys as separate arguments (``gradio("a", "b")``)
    or a single list/tuple of keys (``gradio(["a", "b"])``) and returns the
    matching entries of the module-level ``components`` dict, in order.

    Raises:
        KeyError: if a key has not been registered in ``components``.
    """
    # isinstance (rather than an exact type check) so list/tuple subclasses,
    # e.g. namedtuples, are unpacked as well.
    if len(keys) == 1 and isinstance(keys[0], (list, tuple)):
        keys = keys[0]
    return [components[k] for k in keys]
def create_ui():
    """Build the Gradio interface and return the ``gr.Blocks`` app.

    Two tabs are created: "知识库" (knowledge-base management: list of bases,
    file browser, create-from-upload, and a test retrieval panel) and "问答"
    (chat). Every widget is stored in the module-level ``components`` dict
    under a string key so other functions can fetch it via ``gradio(...)``.
    Event handlers are attached by ``create_event_handlers()``, and ``init``
    is registered to refresh the knowledge-base widgets on page load.

    Returns:
        The constructed ``gr.Blocks`` instance (not launched).
    """
    with gr.Blocks() as demo:
        # Tab 1: knowledge-base management.
        with gr.Tab("知识库"):
            with gr.Row():
                # Left column: table of existing bases and a file browser.
                with gr.Column(scale=1):
                    with gr.Group():
                        components["db_view"] = gr.Dataframe(
                            headers=["列表"],
                            datatype=["str"],
                            row_count=2,
                            col_count=(1, "fixed"),
                            interactive=False
                        )
                        # Browses files under absolute_path (<script dir>/files/input).
                        components["file_expr"] = gr.FileExplorer(
                            scale=1,
                            value=[],
                            file_count="single",
                            root_dir=absolute_path,
                            # ignore_glob="**/__init__.py",
                            elem_id="file_expr",
                        )
                # Right column: create a base from uploads, then test retrieval.
                with gr.Column(scale=2):
                    with gr.Row():
                        with gr.Column(scale=2):
                            components["db_name"] = gr.Textbox(label="名称", info="请输入库名称", lines=1, value="")
                        with gr.Column(scale=2):
                            components["db_submit_btn"] = gr.Button(value="提交")
                    components["file_upload"] = gr.File(elem_id='file_upload',file_count='multiple',label='文档上传', file_types=[".pdf", ".doc", '.docx', '.json', '.csv'])
                    # Retrieval test: keyword + selected bases -> JSON result below.
                    with gr.Row():
                        with gr.Column(scale=2):
                            components["db_input"] = gr.Textbox(label="关键词", lines=1, value="")
                        with gr.Column(scale=1):
                            components["db_test_select"] = gr.Dropdown(knowledgeBase.get_bases(),multiselect=True, label="知识库选择")
                        with gr.Column(scale=1):
                            components["dbtest_submit_btn"] = gr.Button(value="检索")
                    with gr.Row():
                        with gr.Group():
                            components["db_search_result"] = gr.JSON(label="检索结果")
        # Tab 2: chat.
        with gr.Tab("问答"):
            with gr.Row():
                with gr.Column(scale=2):
                    with gr.Group():
                        # Seeded with a single bot greeting (user side is None).
                        components["chatbot"] = gr.Chatbot(
                            [(None,"你好,有什么需要帮助的?")],
                            elem_id="chatbot",
                            bubble_full_width=False,
                            height=600
                        )
                        components["chat_input"] = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)
                    # Bases used to ground chat answers; read by do_llm_response.
                    components["db_select"] = gr.CheckboxGroup(knowledgeBase.get_bases(),label="知识库", info="可选择1个或多个知识库")
        # Registered while the Blocks context is still open.
        create_event_handlers()
        # Refresh base table and selectors for every new page session.
        demo.load(init,None,gradio("db_view","db_select","db_test_select"))
    return demo
def init():
    """Produce fresh values for the knowledge-base widgets.

    Returns:
        A 3-tuple for (db_view, db_select, db_test_select): the tabular list
        of bases from ``knowledgeBase.get_df_bases()``, a CheckboxGroup and a
        multi-select Dropdown whose choices are ``knowledgeBase.get_bases()``.
    """
    names = knowledgeBase.get_bases()
    table = knowledgeBase.get_df_bases()
    chat_selector = gr.CheckboxGroup(names, label="知识库", info="可选择1个或多个知识库")
    search_selector = gr.Dropdown(names, multiselect=True, label="知识库选择")
    return table, chat_selector, search_selector
def create_event_handlers():
    """Wire every widget event to its handler function.

    Reads widgets from the module-level ``components`` dict; create_ui()
    calls this while its ``gr.Blocks`` context is open.
    """
    widgets = components

    # "提交": ingest the uploaded files into the named base, then refresh the
    # three widgets that display the list of bases.
    widgets["db_submit_btn"].click(
        fn=file_handler,
        inputs=gradio("file_upload", "db_name"),
        outputs=gradio("db_view", "db_select", "db_test_select"),
    )

    # Chat: record the user turn and lock the textbox, stream the answer,
    # then make the textbox interactive again.
    submitted = widgets["chat_input"].submit(
        fn=do_llm_request,
        inputs=gradio("chatbot", "chat_input"),
        outputs=gradio("chatbot", "chat_input"),
    )
    answered = submitted.then(
        fn=do_llm_response,
        inputs=gradio("chatbot", "db_select"),
        outputs=gradio("chatbot"),
        api_name="bot_response",
    )
    answered.then(
        fn=lambda: gr.MultimodalTextbox(interactive=True),
        inputs=None,
        outputs=gradio("chat_input"),
    )
    # Like/dislike feedback is not wired; print_like_dislike could be
    # attached with widgets["chatbot"].like(...).

    # "检索": retrieval-only query against the selected bases.
    widgets["dbtest_submit_btn"].click(
        fn=do_search,
        inputs=gradio("db_test_select", "db_input"),
        outputs=gradio("db_search_result"),
    )

    # Selecting a cell in the base table lists that base's files.
    widgets["db_view"].select(
        fn=db_expr,
        inputs=gradio("db_view"),
        outputs=gradio("file_expr"),
    )
def print_like_dislike(x: gr.LikeData):
    """Debug handler for chatbot like/dislike events.

    Prints the message index, its value and the liked flag, space-separated.
    """
    fields = (x.index, x.value, x.liked)
    print(*fields)
def do_llm_request(history, message):
    """Append the user's multimodal message to the chat history.

    Args:
        history: chat history list; mutated in place.
        message: MultimodalTextbox value, a dict with a "files" list and a
            "text" entry.

    Returns:
        ``(history, textbox)``. Each uploaded file is added as its own
        ``((path,), None)`` turn, followed by the text turn when text is not
        None. The textbox comes back cleared and non-interactive; a later step
        in the submit chain re-enables it.
    """
    history.extend(((path,), None) for path in message["files"])
    text = message["text"]
    if text is not None:
        history.append((text, None))
    locked_box = gr.MultimodalTextbox(value=None, interactive=False)
    return history, locked_box
def _build_rag_prompt(knowledge, user_input, limit=3):
    """Build the retrieval-augmented prompt and its citation footer.

    Args:
        knowledge: documents returned by ``knowledgeBase.retrieve_documents``.
            Each is read as ``doc["content"]``, ``doc["meta"]["source"]`` and
            ``doc["meta"].get("page")`` (assumes ``meta`` is dict-like —
            confirm against the retriever).
        user_input: the user's question.
        limit: maximum number of documents placed in the prompt.

    Returns:
        ``(prompt, quote)``. With no documents the prompt is the bare question
        and the quote is empty.
    """
    docs = list(knowledge or [])[:limit]
    if not docs:
        return user_input, ""
    prompt_lines = [f"背景{i}:{doc['content']}" for i, doc in enumerate(docs, start=1)]
    prompt_lines.append(f"基于以上事实回答问题:{user_input}")
    quote_lines = []
    for doc in docs:
        meta = doc["meta"]
        line = f"> 文档:{meta['source']}"
        # Not every source has page numbers (e.g. .json/.csv uploads); only
        # cite a page when the retriever provides one.
        page = meta.get("page")
        if page is not None:
            line += f",页码:{page}"
        quote_lines.append(line)
    prompt = "\n" + "\n".join(prompt_lines) + "\n"
    quote = "\n" + "\n".join(quote_lines) + "\n"
    return prompt, quote


def do_llm_response(history, selected_dbs):
    """Generate the bot reply for the last user turn and stream it.

    When knowledge bases are selected, documents are retrieved for the user's
    question and folded into the prompt (at most 3, however many are
    returned), and a citation block is appended to the answer. Without a
    configured ``llm_client`` a warning is shown and a fixed message is
    streamed instead.

    Args:
        history: chat history whose last entry is ``[user_text, bot_text]``;
            the bot slot is assigned in place, so entries must be mutable.
        selected_dbs: names of the knowledge bases to search (may be empty
            or None).

    Yields:
        ``history`` after each character is appended (10 ms apart), giving
        a typing effect in the Chatbot.
    """
    print("do_llm_response:", history, selected_dbs)
    user_input = history[-1][0]
    prompt = user_input
    quote = ""
    if selected_dbs:
        knowledge = knowledgeBase.retrieve_documents(selected_dbs, user_input)
        print("do_llm_response context:", knowledge)
        prompt, quote = _build_rag_prompt(knowledge, user_input)
    history[-1][1] = ""
    if llm_client is None:
        gr.Warning("请先设置大模型")
        response = "模型参数未设置"
    else:
        print("do_llm_response prompt:", prompt)
        response = llm_client(prompt)
        # Presumably some clients echo the prompt back; strip it if present.
        response = response.removeprefix(prompt)
        # Citations only accompany a real model answer.
        response += quote
    for character in response:
        history[-1][1] += character
        time.sleep(0.01)
        yield history
# LLM callable used by do_llm_response. getattr with a None default means a
# missing client leaves llm_client as None. do_llm_response handles that case
# by warning the user, instead of the module failing with AttributeError at
# import time.
llm_client = getattr(llm, "baidu_client", None)
def file_handler(file_objs, name):
    """Store uploaded files and add them to the knowledge base ``name``.

    Each upload is moved into ``./files/input/`` unless a file with the same
    basename already exists there. In that case the existing copy is indexed
    and the new upload is left where it was. Every file is then added to the
    knowledge base with ``knowledgeBase.add_documents_to_kb``.

    NOTE(review): ``./files/input/`` is relative to the current working
    directory, while the file browser root (``absolute_path``) is relative to
    this script. They coincide only when the app is started from the script's
    directory.

    Args:
        file_objs: uploaded files from the gr.File component. Each may be an
            object with a ``.name`` path attribute or a plain path string.
            None (nothing uploaded) is treated as no files.
        name: target knowledge-base name.

    Returns:
        Refreshed values for (db_view, db_select, db_test_select).
    """
    import shutil
    import os
    print("file_obj:", file_objs)
    target_dir = "./files/input/"
    os.makedirs(target_dir, exist_ok=True)
    for file in file_objs or []:
        print(file)
        src = getattr(file, "name", file)
        file_path = target_dir + os.path.basename(src)
        if not os.path.exists(file_path):
            shutil.move(src, target_dir)
        knowledgeBase.add_documents_to_kb(name, [file_path])
    dbs = knowledgeBase.get_bases()
    dfs = knowledgeBase.get_df_bases()
    return dfs, gr.CheckboxGroup(dbs, label="知识库", info="可选择1个或多个知识库"), gr.Dropdown(dbs, multiselect=True, label="知识库选择")
def db_expr(selected_index: gr.SelectData, dataframe_origin):
    """List the files of the knowledge base whose table cell was selected.

    Args:
        selected_index: select event data; ``.index`` gives (row, column).
        dataframe_origin: current value of the db_view table, indexed with
            ``.iloc`` (assumed to be a pandas DataFrame — confirm).

    Returns:
        Whatever ``knowledgeBase.get_db_files`` returns for that base name;
        it is used as the FileExplorer value.
    """
    cell = selected_index.index
    print("db_expr", cell)
    row, col = cell[0], cell[1]
    dbname = dataframe_origin.iloc[row, col]
    print("db_expr", dbname)
    return knowledgeBase.get_db_files(dbname)
def do_search(selected_dbs, user_input):
    """Run a retrieval-only query and return the raw documents.

    Args:
        selected_dbs: knowledge-base names to search.
        user_input: query text.

    Returns:
        The result of ``knowledgeBase.retrieve_documents``, shown in the
        JSON result panel.
    """
    print("do_search:", selected_dbs, user_input)
    return knowledgeBase.retrieve_documents(selected_dbs, user_input)
if __name__ == "__main__":
    # Build the interface and serve it with launch()'s default settings.
    # To listen on a specific interface, pass server_name="..." to launch().
    demo = create_ui()
    demo.launch()