Spaces:
Sleeping
Sleeping
File size: 4,974 Bytes
963c6da fbcd930 adbb181 fbcd930 963c6da fbcd930 963c6da ea44f66 fbcd930 ea44f66 fbcd930 963c6da ea44f66 adbb181 963c6da adbb181 963c6da adbb181 963c6da adbb181 963c6da adbb181 963c6da adbb181 963c6da adbb181 963c6da adbb181 963c6da fbcd930 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
import streamlit as st
import plotly.graph_objects as go
import numpy as np
import pandas as pd
# Hugging Face palette: the first fill/outline pair is the primary trace
# color, the remaining pairs are used for additional overlaid traces.
fillcolor = "#FFD21E"
line_color = "#FF9D00"
fill_color_list = [
    fillcolor,
    "#F05998",
    "#40BAF0",
]
line_color_list = [
    line_color,
    "#5E233C",
    "#194A5E",
]

# Transparency applied to the filled radar areas.
opacity = 0.75

# Metric columns drawn on the radar axes, in angular order.
categories = [
    "ARC",
    "GSM8K",
    "TruthfulQA",
    "Winogrande",
    "HellaSwag",
    "MMLU",
]

# Column layout of the dataset rows handed to the plotting helpers.
columns = [
    "model_name",
    "ARC",
    "HellaSwag",
    "TruthfulQA",
    "Winogrande",
    "GSM8K",
    "MMLU",
    "Average",
]
#@st.cache_data
def plot_radar_chart_index(dataframe: pd.DataFrame, index: int, categories: list = categories, fillcolor: str = fillcolor, line_color: str = line_color) -> go.Figure:
    """
    Plot one row of ``dataframe`` as a filled radar (polar) chart.

    Arguments:
        dataframe: a pandas DataFrame with a "model_name" column and one
            column per entry of ``categories``
        index: the row label to plot; it is looked up with ``.loc``, so it
            is an index label, not necessarily a positional offset
        categories: the metric columns to draw, in angular order (any
            sequence of column names, e.g. a list or a tuple)
        fillcolor: a string specifying the color to fill the area
        line_color: a string specifying the color of the lines in the graph

    Returns:
        A plotly Figure holding a single Scatterpolar trace named after the
        row's "model_name". Radial values are the row's metrics multiplied
        by 100 and rounded to 2 decimals, on a fixed 0-100 radial axis
        (presumably the metrics are stored as fractions in [0, 1]).
    """
    # Work on a private list copy so the shared default is never mutated
    # and tuples are accepted too.
    categories = list(categories)

    # Cast before scaling so object-dtype cells are treated as numbers.
    values = dataframe.loc[index, categories].to_numpy().astype(float) * 100
    values = values.round(decimals=2)

    # Repeat the first point (and its label) so the polygon closes.
    values = np.append(values, values[0])
    categories_theta = categories + [categories[0]]

    model_name = dataframe.loc[index, "model_name"]

    fig = go.Figure()
    fig.add_trace(go.Scatterpolar(
        r=values,
        theta=categories_theta,
        fill="toself",
        fillcolor=fillcolor,
        opacity=opacity,
        line=dict(color=line_color),
        name=model_name,
    ))
    fig.update_layout(
        polar=dict(
            radialaxis=dict(
                visible=True,
                range=[0, 100.0],
            )),
        showlegend=False,
    )
    return fig
