arjunanand13 committed · Commit 338f585 · verified · 1 Parent(s): 184e87b

Update app.py

Files changed (1):
  1. app.py +75 -46
app.py CHANGED
@@ -11,8 +11,8 @@ from langchain.text_splitter import (
  from langchain_community.vectorstores import FAISS, Chroma, Qdrant
  from langchain_community.document_loaders import PyPDFLoader
  from langchain.chains import ConversationalRetrievalChain
- from langchain_community.embeddings import HuggingFaceEmbeddings
- from langchain_community.llms import HuggingFaceEndpoint
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ from langchain_huggingface import HuggingFaceEndpoint
  from langchain.memory import ConversationBufferMemory
  from sentence_transformers import SentenceTransformer, util
  import torch
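
Note: the import change in this hunk tracks LangChain's package split: HuggingFaceEndpoint moved out of langchain_community.llms into the langchain-huggingface partner package, and the community path is deprecated. A minimal sketch of the new import path, assuming langchain-huggingface is installed; the repo id and token below are placeholders, not values from this repo:

    # pip install langchain-huggingface
    from langchain_huggingface import HuggingFaceEndpoint

    llm = HuggingFaceEndpoint(
        repo_id="mistralai/Mistral-7B-Instruct-v0.2",  # illustrative repo id
        huggingfacehub_api_token="hf_...",             # placeholder token
        temperature=0.5,
        max_new_tokens=256,
    )
    print(llm.invoke("What does this endpoint wrapper do?"))
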
@@ -48,53 +48,46 @@ class RAGEvaluator:
          self.test_samples = []
  
      def load_dataset(self, dataset_name: str, num_samples: int = 10):
-         """Load a smaller subset of questions with proper error handling"""
+         """Load dataset with proper error handling"""
          try:
              if dataset_name == "squad":
                  dataset = load_dataset("squad_v2", split="validation")
-                 # Select diverse questions
                  samples = dataset.select(range(0, 1000, 100))[:num_samples]
  
                  self.test_samples = []
                  for sample in samples:
-                     # Check if answers exist and are not empty
-                     if sample.get("answers") and isinstance(sample["answers"], dict) and sample["answers"].get("text"):
+                     # Handle SQuAD format
+                     answers = sample["answers"]
+                     if answers["text"]:  # Check if there are answers
                          self.test_samples.append({
                              "question": sample["question"],
-                             "ground_truth": sample["answers"]["text"][0],
+                             "ground_truth": answers["text"][0],
                              "context": sample["context"]
                          })
  
              elif dataset_name == "msmarco":
-                 dataset = load_dataset("ms_marco", "v2.1", split="dev")
+                 dataset = load_dataset("ms_marco", "v2.1", split="test")  # Changed from dev to test
                  samples = dataset.select(range(0, 1000, 100))[:num_samples]
  
                  self.test_samples = []
                  for sample in samples:
-                     # Check for valid answers
-                     if sample.get("answers") and sample["answers"]:
+                     if sample["answers"]:  # Check if answers exist
                          self.test_samples.append({
                              "question": sample["query"],
                              "ground_truth": sample["answers"][0],
-                             "context": sample["passages"][0]["passage_text"]
-                             if isinstance(sample["passages"], list)
-                             else sample["passages"]["passage_text"][0]
+                             "context": sample["passages"]["passage_text"][0]
                          })
  
              self.current_dataset = dataset_name
-
-             # Return dataset info
              return {
                  "dataset": dataset_name,
-                 "num_samples": len(self.test_samples),
-                 "sample_questions": [s["question"] for s in self.test_samples[:3]],
-                 "status": "success"
+                 "samples_loaded": len(self.test_samples),
+                 "example_questions": [s["question"] for s in self.test_samples[:3]]
              }
  
          except Exception as e:
              print(f"Error loading dataset: {str(e)}")
              return {
-                 "dataset": dataset_name,
                  "error": str(e),
                  "status": "failed"
              }
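
Note: both branches stride-sample every 100th row of the first 1000 with Dataset.select. One subtlety worth knowing: select returns a row-wise Dataset, while Python slicing ([:num_samples]) on a Dataset returns a dict of columns, so folding the cap into select keeps row iteration straightforward. A small sketch under that assumption, using the same squad_v2 split as above:

    from datasets import load_dataset

    num_samples = 10
    dataset = load_dataset("squad_v2", split="validation")
    # Take every 100th row, capped at num_samples, as a row-wise Dataset.
    samples = dataset.select(range(0, 100 * num_samples, 100))
    for sample in samples:
        answers = sample["answers"]  # {"text": [...], "answer_start": [...]}
        if answers["text"]:
            print(sample["question"], "->", answers["text"][0])
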
@@ -205,36 +198,58 @@ def create_db(splits, db_choice: str = "faiss"):
      return db_creators[db_choice]()
  
  def initialize_database(list_file_obj, splitting_strategy, chunk_size, db_choice, progress=gr.Progress()):
-     list_file_path = [x.name for x in list_file_obj if x is not None]
-     doc_splits = load_doc(list_file_path, splitting_strategy, chunk_size)
-     vector_db = create_db(doc_splits, db_choice)
-     return vector_db, f"Database created using {splitting_strategy} splitting and {db_choice} vector database!"
+     """Initialize vector database with error handling"""
+     try:
+         if not list_file_obj:
+             return None, "No files uploaded. Please upload PDF documents first."
+
+         list_file_path = [x.name for x in list_file_obj if x is not None]
+         if not list_file_path:
+             return None, "No valid files found. Please upload PDF documents."
+
+         doc_splits = load_doc(list_file_path, splitting_strategy, chunk_size)
+         if not doc_splits:
+             return None, "No content extracted from documents."
+
+         vector_db = create_db(doc_splits, db_choice)
+         return vector_db, f"Database created successfully using {splitting_strategy} splitting and {db_choice} vector database!"
+
+     except Exception as e:
+         return None, f"Error creating database: {str(e)}"
  
  def initialize_llmchain(llm_choice, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
-     llm_model = list_llm[llm_choice]
-
-     llm = HuggingFaceEndpoint(
-         repo_id=llm_model,
-         huggingfacehub_api_token=api_token,
-         temperature=temperature,
-         max_new_tokens=max_tokens,
-         top_k=top_k
-     )
-
-     memory = ConversationBufferMemory(
-         memory_key="chat_history",
-         output_key='answer',
-         return_messages=True
-     )
+     """Initialize LLM chain with error handling"""
+     try:
+         if vector_db is None:
+             return None, "Please create vector database first."
+
+         llm_model = list_llm[llm_choice]
+
+         llm = HuggingFaceEndpoint(
+             repo_id=llm_model,
+             huggingfacehub_api_token=api_token,
+             temperature=temperature,
+             max_new_tokens=max_tokens,
+             top_k=top_k
+         )
+
+         memory = ConversationBufferMemory(
+             memory_key="chat_history",
+             output_key='answer',
+             return_messages=True
+         )
  
-     retriever = vector_db.as_retriever()
-     qa_chain = ConversationalRetrievalChain.from_llm(
-         llm,
-         retriever=retriever,
-         memory=memory,
-         return_source_documents=True
-     )
-     return qa_chain, "LLM initialized successfully!"
+         retriever = vector_db.as_retriever()
+         qa_chain = ConversationalRetrievalChain.from_llm(
+             llm,
+             retriever=retriever,
+             memory=memory,
+             return_source_documents=True
+         )
+         return qa_chain, "LLM initialized successfully!"
+
+     except Exception as e:
+         return None, f"Error initializing LLM: {str(e)}"
  
  def conversation(qa_chain, message, history):
      """Fixed conversation function returning all required outputs"""
@@ -424,12 +439,26 @@ def demo():
          initialize_database,
          inputs=[document, splitting_strategy, chunk_size, db_choice],
          outputs=[vector_db, db_progress]
+     ).then(
+         lambda x: gr.update(interactive=True) if x[0] is not None else gr.update(interactive=False),
+         inputs=[vector_db],
+         outputs=[init_llm_btn]
      )
  
      init_llm_btn.click(
          initialize_llmchain,
          inputs=[llm_choice, temperature, max_tokens, top_k, vector_db],
          outputs=[qa_chain, llm_progress]
+     ).then(
+         lambda x: gr.update(interactive=True) if x[0] is not None else gr.update(interactive=False),
+         inputs=[qa_chain],
+         outputs=[msg]
+     )
+
+     load_dataset_btn.click(
+         lambda x: evaluator.load_dataset(x),
+         inputs=[dataset_choice],
+         outputs=[dataset_info]
      )
  
      msg.submit(
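
Note: the new .then() calls chain a follow-up step onto each click event, so a downstream control unlocks only after the previous step stored a non-None result. A runnable sketch of the same gating idea with illustrative component names, assuming a Gradio release where event listeners expose .then():

    import gradio as gr

    with gr.Blocks() as demo:
        built = gr.State(None)
        status = gr.Textbox(label="Status")
        build_btn = gr.Button("Build database")
        next_btn = gr.Button("Initialize LLM", interactive=False)

        def build():
            return "ready", "Built!"  # (state value, status message)

        build_btn.click(build, outputs=[built, status]).then(
            lambda s: gr.update(interactive=s is not None),
            inputs=[built],
            outputs=[next_btn],
        )

    demo.launch()
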
 