|
import streamlit as st |
|
import plotly.graph_objects as go |
|
import numpy as np |
|
import pandas as pd |
|
|
|
|
|
# Visual theme shared by every radar chart built in this module.
fillcolor, line_color = "#FFD21E", "#FF9D00"  # polygon fill / polygon outline
opacity = 0.75  # opacity applied to each Scatterpolar trace

# Metric columns plotted on the radar, in the order they are drawn around it.
categories = ["ARC", "GSM8K", "TruthfulQA", "Winogrande", "HellaSwag", "MMLU"]

# Column layout of one results row: the model name, each benchmark score,
# then an "Average" column.
columns = [
    "model_name",
    "ARC",
    "HellaSwag",
    "TruthfulQA",
    "Winogrande",
    "GSM8K",
    "MMLU",
    "Average",
]
|
|
|
|
|
@st.cache_data
def plot_radar_chart_index(
    dataframe: pd.DataFrame,
    index: int,
    categories: list = categories,
    fillcolor: str = fillcolor,
    line_color: str = line_color,
    opacity: float = opacity,
):
    """
    Plot the row labelled ``index`` of the dataframe as a single radar chart.

    NOTE: ``plot_radar_chart_index`` is defined a second time later in this
    module; that later definition rebinds the name at import time, so this
    copy is never used and should be removed.

    Arguments:

        dataframe: a pandas DataFrame with a "model_name" column and one
            column per entry of ``categories``

        index: the index label (as understood by ``DataFrame.loc``) of the
            row we want to plot

        categories: the list of the metrics, in the order they are drawn
            around the chart

        fillcolor: a string specifying the color to fill the area

        line_color: a string specifying the color of the lines in the graph

        opacity: opacity of the filled trace (defaults to the module setting)

    Returns:

        A plotly ``go.Figure`` holding one closed polygon. Each radius is the
        metric value times 100, rounded to 2 decimals, drawn on a radial axis
        fixed to [0, 100]. This presumably assumes the scores are stored as
        fractions in [0, 1]; confirm against the data source.

    Raises:

        KeyError: if ``index`` or one of the requested columns is missing.
    """
    # Cast to float *before* scaling. On object-dtype columns holding numeric
    # strings, ``* 100`` would repeat the string instead of multiplying it.
    values = dataframe.loc[index, categories].to_numpy(dtype=float) * 100
    values = values.round(decimals=2)

    # Repeat the first point/category so the polygon closes on itself.
    r = np.append(values, values[0])
    theta = [*categories, categories[0]]

    fig = go.Figure()
    fig.add_trace(
        go.Scatterpolar(
            r=r,
            theta=theta,
            fill="toself",
            fillcolor=fillcolor,
            opacity=opacity,
            line=dict(color=line_color),
            name=dataframe.loc[index, "model_name"],
        )
    )
    fig.update_layout(
        polar=dict(radialaxis=dict(visible=True, range=[0, 100.0])),
        showlegend=False,
    )
    return fig
|
|
|
@st.cache_data
def plot_radar_chart_name(
    dataframe: pd.DataFrame,
    model_name: str,
    categories: list = categories,
    fillcolor: str = fillcolor,
    line_color: str = line_color,
    opacity: float = opacity,
):
    """
    Plot the results of the model named ``model_name`` as a single radar chart.

    Arguments:

        dataframe: a pandas DataFrame with a "model_name" column and one
            column per entry of ``categories``

        model_name: a string stating the name of the model; if several rows
            share this name, the first one is plotted

        categories: the list of the metrics, in the order they are drawn
            around the chart

        fillcolor: a string specifying the color to fill the area

        line_color: a string specifying the color of the lines in the graph

        opacity: opacity of the filled trace (defaults to the module setting)

    Returns:

        A plotly ``go.Figure`` holding one closed polygon. Each radius is the
        metric value times 100, rounded to 2 decimals, drawn on a radial axis
        fixed to [0, 100].

    Raises:

        IndexError: if no row has ``model_name`` in its "model_name" column.
    """
    matches = dataframe.loc[dataframe["model_name"] == model_name, categories]
    if matches.shape[0] == 0:
        # Same exception type the previous code raised (from indexing an empty
        # array), now with a message that says what was not found.
        raise IndexError(f"no row with model_name == {model_name!r}")

    # Reduce to a 1-D vector (first matching row). The old code kept the 2-D
    # selection, so np.append flattened it and appended a whole row instead
    # of the single closing point.
    values = matches.iloc[0].to_numpy(dtype=float) * 100
    values = values.round(decimals=2)

    # Repeat the first point/category so the polygon closes on itself.
    r = np.append(values, values[0])
    theta = [*categories, categories[0]]

    fig = go.Figure()
    fig.add_trace(
        go.Scatterpolar(
            r=r,
            theta=theta,
            fill="toself",
            fillcolor=fillcolor,
            opacity=opacity,
            line=dict(color=line_color),
            name=model_name,
        )
    )
    fig.update_layout(
        polar=dict(radialaxis=dict(visible=True, range=[0, 100.0])),
        showlegend=False,
    )
    return fig
