DHEIVER committed on
Commit d853d5a · verified · 1 Parent(s): 99a69cc

Update app.py

Files changed (1)
  1. app.py +159 -91
app.py CHANGED
@@ -1,29 +1,35 @@
import gradio as gr
import os
import torch
-from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
-from langchain_community.embeddings import HuggingFaceEmbeddings
-from langchain_community.llms import HuggingFacePipeline
-from langchain.chains import ConversationChain
-from langchain.memory import ConversationBufferMemory
from langchain_community.llms import HuggingFaceEndpoint

api_token = os.getenv("FirstToken")

# Available LLM models
list_llm = [
-    "meta-llama/Meta-Llama-3-8B-Instruct",
    "mistralai/Mistral-7B-Instruct-v0.2",
    "deepseek-ai/deepseek-llm-7b-chat"
-]
list_llm_simple = [os.path.basename(llm) for llm in list_llm]

def load_doc(list_file_path):
-    """Load and split PDF documents into chunks"""
    loaders = [PyPDFLoader(x) for x in list_file_path]
    pages = []
    for loader in loaders:
@@ -35,14 +41,92 @@ def load_doc(list_file_path):
    doc_splits = text_splitter.split_documents(pages)
    return doc_splits

-def create_db(splits):
-    """Create vector database from document splits"""
    embeddings = HuggingFaceEmbeddings()
-    vectordb = FAISS.from_documents(splits, embeddings)
-    return vectordb

-def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
-    """Initialize the language model chain"""
    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        huggingfacehub_api_token=api_token,
@@ -51,14 +135,13 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
        top_k=top_k,
        task="text-generation"
    )
-
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key='answer',
        return_messages=True
    )

-    retriever = vector_db.as_retriever()
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
@@ -69,35 +152,50 @@
    )
    return qa_chain

-def initialize_database(list_file_obj, progress=gr.Progress()):
-    """Initialize the document database"""
-    list_file_path = [x.name for x in list_file_obj if x is not None]
-    doc_splits = load_doc(list_file_path)
-    vector_db = create_db(doc_splits)
-    return vector_db, "Database created successfully!"
-
-def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
-    """Initialize the Language Model"""
    llm_name = list_llm[llm_option]
    print("Selected LLM model:", llm_name)
-    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
    return qa_chain, "Analysis Assistant initialized and ready!"

def format_chat_history(message, chat_history):
-    """Format chat history for the model"""
    formatted_chat_history = []
    for user_message, bot_message in chat_history:
        formatted_chat_history.append(f"User: {user_message}")
        formatted_chat_history.append(f"Assistant: {bot_message}")
    return formatted_chat_history

-def conversation(qa_chain, message, history):
-    """Handle conversation and document analysis"""
    formatted_chat_history = format_chat_history(message, history)
    response = qa_chain.invoke({"question": message, "chat_history": formatted_chat_history})
    response_answer = response["answer"]
    if response_answer.find("Helpful Answer:") != -1:
        response_answer = response_answer.split("Helpful Answer:")[-1]
    response_sources = response["source_documents"]
    response_source1 = response_sources[0].page_content.strip()
    response_source2 = response_sources[1].page_content.strip()
@@ -106,19 +204,20 @@ def conversation(qa_chain, message, history):
    response_source2_page = response_sources[1].metadata["page"] + 1
    response_source3_page = response_sources[2].metadata["page"] + 1
    new_history = history + [(message, response_answer)]
-    return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page

-# ... (previous code remains the same)
-
def demo():
-    """Main demo application with enhanced layout"""
    theme = gr.themes.Default(
        primary_hue="indigo",
        secondary_hue="blue",
        neutral_hue="slate",
    )
-
    # Custom CSS for advanced layout
    custom_css = """
    .container {background: #ffffff; padding: 1rem; border-radius: 8px; box-shadow: 0 1px 3px rgba(0,0,0,0.1);}
@@ -129,12 +228,12 @@ def demo():
    .control-panel {margin-bottom: 1rem;}
    .chat-area {background: white; padding: 1rem; border-radius: 8px;}
    """
-
    with gr.Blocks(theme=theme, css=custom_css) as demo:
-        vector_db = gr.State()
        qa_chain = gr.State()
-        language = gr.State(value="en")  # New state for language control
-
        # Header
        gr.HTML(
            """
@@ -144,12 +243,12 @@ def demo():
            </div>
            """
        )
-
        with gr.Row():
            # Left Column - Controls
            with gr.Column(scale=1):
                gr.Markdown("## Document Processing")
-
                # File Upload Section
                with gr.Column(elem_classes="section"):
                    gr.Markdown("### 📄 Upload Documents")
@@ -163,7 +262,7 @@ def demo():
                        value="Ready for documents",
                        label="Processing Status"
                    )
-
                # Model Selection Section
                with gr.Column(elem_classes="section"):
                    gr.Markdown("### 🤖 Model Configuration")
@@ -173,15 +272,15 @@ def demo():
                        value=list_llm_simple[0],
                        type="index"
                    )
-
-                    # New language selection button
                    language_btn = gr.Radio(
                        choices=["English", "Português"],
                        label="Response Language",
                        value="English",
                        type="value"
                    )
-
                    with gr.Accordion("Advanced Settings", open=False):
                        slider_temperature = gr.Slider(
                            minimum=0.01,
@@ -204,17 +303,17 @@ def demo():
                            step=1,
                            label="Analysis Diversity"
                        )
-
                    qachain_btn = gr.Button("Initialize Assistant")
                    llm_progress = gr.Textbox(
                        value="Not initialized",
                        label="Assistant Status"
                    )
-
            # Right Column - Chat Interface
            with gr.Column(scale=2):
                gr.Markdown("## Interactive Analysis")
-
                # Features Section
                with gr.Row():
                    with gr.Column():
@@ -235,7 +334,7 @@ def demo():
                    - Specify standards
                    """
                    )
-
                # Chat Interface
                with gr.Column(elem_classes="chat-area"):
                    chatbot = gr.Chatbot(
@@ -252,7 +351,7 @@ def demo():
                        [msg, chatbot],
                        value="Clear"
                    )
-
                # References Section
                with gr.Accordion("Document References", open=False):
                    with gr.Row():
@@ -271,10 +370,10 @@ def demo():
                    """
                    ---
                    ### About MetroAssist AI
-
-                    A specialized tool for metrology professionals, providing advanced analysis
                    of calibration certificates, measurement data, and technical standards compliance.
-
                    **Version 1.0** | © 2024 MetroAssist AI
                    """
                )
@@ -285,16 +384,16 @@ def demo():
            inputs=language_btn,
            outputs=language
        )
-
        db_btn.click(
            initialize_database,
            inputs=[document],
-            outputs=[vector_db, db_progress]
        )
-
        qachain_btn.click(
            initialize_LLM,
-            inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db],
            outputs=[qa_chain, llm_progress]
        ).then(
            lambda: [None, "", 0, "", 0, "", 0],
@@ -309,14 +408,14 @@ def demo():
            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
            queue=False
        )
-
        submit_btn.click(
            conversation,
            inputs=[qa_chain, msg, chatbot, language],
            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
            queue=False
        )
-
        clear_btn.click(
            lambda: [None, "", 0, "", 0, "", 0],
            inputs=None,
@@ -326,36 +425,5 @@ def demo():

    demo.queue().launch(debug=True)

-# Modify the conversation function to include the language
-def conversation(qa_chain, message, history, lang):
-    """Handle conversation and document analysis"""
-    # Add a language instruction to the message
-    if lang == "pt":
-        message += " (Responda em Português)"
-    else:
-        message += " (Respond in English)"
-
-    formatted_chat_history = format_chat_history(message, history)
-    response = qa_chain.invoke({"question": message, "chat_history": formatted_chat_history})
-    response_answer = response["answer"]
-
-    # Remove the language instruction from the chat history
-    if "(Respond" in message:
-        message = message.split(" (Respond")[0]
-
-    if response_answer.find("Helpful Answer:") != -1:
-        response_answer = response_answer.split("Helpful Answer:")[-1]
-
-    response_sources = response["source_documents"]
-    response_source1 = response_sources[0].page_content.strip()
-    response_source2 = response_sources[1].page_content.strip()
-    response_source3 = response_sources[2].page_content.strip()
-    response_source1_page = response_sources[0].metadata["page"] + 1
-    response_source2_page = response_sources[1].metadata["page"] + 1
-    response_source3_page = response_sources[2].metadata["page"] + 1
-    new_history = history + [(message, response_answer)]
-
-    return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
-
if __name__ == "__main__":
-    demo()

import gradio as gr
import os
import torch
+from langchain_community.vectorstores import FAISS, Chroma
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import ConversationalRetrievalChain
+from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
+from langchain.memory import ConversationBufferMemory
+from langchain.retrievers import BM25Retriever, EnsembleRetriever
+from langchain.chains.query_constructor.base import AttributeInfo
+from langchain.retrievers.self_query.base import SelfQueryRetriever
+from langchain.retrievers.multi_query import MultiQueryRetriever

api_token = os.getenv("FirstToken")

# Available LLM models
list_llm = [
+    "meta-llama/Meta-Llama-3-8B-Instruct",
    "mistralai/Mistral-7B-Instruct-v0.2",
    "deepseek-ai/deepseek-llm-7b-chat"
+]
list_llm_simple = [os.path.basename(llm) for llm in list_llm]

+# -----------------------------------------------------------------------------
+# Document Loading and Splitting
+# -----------------------------------------------------------------------------
def load_doc(list_file_path):
+    """Load and split PDF documents into chunks."""
    loaders = [PyPDFLoader(x) for x in list_file_path]
    pages = []
    for loader in loaders:

    doc_splits = text_splitter.split_documents(pages)
    return doc_splits

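+# A minimal usage sketch (hypothetical file names; PyPDFLoader emits one
+# Document per page, which the splitter then chunks):
+#   splits = load_doc(["calibration_certificate.pdf"])
+#   print(len(splits), splits[0].metadata["page"])
+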
+# -----------------------------------------------------------------------------
+# Vector Database Creation (ChromaDB and FAISS)
+# -----------------------------------------------------------------------------
+def create_chromadb(splits, persist_directory="chroma_db"):
+    """Create ChromaDB vector database from document splits."""
+    embeddings = HuggingFaceEmbeddings()
+    chromadb = Chroma.from_documents(
+        documents=splits,
+        embedding=embeddings,
+        persist_directory=persist_directory
+    )
+    chromadb.persist()  # Write to disk (recent Chroma versions persist automatically)
+    return chromadb
+
+def create_faissdb(splits):
+    """Create FAISS vector database from document splits."""
    embeddings = HuggingFaceEmbeddings()
+    faissdb = FAISS.from_documents(splits, embeddings)
+    return faissdb
+
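+# Only create_chromadb() is wired into initialize_database() below;
+# create_faissdb() is an in-memory alternative. A sketch of either builder:
+#   vector_db = create_chromadb(splits)   # persisted under chroma_db/
+#   vector_db = create_faissdb(splits)    # volatile FAISS index
+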
+# -----------------------------------------------------------------------------
+# BM25 Retriever
+# -----------------------------------------------------------------------------
+def create_bm25_retriever(splits):
+    """Create BM25 retriever from document splits."""
+    bm25_retriever = BM25Retriever.from_documents(splits)
+    bm25_retriever.k = 3  # Number of documents to retrieve
+    return bm25_retriever
+
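+# BM25 ranks purely on term statistics, so it needs no embeddings and catches
+# exact keywords (serial numbers, standard codes) that dense search can miss.
+# Sketch (BM25Retriever depends on the rank_bm25 package):
+#   bm25 = create_bm25_retriever(splits)
+#   docs = bm25.get_relevant_documents("ISO 17025 calibration interval")
+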
+# -----------------------------------------------------------------------------
+# MultiQueryRetriever
+# -----------------------------------------------------------------------------
+def create_multi_query_retriever(llm, vector_db):
+    """
+    Create a MultiQueryRetriever.
+
+    Args:
+        llm: The language model used to generate query variants.
+        vector_db: The vector database to retrieve from.
+
+    Returns:
+        A MultiQueryRetriever instance.
+    """
+    retriever = MultiQueryRetriever.from_llm(
+        llm=llm,
+        retriever=vector_db.as_retriever()
+    )
+    return retriever
+
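+# MultiQueryRetriever has the LLM rewrite one question into several phrasings
+# and merges the unique hits. Sketch (the helper is defined but not yet wired in):
+#   mqr = create_multi_query_retriever(llm, vector_db)
+#   docs = mqr.get_relevant_documents("expanded uncertainty of the gauge block")
+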
+# -----------------------------------------------------------------------------
+# Ensemble Retriever (Combine VectorDB and BM25)
+# -----------------------------------------------------------------------------
+def create_ensemble_retriever(vector_db, bm25_retriever):
+    """Create an ensemble retriever combining ChromaDB and BM25."""
+    ensemble_retriever = EnsembleRetriever(
+        retrievers=[vector_db.as_retriever(), bm25_retriever],
+        weights=[0.7, 0.3]  # Adjust weights as needed
+    )
+    return ensemble_retriever
+
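+# EnsembleRetriever fuses the two ranked lists with weighted Reciprocal Rank
+# Fusion; 0.7/0.3 biases results toward semantic matches over keyword matches.
+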
+# -----------------------------------------------------------------------------
+# Initialize Database
+# -----------------------------------------------------------------------------
+def initialize_database(list_file_obj, progress=gr.Progress()):
+    """Initialize the document database."""
+    list_file_path = [x.name for x in list_file_obj if x is not None]
+    doc_splits = load_doc(list_file_path)
+
+    # Create vector databases and retrievers
+    chromadb = create_chromadb(doc_splits)
+    bm25_retriever = create_bm25_retriever(doc_splits)
+
+    # Create ensemble retriever
+    ensemble_retriever = create_ensemble_retriever(chromadb, bm25_retriever)

+    return ensemble_retriever, "Database created successfully!"
+
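+# The ensemble retriever returned here is kept in a gr.State and later passed
+# to initialize_llmchain(), so every chat turn searches both indexes.
+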
+# -----------------------------------------------------------------------------
+# Initialize LLM Chain
+# -----------------------------------------------------------------------------
+def initialize_llmchain(llm_model, temperature, max_tokens, top_k, retriever, progress=gr.Progress()):
+    """Initialize the language model chain."""
    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        huggingfacehub_api_token=api_token,

        top_k=top_k,
        task="text-generation"
    )
+
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key='answer',
        return_messages=True
    )

    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,

    )
    return qa_chain

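+# ConversationalRetrievalChain first condenses (question + chat_history) into a
+# standalone query, retrieves with it, then answers from the retrieved chunks:
+#   result = qa_chain.invoke({"question": q, "chat_history": []})
+#   result["answer"], result["source_documents"]
+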
+# -----------------------------------------------------------------------------
+# Initialize LLM
+# -----------------------------------------------------------------------------
+def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, retriever, progress=gr.Progress()):
+    """Initialize the Language Model."""
    llm_name = list_llm[llm_option]
    print("Selected LLM model:", llm_name)
+    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, retriever, progress)
    return qa_chain, "Analysis Assistant initialized and ready!"

+# -----------------------------------------------------------------------------
+# Chat History Formatting
+# -----------------------------------------------------------------------------
def format_chat_history(message, chat_history):
+    """Format chat history for the model."""
    formatted_chat_history = []
    for user_message, bot_message in chat_history:
        formatted_chat_history.append(f"User: {user_message}")
        formatted_chat_history.append(f"Assistant: {bot_message}")
    return formatted_chat_history

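+# Flattens Gradio's [(user, bot), ...] pairs into lines the chain can embed in
+# its condense-question prompt, e.g.:
+#   format_chat_history("", [("Hi", "Hello")]) -> ["User: Hi", "Assistant: Hello"]
+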
+# -----------------------------------------------------------------------------
+# Conversation Function
+# -----------------------------------------------------------------------------
+def conversation(qa_chain, message, history, lang):
+    """Handle conversation and document analysis."""
+
+    # Add language instruction to the message
+    if lang == "pt":
+        message += " (Responda em Português)"
+    else:
+        message += " (Respond in English)"
+
    formatted_chat_history = format_chat_history(message, history)
    response = qa_chain.invoke({"question": message, "chat_history": formatted_chat_history})
    response_answer = response["answer"]
+
+    # Strip the language instruction before the message is stored in history
+    # (" (Respond" is also a prefix of the Portuguese " (Responda ...)" suffix)
+    if "(Respond" in message:
+        message = message.split(" (Respond")[0]
+
    if response_answer.find("Helpful Answer:") != -1:
        response_answer = response_answer.split("Helpful Answer:")[-1]
+
    response_sources = response["source_documents"]
    response_source1 = response_sources[0].page_content.strip()
    response_source2 = response_sources[1].page_content.strip()

    response_source2_page = response_sources[1].metadata["page"] + 1
    response_source3_page = response_sources[2].metadata["page"] + 1
    new_history = history + [(message, response_answer)]

+    return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page

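+# Note: this checks lang == "pt", so the language_btn change handler (elided
+# from this diff) presumably maps the radio labels "English"/"Português" to the
+# codes "en"/"pt" before storing them in the language state.
+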
+# -----------------------------------------------------------------------------
+# Gradio Demo
+# -----------------------------------------------------------------------------
def demo():
+    """Main demo application with enhanced layout."""
    theme = gr.themes.Default(
        primary_hue="indigo",
        secondary_hue="blue",
        neutral_hue="slate",
    )
+
    # Custom CSS for advanced layout
    custom_css = """
    .container {background: #ffffff; padding: 1rem; border-radius: 8px; box-shadow: 0 1px 3px rgba(0,0,0,0.1);}

    .control-panel {margin-bottom: 1rem;}
    .chat-area {background: white; padding: 1rem; border-radius: 8px;}
    """
+
    with gr.Blocks(theme=theme, css=custom_css) as demo:
+        retriever = gr.State()
        qa_chain = gr.State()
+        language = gr.State(value="en")  # State for language control
+
        # Header
        gr.HTML(
            """

            </div>
            """
        )
+
        with gr.Row():
            # Left Column - Controls
            with gr.Column(scale=1):
                gr.Markdown("## Document Processing")
+
                # File Upload Section
                with gr.Column(elem_classes="section"):
                    gr.Markdown("### 📄 Upload Documents")

                        value="Ready for documents",
                        label="Processing Status"
                    )
+
                # Model Selection Section
                with gr.Column(elem_classes="section"):
                    gr.Markdown("### 🤖 Model Configuration")

                        value=list_llm_simple[0],
                        type="index"
                    )
+
+                    # Language selection button
                    language_btn = gr.Radio(
                        choices=["English", "Português"],
                        label="Response Language",
                        value="English",
                        type="value"
                    )
+
                    with gr.Accordion("Advanced Settings", open=False):
                        slider_temperature = gr.Slider(
                            minimum=0.01,

                            step=1,
                            label="Analysis Diversity"
                        )
+
                    qachain_btn = gr.Button("Initialize Assistant")
                    llm_progress = gr.Textbox(
                        value="Not initialized",
                        label="Assistant Status"
                    )
+
            # Right Column - Chat Interface
            with gr.Column(scale=2):
                gr.Markdown("## Interactive Analysis")
+
                # Features Section
                with gr.Row():
                    with gr.Column():

                    - Specify standards
                    """
                    )
+
                # Chat Interface
                with gr.Column(elem_classes="chat-area"):
                    chatbot = gr.Chatbot(

                        [msg, chatbot],
                        value="Clear"
                    )
+
                # References Section
                with gr.Accordion("Document References", open=False):
                    with gr.Row():

                    """
                    ---
                    ### About MetroAssist AI
+
+                    A specialized tool for metrology professionals, providing advanced analysis
                    of calibration certificates, measurement data, and technical standards compliance.
+
                    **Version 1.0** | © 2024 MetroAssist AI
                    """
                )

            inputs=language_btn,
            outputs=language
        )
+
        db_btn.click(
            initialize_database,
            inputs=[document],
+            outputs=[retriever, db_progress]
        )
+
        qachain_btn.click(
            initialize_LLM,
+            inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, retriever],
            outputs=[qa_chain, llm_progress]
        ).then(
            lambda: [None, "", 0, "", 0, "", 0],

            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
            queue=False
        )
+
        submit_btn.click(
            conversation,
            inputs=[qa_chain, msg, chatbot, language],
            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
            queue=False
        )
+
        clear_btn.click(
            lambda: [None, "", 0, "", 0, "", 0],
            inputs=None,

    demo.queue().launch(debug=True)

if __name__ == "__main__":
+    demo()
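+
+# Running the app assumes the FirstToken env var holds a Hugging Face API token,
+# and that chromadb, faiss-cpu, and rank_bm25 are installed alongside langchain.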