nazneen commited on
Commit
8dd2bc6
·
1 Parent(s): 76e2fde

adding parquets

Browse files
Files changed (1) hide show
  1. app.py +14 -18
app.py CHANGED
@@ -62,17 +62,17 @@ def down_samp(embedding):
62
 
63
 
64
  def data_comparison(df):
65
- selection = alt.selection_multi(fields=['cluster','label'])
66
- color = alt.condition(alt.datum.slice == 'high-loss', alt.Color('cluster:N', scale = alt.Scale(domain=df.cluster.unique().tolist())), alt.value("lightgray"))
67
  opacity = alt.condition(selection, alt.value(0.7), alt.value(0.25))
68
 
69
  # basic chart
70
  scatter = alt.Chart(df).mark_point(size=100, filled=True).encode(
71
- x=alt.X('x', axis=None),
72
- y=alt.Y('y', axis=None),
73
  color=color,
74
- shape=alt.Shape('label', scale=alt.Scale(range=['circle', 'diamond'])),
75
- tooltip=['cluster','slice','content','label','pred'],
76
  opacity=opacity
77
  ).properties(
78
  width=1000,
@@ -80,31 +80,26 @@ def data_comparison(df):
80
  ).interactive()
81
 
82
  legend = alt.Chart(df).mark_point(size=100, filled=True).encode(
83
- x=alt.X("label"),
84
- y=alt.Y('cluster:N', axis=alt.Axis(orient='right'), title=""),
85
- shape=alt.Shape('label', scale=alt.Scale(
86
  range=['circle', 'diamond']), legend=None),
87
  color=color,
88
  ).add_selection(
89
  selection
90
  )
91
-
92
  layered = scatter | legend
93
-
94
  layered = layered.configure_axis(
95
  grid=False
96
  ).configure_view(
97
  strokeOpacity=0
98
  )
99
-
100
  return layered
101
 
102
  def quant_panel(embedding_df):
103
  """ Quantitative Panel Layout"""
104
-
105
  all_metrics = {}
106
  st.warning("**Error slice visualization**")
107
-
108
  with st.expander("How to read this chart:"):
109
  st.markdown("* Each **point** is an input example.")
110
  st.markdown("* Gray points have low-loss and the colored have high-loss. High-loss instances are clustered using **kmeans** and each color represents a cluster.")
@@ -210,12 +205,14 @@ def topic_distribution(weights, smoothing=0.01):
210
 
211
  if __name__ == "__main__":
212
  ### STREAMLIT APP CONGFIG ###
213
- st.set_page_config(layout="wide", page_title="Error Analysis")
214
 
215
  ut.init_style()
216
 
217
  lcol, rcol = st.columns([2, 2])
218
  # ******* loading the mode and the data
 
 
