dimbyTa commited on
Commit
963c6da
1 Parent(s): 14bee3a

Adding caching and row plotting

Browse files
Files changed (3) hide show
  1. src/display.py +4 -7
  2. src/load_data.py +3 -0
  3. src/plot.py +98 -0
src/display.py CHANGED
@@ -5,7 +5,7 @@
5
  from st_aggrid import GridOptionsBuilder, AgGrid
6
  import streamlit as st
7
  from .load_data import load_dataframe, sort_by
8
- from .plot import plot_radar_chart_index, plot_radar_chart_name
9
 
10
 
11
  def display_app():
@@ -69,9 +69,8 @@ def display_app():
69
 
70
  with column2:
71
  if grid_response['selected_rows'] is not None and len(grid_response['selected_rows']) > 0:
72
- model_name = grid_response['selected_rows'][0]["model_name"]
73
- figure = plot_radar_chart_name(dataframe=dataframe, model_name=model_name)
74
- st.plotly_chart(figure, use_container_width=False)
75
  else:
76
  if len(subdata)>0:
77
  figure = plot_radar_chart_name(dataframe=subdata, model_name=model_name)
@@ -80,6 +79,4 @@ def display_app():
80
  if grid_response['selected_rows'] is not None and len(grid_response['selected_rows']) > 0:
81
  st.markdown("**Model name:** %s" % grid_response['selected_rows'][0]["model_name"])
82
  else:
83
- st.markdown("**Model name:** %s" % model_name)
84
-
85
-
 
5
  from st_aggrid import GridOptionsBuilder, AgGrid
6
  import streamlit as st
7
  from .load_data import load_dataframe, sort_by
8
+ from .plot import plot_radar_chart_name, plot_radar_chart_rows
9
 
10
 
11
  def display_app():
 
69
 
70
  with column2:
71
  if grid_response['selected_rows'] is not None and len(grid_response['selected_rows']) > 0:
72
+ figure = plot_radar_chart_rows(rows=grid_response['selected_rows'])
73
+ st.plotly_chart(figure, use_container_width=True)
 
74
  else:
75
  if len(subdata)>0:
76
  figure = plot_radar_chart_name(dataframe=subdata, model_name=model_name)
 
79
  if grid_response['selected_rows'] is not None and len(grid_response['selected_rows']) > 0:
80
  st.markdown("**Model name:** %s" % grid_response['selected_rows'][0]["model_name"])
81
  else:
82
+ st.markdown("**Model name:** %s" % model_name)
 
 
src/load_data.py CHANGED
@@ -1,5 +1,7 @@
 
1
  import pandas as pd
2
 
 
3
  def load_dataframe() -> pd.DataFrame:
4
  """
5
  Load dataframe from the csv file in public directory
@@ -11,6 +13,7 @@ def load_dataframe() -> pd.DataFrame:
11
  dataframe = dataframe.drop(columns = "Unnamed: 0")
12
  return dataframe
13
 
 
14
  def sort_by(dataframe: pd.DataFrame, column_name: str, ascending:bool = False) -> pd.DataFrame:
15
  """
16
  Sort the dataframe by column_name
 
1
+ import streamlit as st
2
  import pandas as pd
3
 
4
+ @st.cache_data
5
  def load_dataframe() -> pd.DataFrame:
6
  """
7
  Load dataframe from the csv file in public directory
 
13
  dataframe = dataframe.drop(columns = "Unnamed: 0")
14
  return dataframe
15
 
16
+ @st.cache_data
17
  def sort_by(dataframe: pd.DataFrame, column_name: str, ascending:bool = False) -> pd.DataFrame:
18
  """
19
  Sort the dataframe by column_name
src/plot.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import plotly.graph_objects as go
2
  import numpy as np
3
  import pandas as pd
@@ -11,7 +12,12 @@ opacity = 0.75
11
 
12
  # categories to show radar chart
13
  categories = ["ARC", "GSM8K", "TruthfulQA", "Winogrande", "HellaSwag", "MMLU"]
 
 
 
14
 
 
 
15
  def plot_radar_chart_index(dataframe: pd.DataFrame, index: int, categories: list = categories, fillcolor: str = fillcolor, line_color:str = line_color):
16
  """
17
  plot the index-th row of the dataframe
@@ -56,6 +62,7 @@ def plot_radar_chart_index(dataframe: pd.DataFrame, index: int, categories: list
56
 
57
  return fig
58
 
 
59
  def plot_radar_chart_name(dataframe: pd.DataFrame, model_name: str, categories: list = categories, fillcolor: str = fillcolor, line_color:str = line_color):
60
  """
61
  plot the results of the model named model_name row of the dataframe
