hiwei committed on
Commit
a591b90
·
verified ·
1 Parent(s): a9d03e7

modify gemini dependencies

Browse files
Files changed (1) hide show
  1. app.py +12 -19
app.py CHANGED
@@ -2,15 +2,11 @@ import gradio
2
  import gradio as gr
3
  from langchain.chains import RetrievalQA
4
  from langchain.text_splitter import SpacyTextSplitter
5
- from langchain_community.chat_models import ChatZhipuAI, ChatGooglePalm
6
  from langchain_community.document_loaders import PyPDFLoader
7
- from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings, GooglePalmEmbeddings
8
  from langchain_community.vectorstores import Chroma
9
  from langchain_core.prompts import PromptTemplate
10
-
11
- import spacy
12
-
13
- spacy.cli.download("en_core_web_sm")
14
 
15
  template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. Use three sentences maximum. Keep the answer as concise as possible. Always say "thanks for asking!" at the end of the answer.
16
  Tips: Make sure to cite your sources, and use the exact words from the context.
@@ -31,15 +27,11 @@ class RAGDemo(object):
31
  gradio.Error("Please enter model API key.")
32
  return
33
  if 'glm' in model_name:
34
- self.chat_model = ChatZhipuAI(
35
- temperature=0.5,
36
- api_key=api_key,
37
- model="glm-3-turbo",
38
- )
39
  elif 'gemini' in model_name:
40
- self.chat_model = ChatGooglePalm(
41
  google_api_key=api_key,
42
- model_name='gemini-pro'
43
  )
44
 
45
  def _init_embedding(self, embedding_model_name, api_key):
@@ -48,8 +40,6 @@ class RAGDemo(object):
48
  return
49
  if 'glm' in embedding_model_name:
50
  gradio.Error("GLM is not supported yet.")
51
- elif 'gemini' in embedding_model_name:
52
- self.embedding = GooglePalmEmbeddings(google_api_key=api_key, show_progress_bar=True)
53
  else:
54
  self.embedding = HuggingFaceInferenceAPIEmbeddings(
55
  api_key=api_key, model_name=embedding_model_name
@@ -59,6 +49,7 @@ class RAGDemo(object):
59
  if not file_path:
60
  gradio.Error("Please enter vector database file path.")
61
  return
 
62
  loader = PyPDFLoader(file_path)
63
  pages = loader.load()
64
 
@@ -68,6 +59,7 @@ class RAGDemo(object):
68
  self.vector_db = Chroma.from_documents(
69
  documents=docs, embedding=self.embedding
70
  )
 
71
 
72
  def _init_settings(self, model_name, api_key, embedding_model, embedding_api_key, data_file):
73
  self._init_chat_model(model_name, api_key)
@@ -78,7 +70,8 @@ class RAGDemo(object):
78
  basic_qa = RetrievalQA.from_chain_type(
79
  self.chat_model,
80
  retriever=self.vector_db.as_retriever(),
81
- chain_type_kwargs={"prompt": QA_CHAIN_PROMPT}
 
82
  )
83
  return basic_qa.invoke(input_text)
84
 
@@ -90,13 +83,13 @@ class RAGDemo(object):
90
  with gr.Row():
91
  with gr.Column():
92
  model_name = gr.Dropdown(
93
- choices=['glm-3-turbo', 'gemini-1.0-pro'],
94
  value='glm-3-turbo',
95
  label="model"
96
  )
97
  api_key = gr.Textbox(placeholder="your api key for LLM", label="api key")
98
  embedding_model = gr.Dropdown(
99
- choices=['glm-embedding-2', 'gemini-embedding', 'sentence-transformers/all-MiniLM-L6-v2',
100
  'intfloat/multilingual-e5-large'],
101
  value="sentence-transformers/all-MiniLM-L6-v2",
102
  label="embedding model"
@@ -122,7 +115,7 @@ class RAGDemo(object):
122
  inputs=input_text,
123
  outputs=output,
124
  )
125
- return demo
126
 
127
 
128
  app = RAGDemo()
 
2
  import gradio as gr
3
  from langchain.chains import RetrievalQA
4
  from langchain.text_splitter import SpacyTextSplitter
 
5
  from langchain_community.document_loaders import PyPDFLoader
6
+ from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
7
  from langchain_community.vectorstores import Chroma
8
  from langchain_core.prompts import PromptTemplate
9
+ from langchain_google_genai import ChatGoogleGenerativeAI
 
 
 
10
 
11
  template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. Use three sentences maximum. Keep the answer as concise as possible. Always say "thanks for asking!" at the end of the answer.
12
  Tips: Make sure to cite your sources, and use the exact words from the context.
 
27
  gradio.Error("Please enter model API key.")
28
  return
29
  if 'glm' in model_name:
30
+ gradio.Error("GLM is not supported yet.")
 
 
 
 
31
  elif 'gemini' in model_name:
32
+ self.chat_model = ChatGoogleGenerativeAI(
33
  google_api_key=api_key,
34
+ model='gemini-pro'
35
  )
36
 
37
  def _init_embedding(self, embedding_model_name, api_key):
 
40
  return
41
  if 'glm' in embedding_model_name:
42
  gradio.Error("GLM is not supported yet.")
 
 
43
  else:
44
  self.embedding = HuggingFaceInferenceAPIEmbeddings(
45
  api_key=api_key, model_name=embedding_model_name
 
49
  if not file_path:
50
  gradio.Error("Please enter vector database file path.")
51
  return
52
+ gr.Info("Building vector database...")
53
  loader = PyPDFLoader(file_path)
54
  pages = loader.load()
55
 
 
59
  self.vector_db = Chroma.from_documents(
60
  documents=docs, embedding=self.embedding
61
  )
62
+ gr.Info("Vector database built successfully.")
63
 
64
  def _init_settings(self, model_name, api_key, embedding_model, embedding_api_key, data_file):
65
  self._init_chat_model(model_name, api_key)
 
70
  basic_qa = RetrievalQA.from_chain_type(
71
  self.chat_model,
72
  retriever=self.vector_db.as_retriever(),
73
+ chain_type_kwargs={"prompt": QA_CHAIN_PROMPT},
74
+ verbose=True,
75
  )
76
  return basic_qa.invoke(input_text)
77
 
 
83
  with gr.Row():
84
  with gr.Column():
85
  model_name = gr.Dropdown(
86
+ choices=['gemini-1.0-pro'],
87
  value='glm-3-turbo',
88
  label="model"
89
  )
90
  api_key = gr.Textbox(placeholder="your api key for LLM", label="api key")
91
  embedding_model = gr.Dropdown(
92
+ choices=['sentence-transformers/all-MiniLM-L6-v2',
93
  'intfloat/multilingual-e5-large'],
94
  value="sentence-transformers/all-MiniLM-L6-v2",
95
  label="embedding model"
 
115
  inputs=input_text,
116
  outputs=output,
117
  )
118
+ return demo
119
 
120
 
121
  app = RAGDemo()