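# Spaces: Sleeping
# (Presumably a hosting-page status banner captured along with the source;
# it is not part of the program.)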
import gradio as gr | |
import plotly.express as px | |
import plotly.graph_objs as go | |
from collections import defaultdict | |
import json, math, gdown | |
import numpy as np | |
import pandas as pd | |
from Config import * | |
# Print floats with two decimal places whenever pandas renders them.
pd.set_option('display.float_format', '{:.2f}'.format)

# 100 evenly spaced sample points from 0 to 100 (both endpoints included).
# NOTE(review): not referenced by any other visible code — confirm it is
# still needed.
battles = np.linspace(start=0, stop=100, num=100)

# Benchmark families available for plotting; generate_plot() selects one of
# these by position.
meta_topics = ['mmlu']
def generate_plot(meta_index, topic_index):
    """
    Build a grouped bar chart of per-model accuracy for one sub-topic.

    Reads ``data/<meta_topic>/response_rec.csv``, keeps the rows whose
    ``sub_topic`` equals the selected topic, derives human and LLM accuracy
    from the correct/response counts, and plots the per-model mean of each
    accuracy type with the per-model standard deviation as error bars.

    Args:
        meta_index: Position of the benchmark family in the module-level
            ``meta_topics`` list.
        topic_index: Position of the sub-topic inside ``TOPICS[meta_topic]``
            (``TOPICS`` is provided by ``Config``).

    Returns:
        A ``plotly.graph_objs.Figure`` with one bar trace per accuracy type
        (Oracle, Human, Llm), grouped by model name.

    Raises:
        FileNotFoundError: If the CSV for the selected meta topic is missing.
        KeyError: If the CSV lacks one of the columns used below.
    """
    # Columns that end up on the chart. 'oracle_acc' is not derived here, so
    # it must already be a column of the CSV.
    acc_columns = ['oracle_acc', 'human_acc', 'llm_acc']
    # Light Salmon, Light Sea Green, Light Slate Gray -- one per accuracy type.
    colors = ['#FFA07A', '#20B2AA', '#778899']
    bar_opacity = 0.7

    meta_topic = meta_topics[meta_index]
    print(meta_topic)
    topic = TOPICS[meta_topic][topic_index]

    data = pd.read_csv(f"data/{meta_topic}/response_rec.csv", sep=",")
    # Work on an independent copy: assigning new columns to a filtered slice
    # of `data` triggers pandas' SettingWithCopy hazard.
    topic_data = data[data['sub_topic'] == topic].copy()

    # A zero response count becomes NaN so the ratio is NaN rather than inf.
    topic_data['human_acc'] = (
        topic_data['no_correct_human']
        / topic_data['no_responses_human'].replace(0, np.nan)
    )
    topic_data['llm_acc'] = (
        topic_data['no_correct_llm']
        / topic_data['no_responses_llm'].replace(0, np.nan)
    )

    # Aggregate only the accuracy columns. Aggregating the whole frame would
    # include the string column 'sub_topic', which makes mean()/std() raise
    # TypeError on pandas >= 2.0.
    per_model = topic_data.groupby('model_name')[acc_columns]
    mean_data = per_model.mean().reset_index()
    std_deviation = per_model.std().reset_index()

    # One bar trace per accuracy type, with the std deviation as error bars.
    plot_data = [
        go.Bar(
            x=mean_data['model_name'],
            y=mean_data[acc_type],
            error_y=dict(
                type='data',
                array=std_deviation[acc_type],
                visible=True,
            ),
            name=acc_type.split('_')[0].capitalize(),
            marker=dict(color=color, opacity=bar_opacity),
        )
        for acc_type, color in zip(acc_columns, colors)
    ]

    layout = go.Layout(
        title=f"Accuracy for {meta_topic} ({topic})",
        xaxis=dict(title='Model Name'),
        yaxis=dict(title='Accuracy'),
        showlegend=True,
        legend=dict(title='Accuracy Type'),
        barmode='group',
    )
    return go.Figure(data=plot_data, layout=layout)
# Gradio interface: a 3x2 grid of plot panels, then launch the app.
#
# NOTE(review): every panel shows the same (meta_index=0, topic_index=0)
# chart. Earlier drafts hinted at per-panel pairs (0,0), (0,1), (1,0), (1,1),
# (2,0), (2,1); confirm the intended mapping before wiring them up.
# meta_topics currently holds a single entry, so meta indices 1 and 2 would
# raise IndexError.
with gr.Blocks() as interface:
    # Build the figure once: all six panels are identical, and every
    # generate_plot() call re-reads the CSV from disk.
    shared_figure = generate_plot(0, 0)

    with gr.Row():  # Row 1
        plot1 = gr.Plot(shared_figure)
        plot2 = gr.Plot(shared_figure)
    with gr.Row():  # Row 2
        plot3 = gr.Plot(shared_figure)
        plot4 = gr.Plot(shared_figure)
    with gr.Row():  # Row 3
        plot5 = gr.Plot(shared_figure)
        plot6 = gr.Plot(shared_figure)

interface.launch()