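# Page residue captured with this file (appears to come from a Hugging Face
# file-viewer page; it is not part of the program):
#   avatar alt text: "dimbyTa's picture"
#   commit 963c6da: "Adding caching and row plotting"
#   page links: "raw", "history blame"
#   file size: 6.12 kB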
import streamlit as st
import plotly.graph_objects as go
import numpy as np
import pandas as pd
# Hugging Face colors: area fill and outline of every radar trace.
fillcolor = "#FFD21E"
line_color = "#FF9D00"

# Opacity applied to each plotted trace.
opacity = 0.75

# Metrics drawn as the angular axes of the radar chart, in plotting order.
categories = [
    "ARC",
    "GSM8K",
    "TruthfulQA",
    "Winogrande",
    "HellaSwag",
    "MMLU",
]

# Column names (and order) used when a DataFrame is built from selected rows.
columns = [
    "model_name",
    "ARC",
    "HellaSwag",
    "TruthfulQA",
    "Winogrande",
    "GSM8K",
    "MMLU",
    "Average",
]
@st.cache_data
def plot_radar_chart_index(
    dataframe: pd.DataFrame,
    index: int,
    categories: list = categories,
    fillcolor: str = fillcolor,
    line_color: str = line_color,
):
    """
    Build a radar chart for one row of the dataframe.

    The row is looked up with ``dataframe.loc[index, ...]``, i.e. by index
    label; this coincides with the row position only when the dataframe
    uses the default 0..n-1 index.

    Arguments:
        dataframe: a pandas DataFrame holding a "model_name" column and one
            column per entry of categories
        index: the index label of the row to plot
        categories: the list of the metrics (one radar axis each)
        fillcolor: a string specifying the color to fill the area
        line_color: a string specifying the color of the lines in the graph

    Returns:
        a plotly Figure containing a single Scatterpolar trace
    """
    # Metric values are multiplied by 100 and rounded to two decimals so
    # they sit on the fixed 0-100 radial axis configured below.
    scores = (
        (dataframe.loc[index, categories].to_numpy() * 100)
        .astype(float)
        .round(decimals=2)
    )
    # Repeat the first point at the end so the outline closes on itself.
    closed_scores = np.append(scores, scores[0])
    closed_theta = [*categories, categories[0]]

    trace = go.Scatterpolar(
        r=closed_scores,
        theta=closed_theta,
        fill="toself",
        fillcolor=fillcolor,
        opacity=opacity,
        line={"color": line_color},
        name=dataframe.loc[index, "model_name"],
    )

    fig = go.Figure()
    fig.add_trace(trace)
    fig.update_layout(
        polar={"radialaxis": {"visible": True, "range": [0, 100.0]}},
        showlegend=False,
    )
    return fig
@st.cache_data
def plot_radar_chart_name(dataframe: pd.DataFrame, model_name: str, categories: list = categories, fillcolor: str = fillcolor, line_color:str = line_color):
    """
    Build a radar chart for the model called model_name.

    Arguments:
        dataframe: a pandas DataFrame holding a "model_name" column and one
            column per entry of categories
        model_name: a string stating the name of the model
        categories: the list of the metrics (one radar axis each)
        fillcolor: a string specifying the color to fill the area
        line_color: a string specifying the color of the lines in the graph

    Returns:
        a plotly Figure containing a single Scatterpolar trace

    Raises:
        IndexError: if no row of dataframe has this model_name

    If several rows share the same model_name, the first one is plotted.
    """
    matches = dataframe.loc[dataframe["model_name"] == model_name, categories]
    if len(matches) == 0:
        raise IndexError(
            f"no row with model_name == {model_name!r} in the dataframe"
        )

    # Work on ONE row as a 1-D vector. A boolean-mask selection is 2-D, and
    # np.append without an axis flattens its inputs, which would tack the
    # whole first row onto the end instead of just the closing point.
    data = (matches.iloc[0].to_numpy() * 100).astype(float)
    # rounding data
    data = data.round(decimals=2)
    # repeat the first point to close the area of the radar chart
    data = np.append(data, data[0])
    categories_theta = [*categories, categories[0]]

    fig = go.Figure()
    fig.add_trace(go.Scatterpolar(
        r=data,
        theta=categories_theta,
        fill='toself',
        fillcolor=fillcolor,
        opacity=opacity,
        line=dict(color=line_color),
        name=model_name
    ))
    fig.update_layout(
        polar=dict(
            radialaxis=dict(
                visible=True,
                range=[0, 100.]
            )),
        showlegend=False
    )
    return fig
# NOTE: a function with this same name and signature is also defined earlier
# in this file; this later definition is the one bound at import time.
@st.cache_data
def plot_radar_chart_index(dataframe: pd.DataFrame, index: int, categories: list = categories, fillcolor: str = fillcolor, line_color:str = line_color):
    """
    Plot the metrics of the dataframe row whose index label is ``index``.

    The lookup goes through ``dataframe.loc``, so ``index`` is an index
    label; it equals the row position only for a default 0..n-1 index.

    Arguments:
        dataframe: a pandas DataFrame holding a "model_name" column and one
            column per entry of categories
        index: the index label of the row we want to plot
        categories: the list of the metrics
        fillcolor: a string specifying the color to fill the area
        line_color: a string specifying the color of the lines in the graph

    Returns:
        a plotly Figure with one Scatterpolar trace named after the model
    """
    figure = go.Figure()

    raw_values = dataframe.loc[index, categories].to_numpy() * 100
    values = raw_values.astype(float).round(decimals=2)
    # The first value is appended again so the drawn area is closed.
    values = np.append(values, values[0])

    theta = categories.copy()
    theta.append(categories[0])

    label = dataframe.loc[index, "model_name"]
    polar_trace = go.Scatterpolar(
        r=values,
        theta=theta,
        fill='toself',
        fillcolor=fillcolor,
        opacity=opacity,
        line=dict(color=line_color),
        name=label,
    )
    figure.add_trace(polar_trace)

    # Radial axis is pinned to 0-100 and the legend is hidden.
    radial_axis = dict(visible=True, range=[0, 100.])
    figure.update_layout(polar=dict(radialaxis=radial_axis), showlegend=False)
    return figure
@st.cache_data
def plot_radar_chart_rows(rows: object, columns:list = columns, categories: list = categories, fillcolor: str = fillcolor, line_color:str = line_color):
    """
    Build one radar chart overlaying every selected model.

    Arguments:
        rows: an iterable whose elements are dicts with columns as their keys
        columns: the list of the columns to use
        categories: the list of the metrics (one radar axis each)
        fillcolor: a string specifying the color to fill the area
        line_color: a string specifying the color of the lines in the graph

    Returns:
        a plotly Figure with one Scatterpolar trace per row; the figure has
        no traces when rows is empty

    Metric values are plotted as given: no scaling or rounding is applied
    here, and the radial axis is fixed to the 0-100 range.
    """
    dataset = pd.DataFrame(rows, columns=columns)
    data = dataset[categories].to_numpy().astype(float)
    # repeat the first column at the end to close the area of each radar shape
    data = np.append(data, data[:, 0].reshape((-1, 1)), axis=1)
    categories_theta = [*categories, categories[0]]

    fig = go.Figure()
    # Pair each score vector with its model name by position, so the result
    # does not depend on the index labels the input happens to carry.
    for scores, name in zip(data, dataset["model_name"]):
        fig.add_trace(go.Scatterpolar(
            r=scores,
            theta=categories_theta,
            fill='toself',
            fillcolor=fillcolor,
            opacity=opacity,
            line=dict(color=line_color),
            name=name
        ))
    fig.update_layout(
        polar=dict(
            radialaxis=dict(
                visible=True,
                range=[0, 100.]
            )),
        showlegend=False
    )
    return fig