#@st.cache_data
def plot_radar_chart_name(dataframe: pd.DataFrame, model_name: str, categories: list = categories, fillcolor: str = fillcolor, line_color: str = line_color) -> go.Figure:
    """
    Plot the row of ``dataframe`` whose "model_name" equals ``model_name``
    as a filled radar (polar) chart.

    Arguments:
        dataframe: a pandas DataFrame with a "model_name" column and one
            column per entry of ``categories``
        model_name: a string stating the name of the model; if several rows
            share this name, the first one is plotted
        categories: the metric columns to draw, in angular order (any
            sequence of column names, e.g. a list or a tuple)
        fillcolor: a string specifying the color to fill the area
        line_color: a string specifying the color of the lines in the graph

    Returns:
        A plotly Figure holding a single Scatterpolar trace. Radial values
        are the metrics multiplied by 100 and rounded to 2 decimals, on a
        fixed 0-100 radial axis (presumably the metrics are stored as
        fractions in [0, 1]).

    Raises:
        IndexError: if no row has the requested model name.
    """
    # Work on a private list copy so the shared default is never mutated
    # and tuples are accepted too.
    categories = list(categories)

    matches = dataframe.loc[dataframe["model_name"] == model_name, categories]
    if matches.shape[0] == 0:
        raise IndexError(f"no row with model_name == {model_name!r}")

    # Reduce to a 1-D vector: np.append without an axis flattens its inputs,
    # so a 2-D (rows x metrics) selection would yield more radial values
    # than angular labels. Cast before scaling so object dtype is numeric.
    values = matches.iloc[0].to_numpy().astype(float) * 100
    values = values.round(decimals=2)

    # Repeat the first point (and its label) so the polygon closes.
    values = np.append(values, values[0])
    categories_theta = categories + [categories[0]]

    fig = go.Figure()
    fig.add_trace(go.Scatterpolar(
        r=values,
        theta=categories_theta,
        fill="toself",
        fillcolor=fillcolor,
        opacity=opacity,
        line=dict(color=line_color),
        name=model_name,
    ))
    fig.update_layout(
        polar=dict(
            radialaxis=dict(
                visible=True,
                range=[0, 100.0],
            )),
        showlegend=False,
    )
    return fig
#@st.cache_data
def plot_radar_chart_rows(rows: object, columns: list = columns, categories: list = categories, fillcolor_list: list = fill_color_list, line_color_list: list = line_color_list) -> go.Figure:
    """
    Overlay the selected models (e.g. rows picked via a checkbox) on one
    radar chart.

    Arguments:
        rows: an iterable whose elements are dicts with columns as their keys
        columns: the list of the columns to use
        categories: the metric columns to draw, in angular order (any
            sequence of column names, e.g. a list or a tuple)
        fillcolor_list: fill colors, one per trace; reused cyclically when
            there are more rows than colors
        line_color_list: outline colors, one per trace; reused cyclically
            when there are more rows than colors

    Returns:
        A plotly Figure with one Scatterpolar trace per row, on a fixed
        0-100 radial axis. The legend is shown only when more than one row
        is plotted.

    NOTE(review): unlike the single-model helpers in this module, values are
    plotted as-is (not multiplied by 100); presumably ``rows`` already holds
    percentages - confirm against the caller.
    """
    dataset = pd.DataFrame(rows, columns=columns)
    categories = list(categories)

    values = dataset[categories].to_numpy().astype(float)
    # Repeat each row's first metric (and the first label) so every polygon
    # closes.
    values = np.append(values, values[:, :1], axis=1)
    categories_theta = categories + [categories[0]]

    fig = go.Figure()
    for i, (row_values, name) in enumerate(zip(values, dataset["model_name"])):
        fig.add_trace(go.Scatterpolar(
            r=row_values,
            theta=categories_theta,
            fill="toself",
            fillcolor=fillcolor_list[i % len(fillcolor_list)],
            # Later traces are drawn more transparent so overlapping areas
            # stay visible; the floor keeps the value inside plotly's valid
            # [0, 1] opacity range.
            opacity=max(0.75 - 0.2 * i, 0.15),
            line=dict(color=line_color_list[i % len(line_color_list)]),
            name=name,
        ))

    fig.update_layout(
        polar=dict(
            radialaxis=dict(
                visible=True,
                range=[0, 100.0],
            )),
        # A legend only helps when more than one model is drawn.
        showlegend=len(dataset) > 1,
    )
    return fig