arjunanand13 committed
Commit ababf21
1 Parent(s): 4ad946f

Update app.py

Files changed (1): app.py (+50 -28)
app.py CHANGED
@@ -47,10 +47,12 @@ class RAGEvaluator:
         self.current_dataset = None
         self.test_samples = []
 
-    def load_dataset(self, dataset_name: str, num_samples: int = 50):
+    def load_dataset(self, dataset_name: str, num_samples: int = 5):
+        """Load a smaller subset of questions"""
         if dataset_name == "squad":
             dataset = load_dataset("squad_v2", split="validation")
-            samples = dataset.select(range(num_samples))
+            # Select diverse questions based on length and type
+            samples = dataset.select(range(0, 1000, 100))[:num_samples]  # Take 10 spaced-out samples
             self.test_samples = [
                 {
                     "question": sample["question"],
@@ -62,7 +64,7 @@ class RAGEvaluator:
             ]
         elif dataset_name == "msmarco":
             dataset = load_dataset("ms_marco", "v2.1", split="train")
-            samples = dataset.select(range(num_samples))
+            samples = dataset.select(range(0, 1000, 100))[:num_samples]
             self.test_samples = [
                 {
                     "question": sample["query"],
@@ -76,40 +78,60 @@ class RAGEvaluator:
         return self.test_samples
 
     def evaluate_configuration(self, vector_db, qa_chain, splitting_strategy: str, chunk_size: str) -> Dict:
+        """Evaluate with progress tracking"""
         if not self.test_samples:
             return {"error": "No dataset loaded"}
 
         results = []
-        for sample in self.test_samples:
-            response = qa_chain.invoke({
-                "question": sample["question"],
-                "chat_history": []
-            })
+        total_questions = len(self.test_samples)
+
+        # Add progress tracking
+        for i, sample in enumerate(self.test_samples):
+            print(f"Evaluating question {i+1}/{total_questions}")
 
-            results.append({
-                "question": sample["question"],
-                "answer": response["answer"],
-                "contexts": [doc.page_content for doc in response["source_documents"]],
-                "ground_truths": [sample["ground_truth"]]
-            })
+            try:
+                response = qa_chain.invoke({
+                    "question": sample["question"],
+                    "chat_history": []
+                })
+
+                results.append({
+                    "question": sample["question"],
+                    "answer": response["answer"],
+                    "contexts": [doc.page_content for doc in response["source_documents"]],
+                    "ground_truths": [sample["ground_truth"]]
+                })
+            except Exception as e:
+                print(f"Error processing question {i+1}: {str(e)}")
+                continue
 
+        # Calculate RAGAS metrics
         eval_dataset = Dataset.from_list(results)
         metrics = [ContextRecall(), AnswerRelevancy(), Faithfulness(), ContextPrecision()]
-        scores = evaluate(eval_dataset, metrics=metrics)
 
-        return {
-            "configuration": f"{splitting_strategy}_{chunk_size}",
-            "context_recall": float(scores['context_recall']),
-            "answer_relevancy": float(scores['answer_relevancy']),
-            "faithfulness": float(scores['faithfulness']),
-            "context_precision": float(scores['context_precision']),
-            "average_score": float(np.mean([
-                scores['context_recall'],
-                scores['answer_relevancy'],
-                scores['faithfulness'],
-                scores['context_precision']
-            ]))
-        }
+        try:
+            scores = evaluate(eval_dataset, metrics=metrics)
+
+            return {
+                "configuration": f"{splitting_strategy}_{chunk_size}",
+                "questions_evaluated": len(results),
+                "context_recall": float(scores['context_recall']),
+                "answer_relevancy": float(scores['answer_relevancy']),
+                "faithfulness": float(scores['faithfulness']),
+                "context_precision": float(scores['context_precision']),
+                "average_score": float(np.mean([
+                    scores['context_recall'],
+                    scores['answer_relevancy'],
+                    scores['faithfulness'],
+                    scores['context_precision']
+                ]))
+            }
+        except Exception as e:
+            return {
+                "configuration": f"{splitting_strategy}_{chunk_size}",
+                "error": str(e),
+                "questions_evaluated": len(results)
+            }
 
 # Text splitting and database functions
 def get_text_splitter(strategy: str, chunk_size: int = 1024, chunk_overlap: int = 64):
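
A note on the new subsampling in load_dataset: in the Hugging Face datasets API, slicing a Dataset as in dataset.select(range(0, 1000, 100))[:num_samples] returns a plain dict of columns rather than a row-iterable Dataset, so the later "for sample in samples" comprehension would iterate column names; the inline comment also mentions 10 samples while the new default num_samples is 5. A minimal sketch of the intended evenly spaced subset that keeps the selection a Dataset (the stride of 100 comes from the commit; the ground-truth extraction is an assumption based on the public squad_v2 schema, since those lines are not shown in the diff):

from datasets import load_dataset

def load_squad_subset(num_samples: int = 5, stride: int = 100):
    """Illustrative helper: evenly spaced SQuAD v2 validation questions."""
    dataset = load_dataset("squad_v2", split="validation")
    # Passing the index range to select() keeps the result a Dataset,
    # so iterating it yields row dicts instead of column names.
    samples = dataset.select(range(0, stride * num_samples, stride))
    return [
        {
            "question": sample["question"],
            # Assumption: squad_v2 answers can be empty for unanswerable questions,
            # so fall back to an empty string as the ground truth.
            "ground_truth": sample["answers"]["text"][0] if sample["answers"]["text"] else "",
        }
        for sample in samples
    ]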
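
Because the reworked evaluate_configuration now returns either a scores dict or an error dict carrying questions_evaluated (an empty results list, e.g. when every question fails, surfaces through the same except branch), callers need to branch on the "error" key. A hedged sketch of a driver loop; build_database and build_qa_chain are hypothetical stand-ins for the vector-store and chain setup elsewhere in app.py, and the strategy labels are illustrative:

# Hypothetical driver; RAGEvaluator comes from app.py, the helpers below are assumed.
evaluator = RAGEvaluator()
evaluator.load_dataset("squad", num_samples=5)

summary = []
for strategy in ("recursive", "character"):                 # illustrative strategy labels
    vector_db = build_database(strategy, chunk_size=1024)   # assumed helper
    qa_chain = build_qa_chain(vector_db)                    # assumed helper
    result = evaluator.evaluate_configuration(vector_db, qa_chain, strategy, "1024")
    if "error" in result:
        print(f"{strategy}: failed after {result.get('questions_evaluated', 0)} answers: {result['error']}")
    else:
        print(f"{result['configuration']}: average score {result['average_score']:.3f}")
    summary.append(result)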
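
The diff's trailing context stops at the unchanged signature of get_text_splitter. For readers without the rest of app.py, a minimal sketch of what such a strategy switch commonly looks like with LangChain's splitters; the strategy names and the choice of RecursiveCharacterTextSplitter / CharacterTextSplitter are assumptions, since the function body is not part of this commit:

from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter

def get_text_splitter(strategy: str, chunk_size: int = 1024, chunk_overlap: int = 64):
    """Sketch only: map a strategy label to a LangChain text splitter."""
    if strategy == "recursive":
        # Tries paragraph, sentence, then word boundaries until chunks fit chunk_size.
        return RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    if strategy == "character":
        # Splits on a single separator (default "\n\n") with the same size/overlap budget.
        return CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    raise ValueError(f"Unknown splitting strategy: {strategy}")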