arjunanand13 committed
Commit • ababf21
1 Parent(s): 4ad946f
Update app.py
app.py
CHANGED
@@ -47,10 +47,12 @@ class RAGEvaluator:
         self.current_dataset = None
         self.test_samples = []

-    def load_dataset(self, dataset_name: str, num_samples: int =
+    def load_dataset(self, dataset_name: str, num_samples: int = 5):
+        """Load a smaller subset of questions"""
         if dataset_name == "squad":
             dataset = load_dataset("squad_v2", split="validation")
-
+            # Select diverse questions based on length and type
+            samples = dataset.select(range(0, 1000, 100))[:num_samples]  # Take 10 spaced-out samples
             self.test_samples = [
                 {
                     "question": sample["question"],
@@ -62,7 +64,7 @@ class RAGEvaluator:
             ]
         elif dataset_name == "msmarco":
             dataset = load_dataset("ms_marco", "v2.1", split="train")
-            samples = dataset.select(range(
+            samples = dataset.select(range(0, 1000, 100))[:num_samples]
             self.test_samples = [
                 {
                     "question": sample["query"],
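A side note on the sampling added above: range(0, 1000, 100) yields ten evenly spaced indices (0, 100, ..., 900), and the trailing [:num_samples] then keeps only num_samples of them (5 by default), so the inline comment about ten samples describes the spacing rather than the final count. Below is a minimal, illustrative sketch of the same pattern; it assumes only the `datasets` library, and none of the variable names come from app.py.

    # Spaced subsampling sketch, mirroring the pattern used in load_dataset.
    from datasets import load_dataset

    num_samples = 5  # mirrors the new default in the diff
    dataset = load_dataset("squad_v2", split="validation")

    # Ten evenly spaced row indices: 0, 100, ..., 900.
    spaced = dataset.select(range(0, 1000, 100))

    # Note: slicing a Dataset (spaced[:num_samples]) returns a dict of column
    # lists, not rows; a second select keeps a Dataset whose rows iterate as
    # per-example dicts.
    subset = spaced.select(range(num_samples))
    for sample in subset:
        print(sample["question"])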
@@ -76,40 +78,60 @@ class RAGEvaluator:
         return self.test_samples

     def evaluate_configuration(self, vector_db, qa_chain, splitting_strategy: str, chunk_size: str) -> Dict:
+        """Evaluate with progress tracking"""
         if not self.test_samples:
             return {"error": "No dataset loaded"}

         results = []
-
-
-
-
-            })
+        total_questions = len(self.test_samples)
+
+        # Add progress tracking
+        for i, sample in enumerate(self.test_samples):
+            print(f"Evaluating question {i+1}/{total_questions}")

-
-
-
-
-
-
+            try:
+                response = qa_chain.invoke({
+                    "question": sample["question"],
+                    "chat_history": []
+                })
+
+                results.append({
+                    "question": sample["question"],
+                    "answer": response["answer"],
+                    "contexts": [doc.page_content for doc in response["source_documents"]],
+                    "ground_truths": [sample["ground_truth"]]
+                })
+            except Exception as e:
+                print(f"Error processing question {i+1}: {str(e)}")
+                continue

+        # Calculate RAGAS metrics
         eval_dataset = Dataset.from_list(results)
         metrics = [ContextRecall(), AnswerRelevancy(), Faithfulness(), ContextPrecision()]
-        scores = evaluate(eval_dataset, metrics=metrics)

-
-
-
-
-
-
-
-        scores['
-        scores['
-        scores['
-
-
-
+        try:
+            scores = evaluate(eval_dataset, metrics=metrics)
+
+            return {
+                "configuration": f"{splitting_strategy}_{chunk_size}",
+                "questions_evaluated": len(results),
+                "context_recall": float(scores['context_recall']),
+                "answer_relevancy": float(scores['answer_relevancy']),
+                "faithfulness": float(scores['faithfulness']),
+                "context_precision": float(scores['context_precision']),
+                "average_score": float(np.mean([
+                    scores['context_recall'],
+                    scores['answer_relevancy'],
+                    scores['faithfulness'],
+                    scores['context_precision']
+                ]))
+            }
+        except Exception as e:
+            return {
+                "configuration": f"{splitting_strategy}_{chunk_size}",
+                "error": str(e),
+                "questions_evaluated": len(results)
+            }

 # Text splitting and database functions
 def get_text_splitter(strategy: str, chunk_size: int = 1024, chunk_overlap: int = 64):
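For orientation, here is a minimal, hypothetical sketch of how the updated evaluator might be driven end to end. None of this comes from app.py: build_vector_db and build_qa_chain are placeholder names for the app's own setup code, and the splitting-strategy strings are assumed.

    # Hypothetical driver for the RAGEvaluator changes above; build_vector_db
    # and build_qa_chain stand in for whatever app.py uses to build the RAG stack.
    evaluator = RAGEvaluator()
    evaluator.load_dataset("squad", num_samples=5)

    reports = []
    for strategy in ["recursive", "character"]:                  # assumed strategy names
        vector_db = build_vector_db(strategy, chunk_size=1024)   # placeholder helper
        qa_chain = build_qa_chain(vector_db)                     # placeholder helper
        reports.append(
            evaluator.evaluate_configuration(vector_db, qa_chain, strategy, "1024")
        )

    for report in reports:
        print(report.get("configuration"), report.get("average_score", report.get("error")))

One version-dependent caveat: depending on the installed ragas release, the ground-truth column that evaluate expects may be named ground_truth (or reference) rather than the ground_truths list built here, which is worth checking if the metric calculation raises a schema error.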