modify demo layout
Browse files
app.py
CHANGED
@@ -4,7 +4,7 @@ from langchain.chains import RetrievalQA
|
|
4 |
from langchain.text_splitter import SpacyTextSplitter
|
5 |
from langchain_community.chat_models import ChatZhipuAI, ChatGooglePalm
|
6 |
from langchain_community.document_loaders import PyPDFLoader
|
7 |
-
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
|
8 |
from langchain_community.vectorstores import Chroma
|
9 |
from langchain_core.prompts import PromptTemplate
|
10 |
|
@@ -44,6 +44,8 @@ class RAGDemo(object):
|
|
44 |
return
|
45 |
if 'glm' in embedding_model_name:
|
46 |
gradio.Error("GLM is not supported yet.")
|
|
|
|
|
47 |
else:
|
48 |
self.embedding = HuggingFaceInferenceAPIEmbeddings(
|
49 |
api_key=api_key, model_name=embedding_model_name
|
@@ -53,7 +55,6 @@ class RAGDemo(object):
|
|
53 |
if not file_path:
|
54 |
gradio.Error("Please enter vector database file path.")
|
55 |
return
|
56 |
-
gr.Info("Building vector database...")
|
57 |
loader = PyPDFLoader(file_path)
|
58 |
pages = loader.load()
|
59 |
|
@@ -63,7 +64,11 @@ class RAGDemo(object):
|
|
63 |
self.vector_db = Chroma.from_documents(
|
64 |
documents=docs, embedding=self.embedding
|
65 |
)
|
66 |
-
|
|
|
|
|
|
|
|
|
67 |
|
68 |
def _retrieval_qa(self, input_text):
|
69 |
basic_qa = RetrievalQA.from_chain_type(
|
@@ -77,37 +82,37 @@ class RAGDemo(object):
|
|
77 |
with gr.Blocks() as demo:
|
78 |
gr.Markdown("# RAG Demo\n\nbase on the [RAG learning note](https://www.jianshu.com/p/9792f1e6c3f9) and "
|
79 |
"[rag-practice](https://github.com/hiwei93/rag-practice/tree/main)")
|
80 |
-
with gr.
|
81 |
-
with gr.
|
82 |
-
|
83 |
-
submit_btn = gr.Button("submit")
|
84 |
-
with gr.Accordion("model settings"):
|
85 |
-
api_key = gr.Textbox(placeholder="your api key", label="api key")
|
86 |
model_name = gr.Dropdown(
|
87 |
choices=['glm-3-turbo', 'gemini-1.0-pro'],
|
88 |
value='glm-3-turbo',
|
89 |
label="model"
|
90 |
)
|
91 |
-
|
92 |
-
embedding_api_key = gr.Textbox(placeholder="your api key", label="embedding api key")
|
93 |
embedding_model = gr.Dropdown(
|
94 |
-
choices=['glm-embedding-2', 'sentence-transformers/all-MiniLM-L6-v2',
|
95 |
'intfloat/multilingual-e5-large'],
|
96 |
value="sentence-transformers/all-MiniLM-L6-v2",
|
97 |
label="embedding model"
|
98 |
)
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
|
|
|
|
|
|
|
|
109 |
)
|
110 |
-
|
111 |
submit_btn.click(
|
112 |
self._retrieval_qa,
|
113 |
inputs=input_text,
|
@@ -117,4 +122,4 @@ class RAGDemo(object):
|
|
117 |
|
118 |
|
119 |
app = RAGDemo()
|
120 |
-
app().launch()
|
|
|
4 |
from langchain.text_splitter import SpacyTextSplitter
|
5 |
from langchain_community.chat_models import ChatZhipuAI, ChatGooglePalm
|
6 |
from langchain_community.document_loaders import PyPDFLoader
|
7 |
+
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings, GooglePalmEmbeddings
|
8 |
from langchain_community.vectorstores import Chroma
|
9 |
from langchain_core.prompts import PromptTemplate
|
10 |
|
|
|
44 |
return
|
45 |
if 'glm' in embedding_model_name:
|
46 |
gradio.Error("GLM is not supported yet.")
|
47 |
+
elif 'gemini' in embedding_model_name:
|
48 |
+
self.embedding = GooglePalmEmbeddings(google_api_key=api_key, show_progress_bar=True)
|
49 |
else:
|
50 |
self.embedding = HuggingFaceInferenceAPIEmbeddings(
|
51 |
api_key=api_key, model_name=embedding_model_name
|
|
|
55 |
if not file_path:
|
56 |
gradio.Error("Please enter vector database file path.")
|
57 |
return
|
|
|
58 |
loader = PyPDFLoader(file_path)
|
59 |
pages = loader.load()
|
60 |
|
|
|
64 |
self.vector_db = Chroma.from_documents(
|
65 |
documents=docs, embedding=self.embedding
|
66 |
)
|
67 |
+
|
68 |
+
def _init_settings(self, model_name, api_key, embedding_model, embedding_api_key, data_file):
|
69 |
+
self._init_chat_model(model_name, api_key)
|
70 |
+
self._init_embedding(embedding_model, embedding_api_key)
|
71 |
+
self._build_vector_db(data_file)
|
72 |
|
73 |
def _retrieval_qa(self, input_text):
|
74 |
basic_qa = RetrievalQA.from_chain_type(
|
|
|
82 |
with gr.Blocks() as demo:
|
83 |
gr.Markdown("# RAG Demo\n\nbase on the [RAG learning note](https://www.jianshu.com/p/9792f1e6c3f9) and "
|
84 |
"[rag-practice](https://github.com/hiwei93/rag-practice/tree/main)")
|
85 |
+
with gr.Tab("Settings"):
|
86 |
+
with gr.Row():
|
87 |
+
with gr.Column():
|
|
|
|
|
|
|
88 |
model_name = gr.Dropdown(
|
89 |
choices=['glm-3-turbo', 'gemini-1.0-pro'],
|
90 |
value='glm-3-turbo',
|
91 |
label="model"
|
92 |
)
|
93 |
+
api_key = gr.Textbox(placeholder="your api key for LLM", label="api key")
|
|
|
94 |
embedding_model = gr.Dropdown(
|
95 |
+
choices=['glm-embedding-2', 'gemini-embedding', 'sentence-transformers/all-MiniLM-L6-v2',
|
96 |
'intfloat/multilingual-e5-large'],
|
97 |
value="sentence-transformers/all-MiniLM-L6-v2",
|
98 |
label="embedding model"
|
99 |
)
|
100 |
+
embedding_api_key = gr.Textbox(placeholder="your api key for embedding", label="embedding api key")
|
101 |
+
with gr.Column():
|
102 |
+
data_file = gr.File(file_count='single', label="pdf file")
|
103 |
+
initial_btn = gr.Button("submit")
|
104 |
+
with gr.Tab("RAG"):
|
105 |
+
with gr.Row():
|
106 |
+
with gr.Column():
|
107 |
+
input_text = gr.Textbox(placeholder="input your question...", label="input")
|
108 |
+
submit_btn = gr.Button("submit")
|
109 |
+
with gr.Column():
|
110 |
+
output = gr.TextArea(label="answer")
|
111 |
+
initial_btn.click(
|
112 |
+
self._init_settings,
|
113 |
+
inputs=[model_name, api_key, embedding_model, embedding_api_key, data_file]
|
114 |
)
|
115 |
+
|
116 |
submit_btn.click(
|
117 |
self._retrieval_qa,
|
118 |
inputs=input_text,
|
|
|
122 |
|
123 |
|
124 |
app = RAGDemo()
|
125 |
+
app().launch(debug=True)
|