219
  dataset = st.sidebar.selectbox(
220
  "Dataset",
221
  ["amazon_polarity", "yelp_polarity"],
@@ -246,8 +243,6 @@ if __name__ == "__main__":
246
  st.session_state["user_data"] = data_df
247
  if "selected_slice" not in st.session_state:
248
  st.session_state["selected_slice"] = None
249
- if "embedding" not in st.session_state:
250
- st.session_state["embedding"] = embedding_umap
251
 
252
  data_df['loss'] = data_df['loss'].astype(float)
253
  losses = data_df['loss']
@@ -258,13 +253,14 @@ if __name__ == "__main__":
258
  if run_kmeans == 'True':
259
  merged = kmeans(data_df,num_clusters=num_clusters)
260
  with lcol:
261
- st.markdown('<h3>Error Slices (Subset of evaluation dataset the model performs poorly)</h3>',unsafe_allow_html=True)
262
  dataframe = merged[['content', 'label', 'pred', 'loss', 'cluster']].sort_values(
263
  by=['loss'], ascending=False)
264
  table_html = dataframe.to_html(
265
  columns=['content', 'label', 'pred', 'loss', 'cluster'], max_rows=50)
266
  # table_html = table_html.replace("<th>", '<th align="left">') # left-align the headers
267
  with st.expander("How to read the table:"):
 
268
  st.markdown("* The table displays model error slices on the evaluation dataset, sorted by loss.")
269
  st.markdown("* Each row is an input example that includes the label, model pred, loss, and error cluster.")
270
  st.write(dataframe,width=900, height=300)
 
62
 
63
 
64
  def data_comparison(df):
65
+ selection = alt.selection_multi(fields=['cluster:O','label:O'])
66
+ color = alt.condition(alt.datum.slice == 'high-loss', alt.Color('cluster:O', scale = alt.Scale(domain=df.cluster.unique().tolist())), alt.value("lightgray"))
67
  opacity = alt.condition(selection, alt.value(0.7), alt.value(0.25))
68
 
69
  # basic chart
70
  scatter = alt.Chart(df).mark_point(size=100, filled=True).encode(
71
+ x=alt.X('x:Q', axis=None),
72
+ y=alt.Y('y:Q', axis=None),
73
  color=color,
74
+ shape=alt.Shape('label:O', scale=alt.Scale(range=['circle', 'diamond'])),
75
+ tooltip=['cluster:O','slice:N','content:N','label:O','pred:O'],
76
  opacity=opacity
77
  ).properties(
78
  width=1000,
 
80
  ).interactive()
81
 
82
  legend = alt.Chart(df).mark_point(size=100, filled=True).encode(
83
+ x=alt.X("label:O"),
84
+ y=alt.Y('cluster:O', axis=alt.Axis(orient='right'), title=""),
85
+ shape=alt.Shape('label:O', scale=alt.Scale(
86
  range=['circle', 'diamond']), legend=None),
87
  color=color,
88
  ).add_selection(
89
  selection
90
  )
 
91
  layered = scatter | legend
 
92
  layered = layered.configure_axis(
93
  grid=False
94
  ).configure_view(
95
  strokeOpacity=0
96
  )
 
97
  return layered
98
 
99
  def quant_panel(embedding_df):
100
  """ Quantitative Panel Layout"""
 
101
  all_metrics = {}
102
  st.warning("**Error slice visualization**")
 
103
  with st.expander("How to read this chart:"):
104
  st.markdown("* Each **point** is an input example.")
105
  st.markdown("* Gray points have low-loss and the colored have high-loss. High-loss instances are clustered using **kmeans** and each color represents a cluster.")
 
205
 
206
  if __name__ == "__main__":
207
  ### STREAMLIT APP CONGFIG ###
208
+ st.set_page_config(layout="wide", page_title="Interactive Error Analysis")
209
 
210
  ut.init_style()
211
 
212
  lcol, rcol = st.columns([2, 2])
213
  # ******* loading the mode and the data
214
+ st.sidebar.mardown("<h4>Interactive Error Analysis</h4>", unsafe_allow_html=True)
215
+
216
  dataset = st.sidebar.selectbox(
217
  "Dataset",
218
  ["amazon_polarity", "yelp_polarity"],
 
243
  st.session_state["user_data"] = data_df
244
  if "selected_slice" not in st.session_state:
245
  st.session_state["selected_slice"] = None
 
 
246
 
247
  data_df['loss'] = data_df['loss'].astype(float)
248
  losses = data_df['loss']
 
253
  if run_kmeans == 'True':
254
  merged = kmeans(data_df,num_clusters=num_clusters)
255
  with lcol:
256
+ st.markdown('<h3>Error Slices</h3>',unsafe_allow_html=True)
257
  dataframe = merged[['content', 'label', 'pred', 'loss', 'cluster']].sort_values(
258
  by=['loss'], ascending=False)
259
  table_html = dataframe.to_html(
260
  columns=['content', 'label', 'pred', 'loss', 'cluster'], max_rows=50)
261
  # table_html = table_html.replace("<th>", '<th align="left">') # left-align the headers
262
  with st.expander("How to read the table:"):
263
+ st.markdown("* *Error slice* refers to the subset of evaluation dataset the model performs poorly on.")
264
  st.markdown("* The table displays model error slices on the evaluation dataset, sorted by loss.")
265
  st.markdown("* Each row is an input example that includes the label, model pred, loss, and error cluster.")
266
  st.write(dataframe,width=900, height=300)