Fecalisboa committed
Commit 8dabaa3 · verified · 1 Parent(s): ce1efe0

Update app.py

Files changed (1): app.py (+35 −22)
app.py CHANGED
@@ -20,7 +20,7 @@ api_token = os.getenv("HF_TOKEN")
 
 
 
-list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.3","google/flan-t5-base"]
+list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.3"]
 list_llm_simple = [os.path.basename(llm) for llm in list_llm]
 
 # Load PDF document and create doc splits
@@ -34,15 +34,36 @@ def load_doc(list_file_path, chunk_size, chunk_overlap):
     return doc_splits
 
 # Create vector database
-def create_db(splits, collection_name):
+def create_db(splits, collection_name, db_type):
     embedding = HuggingFaceEmbeddings()
-    new_client = chromadb.EphemeralClient()
-    vectordb = Chroma.from_documents(
-        documents=splits,
-        embedding=embedding,
-        client=new_client,
-        collection_name=collection_name,
-    )
+
+    if db_type == "ChromaDB":
+        new_client = chromadb.EphemeralClient()
+        vectordb = Chroma.from_documents(
+            documents=splits,
+            embedding=embedding,
+            client=new_client,
+            collection_name=collection_name,
+        )
+    elif db_type == "FAISS":
+        vectordb = FAISS.from_documents(
+            documents=splits,
+            embedding=embedding
+        )
+    elif db_type == "ScaNN":
+        vectordb = ScaNN.from_documents(
+            documents=splits,
+            embedding=embedding
+        )
+    elif db_type == "Milvus":
+        vectordb = Milvus.from_documents(
+            documents=splits,
+            embedding=embedding,
+            collection_name=collection_name,
+        )
+    else:
+        raise ValueError(f"Unsupported vector database type: {db_type}")
+
     return vectordb
 
 # Load vector database
@@ -67,14 +88,7 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, pr
             max_new_tokens=max_tokens,
             top_k=top_k,
         )
-    elif llm_model == "mistralai/Mistral-7B-Instruct-v0.3":
-        llm = HuggingFaceEndpoint(
-            repo_id=llm_model,
-            huggingfacehub_api_token=api_token,
-            temperature=temperature,
-            max_new_tokens=max_tokens,
-            top_k=top_k,
-        )
+
     else:
 
         llm = HuggingFaceEndpoint(
@@ -122,14 +136,14 @@ def create_collection_name(filepath):
     return collection_name
 
 # Initialize database
-def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
+def initialize_database(list_file_obj, chunk_size, chunk_overlap, db_type, progress=gr.Progress()):
     list_file_path = [x.name for x in list_file_obj if x is not None]
     progress(0.1, desc="Creating collection name...")
     collection_name = create_collection_name(list_file_path[0])
     progress(0.25, desc="Loading document...")
     doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
     progress(0.5, desc="Generating vector database...")
-    vector_db = create_db(doc_splits, collection_name)
+    vector_db = create_db(doc_splits, collection_name, db_type)
     progress(0.9, desc="Done!")
     return vector_db, collection_name, "Complete!"
 
@@ -190,7 +204,7 @@ def demo():
 
         with gr.Tab("Step 2 - Process document"):
             with gr.Row():
-                db_btn = gr.Radio(["ChromaDB"], label="Vector database type", value="ChromaDB", type="index", info="Choose your vector database")
+                db_btn = gr.Radio(["ChromaDB", "FAISS", "ScaNN", "Milvus"], label="Vector database type", value="ChromaDB", type="index", info="Choose your vector database")
             with gr.Accordion("Advanced options - Document text splitter", open=False):
                 with gr.Row():
                     slider_chunk_size = gr.Slider(minimum=100, maximum=1000, value=600, step=20, label="Chunk size", info="Chunk size", interactive=True)
@@ -237,7 +251,7 @@ def demo():
 
         # Preprocessing events
         db_btn.click(initialize_database,
-            inputs=[document, slider_chunk_size, slider_chunk_overlap],
+            inputs=[document, slider_chunk_size, slider_chunk_overlap, db_btn],
             outputs=[vector_db, collection_name, db_progress])
         qachain_btn.click(initialize_LLM,
             inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db],
@@ -261,6 +275,5 @@ def demo():
             queue=False)
     demo.queue().launch(debug=True)
 
-
 if __name__ == "__main__":
     demo()
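
The new create_db branches reference FAISS, ScaNN, and Milvus, but the diff shows no matching change to the imports. A minimal sketch of what the new branches would presumably need, assuming the app uses the LangChain community integrations (the import path is an assumption; it is not confirmed by this diff):

# Presumed imports for the new create_db branches (assumption: not shown in this diff).
# FAISS, ScaNN, and Milvus are exported by langchain_community.vectorstores and need
# the faiss-cpu, scann, and pymilvus packages installed, respectively.
from langchain_community.vectorstores import FAISS, ScaNN, Milvus

Note also that Milvus.from_documents typically connects to a Milvus server (configurable via connection_args), so unlike the in-process Chroma, FAISS, and ScaNN options, that branch assumes a running Milvus instance.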
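
One behavioral note on the Step 2 wiring: db_btn is declared with type="index", so Gradio passes the selected option's integer position (0-3) into initialize_database, while create_db compares db_type against the strings "ChromaDB", "FAISS", "ScaNN", and "Milvus". As committed, every selection would fall through to the raise ValueError branch. A hedged sketch of two possible fixes (db_names is a hypothetical helper, not part of this commit):

# Fix 1: let Gradio pass the label string itself rather than its position.
db_btn = gr.Radio(["ChromaDB", "FAISS", "ScaNN", "Milvus"], label="Vector database type",
                  value="ChromaDB", type="value", info="Choose your vector database")

# Fix 2: keep type="index" and translate the index back to a name inside initialize_database.
db_names = ["ChromaDB", "FAISS", "ScaNN", "Milvus"]  # hypothetical lookup table
db_type = db_names[db_type]  # db_type arrives as an int when type="index"

Either variant makes the string comparisons in create_db match what the UI actually sends.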