import gradio as gr
import spaces
import torch
import pandas as pd
import matplotlib.pyplot as plt
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import InformationRetrievalEvaluator, SequentialEvaluator
from sentence_transformers.util import cos_sim

# Check for GPU support and configure accordingly
device = "cuda" if torch.cuda.is_available() else "cpu"
zero = torch.tensor([0.0]).to(device)
print(f"Device being used: {zero.device}")

# Request ZeroGPU hardware for the duration of each evaluation call
@spaces.GPU
def evaluate_model(model_id):
    model = SentenceTransformer(model_id, device=device)
    matryoshka_dimensions = [768, 512, 256, 128, 64]

    # Prepare the evaluation datasets
    datasets_info = [
        {
            "name": "Arabic Financial Dataset (Financial Evaluation)",
            "dataset_id": "Omartificial-Intelligence-Space/Arabic-finanical-rag-embedding-dataset",
            "split": "train",
            "size": 7000,
            "columns": ("question", "context"),
            "sample_size": 500
        },
        {
            "name": "MLQA Arabic (Long Context Evaluation)",
            "dataset_id": "google/xtreme",
            "subset": "MLQA.ar.ar",
            "split": "validation",
            "size": 500,
            "columns": ("question", "context"),
            "sample_size": 500
        },
        {
            "name": "ARCD (Short Context Evaluation)",
            "dataset_id": "hsseinmz/arcd",
            "split": "train",
            "size": None,
            "columns": ("question", "context"),
            "sample_size": 500,
            "last_rows": True  # take the last 500 rows
        }
    ]

    evaluation_results = []
    scores_by_dataset = {}

    for dataset_info in datasets_info:
        # Load the dataset, passing the subset name when one is defined
        if "subset" in dataset_info:
            dataset = load_dataset(dataset_info["dataset_id"], dataset_info["subset"], split=dataset_info["split"])
        else:
            dataset = load_dataset(dataset_info["dataset_id"], split=dataset_info["split"])

        # Take either the last `sample_size` rows or the first ones
        if dataset_info.get("last_rows"):
            dataset = dataset.select(range(len(dataset) - dataset_info["sample_size"], len(dataset)))
        else:
            dataset = dataset.select(range(min(dataset_info["sample_size"], len(dataset))))

        # Rename columns to the (anchor, positive) pair used below
        dataset = dataset.rename_column(dataset_info["columns"][0], "anchor")
        dataset = dataset.rename_column(dataset_info["columns"][1], "positive")

        # Add an "id" column unless the dataset already provides one
        if "id" not in dataset.column_names:
            dataset = dataset.add_column("id", range(len(dataset)))

        # Prepare queries and corpus keyed by id
        corpus = dict(zip(dataset["id"], dataset["positive"]))
        queries = dict(zip(dataset["id"], dataset["anchor"]))
        # Each query's single relevant document is its own positive context
        relevant_docs = {q_id: {q_id} for q_id in queries}
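        # For illustration (assuming integer ids), the structures above look like:
        #   queries       = {0: "سؤال ...", 1: "سؤال ...", ...}
        #   corpus        = {0: "سياق ...", 1: "سياق ...", ...}
        #   relevant_docs = {0: {0}, 1: {1}, ...}  # query i matches only context i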

        # Build one IR evaluator per Matryoshka dimension, truncating embeddings to `dim`
        matryoshka_evaluators = []
        for dim in matryoshka_dimensions:
            ir_evaluator = InformationRetrievalEvaluator(
                queries=queries,
                corpus=corpus,
                relevant_docs=relevant_docs,
                name=f"dim_{dim}",
                truncate_dim=dim,
                score_functions={"cosine": cos_sim},
            )
            matryoshka_evaluators.append(ir_evaluator)

        evaluator = SequentialEvaluator(matryoshka_evaluators)
        results = evaluator(model)
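        # With recent sentence-transformers versions (v3+), SequentialEvaluator returns
        # a single flat dict of metrics; given name=f"dim_{dim}" above, the NDCG@10
        # entries are keyed "dim_768_cosine_ndcg@10", ..., "dim_64_cosine_ndcg@10".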

        # Collect the NDCG@10 score for each dimension
        scores = []
        for dim in matryoshka_dimensions:
            key = f"dim_{dim}_cosine_ndcg@10"
            score = results.get(key)
            evaluation_results.append({
                "Dataset": dataset_info["name"],
                "Dimension": dim,
                "Score": score
            })
            scores.append(score)

        # Store scores by dataset for the bar charts below
        scores_by_dataset[dataset_info["name"]] = scores

    # Convert results to a DataFrame for display
    result_df = pd.DataFrame(evaluation_results)

    # Generate a bar chart for each dataset
    charts = []
    colors = ['#FF5733', '#33FF57', '#3357FF', '#FF33C4', '#F3FF33']  # one color per dimension
    for dataset_name, scores in scores_by_dataset.items():
        fig, ax = plt.subplots()
        ax.bar([str(dim) for dim in matryoshka_dimensions], scores, color=colors)
        ax.set_title(f"{dataset_name} Evaluation Scores", fontsize=16, color='darkblue')
        ax.set_xlabel("Embedding Dimension", fontsize=12)
        ax.set_ylabel("NDCG@10 Score", fontsize=12)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        fig.tight_layout()
        charts.append(fig)

    return result_df, charts[0], charts[1], charts[2]
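
# A minimal sketch for running the same evaluation headlessly (outside Gradio),
# e.g. from a notebook; the default model ID below is only an example:
def run_headless(model_id="Omartificial-Intelligence-Space/GATE-AraBert-v1"):
    result_df, *_ = evaluate_model(model_id)
    print(result_df.to_string(index=False))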

# Define the Gradio interface
def display_results(model_name):
    result_df, chart1, chart2, chart3 = evaluate_model(model_name)
    return result_df, chart1, chart2, chart3

demo = gr.Interface(
    fn=display_results,
    inputs=gr.Textbox(label="Enter Your Embedding Model ID", placeholder="e.g., Omartificial-Intelligence-Space/GATE-AraBert-v1"),
    outputs=[
        gr.Dataframe(label="Evaluation Results"),
        gr.Plot(label="Arabic Financial Dataset (Financial Evaluation)"),
        gr.Plot(label="MLQA Arabic (Long Context Evaluation)"),
        gr.Plot(label="ARCD (Short Context Evaluation)")
    ],
    title="Evaluation of Arabic Matryoshka Embedding Models on Retrieval Tasks",
    description=(
        "Evaluate your Sentence Transformer model's **context and question retrieval** performance on Arabic datasets, for building stronger Arabic RAG pipelines.\n"
        "- **ARCD** evaluates short-context retrieval performance.\n"
        "- **MLQA Arabic** evaluates long-context retrieval performance.\n"
        "- **Arabic Financial Dataset** focuses on financial-context retrieval.\n\n"
        "**Evaluation Metric:**\n"
        "The evaluation uses **NDCG@10** (Normalized Discounted Cumulative Gain), which measures how well the retrieved contexts match each query.\n"
        "Higher scores indicate better performance. Embeddings are truncated from 768 down to 64 dimensions to show how well the model holds up with fewer dimensions."
    ),
    theme="default",
    live=False,
    css="footer {visibility: hidden;}"
)
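
# For reference, the NDCG@10 described above: DCG@10 = sum_{i=1..10} rel_i / log2(i + 1)
# over the top-10 ranked contexts, normalized by the ideal DCG. With exactly one
# relevant context per query (as constructed here), it reduces to 1 / log2(rank + 1)
# when the true context is retrieved at 1-indexed position rank <= 10, and 0 otherwise.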

# Credit line, printed to the logs before the blocking launch() call
print("\nCreated by Omar Najar | Omartificial Intelligence Space")

demo.launch()