DHEIVER commited on
Commit
914f0c8
·
verified ·
1 Parent(s): 3bf61c2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +268 -71
app.py CHANGED
@@ -1,7 +1,6 @@
1
  import gradio as gr
2
  import os
3
- api_token = os.getenv("HF_TOKEN")
4
-
5
  from langchain_community.vectorstores import FAISS
6
  from langchain_community.document_loaders import PyPDFLoader
7
  from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -12,9 +11,10 @@ from langchain_community.llms import HuggingFacePipeline
12
  from langchain.chains import ConversationChain
13
  from langchain.memory import ConversationBufferMemory
14
  from langchain_community.llms import HuggingFaceEndpoint
15
- import torch
16
 
17
- # Added Deepseek model to the list
 
 
18
  list_llm = [
19
  "meta-llama/Meta-Llama-3-8B-Instruct",
20
  "mistralai/Mistral-7B-Instruct-v0.2",
@@ -22,25 +22,33 @@ list_llm = [
22
  ]
23
  list_llm_simple = [os.path.basename(llm) for llm in list_llm]
24
 
25
- # Rest of the functions remain the same until demo()
26
  def load_doc(list_file_path):
 
 
 
27
  loaders = [PyPDFLoader(x) for x in list_file_path]
28
  pages = []
29
  for loader in loaders:
30
  pages.extend(loader.load())
31
  text_splitter = RecursiveCharacterTextSplitter(
32
- chunk_size = 1024,
33
- chunk_overlap = 64
34
- )
35
  doc_splits = text_splitter.split_documents(pages)
36
  return doc_splits
37
 
38
  def create_db(splits):
 
 
 
39
  embeddings = HuggingFaceEmbeddings()
40
  vectordb = FAISS.from_documents(splits, embeddings)
41
  return vectordb
42
 
43
  def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
 
 
 
44
  llm = HuggingFaceEndpoint(
45
  repo_id=llm_model,
46
  huggingfacehub_api_token=api_token,
@@ -60,7 +68,7 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, pr
60
  qa_chain = ConversationalRetrievalChain.from_llm(
61
  llm,
62
  retriever=retriever,
63
- chain_type="stuff",
64
  memory=memory,
65
  return_source_documents=True,
66
  verbose=False,
@@ -68,18 +76,27 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, pr
68
  return qa_chain
69
 
70
  def initialize_database(list_file_obj, progress=gr.Progress()):
 
 
 
71
  list_file_path = [x.name for x in list_file_obj if x is not None]
72
  doc_splits = load_doc(list_file_path)
73
  vector_db = create_db(doc_splits)
74
- return vector_db, "Database created!"
75
 
76
  def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
 
 
 
77
  llm_name = list_llm[llm_option]
78
- print("llm_name: ",llm_name)
79
  qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
80
- return qa_chain, "QA chain initialized. Chatbot is ready!"
81
 
82
  def format_chat_history(message, chat_history):
 
 
 
83
  formatted_chat_history = []
84
  for user_message, bot_message in chat_history:
85
  formatted_chat_history.append(f"User: {user_message}")
@@ -87,6 +104,9 @@ def format_chat_history(message, chat_history):
87
  return formatted_chat_history
88
 
89
  def conversation(qa_chain, message, history):
 
 
 
90
  formatted_chat_history = format_chat_history(message, history)
91
  response = qa_chain.invoke({"question": message, "chat_history": formatted_chat_history})
92
  response_answer = response["answer"]
@@ -102,95 +122,272 @@ def conversation(qa_chain, message, history):
102
  new_history = history + [(message, response_answer)]
103
  return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
104
 
105
- def upload_file(file_obj):
106
- list_file_path = []
107
- for idx, file in enumerate(file_obj):
108
- file_path = file_obj.name
109
- list_file_path.append(file_path)
110
- return list_file_path
111
-
112
  def demo():
113
- # Modified theme to use dark blue colors
 
 
 
114
  theme = gr.themes.Default(
115
  primary_hue="indigo",
116
  secondary_hue="blue",
117
- neutral_hue="slate"
 
118
  )
119
 
120
- with gr.Blocks(theme=theme) as demo:
 
 
 
 
 
 
 
 
121
  vector_db = gr.State()
122
  qa_chain = gr.State()
123
- gr.HTML("<center><h1>RAG PDF chatbot</h1><center>")
124
- gr.Markdown("""<b>Query your PDF documents!</b> This AI agent is designed to perform retrieval augmented generation (RAG) on PDF documents. The app is hosted on Hugging Face Hub for the sole purpose of demonstration. \
125
- <b>Please do not upload confidential documents.</b>
126
- """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  with gr.Row():
128
- with gr.Column(scale = 86):
129
- gr.Markdown("<b>Step 1 - Upload PDF documents and Initialize RAG pipeline</b>")
 
 
 
 
 
130
  with gr.Row():
131
- document = gr.Files(height=300, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload PDF documents")
 
 
 
 
 
 
 
132
  with gr.Row():
133
- db_btn = gr.Button("Create vector database")
 
 
 
 
134
  with gr.Row():
135
- db_progress = gr.Textbox(value="Not initialized", show_label=False)
136
- gr.Markdown("<style>body { font-size: 16px; }</style><b>Select Large Language Model (LLM) and input parameters</b>")
 
 
 
 
 
 
 
 
 
 
137
  with gr.Row():
138
- llm_btn = gr.Radio(list_llm_simple, label="Available LLMs", value = list_llm_simple[0], type="index")
 
 
 
 
 
 
 
139
  with gr.Row():
140
- with gr.Accordion("LLM input parameters", open=False):
141
  with gr.Row():
142
- slider_temperature = gr.Slider(minimum = 0.01, maximum = 1.0, value=0.5, step=0.1, label="Temperature", info="Controls randomness in token generation", interactive=True)
 
 
 
 
 
 
 
 
143
  with gr.Row():
144
- slider_maxtokens = gr.Slider(minimum = 128, maximum = 9192, value=4096, step=128, label="Max New Tokens", info="Maximum number of tokens to be generated",interactive=True)
 
 
 
 
 
 
 
 
145
  with gr.Row():
146
- slider_topk = gr.Slider(minimum = 1, maximum = 10, value=3, step=1, label="top-k", info="Number of tokens to select the next token from", interactive=True)
 
 
 
 
 
 
 
 
 
147
  with gr.Row():
148
- qachain_btn = gr.Button("Initialize Question Answering Chatbot")
 
 
 
 
149
  with gr.Row():
150
- llm_progress = gr.Textbox(value="Not initialized", show_label=False)
 
 
 
151
 
152
- with gr.Column(scale = 200):
153
- gr.Markdown("<b>Step 2 - Chat with your Document</b>")
154
- chatbot = gr.Chatbot(height=505)
155
- with gr.Accordion("Relevent context from the source document", open=False):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  with gr.Row():
157
- doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
 
 
 
 
 
158
  source1_page = gr.Number(label="Page", scale=1)
159
  with gr.Row():
160
- doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
 
 
 
 
 
161
  source2_page = gr.Number(label="Page", scale=1)
162
  with gr.Row():
163
- doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
 
 
 
 
 
164
  source3_page = gr.Number(label="Page", scale=1)
 
165
  with gr.Row():
166
- msg = gr.Textbox(placeholder="Ask a question", container=True)
 
 
 
 
167
  with gr.Row():
168
- submit_btn = gr.Button("Submit")
169
- clear_btn = gr.ClearButton([msg, chatbot], value="Clear")
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  # Event handlers
172
- db_btn.click(initialize_database,
173
- inputs=[document],
174
- outputs=[vector_db, db_progress])
175
- qachain_btn.click(initialize_LLM,
176
- inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db],
177
- outputs=[qa_chain, llm_progress]).then(lambda:[None,"",0,"",0,"",0],
178
- inputs=None,
179
- outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
180
- queue=False)
181
-
182
- msg.submit(conversation,
183
- inputs=[qa_chain, msg, chatbot],
184
- outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
185
- queue=False)
186
- submit_btn.click(conversation,
187
- inputs=[qa_chain, msg, chatbot],
188
- outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
189
- queue=False)
190
- clear_btn.click(lambda:[None,"",0,"",0,"",0],
191
- inputs=None,
192
- outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
193
- queue=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
  demo.queue().launch(debug=True)
196
 
 
1
  import gradio as gr
2
  import os
3
+ import torch
 
4
  from langchain_community.vectorstores import FAISS
5
  from langchain_community.document_loaders import PyPDFLoader
6
  from langchain.text_splitter import RecursiveCharacterTextSplitter
 
11
  from langchain.chains import ConversationChain
12
  from langchain.memory import ConversationBufferMemory
13
  from langchain_community.llms import HuggingFaceEndpoint
 
14
 
15
+ api_token = os.getenv("HF_TOKEN")
16
+
17
+ # Available LLM models
18
  list_llm = [
19
  "meta-llama/Meta-Llama-3-8B-Instruct",
20
  "mistralai/Mistral-7B-Instruct-v0.2",
 
22
  ]
23
  list_llm_simple = [os.path.basename(llm) for llm in list_llm]
24
 
 
25
  def load_doc(list_file_path):
26
+ """
27
+ Load and split PDF documents into chunks
28
+ """
29
  loaders = [PyPDFLoader(x) for x in list_file_path]
30
  pages = []
31
  for loader in loaders:
32
  pages.extend(loader.load())
33
  text_splitter = RecursiveCharacterTextSplitter(
34
+ chunk_size=1024,
35
+ chunk_overlap=64
36
+ )
37
  doc_splits = text_splitter.split_documents(pages)
38
  return doc_splits
39
 
40
  def create_db(splits):
41
+ """
42
+ Create vector database from document splits
43
+ """
44
  embeddings = HuggingFaceEmbeddings()
45
  vectordb = FAISS.from_documents(splits, embeddings)
46
  return vectordb
47
 
48
  def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
49
+ """
50
+ Initialize the language model chain
51
+ """
52
  llm = HuggingFaceEndpoint(
53
  repo_id=llm_model,
54
  huggingfacehub_api_token=api_token,
 
68
  qa_chain = ConversationalRetrievalChain.from_llm(
69
  llm,
70
  retriever=retriever,
71
+ chain_type="stuff",
72
  memory=memory,
73
  return_source_documents=True,
74
  verbose=False,
 
76
  return qa_chain
77
 
78
  def initialize_database(list_file_obj, progress=gr.Progress()):
79
+ """
80
+ Initialize the document database
81
+ """
82
  list_file_path = [x.name for x in list_file_obj if x is not None]
83
  doc_splits = load_doc(list_file_path)
84
  vector_db = create_db(doc_splits)
85
+ return vector_db, "Database created successfully!"
86
 
87
  def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
88
+ """
89
+ Initialize the Language Model
90
+ """
91
  llm_name = list_llm[llm_option]
92
+ print("Selected LLM model:", llm_name)
93
  qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
94
+ return qa_chain, "Analysis Assistant initialized and ready!"
95
 
96
  def format_chat_history(message, chat_history):
97
+ """
98
+ Format chat history for the model
99
+ """
100
  formatted_chat_history = []
101
  for user_message, bot_message in chat_history:
102
  formatted_chat_history.append(f"User: {user_message}")
 
104
  return formatted_chat_history
105
 
106
  def conversation(qa_chain, message, history):
107
+ """
108
+ Handle conversation and document analysis
109
+ """
110
  formatted_chat_history = format_chat_history(message, history)
111
  response = qa_chain.invoke({"question": message, "chat_history": formatted_chat_history})
112
  response_answer = response["answer"]
 
122
  new_history = history + [(message, response_answer)]
123
  return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
124
 
 
 
 
 
 
 
 
125
  def demo():
126
+ """
127
+ Main demo application
128
+ """
129
+ # Enhanced theme with professional colors
130
  theme = gr.themes.Default(
131
  primary_hue="indigo",
132
  secondary_hue="blue",
133
+ neutral_hue="slate",
134
+ font=[gr.themes.GoogleFont("Roboto"), "system-ui", "sans-serif"]
135
  )
136
 
137
+ css = """
138
+ .container { max-width: 1200px; margin: auto; }
139
+ .metadata { font-size: 0.9em; color: #666; }
140
+ .highlight { background-color: #f0f7ff; padding: 1em; border-radius: 8px; }
141
+ .warning { color: #e53e3e; }
142
+ .success { color: #38a169; }
143
+ """
144
+
145
+ with gr.Blocks(theme=theme, css=css) as demo:
146
  vector_db = gr.State()
147
  qa_chain = gr.State()
148
+
149
+ # Enhanced header
150
+ gr.HTML(
151
+ """
152
+ <div style='text-align: center; padding: 20px;'>
153
+ <h1 style='color: #1a365d; margin-bottom: 10px;'>MetroAssist AI - Expert in Metrology Report Analysis</h1>
154
+ <p style='color: #4a5568; font-size: 1.2em;'>Your intelligent assistant for advanced analysis of metrological documents</p>
155
+ </div>
156
+ """
157
+ )
158
+
159
+ # Marketing and feature description
160
+ gr.Markdown(
161
+ """
162
+ ### 🔍 Specialized Metrology Analysis
163
+
164
+ MetroAssist AI is a specialized assistant designed to revolutionize metrology report analysis.
165
+ Powered by cutting-edge AI technology, it offers:
166
+
167
+ * **Precise Analysis**: Detailed interpretation of measurements, calibrations, and compliance
168
+ * **Intelligent Contextualization**: Deep understanding of metrological standards and norms
169
+ * **Advanced Technical Support**: Assistance in complex instrument and measurement analyses
170
+ * **Rapid Processing**: Efficient analysis of multiple technical documents
171
+
172
+ ⚠️ **Security Note**: Your documents are processed with complete security. We do not permanently store confidential data.
173
+ """
174
+ )
175
+
176
  with gr.Row():
177
+ with gr.Column(scale=86):
178
+ gr.Markdown(
179
+ """
180
+ ### 📥 Step 1: Document Loading and Preparation
181
+ Upload your metrology reports for expert analysis.
182
+ """
183
+ )
184
  with gr.Row():
185
+ document = gr.Files(
186
+ height=300,
187
+ file_count="multiple",
188
+ file_types=["pdf"],
189
+ interactive=True,
190
+ label="Upload Metrology Reports (PDF)",
191
+ info="Accepts multiple PDF files"
192
+ )
193
  with gr.Row():
194
+ db_btn = gr.Button(
195
+ "Process Documents",
196
+ variant="primary",
197
+ size="lg"
198
+ )
199
  with gr.Row():
200
+ db_progress = gr.Textbox(
201
+ value="Waiting for documents...",
202
+ show_label=False,
203
+ container=False
204
+ )
205
+
206
+ gr.Markdown(
207
+ """
208
+ ### 🤖 Analysis Engine Configuration
209
+ Select and configure the AI model to best meet your needs.
210
+ """
211
+ )
212
  with gr.Row():
213
+ llm_btn = gr.Radio(
214
+ list_llm_simple,
215
+ label="Available AI Models",
216
+ value=list_llm_simple[0],
217
+ type="index",
218
+ info="Choose the most suitable model for your analysis"
219
+ )
220
+
221
  with gr.Row():
222
+ with gr.Accordion("Advanced Analysis Parameters", open=False):
223
  with gr.Row():
224
+ slider_temperature = gr.Slider(
225
+ minimum=0.01,
226
+ maximum=1.0,
227
+ value=0.5,
228
+ step=0.1,
229
+ label="Analysis Precision",
230
+ info="Controls the balance between precision and creativity in analysis",
231
+ interactive=True
232
+ )
233
  with gr.Row():
234
+ slider_maxtokens = gr.Slider(
235
+ minimum=128,
236
+ maximum=9192,
237
+ value=4096,
238
+ step=128,
239
+ label="Response Length",
240
+ info="Defines the level of detail in analyses",
241
+ interactive=True
242
+ )
243
  with gr.Row():
244
+ slider_topk = gr.Slider(
245
+ minimum=1,
246
+ maximum=10,
247
+ value=3,
248
+ step=1,
249
+ label="Analysis Diversity",
250
+ info="Controls the variety of perspectives in analysis",
251
+ interactive=True
252
+ )
253
+
254
  with gr.Row():
255
+ qachain_btn = gr.Button(
256
+ "Initialize Analysis Assistant",
257
+ variant="primary",
258
+ size="lg"
259
+ )
260
  with gr.Row():
261
+ llm_progress = gr.Textbox(
262
+ value="Waiting for initialization...",
263
+ show_label=False
264
+ )
265
 
266
+ with gr.Column(scale=200):
267
+ gr.Markdown(
268
+ """
269
+ ### 💬 Step 2: Expert Consultation and Analysis
270
+ Ask questions about your metrology reports. The assistant will provide detailed technical analyses.
271
+
272
+ **Suggested questions:**
273
+ - Analyze the calibration results of this instrument
274
+ - Verify compliance with technical standards
275
+ - Identify critical points in measurements
276
+ - Compare results with specified limits
277
+ - Evaluate measurement uncertainty
278
+ - Assess calibration intervals
279
+ """
280
+ )
281
+ chatbot = gr.Chatbot(
282
+ height=505,
283
+ show_label=True,
284
+ container=True,
285
+ label="Metrology Analysis"
286
+ )
287
+
288
+ with gr.Accordion("Source Document References", open=False):
289
  with gr.Row():
290
+ doc_source1 = gr.Textbox(
291
+ label="Technical Reference 1",
292
+ lines=2,
293
+ container=True,
294
+ scale=20
295
+ )
296
  source1_page = gr.Number(label="Page", scale=1)
297
  with gr.Row():
298
+ doc_source2 = gr.Textbox(
299
+ label="Technical Reference 2",
300
+ lines=2,
301
+ container=True,
302
+ scale=20
303
+ )
304
  source2_page = gr.Number(label="Page", scale=1)
305
  with gr.Row():
306
+ doc_source3 = gr.Textbox(
307
+ label="Technical Reference 3",
308
+ lines=2,
309
+ container=True,
310
+ scale=20
311
+ )
312
  source3_page = gr.Number(label="Page", scale=1)
313
+
314
  with gr.Row():
315
+ msg = gr.Textbox(
316
+ placeholder="Enter your question about the metrology report...",
317
+ container=True,
318
+ label="Your Query"
319
+ )
320
  with gr.Row():
321
+ submit_btn = gr.Button(
322
+ "Submit Query",
323
+ variant="primary"
324
+ )
325
+ clear_btn = gr.ClearButton(
326
+ [msg, chatbot],
327
+ value="Clear Conversation",
328
+ variant="secondary"
329
+ )
330
+
331
+ # Footer
332
+ gr.Markdown(
333
+ """
334
+ ---
335
+ ### ℹ️ About MetroAssist AI
336
 
337
+ Developed for metrology professionals, engineers, and technicians who need precise
338
+ and reliable analysis of technical documents. Our tool uses advanced AI technology
339
+ to provide valuable insights and support decision-making in metrology.
340
+
341
+ **Specialized Features:**
342
+ - Detailed analysis of calibration certificates
343
+ - Interpretation of complex metrological data
344
+ - Verification of compliance with technical standards
345
+ - Decision support in metrological processes
346
+ - Uncertainty analysis and measurement traceability
347
+ - Quality control and measurement system analysis
348
+
349
+ *Version 1.0 - Updated 2024*
350
+ """
351
+ )
352
+
353
  # Event handlers
354
+ db_btn.click(
355
+ initialize_database,
356
+ inputs=[document],
357
+ outputs=[vector_db, db_progress]
358
+ )
359
+
360
+ qachain_btn.click(
361
+ initialize_LLM,
362
+ inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db],
363
+ outputs=[qa_chain, llm_progress]
364
+ ).then(
365
+ lambda: [None, "", 0, "", 0, "", 0],
366
+ inputs=None,
367
+ outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
368
+ queue=False
369
+ )
370
+
371
+ msg.submit(
372
+ conversation,
373
+ inputs=[qa_chain, msg, chatbot],
374
+ outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
375
+ queue=False
376
+ )
377
+
378
+ submit_btn.click(
379
+ conversation,
380
+ inputs=[qa_chain, msg, chatbot],
381
+ outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
382
+ queue=False
383
+ )
384
+
385
+ clear_btn.click(
386
+ lambda: [None, "", 0, "", 0, "", 0],
387
+ inputs=None,
388
+ outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
389
+ queue=False
390
+ )
391
 
392
  demo.queue().launch(debug=True)
393