hiwei committed
Commit a7c7b3c · verified · 1 Parent(s): d157382

Create app.py

Files changed (1)
app.py +120 -0
app.py ADDED
@@ -0,0 +1,120 @@
import gradio as gr
from langchain.chains import RetrievalQA
from langchain.text_splitter import SpacyTextSplitter
from langchain_community.chat_models import ChatZhipuAI, ChatGooglePalm
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.prompts import PromptTemplate

template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. Use three sentences maximum. Keep the answer as concise as possible. Always say "thanks for asking!" at the end of the answer.
Tips: Make sure to cite your sources, and use the exact words from the context.
{context}
Question: {question}
Helpful Answer:"""
QA_CHAIN_PROMPT = PromptTemplate.from_template(template)


class RAGDemo(object):
    def __init__(self):
        self.embedding = None
        self.vector_db = None
        self.chat_model = None

    def _init_chat_model(self, model_name, api_key):
        """Initialize the chat model selected in the UI."""
        if not api_key:
            raise gr.Error("Please enter model API key.")
        if 'glm' in model_name:
            self.chat_model = ChatZhipuAI(
                temperature=0.5,
                api_key=api_key,
                model="glm-3-turbo",
            )
        elif 'gemini' in model_name:
            self.chat_model = ChatGooglePalm(
                google_api_key=api_key,
                model_name='gemini-pro'
            )

    def _init_embedding(self, embedding_model_name, api_key):
        """Initialize the embedding model used to index the documents."""
        if not api_key:
            raise gr.Error("Please enter embedding API key.")
        if 'glm' in embedding_model_name:
            raise gr.Error("GLM embeddings are not supported yet.")
        self.embedding = HuggingFaceInferenceAPIEmbeddings(
            api_key=api_key, model_name=embedding_model_name
        )

    def _build_vector_db(self, file_path):
        """Load the uploaded PDF, split it into chunks, and index it in Chroma."""
        if not file_path:
            raise gr.Error("Please upload a PDF file.")
        gr.Info("Building vector database...")
        loader = PyPDFLoader(file_path)
        pages = loader.load()

        text_splitter = SpacyTextSplitter(chunk_size=500, chunk_overlap=50)
        docs = text_splitter.split_documents(pages)

        self.vector_db = Chroma.from_documents(
            documents=docs, embedding=self.embedding
        )
        gr.Info("Vector database built.")

    def _retrieval_qa(self, input_text):
        """Answer a question with a RetrievalQA chain over the vector store."""
        basic_qa = RetrievalQA.from_chain_type(
            self.chat_model,
            retriever=self.vector_db.as_retriever(),
            chain_type_kwargs={"prompt": QA_CHAIN_PROMPT}
        )
        # invoke() returns a dict; show only the answer text in the UI
        return basic_qa.invoke(input_text)["result"]

    def __call__(self):
        with gr.Blocks() as demo:
            gr.Markdown("# RAG Demo\n\nBased on the [RAG learning note](https://www.jianshu.com/p/9792f1e6c3f9) and "
                        "[rag-practice](https://github.com/hiwei93/rag-practice/tree/main)")
            with gr.Row():
                with gr.Column():
                    input_text = gr.Textbox(placeholder="input your question...", label="input")
                    submit_btn = gr.Button("submit")
                    with gr.Accordion("model settings"):
                        api_key = gr.Textbox(placeholder="your api key", label="api key")
                        model_name = gr.Dropdown(
                            choices=['glm-3-turbo', 'gemini-1.0-pro'],
                            value='glm-3-turbo',
                            label="model"
                        )
                    with gr.Accordion("knowledge base settings"):
                        embedding_api_key = gr.Textbox(placeholder="your api key", label="embedding api key")
                        embedding_model = gr.Dropdown(
                            choices=['glm-embedding-2', 'sentence-transformers/all-MiniLM-L6-v2',
                                     'intfloat/multilingual-e5-large'],
                            value="sentence-transformers/all-MiniLM-L6-v2",
                            label="embedding model"
                        )
                        data_file = gr.File(file_count='single', label="data pdf file")
                with gr.Column():
                    output = gr.TextArea(label="answer")
            # Wire UI events to the setup and QA steps above.
            model_name.select(
                self._init_chat_model,
                inputs=[model_name, api_key]
            )
            embedding_model.select(
                self._init_embedding,
                inputs=[embedding_model, embedding_api_key]
            )
            data_file.upload(self._build_vector_db, inputs=data_file)
            submit_btn.click(
                self._retrieval_qa,
                inputs=input_text,
                outputs=output,
            )
        return demo


app = RAGDemo()
app().launch()
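
For debugging, the same pipeline can also be driven outside the Gradio UI. A minimal sketch, assuming the app.py added in this commit is importable, the spaCy pipeline required by SpacyTextSplitter is available, and the API keys and PDF path below are placeholders to be replaced:

# Hypothetical smoke test for the RAG pipeline in app.py, bypassing the Gradio UI.
# The API keys and the PDF path are placeholders, not real values; gr.Info()
# messages inside _build_vector_db only surface when running inside the app.
from app import RAGDemo

demo = RAGDemo()

# Index a local PDF with the default Hugging Face Inference API embedding model.
demo._init_embedding("sentence-transformers/all-MiniLM-L6-v2", "hf_xxx")
demo._build_vector_db("example.pdf")

# Ask a question through the GLM chat model.
demo._init_chat_model("glm-3-turbo", "zhipuai_xxx")
print(demo._retrieval_qa("What is this document about?"))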