File size: 4,744 Bytes
963c6da
fbcd930
 
 
 
 
 
 
 
 
 
 
 
 
963c6da
 
 
fbcd930
963c6da
ea44f66
fbcd930
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ea44f66
fbcd930
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
963c6da
 
 
ea44f66
963c6da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fbcd930
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import streamlit as st
import plotly.graph_objects as go
import numpy as np
import pandas as pd

# Hugging Face brand palette used by every radar trace.
fillcolor = "#FFD21E"   # polygon fill
line_color = "#FF9D00"  # polygon outline

# Opacity applied to each radar trace.
opacity = 0.75

# Metrics drawn as radar axes, in angular order.
categories = [
    "ARC",
    "GSM8K",
    "TruthfulQA",
    "Winogrande",
    "HellaSwag",
    "MMLU",
]

# Column layout of the leaderboard dataset rows.
columns = [
    "model_name",
    "ARC",
    "HellaSwag",
    "TruthfulQA",
    "Winogrande",
    "GSM8K",
    "MMLU",
    "Average",
]


#@st.cache_data
def plot_radar_chart_index(dataframe: pd.DataFrame, index: int, categories: list = categories, fillcolor: str = fillcolor, line_color:str = line_color):
    """
    Plot a radar chart of the scores stored in one row of the dataframe.

    Arguments:
    dataframe: a pandas DataFrame with a "model_name" column and one column
        per entry of ``categories``
    index: the row label (looked up with ``.loc``) of the model to plot
    categories: the list (or any sequence) of metric columns; one radar axis each
    fillcolor: a string specifying the color to fill the area
    line_color: a string specifying the color of the lines in the graph

    Returns:
    a plotly ``go.Figure`` holding one closed ``Scatterpolar`` trace. Its radial
    values are the row's scores multiplied by 100 and rounded to two decimals,
    drawn on a fixed 0-100 radial axis with the legend hidden.
    """
    # Normalise to a list: ``.copy()``/concatenation need a list, and a tuple
    # passed to ``.loc`` would be interpreted as a MultiIndex key.
    categories = list(categories)

    # Cast to float *before* scaling: on an object-dtype column holding numeric
    # strings, ``* 100`` would repeat the string rather than scale the number.
    data = dataframe.loc[index, categories].to_numpy().astype(float) * 100
    data = data.round(decimals=2)

    # Repeat the first point (and its label) so the radar polygon is closed.
    data = np.append(data, data[0])
    categories_theta = categories + [categories[0]]
    model_name = dataframe.loc[index, "model_name"]

    fig = go.Figure()
    fig.add_trace(go.Scatterpolar(
        r=data,
        theta=categories_theta,
        fill='toself',
        fillcolor=fillcolor,
        opacity=opacity,
        line=dict(color=line_color),
        name=model_name,
    ))
    fig.update_layout(
        polar=dict(
            radialaxis=dict(
                visible=True,
                range=[0, 100.],
            )),
        showlegend=False,
    )

    return fig

#@st.cache_data
def plot_radar_chart_name(dataframe: pd.DataFrame, model_name: str, categories: list = categories, fillcolor: str = fillcolor, line_color:str = line_color):
    """
    Plot a radar chart of the scores of the model named ``model_name``.

    Arguments:
    dataframe: a pandas DataFrame with a "model_name" column and one column
        per entry of ``categories``
    model_name: a string stating the name of the model; if several rows share
        this name, the first one is plotted
    categories: the list (or any sequence) of metric columns; one radar axis each
    fillcolor: a string specifying the color to fill the area
    line_color: a string specifying the color of the lines in the graph

    Returns:
    a plotly ``go.Figure`` holding one closed ``Scatterpolar`` trace. Its radial
    values are the model's scores multiplied by 100 and rounded to two decimals,
    drawn on a fixed 0-100 radial axis with the legend hidden.

    Raises:
    IndexError: if no row has ``model_name`` in its "model_name" column.
    """
    categories = list(categories)

    matches = dataframe.loc[dataframe["model_name"] == model_name, categories]
    if matches.empty:
        raise IndexError(f"no row with model_name == {model_name!r}")

    # Take a single 1-D row. Flattening the whole (n_matches, n_categories)
    # block with np.append would give more radial values than theta labels.
    # Cast before scaling so numeric strings are not repeated by ``* 100``.
    data = matches.iloc[0].to_numpy().astype(float) * 100
    data = data.round(decimals=2)

    # Repeat the first point (and its label) so the radar polygon is closed.
    data = np.append(data, data[0])
    categories_theta = categories + [categories[0]]

    fig = go.Figure()
    fig.add_trace(go.Scatterpolar(
        r=data,
        theta=categories_theta,
        fill='toself',
        fillcolor=fillcolor,
        opacity=opacity,
        line=dict(color=line_color),
        name=model_name,
    ))
    fig.update_layout(
        polar=dict(
            radialaxis=dict(
                visible=True,
                range=[0, 100.],
            )),
        showlegend=False,
    )

    return fig


#@st.cache_data
def plot_radar_chart_rows(rows: object, columns:list = columns, categories: list = categories, fillcolor: str = fillcolor, line_color:str = line_color):
    """
    Plot one radar trace per selected model row, overlaid on a single chart.

    Arguments:
    rows: anything ``pd.DataFrame(rows, columns=columns)`` accepts, e.g. an
        iterable of dicts keyed by ``columns`` or a DataFrame (any index)
    columns: the list of the columns to use
    categories: the list (or any sequence) of metric columns; one radar axis each
    fillcolor: a string specifying the color to fill the area
    line_color: a string specifying the color of the lines in the graph

    Returns:
    a plotly ``go.Figure`` with one closed ``Scatterpolar`` trace per row, named
    after the row's "model_name". Values are plotted as given (no rescaling or
    rounding), presumably already on a 0-100 scale; TODO confirm with the caller.
    The radial axis is fixed to 0-100 and the legend is hidden, also when
    ``rows`` is empty.
    """
    categories = list(categories)

    dataset = pd.DataFrame(rows, columns=columns)
    data = dataset[categories].to_numpy().astype(float)

    # Append each row's first value as an extra column so every polygon closes.
    data = np.append(data, data[:, :1], axis=1)
    categories_theta = categories + [categories[0]]

    fig = go.Figure()
    # zip pairs the rows positionally. ``dataset.loc[i]`` would break when
    # ``rows`` is a DataFrame whose index is not 0..n-1, because
    # pd.DataFrame(df, columns=...) keeps that index.
    for scores, name in zip(data, dataset["model_name"]):
        fig.add_trace(go.Scatterpolar(
            r=scores,
            theta=categories_theta,
            fill='toself',
            fillcolor=fillcolor,
            opacity=opacity,
            line=dict(color=line_color),
            name=name,
        ))

    # The layout does not depend on any row: set it once, outside the loop, so
    # an empty selection still gets the fixed 0-100 axis and hidden legend.
    fig.update_layout(
        polar=dict(
            radialaxis=dict(
                visible=True,
                range=[0, 100.],
            )),
        showlegend=False,
    )

    return fig