hiwei committed on
Commit
83588c4
·
verified ·
1 Parent(s): a7c7b3c

modify demo layout

Browse files
Files changed (1) hide show
  1. app.py +29 -24
app.py CHANGED
@@ -4,7 +4,7 @@ from langchain.chains import RetrievalQA
4
  from langchain.text_splitter import SpacyTextSplitter
5
  from langchain_community.chat_models import ChatZhipuAI, ChatGooglePalm
6
  from langchain_community.document_loaders import PyPDFLoader
7
- from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
8
  from langchain_community.vectorstores import Chroma
9
  from langchain_core.prompts import PromptTemplate
10
 
@@ -44,6 +44,8 @@ class RAGDemo(object):
44
  return
45
  if 'glm' in embedding_model_name:
46
  gradio.Error("GLM is not supported yet.")
 
 
47
  else:
48
  self.embedding = HuggingFaceInferenceAPIEmbeddings(
49
  api_key=api_key, model_name=embedding_model_name
@@ -53,7 +55,6 @@ class RAGDemo(object):
53
  if not file_path:
54
  gradio.Error("Please enter vector database file path.")
55
  return
56
- gr.Info("Building vector database...")
57
  loader = PyPDFLoader(file_path)
58
  pages = loader.load()
59
 
@@ -63,7 +64,11 @@ class RAGDemo(object):
63
  self.vector_db = Chroma.from_documents(
64
  documents=docs, embedding=self.embedding
65
  )
66
- gr.Info("Vector database built.")
 
 
 
 
67
 
68
  def _retrieval_qa(self, input_text):
69
  basic_qa = RetrievalQA.from_chain_type(
@@ -77,37 +82,37 @@ class RAGDemo(object):
77
  with gr.Blocks() as demo:
78
  gr.Markdown("# RAG Demo\n\nbase on the [RAG learning note](https://www.jianshu.com/p/9792f1e6c3f9) and "
79
  "[rag-practice](https://github.com/hiwei93/rag-practice/tree/main)")
80
- with gr.Row():
81
- with gr.Column():
82
- input_text = gr.Textbox(placeholder="input your question...", label="input")
83
- submit_btn = gr.Button("submit")
84
- with gr.Accordion("model settings"):
85
- api_key = gr.Textbox(placeholder="your api key", label="api key")
86
  model_name = gr.Dropdown(
87
  choices=['glm-3-turbo', 'gemini-1.0-pro'],
88
  value='glm-3-turbo',
89
  label="model"
90
  )
91
- with gr.Accordion("knowledge base settigns"):
92
- embedding_api_key = gr.Textbox(placeholder="your api key", label="embedding api key")
93
  embedding_model = gr.Dropdown(
94
- choices=['glm-embedding-2', 'sentence-transformers/all-MiniLM-L6-v2',
95
  'intfloat/multilingual-e5-large'],
96
  value="sentence-transformers/all-MiniLM-L6-v2",
97
  label="embedding model"
98
  )
99
- data_file = gr.File(file_count='single', label="data pdf file")
100
- with gr.Column():
101
- output = gr.TextArea(label="answer")
102
- model_name.select(
103
- self._init_chat_model,
104
- inputs=[model_name, api_key]
105
- )
106
- embedding_model.select(
107
- self._init_embedding,
108
- inputs=[embedding_model, embedding_api_key]
 
 
 
 
109
  )
110
- data_file.upload(self._build_vector_db, inputs=data_file)
111
  submit_btn.click(
112
  self._retrieval_qa,
113
  inputs=input_text,
@@ -117,4 +122,4 @@ class RAGDemo(object):
117
 
118
 
119
  app = RAGDemo()
120
- app().launch()
 
4
  from langchain.text_splitter import SpacyTextSplitter
5
  from langchain_community.chat_models import ChatZhipuAI, ChatGooglePalm
6
  from langchain_community.document_loaders import PyPDFLoader
7
+ from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings, GooglePalmEmbeddings
8
  from langchain_community.vectorstores import Chroma
9
  from langchain_core.prompts import PromptTemplate
10
 
 
44
  return
45
  if 'glm' in embedding_model_name:
46
  gradio.Error("GLM is not supported yet.")
47
+ elif 'gemini' in embedding_model_name:
48
+ self.embedding = GooglePalmEmbeddings(google_api_key=api_key, show_progress_bar=True)
49
  else:
50
  self.embedding = HuggingFaceInferenceAPIEmbeddings(
51
  api_key=api_key, model_name=embedding_model_name
 
55
  if not file_path:
56
  gradio.Error("Please enter vector database file path.")
57
  return
 
58
  loader = PyPDFLoader(file_path)
59
  pages = loader.load()
60
 
 
64
  self.vector_db = Chroma.from_documents(
65
  documents=docs, embedding=self.embedding
66
  )
67
+
68
+ def _init_settings(self, model_name, api_key, embedding_model, embedding_api_key, data_file):
69
+ self._init_chat_model(model_name, api_key)
70
+ self._init_embedding(embedding_model, embedding_api_key)
71
+ self._build_vector_db(data_file)
72
 
73
  def _retrieval_qa(self, input_text):
74
  basic_qa = RetrievalQA.from_chain_type(
 
82
  with gr.Blocks() as demo:
83
  gr.Markdown("# RAG Demo\n\nbase on the [RAG learning note](https://www.jianshu.com/p/9792f1e6c3f9) and "
84
  "[rag-practice](https://github.com/hiwei93/rag-practice/tree/main)")
85
+ with gr.Tab("Settings"):
86
+ with gr.Row():
87
+ with gr.Column():
 
 
 
88
  model_name = gr.Dropdown(
89
  choices=['glm-3-turbo', 'gemini-1.0-pro'],
90
  value='glm-3-turbo',
91
  label="model"
92
  )
93
+ api_key = gr.Textbox(placeholder="your api key for LLM", label="api key")
 
94
  embedding_model = gr.Dropdown(
95
+ choices=['glm-embedding-2', 'gemini-embedding', 'sentence-transformers/all-MiniLM-L6-v2',
96
  'intfloat/multilingual-e5-large'],
97
  value="sentence-transformers/all-MiniLM-L6-v2",
98
  label="embedding model"
99
  )
100
+ embedding_api_key = gr.Textbox(placeholder="your api key for embedding", label="embedding api key")
101
+ with gr.Column():
102
+ data_file = gr.File(file_count='single', label="pdf file")
103
+ initial_btn = gr.Button("submit")
104
+ with gr.Tab("RAG"):
105
+ with gr.Row():
106
+ with gr.Column():
107
+ input_text = gr.Textbox(placeholder="input your question...", label="input")
108
+ submit_btn = gr.Button("submit")
109
+ with gr.Column():
110
+ output = gr.TextArea(label="answer")
111
+ initial_btn.click(
112
+ self._init_settings,
113
+ inputs=[model_name, api_key, embedding_model, embedding_api_key, data_file]
114
  )
115
+
116
  submit_btn.click(
117
  self._retrieval_qa,
118
  inputs=input_text,
 
122
 
123
 
124
  app = RAGDemo()
125
+ app().launch(debug=True)