|
|
|
|
|
@st.cache_data
def plot_radar_chart_index(
    dataframe: pd.DataFrame,
    index: int,
    categories: list = categories,
    fillcolor: str = fillcolor,
    line_color: str = line_color,
    opacity: float = opacity,
):
    """
    Plot the row labelled ``index`` of the dataframe as a single radar chart.

    NOTE: this is the second definition of ``plot_radar_chart_index`` in this
    module. Being the later one, it is the definition bound to the name at
    import time; the earlier copy is dead code.

    Arguments:

        dataframe: a pandas DataFrame with a "model_name" column and one
            column per entry of ``categories``

        index: the index label (as understood by ``DataFrame.loc``) of the
            row we want to plot

        categories: the list of the metrics, in the order they are drawn
            around the chart

        fillcolor: a string specifying the color to fill the area

        line_color: a string specifying the color of the lines in the graph

        opacity: opacity of the filled trace (defaults to the module setting)

    Returns:

        A plotly ``go.Figure`` holding one closed polygon. Each radius is the
        metric value times 100, rounded to 2 decimals, drawn on a radial axis
        fixed to [0, 100]. This presumably assumes the scores are stored as
        fractions in [0, 1]; confirm against the data source.

    Raises:

        KeyError: if ``index`` or one of the requested columns is missing.
    """
    # Cast to float *before* scaling. On object-dtype columns holding numeric
    # strings, ``* 100`` would repeat the string instead of multiplying it.
    values = dataframe.loc[index, categories].to_numpy(dtype=float) * 100
    values = values.round(decimals=2)

    # Repeat the first point/category so the polygon closes on itself.
    r = np.append(values, values[0])
    theta = [*categories, categories[0]]

    fig = go.Figure()
    fig.add_trace(
        go.Scatterpolar(
            r=r,
            theta=theta,
            fill="toself",
            fillcolor=fillcolor,
            opacity=opacity,
            line=dict(color=line_color),
            name=dataframe.loc[index, "model_name"],
        )
    )
    fig.update_layout(
        polar=dict(radialaxis=dict(visible=True, range=[0, 100.0])),
        showlegend=False,
    )
    return fig
|
|
|
@st.cache_data
def plot_radar_chart_rows(
    rows: object,
    columns: list = columns,
    categories: list = categories,
    fillcolor: str = fillcolor,
    line_color: str = line_color,
    opacity: float = opacity,
):
    """
    Plot the results of the models selected by the checkbox, one trace per row.

    Arguments:

        rows: an iterable whose elements are dicts with columns as their keys
            (anything ``pd.DataFrame(rows, columns=columns)`` accepts)

        columns: the list of the columns to use

        categories: the list of the metrics, in the order they are drawn
            around the chart

        fillcolor: a string specifying the color to fill the area

        line_color: a string specifying the color of the lines in the graph

        opacity: opacity of each filled trace (defaults to the module setting)

    Returns:

        A plotly ``go.Figure`` with one closed polygon per row, drawn on a
        radial axis fixed to [0, 100]. An empty ``rows`` yields a figure with
        no traces.
    """
    dataset = pd.DataFrame(rows, columns=columns)

    # Unlike the single-row helpers, values are neither multiplied by 100 nor
    # rounded here.
    # NOTE(review): confirm the incoming rows are already on a 0-100 scale,
    # since the radial axis below is fixed to [0, 100].
    values = dataset[categories].to_numpy(dtype=float)

    # Append each row's first value (and the first category) so every polygon
    # closes on itself.
    closed = np.append(values, values[:, :1], axis=1)
    theta = [*categories, categories[0]]

    model_names = dataset["model_name"] if len(dataset) else None
    fig = go.Figure()
    for i, r in enumerate(closed):
        fig.add_trace(
            go.Scatterpolar(
                r=r,
                theta=theta,
                fill="toself",
                fillcolor=fillcolor,
                opacity=opacity,
                line=dict(color=line_color),
                # Positional lookup: ``dataset.loc[i, ...]`` assumed a 0..n-1
                # RangeIndex, which does not hold when ``rows`` is itself a
                # DataFrame with a different index.
                name=model_names.iloc[i],
            )
        )
    fig.update_layout(
        polar=dict(radialaxis=dict(visible=True, range=[0, 100.0])),
        showlegend=False,
    )
    return fig