Commit daf1d9d (1 parent: e1559b3)

Create app.py

Files changed (1):
  1. app.py +146 -0
app.py ADDED
@@ -0,0 +1,146 @@
+ import gradio as gr
+ import random
+ import numpy as np
+ import pandas as pd
+ from datasets import load_dataset
+ from sentence_transformers import CrossEncoder
+ from sklearn.metrics import average_precision_score
+ import matplotlib.pyplot as plt
+ import torch
+ import spaces
+
+ # Check for GPU support and configure appropriately
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ zero = torch.Tensor([0]).to(device)
+ print(f"Device being used: {zero.device}")
+
+ # Define evaluation metrics
+ def mean_reciprocal_rank(relevance_labels, scores):
+     sorted_indices = np.argsort(scores)[::-1]
+     for rank, idx in enumerate(sorted_indices, start=1):
+         if relevance_labels[idx] == 1:
+             return 1 / rank
+     return 0
+
+ def mean_average_precision(relevance_labels, scores):
+     return average_precision_score(relevance_labels, scores)
+
+ def ndcg_at_k(relevance_labels, scores, k=10):
+     sorted_indices = np.argsort(scores)[::-1]
+     relevance_sorted = np.take(relevance_labels, sorted_indices[:k])
+     dcg = sum(rel / np.log2(rank + 2) for rank, rel in enumerate(relevance_sorted))
+     idcg = sum(1 / np.log2(rank + 2) for rank in range(min(k, sum(relevance_labels))))
+     return dcg / idcg if idcg > 0 else 0
+
+ # Load datasets
+ datasets = {
+     "Relevance_Labels_Dataset": load_dataset("NAMAA-Space/Ar-Reranking-Eval")["train"],
+     "Positive_Negatives_Dataset": load_dataset("NAMAA-Space/Arabic-Reranking-Triplet-5-Eval")["train"]
+ }
+
+ @spaces.GPU
+ def evaluate_model_with_insights(model_name):
+     model = CrossEncoder(model_name, device=device)
+     results = []
+     sample_outputs = []
+
+     for dataset_name, dataset in datasets.items():
+         all_mrr, all_map, all_ndcg = [], [], []
+         dataset_samples = []
+
+         if 'candidate_document' in dataset.column_names:
+             grouped_data = dataset.to_pandas().groupby("query")
+             for query, group in grouped_data:
+                 candidate_texts = group['candidate_document'].tolist()
+                 relevance_labels = group['relevance_label'].tolist()
+                 pairs = [(query, doc) for doc in candidate_texts]
+                 scores = model.predict(pairs)
+
+                 # Collecting top-5 results for display
+                 sorted_indices = np.argsort(scores)[::-1]
+                 top_docs = [(candidate_texts[i], scores[i], relevance_labels[i]) for i in sorted_indices[:5]]
+                 dataset_samples.append({
+                     "Query": query,
+                     "Top 5 Candidates": top_docs
+                 })
+
+                 # Metrics
+                 all_mrr.append(mean_reciprocal_rank(relevance_labels, scores))
+                 all_map.append(mean_average_precision(relevance_labels, scores))
+                 all_ndcg.append(ndcg_at_k(relevance_labels, scores, k=10))
+         else:
+             for entry in dataset:
+                 query = entry['query']
+                 candidate_texts = [entry['positive'], entry['negative1'], entry['negative2'], entry['negative3'], entry['negative4']]
+                 relevance_labels = [1, 0, 0, 0, 0]
+                 pairs = [(query, doc) for doc in candidate_texts]
+                 scores = model.predict(pairs)
+
+                 # Collecting top-5 results for display
+                 sorted_indices = np.argsort(scores)[::-1]
+                 top_docs = [(candidate_texts[i], scores[i], relevance_labels[i]) for i in sorted_indices[:5]]
+                 dataset_samples.append({
+                     "Query": query,
+                     "Top 5 Candidates": top_docs
+                 })
+
+                 # Metrics
+                 all_mrr.append(mean_reciprocal_rank(relevance_labels, scores))
+                 all_map.append(mean_average_precision(relevance_labels, scores))
+                 all_ndcg.append(ndcg_at_k(relevance_labels, scores, k=10))
+
+         # Metrics for this dataset
+         results.append({
+             "Dataset": dataset_name,
+             "MRR": np.mean(all_mrr),
+             "MAP": np.mean(all_map),
+             "nDCG@10": np.mean(all_ndcg)
+         })
+
+         # Collect sample outputs for inspection
+         sample_outputs.extend(dataset_samples)
+
+     results_df = pd.DataFrame(results)
+
+     # Plot results as a bar chart
+     fig, ax = plt.subplots(figsize=(8, 6))
+     results_df.plot(kind='bar', x='Dataset', y=['MRR', 'MAP', 'nDCG@10'], ax=ax)
+     ax.set_title(f"Evaluation Results for {model_name}")
+     ax.set_ylabel("Score")
+     plt.xticks(rotation=0)
+
+     return results_df, fig, sample_outputs
+
+ # Gradio app interface
+ def gradio_app_with_insights(model_name):
+     results_df, chart, samples = evaluate_model_with_insights(model_name)
+     sample_display = []
+     for sample in samples:
+         sample_display.append(f"Query: {sample['Query']}")
+         for doc, score, label in sample["Top 5 Candidates"]:
+             sample_display.append(f"  Doc: {doc[:50]}... | Score: {score:.2f} | Relevance: {label}")
+         sample_display.append("\n")
+     return results_df, chart, "\n".join(sample_display)
+
+ interface = gr.Interface(
+     fn=gradio_app_with_insights,
+     inputs=gr.Textbox(label="Enter Model Name", placeholder="e.g., NAMAA-Space/GATE-Reranker-V1"),
+     outputs=[
+         gr.Dataframe(label="Evaluation Results"),
+         gr.Plot(label="Evaluation Metrics Chart"),
+         gr.Textbox(label="Sample Reranking Insights", lines=15)
+     ],
+     title="Arabic Reranking Model Evaluation and Insights",
+     description=(
+         "This app evaluates Arabic reranking models on two datasets:\n"
+         "1. **Relevance Labels Dataset**\n"
+         "2. **Positive-Negatives Dataset**\n\n"
+         "### Metrics Used:\n"
+         "- **MRR (Mean Reciprocal Rank)**: Measures how quickly the first relevant document appears.\n"
+         "- **MAP (Mean Average Precision)**: Reflects ranking quality across all relevant documents.\n"
+         "- **nDCG@10 (Normalized Discounted Cumulative Gain)**: Focuses on the ranking of relevant documents in the top-10.\n\n"
+         "Input a model name to evaluate its performance, view metrics, and examine sample reranking results."
+     )
+ )
+
+ interface.launch(debug=True)
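
For anyone who wants to sanity-check the metric definitions above outside the Space, the snippet below runs them on a tiny hand-made ranking. It is a minimal sketch, not part of the commit: the scores and labels are made up for illustration, and it assumes mean_reciprocal_rank, mean_average_precision, and ndcg_at_k are already defined in the session (for example, pasted from app.py; importing the module directly would also trigger the dataset downloads and interface.launch()).

import numpy as np

# One query with four candidate documents: two relevant (label 1), two not (label 0).
relevance_labels = [0, 1, 0, 1]
# Hypothetical cross-encoder scores that rank both relevant documents last.
scores = np.array([0.70, 0.20, 0.55, 0.40])

print(mean_reciprocal_rank(relevance_labels, scores))    # first relevant doc appears at rank 3 -> ~0.333
print(mean_average_precision(relevance_labels, scores))  # mean of precision at the relevant ranks (1/3 and 2/4) -> ~0.417
print(ndcg_at_k(relevance_labels, scores, k=10))         # discounted gain of this ordering vs. the ideal one -> ~0.571

A perfect ranking on this example would give 1.0 for all three metrics, which is a quick way to confirm the functions are wired up correctly.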