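"""Gradio leaderboard app for text classification benchmarks.

Users pick a test set, upload a CSV of predictions, and the app scores the
submission with scikit-learn (accuracy, weighted precision/recall/F1), stores
the result in a local SQLite database, and shows all previous submissions for
that test set.
"""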
import os
import gradio as gr
from sqlalchemy import create_engine, Column, Integer, String, Float, DateTime
from sqlalchemy.orm import declarative_base, sessionmaker
from datetime import datetime
import pandas as pd
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

testsets_root_path = "./testsets/"
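# Expected data layout (inferred from the column checks further down): each test set
# is a CSV under ./testsets/ with at least 'file_name' and 'label' columns, and every
# uploaded submission must provide the same two columns.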

# Function to load the dataset
def load_testsets(testsets_root_path: str) -> dict:
    datasets_dict = {}
    for ds in os.listdir(testsets_root_path):
        if ds.endswith(".csv"):  # Ensure only CSV files are processed
            csv_path = os.path.join(testsets_root_path, ds)
            df = pd.read_csv(csv_path)
            datasets_dict[ds.replace(".csv", "")] = df
    return datasets_dict
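
# A minimal, illustrative way to produce a valid submission CSV (the file names and
# labels below are placeholders, not drawn from any real test set):
#
#   preds = pd.DataFrame({"file_name": ["doc_001.txt", "doc_002.txt"],
#                         "label": ["positive", "negative"]})
#   preds.to_csv("my_submission.csv", index=False)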

# Database setup
Base = declarative_base()

class Submission(Base):
    __tablename__ = 'submissions'
    id = Column(Integer, primary_key=True)
    dataset_name = Column(String)
    submission_name = Column(String)
    model_link = Column(String)
    person_name = Column(String)
    accuracy = Column(Float)
    precision = Column(Float)
    recall = Column(Float)
    f1 = Column(Float)
    submission_date = Column(DateTime, default=datetime.utcnow)

engine = create_engine('sqlite:///submissions.db')
Base.metadata.create_all(engine)
Session = sessionmaker(bind=engine)
session = Session()

# Function to fetch previous submissions for a selected dataset
def get_existing_submissions(dataset_name):
    existing_submissions = session.query(Submission).filter_by(dataset_name=dataset_name).order_by(
        Submission.submission_date.desc()).all()

    submissions_list = [{
        "Submission Name": sub.submission_name,
        "Model Link": sub.model_link,
        "Person Name": sub.person_name,
        "Accuracy": sub.accuracy,
        "Precision": sub.precision,
        "Recall": sub.recall,
        "F1": sub.f1,
        "Submission Date": sub.submission_date.strftime("%Y-%m-%d %H:%M:%S")
    } for sub in existing_submissions]

    return pd.DataFrame(submissions_list) if submissions_list else pd.DataFrame(columns=[
        "Submission Name", "Model Link", "Person Name", "Accuracy", "Precision", "Recall", "F1", "Submission Date"
    ])

# Evaluation function for text classification
def calculate_metrics(gs, pred):
    # Align predictions with the gold standard by file_name when both frames have it,
    # so scoring does not depend on row order; otherwise fall back to positional alignment.
    if 'file_name' in gs.columns and 'file_name' in pred.columns:
        merged = gs.merge(pred, on='file_name', suffixes=('_true', '_pred'))
        y_true = merged['label_true']
        y_pred = merged['label_pred']
    else:
        y_true = gs['label']
        y_pred = pred['label']
    try:
        accuracy = accuracy_score(y_true, y_pred)
        precision = precision_score(y_true, y_pred, average='weighted', zero_division=0)
        recall = recall_score(y_true, y_pred, average='weighted', zero_division=0)
        f1 = f1_score(y_true, y_pred, average='weighted', zero_division=0)
        return accuracy, precision, recall, f1
    except Exception:
        return None, None, None, None

def benchmark_interface(dataset_name, submission_file, submission_name, model_link, person_name):
    if not all([dataset_name, submission_file, submission_name, model_link, person_name]):
        return {"error": "All fields are required."}, pd.DataFrame()

    dataset_dict = load_testsets(testsets_root_path)
    df_gs = dataset_dict.get(dataset_name)
    if df_gs is None:
        return {"error": "Dataset not found."}, pd.DataFrame()

    # Parse the uploaded submission CSV. Depending on the Gradio version, gr.File may
    # return a tempfile-like object (with a .name path) or a plain file path string.
    file_path = getattr(submission_file, "name", submission_file)
    submission_df = pd.read_csv(file_path)

    # Ensure the columns are present
    if not all(col in submission_df.columns for col in ['file_name', 'label']):
        return {"error": "Submission file must contain 'file_name' and 'label' columns."}, pd.DataFrame()

    # Calculate metrics
    accuracy, precision, recall, f1 = calculate_metrics(gs=df_gs, pred=submission_df)
    metrics = {'Accuracy': accuracy, 'Precision': precision, 'Recall': recall, 'F1': f1}
    if f1 is not None:
        # Save submission to the database
        new_submission = Submission(
            dataset_name=dataset_name,
            submission_name=submission_name,
            model_link=model_link,
            person_name=person_name,
            accuracy=accuracy,
            precision=precision,
            recall=recall,
            f1=f1
        )
        session.add(new_submission)
        session.commit()

    # Fetch updated submissions
    submissions_df = get_existing_submissions(dataset_name)
    return metrics, submissions_df


def create_gradio_app():
    dataset_dict = load_testsets(testsets_root_path)
    dataset_names = list(dataset_dict.keys())

    with gr.Blocks() as demo:
        gr.Markdown("## Benchmarking Leaderboard for Text Classification")
        dataset_radio = gr.Radio(choices=dataset_names, label="Select Dataset")
        submission_file = gr.File(label="Upload Submission CSV")
        submission_name = gr.Textbox(label="Submission Name")
        model_link = gr.Textbox(label="Model Link on HuggingFace")
        person_name = gr.Textbox(label="Person Name")
        submit_button = gr.Button("Submit")
        metrics_output = gr.JSON(label="Evaluation Metrics")
        existing_submissions_output = gr.Dataframe(label="Existing Submissions")

        # When a dataset is selected, fetch previous submissions
        dataset_radio.change(
            fn=get_existing_submissions,
            inputs=[dataset_radio],
            outputs=[existing_submissions_output]
        )

        submit_button.click(
            fn=benchmark_interface,
            inputs=[dataset_radio, submission_file, submission_name, model_link, person_name],
            outputs=[metrics_output, existing_submissions_output]
        )
    return demo

def main():
    app = create_gradio_app()
    app.launch()

if __name__ == "__main__":
    main()