@@ -98,4 +105,95 @@ def plot_radar_chart_name(dataframe: pd.DataFrame, model_name: str, categories:
98
  showlegend=False
99
  )
100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  return fig
 
1
+ import streamlit as st
2
  import plotly.graph_objects as go
3
  import numpy as np
4
  import pandas as pd
 
12
 
13
  # categories to show radar chart
14
  categories = ["ARC", "GSM8K", "TruthfulQA", "Winogrande", "HellaSwag", "MMLU"]
15
+ # Dataset columns
16
+ columns = ["model_name", "ARC", "HellaSwag", "TruthfulQA", "Winogrande", "GSM8K",
17
+ "MMLU", "Average"]
18
 
19
+
20
+ @st.cache_data
21
  def plot_radar_chart_index(dataframe: pd.DataFrame, index: int, categories: list = categories, fillcolor: str = fillcolor, line_color:str = line_color):
22
  """
23
  plot the index-th row of the dataframe
 
62
 
63
  return fig
64
 
65
+ @st.cache_data
66
  def plot_radar_chart_name(dataframe: pd.DataFrame, model_name: str, categories: list = categories, fillcolor: str = fillcolor, line_color:str = line_color):
67
  """
68
  plot the results of the model named model_name row of the dataframe
 
105
  showlegend=False
106
  )
107
 
108
+ return fig
109
+
110
+
111
+ @st.cache_data
112
+ def plot_radar_chart_index(dataframe: pd.DataFrame, index: int, categories: list = categories, fillcolor: str = fillcolor, line_color:str = line_color):
113
+ """
114
+ plot the index-th row of the dataframe
115
+
116
+ Arguments:
117
+ dataframe: a pandas DataFrame
118
+ index: the index of the row we want to plot
119
+ categories: the list of the metrics
120
+ fillcolor: a string specifying the color to fill the area
121
+ line_color: a string specifying the color of the lines in the graph
122
+ """
123
+ fig = go.Figure()
124
+ data = dataframe.loc[index,categories].to_numpy()*100
125
+ data = data.astype(float)
126
+ # rounding data
127
+ data = data.round(decimals = 2)
128
+
129
+ # add data to close the area of the radar chart
130
+ data = np.append(data, data[0])
131
+ categories_theta = categories.copy()
132
+ categories_theta.append(categories[0])
133
+ model_name = dataframe.loc[index,"model_name"]
134
+ #print("Printing data ", data, " for ", model_name)
135
+
136
+ fig.add_trace(go.Scatterpolar(
137
+ r=data,
138
+ theta=categories_theta,
139
+ fill='toself',
140
+ fillcolor = fillcolor,
141
+ opacity = opacity,
142
+ line=dict(color = line_color),
143
+ name= model_name
144
+ ))
145
+ fig.update_layout(
146
+ polar=dict(
147
+ radialaxis=dict(
148
+ visible=True,
149
+ range=[0, 100.]
150
+ )),
151
+ showlegend=False
152
+ )
153
+
154
+ return fig
155
+
156
+ @st.cache_data
157
+ def plot_radar_chart_rows(rows: object, columns:list = columns, categories: list = categories, fillcolor: str = fillcolor, line_color:str = line_color):
158
+ """
159
+ plot the results of the model selected by the checkbox
160
+
161
+ Arguments:
162
+ rows: an iterable whose elements are dicts with columns as their keys
163
+ columns: the list of the columns to use
164
+ categories: the list of the metrics
165
+ fillcolor: a string specifying the color to fill the area
166
+ line_color: a string specifying the color of the lines in the graph
167
+ """
168
+ fig = go.Figure()
169
+ dataset = pd.DataFrame(rows, columns=columns)
170
+ data = dataset[categories].to_numpy()
171
+ data = data.astype(float)
172
+
173
+ # add data to close the area of the radar chart
174
+ data = np.append(data, data[:,0].reshape((-1,1)), axis=1)
175
+ categories_theta = categories.copy()
176
+ categories_theta.append(categories[0])
177
+
178
+ #print("Printing data ", data, " for ", model_name)
179
+ for i in range(len(dataset)):
180
+
181
+ fig.add_trace(go.Scatterpolar(
182
+ r=data[i,:],
183
+ theta=categories_theta,
184
+ fill='toself',
185
+ fillcolor = fillcolor,
186
+ opacity = opacity,
187
+ line=dict(color = line_color),
188
+ name= dataset.loc[i,"model_name"]
189
+ ))
190
+ fig.update_layout(
191
+ polar=dict(
192
+ radialaxis=dict(
193
+ visible=True,
194
+ range=[0, 100.]
195
+ )),
196
+ showlegend=False
197
+ )
198
+
199
  return fig