Spaces:
Runtime error
Runtime error
adding parquets
Browse files
app.py
CHANGED
@@ -62,17 +62,17 @@ def down_samp(embedding):
|
|
62 |
|
63 |
|
64 |
def data_comparison(df):
|
65 |
-
selection = alt.selection_multi(fields=['cluster','label'])
|
66 |
-
color = alt.condition(alt.datum.slice == 'high-loss', alt.Color('cluster:
|
67 |
opacity = alt.condition(selection, alt.value(0.7), alt.value(0.25))
|
68 |
|
69 |
# basic chart
|
70 |
scatter = alt.Chart(df).mark_point(size=100, filled=True).encode(
|
71 |
-
x=alt.X('x', axis=None),
|
72 |
-
y=alt.Y('y', axis=None),
|
73 |
color=color,
|
74 |
-
shape=alt.Shape('label', scale=alt.Scale(range=['circle', 'diamond'])),
|
75 |
-
tooltip=['cluster','slice','content','label','pred'],
|
76 |
opacity=opacity
|
77 |
).properties(
|
78 |
width=1000,
|
@@ -80,31 +80,26 @@ def data_comparison(df):
|
|
80 |
).interactive()
|
81 |
|
82 |
legend = alt.Chart(df).mark_point(size=100, filled=True).encode(
|
83 |
-
x=alt.X("label"),
|
84 |
-
y=alt.Y('cluster:
|
85 |
-
shape=alt.Shape('label', scale=alt.Scale(
|
86 |
range=['circle', 'diamond']), legend=None),
|
87 |
color=color,
|
88 |
).add_selection(
|
89 |
selection
|
90 |
)
|
91 |
-
|
92 |
layered = scatter | legend
|
93 |
-
|
94 |
layered = layered.configure_axis(
|
95 |
grid=False
|
96 |
).configure_view(
|
97 |
strokeOpacity=0
|
98 |
)
|
99 |
-
|
100 |
return layered
|
101 |
|
102 |
def quant_panel(embedding_df):
|
103 |
""" Quantitative Panel Layout"""
|
104 |
-
|
105 |
all_metrics = {}
|
106 |
st.warning("**Error slice visualization**")
|
107 |
-
|
108 |
with st.expander("How to read this chart:"):
|
109 |
st.markdown("* Each **point** is an input example.")
|
110 |
st.markdown("* Gray points have low-loss and the colored have high-loss. High-loss instances are clustered using **kmeans** and each color represents a cluster.")
|
@@ -210,12 +205,14 @@ def topic_distribution(weights, smoothing=0.01):
|
|
210 |
|
211 |
if __name__ == "__main__":
|
212 |
### STREAMLIT APP CONGFIG ###
|
213 |
-
st.set_page_config(layout="wide", page_title="Error Analysis")
|
214 |
|
215 |
ut.init_style()
|
216 |
|
217 |
lcol, rcol = st.columns([2, 2])
|
218 |
# ******* loading the mode and the data
|
|
|
|
|
219 |
dataset = st.sidebar.selectbox(
|
220 |
"Dataset",
|
221 |
["amazon_polarity", "yelp_polarity"],
|
@@ -246,8 +243,6 @@ if __name__ == "__main__":
|
|
246 |
st.session_state["user_data"] = data_df
|
247 |
if "selected_slice" not in st.session_state:
|
248 |
st.session_state["selected_slice"] = None
|
249 |
-
if "embedding" not in st.session_state:
|
250 |
-
st.session_state["embedding"] = embedding_umap
|
251 |
|
252 |
data_df['loss'] = data_df['loss'].astype(float)
|
253 |
losses = data_df['loss']
|
@@ -258,13 +253,14 @@ if __name__ == "__main__":
|
|
258 |
if run_kmeans == 'True':
|
259 |
merged = kmeans(data_df,num_clusters=num_clusters)
|
260 |
with lcol:
|
261 |
-
st.markdown('<h3>Error Slices
|
262 |
dataframe = merged[['content', 'label', 'pred', 'loss', 'cluster']].sort_values(
|
263 |
by=['loss'], ascending=False)
|
264 |
table_html = dataframe.to_html(
|
265 |
columns=['content', 'label', 'pred', 'loss', 'cluster'], max_rows=50)
|
266 |
# table_html = table_html.replace("<th>", '<th align="left">') # left-align the headers
|
267 |
with st.expander("How to read the table:"):
|
|
|
268 |
st.markdown("* The table displays model error slices on the evaluation dataset, sorted by loss.")
|
269 |
st.markdown("* Each row is an input example that includes the label, model pred, loss, and error cluster.")
|
270 |
st.write(dataframe,width=900, height=300)
|
|
|
62 |
|
63 |
|
64 |
def data_comparison(df):
|
65 |
+
selection = alt.selection_multi(fields=['cluster:O','label:O'])
|
66 |
+
color = alt.condition(alt.datum.slice == 'high-loss', alt.Color('cluster:O', scale = alt.Scale(domain=df.cluster.unique().tolist())), alt.value("lightgray"))
|
67 |
opacity = alt.condition(selection, alt.value(0.7), alt.value(0.25))
|
68 |
|
69 |
# basic chart
|
70 |
scatter = alt.Chart(df).mark_point(size=100, filled=True).encode(
|
71 |
+
x=alt.X('x:Q', axis=None),
|
72 |
+
y=alt.Y('y:Q', axis=None),
|
73 |
color=color,
|
74 |
+
shape=alt.Shape('label:O', scale=alt.Scale(range=['circle', 'diamond'])),
|
75 |
+
tooltip=['cluster:O','slice:N','content:N','label:O','pred:O'],
|
76 |
opacity=opacity
|
77 |
).properties(
|
78 |
width=1000,
|
|
|
80 |
).interactive()
|
81 |
|
82 |
legend = alt.Chart(df).mark_point(size=100, filled=True).encode(
|
83 |
+
x=alt.X("label:O"),
|
84 |
+
y=alt.Y('cluster:O', axis=alt.Axis(orient='right'), title=""),
|
85 |
+
shape=alt.Shape('label:O', scale=alt.Scale(
|
86 |
range=['circle', 'diamond']), legend=None),
|
87 |
color=color,
|
88 |
).add_selection(
|
89 |
selection
|
90 |
)
|
|
|
91 |
layered = scatter | legend
|
|
|
92 |
layered = layered.configure_axis(
|
93 |
grid=False
|
94 |
).configure_view(
|
95 |
strokeOpacity=0
|
96 |
)
|
|
|
97 |
return layered
|
98 |
|
99 |
def quant_panel(embedding_df):
|
100 |
""" Quantitative Panel Layout"""
|
|
|
101 |
all_metrics = {}
|
102 |
st.warning("**Error slice visualization**")
|
|
|
103 |
with st.expander("How to read this chart:"):
|
104 |
st.markdown("* Each **point** is an input example.")
|
105 |
st.markdown("* Gray points have low-loss and the colored have high-loss. High-loss instances are clustered using **kmeans** and each color represents a cluster.")
|
|
|
205 |
|
206 |
if __name__ == "__main__":
|
207 |
### STREAMLIT APP CONGFIG ###
|
208 |
+
st.set_page_config(layout="wide", page_title="Interactive Error Analysis")
|
209 |
|
210 |
ut.init_style()
|
211 |
|
212 |
lcol, rcol = st.columns([2, 2])
|
213 |
# ******* loading the mode and the data
|
214 |
+
st.sidebar.mardown("<h4>Interactive Error Analysis</h4>", unsafe_allow_html=True)
|
215 |
+
|
216 |
dataset = st.sidebar.selectbox(
|
217 |
"Dataset",
|
218 |
["amazon_polarity", "yelp_polarity"],
|
|
|
243 |
st.session_state["user_data"] = data_df
|
244 |
if "selected_slice" not in st.session_state:
|
245 |
st.session_state["selected_slice"] = None
|
|
|
|
|
246 |
|
247 |
data_df['loss'] = data_df['loss'].astype(float)
|
248 |
losses = data_df['loss']
|
|
|
253 |
if run_kmeans == 'True':
|
254 |
merged = kmeans(data_df,num_clusters=num_clusters)
|
255 |
with lcol:
|
256 |
+
st.markdown('<h3>Error Slices</h3>',unsafe_allow_html=True)
|
257 |
dataframe = merged[['content', 'label', 'pred', 'loss', 'cluster']].sort_values(
|
258 |
by=['loss'], ascending=False)
|
259 |
table_html = dataframe.to_html(
|
260 |
columns=['content', 'label', 'pred', 'loss', 'cluster'], max_rows=50)
|
261 |
# table_html = table_html.replace("<th>", '<th align="left">') # left-align the headers
|
262 |
with st.expander("How to read the table:"):
|
263 |
+
st.markdown("* *Error slice* refers to the subset of evaluation dataset the model performs poorly on.")
|
264 |
st.markdown("* The table displays model error slices on the evaluation dataset, sorted by loss.")
|
265 |
st.markdown("* Each row is an input example that includes the label, model pred, loss, and error cluster.")
|
266 |
st.write(dataframe,width=900, height=300)
|