Spaces:
Runtime error
Runtime error
adding parquets
Browse files
app.py
CHANGED
@@ -224,15 +224,6 @@ if __name__ == "__main__":
|
|
224 |
["distilbert-base-uncased-finetuned-sst-2-english",
|
225 |
"albert-base-v2-yelp-polarity"],
|
226 |
)
|
227 |
-
|
228 |
-
loss_quantile = st.sidebar.slider(
|
229 |
-
"Loss Quantile", min_value=0.5, max_value=1.0,step=0.01,value=0.95
|
230 |
-
)
|
231 |
-
|
232 |
-
run_kmeans = st.sidebar.radio("Cluster error slice?", ('True', 'False'), index=0)
|
233 |
-
|
234 |
-
num_clusters = st.sidebar.slider("# clusters", min_value=1, max_value=20, step=1, value=3)
|
235 |
-
|
236 |
### LOAD DATA AND SESSION VARIABLES ###
|
237 |
data_df = pd.read_parquet('./assets/data/'+dataset+ '_'+ model+'.parquet')
|
238 |
if model == 'albert-base-v2-yelp-polarity':
|
@@ -243,13 +234,28 @@ if __name__ == "__main__":
|
|
243 |
st.session_state["user_data"] = data_df
|
244 |
if "selected_slice" not in st.session_state:
|
245 |
st.session_state["selected_slice"] = None
|
246 |
-
|
|
|
|
|
|
|
247 |
data_df['loss'] = data_df['loss'].astype(float)
|
248 |
losses = data_df['loss']
|
249 |
high_loss = losses.quantile(loss_quantile)
|
250 |
data_df['slice'] = 'high-loss'
|
251 |
data_df['slice'] = data_df['slice'].where(data_df['loss'] > high_loss, 'low-loss')
|
252 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
253 |
if run_kmeans == 'True':
|
254 |
merged = kmeans(data_df,num_clusters=num_clusters)
|
255 |
with lcol:
|
@@ -264,12 +270,5 @@ if __name__ == "__main__":
|
|
264 |
st.markdown("* The table displays model error slices on the evaluation dataset, sorted by loss.")
|
265 |
st.markdown("* Each row is an input example that includes the label, model pred, loss, and error cluster.")
|
266 |
st.write(dataframe,width=900, height=300)
|
267 |
-
|
268 |
-
with rcol:
|
269 |
-
with st.spinner(text='loading...'):
|
270 |
-
st.markdown('<h3>Word Distribution in Error Slice</h3>', unsafe_allow_html=True)
|
271 |
-
commontokens = frequent_tokens(merged, tokenizer, loss_quantile=loss_quantile)
|
272 |
-
with st.expander("How to read the table:"):
|
273 |
-
st.markdown("* The table displays the most frequent tokens in error slices, relative to their frequencies in the val set.")
|
274 |
-
st.write(commontokens)
|
275 |
quant_panel(merged)
|
|
|
224 |
["distilbert-base-uncased-finetuned-sst-2-english",
|
225 |
"albert-base-v2-yelp-polarity"],
|
226 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
227 |
### LOAD DATA AND SESSION VARIABLES ###
|
228 |
data_df = pd.read_parquet('./assets/data/'+dataset+ '_'+ model+'.parquet')
|
229 |
if model == 'albert-base-v2-yelp-polarity':
|
|
|
234 |
st.session_state["user_data"] = data_df
|
235 |
if "selected_slice" not in st.session_state:
|
236 |
st.session_state["selected_slice"] = None
|
237 |
+
|
238 |
+
loss_quantile = st.sidebar.slider(
|
239 |
+
"Loss Quantile", min_value=0.5, max_value=1.0,step=0.01,value=0.95
|
240 |
+
)
|
241 |
data_df['loss'] = data_df['loss'].astype(float)
|
242 |
losses = data_df['loss']
|
243 |
high_loss = losses.quantile(loss_quantile)
|
244 |
data_df['slice'] = 'high-loss'
|
245 |
data_df['slice'] = data_df['slice'].where(data_df['loss'] > high_loss, 'low-loss')
|
246 |
|
247 |
+
with rcol:
|
248 |
+
with st.spinner(text='loading...'):
|
249 |
+
st.markdown('<h3>Word Distribution in Error Slice</h3>', unsafe_allow_html=True)
|
250 |
+
commontokens = frequent_tokens(data_df, tokenizer, loss_quantile=loss_quantile)
|
251 |
+
with st.expander("How to read the table:"):
|
252 |
+
st.markdown("* The table displays the most frequent tokens in error slices, relative to their frequencies in the val set.")
|
253 |
+
st.write(commontokens)
|
254 |
+
|
255 |
+
run_kmeans = st.sidebar.radio("Cluster error slice?", ('True', 'False'), index=0)
|
256 |
+
|
257 |
+
num_clusters = st.sidebar.slider("# clusters", min_value=1, max_value=20, step=1, value=3)
|
258 |
+
|
259 |
if run_kmeans == 'True':
|
260 |
merged = kmeans(data_df,num_clusters=num_clusters)
|
261 |
with lcol:
|
|
|
270 |
st.markdown("* The table displays model error slices on the evaluation dataset, sorted by loss.")
|
271 |
st.markdown("* Each row is an input example that includes the label, model pred, loss, and error cluster.")
|
272 |
st.write(dataframe,width=900, height=300)
|
273 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
274 |
quant_panel(merged)
|