Commit fb9cb6e · Alexander Seifert committed
Parent(s): 17ba05a
update

Changed files:
- main.py +4 -17
- subpages/__init__.py +1 -1
- subpages/{embeddings.py → hidden_states.py} +2 -2
- subpages/home.py +5 -5
- utils.py +13 -12
main.py
CHANGED
@@ -17,8 +17,9 @@ from subpages import (
     RawDataPage,
 )
 from subpages.attention import AttentionPage
-from subpages.embeddings import EmbeddingsPage
+from subpages.hidden_states import HiddenStatesPage
 from subpages.inspect import InspectPage
+from utils import classmap
 
 sts = st.sidebar
 st.set_page_config(
@@ -54,25 +55,11 @@ def _write_color_legend(context):
     def style(x):
         return [f"background-color: {rgb}; opacity: 1;" for rgb in colors]
 
-    labelmap = {
-        "O": "O",
-        "person": "π",
-        "PER": "π",
-        "location": "π",
-        "LOC": "π",
-        "corporation": "π€",
-        "ORG": "π€",
-        "product": "π±",
-        "creative": "π·",
-        "group": "π·",
-        "MISC": "π·",
-    }
-
     labels = list(set([lbl.split("-")[1] if "-" in lbl else lbl for lbl in context.labels]))
     colors = [st.session_state.get(f"color_{lbl}", "#000000") for lbl in labels]
 
     color_legend_df = pd.DataFrame(
-        [labelmap[l] for l in labels], columns=["label"], index=labels
+        [classmap[l] for l in labels], columns=["label"], index=labels
     ).T
     st.sidebar.write(
         color_legend_df.T.style.apply(style, axis=0).set_properties(
@@ -85,7 +72,7 @@ def main():
     pages: list[Page] = [
         HomePage(),
         AttentionPage(),
-        EmbeddingsPage(),
+        HiddenStatesPage(),
         ProbingPage(),
         MetricsPage(),
         MisclassifiedPage(),
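The main change in main.py is that _write_color_legend no longer builds its own labelmap and instead reads the shared classmap from utils. Below is a minimal sketch of the resulting legend construction, using pandas only; the labels, icons, and colors are hard-coded stand-ins for what the real app pulls from utils.classmap and st.session_state.

import pandas as pd

classmap = {"O": "O", "PER": "PER-icon", "LOC": "LOC-icon"}  # stand-in for utils.classmap
labels = ["O", "PER", "LOC"]
colors = ["#000000", "#ff0000", "#00ff00"]  # stand-in for the per-label colors kept in st.session_state

# one-row frame: a single "label" row holding each label's icon, one column per label
color_legend_df = pd.DataFrame(
    [classmap[l] for l in labels], columns=["label"], index=labels
).T

def style(x):
    # one background color per label, mirroring Styler.apply(style, axis=0) in the app
    return [f"background-color: {rgb}; opacity: 1;" for rgb in colors]

styled = color_legend_df.T.style.apply(style, axis=0)  # the app passes this Styler to st.sidebar.write
html = styled.to_html()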
subpages/__init__.py
CHANGED
@@ -1,7 +1,7 @@
 from subpages.attention import AttentionPage
 from subpages.debug import DebugPage
-from subpages.embeddings import EmbeddingsPage
 from subpages.find_duplicates import FindDuplicatesPage
+from subpages.hidden_states import HiddenStatesPage
 from subpages.home import HomePage
 from subpages.inspect import InspectPage
 from subpages.losses import LossesPage
subpages/{embeddings.py → hidden_states.py}
RENAMED
@@ -28,8 +28,8 @@ def reduce_dim_umap(X, n_neighbors=5, min_dist=0.1, metric="euclidean"):
     return UMAP(n_neighbors=n_neighbors, min_dist=min_dist, metric=metric).fit_transform(X)
 
 
-class EmbeddingsPage(Page):
-    name = "Embeddings"
+class HiddenStatesPage(Page):
+    name = "Hidden States"
     icon = "grid-3x3"
 
     def get_widget_defaults(self):
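The rename only touches the page class; the projection code above it is unchanged. For context, here is a hedged sketch of the UMAP reduction the page relies on, assuming the umap-learn package is installed; the random X below stands in for per-token hidden states.

import numpy as np
from umap import UMAP

X = np.random.rand(200, 768)  # dummy hidden states: 200 tokens, 768 dimensions

def reduce_dim_umap(X, n_neighbors=5, min_dist=0.1, metric="euclidean"):
    # project the high-dimensional states down to 2D for plotting
    return UMAP(n_neighbors=n_neighbors, min_dist=min_dist, metric=metric).fit_transform(X)

coords_2d = reduce_dim_umap(X)  # (200, 2) array ready for a scatter plot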
subpages/home.py
CHANGED
@@ -3,11 +3,10 @@ import random
 from typing import Optional
 
 import streamlit as st
-from pandas import wide_to_long
 
 from data import get_data
 from subpages.page import Context, Page
-from utils import color_map_color
+from utils import classmap, color_map_color
 
 _SENTENCE_ENCODER_MODEL = (
     "sentence-transformers/all-MiniLM-L6-v2",
@@ -53,7 +52,7 @@ class HomePage(Page):
 
         with st.expander("π‘", expanded=True):
             st.write(
-                "**Error Analysis is an important but often overlooked part of the data science project lifecycle**, for which there is still very little tooling available. Practitioners tend to write throwaway code or, worse, skip this crucial step of understanding their models' errors altogether. This project tries to provide an **extensive toolkit to probe any NER model/dataset combination**, find labeling errors and understand the models' and datasets' limitations, leading the user on her way to further
+                "**Error Analysis is an important but often overlooked part of the data science project lifecycle**, for which there is still very little tooling available. Practitioners tend to write throwaway code or, worse, skip this crucial step of understanding their models' errors altogether. This project tries to provide an **extensive toolkit to probe any NER model/dataset combination**, find labeling errors and understand the models' and datasets' limitations, leading the user on her way to further **improving both model AND dataset**."
             )
 
         col1, _, col2a, col2b = st.columns([1, 0.05, 0.15, 0.15])
@@ -91,7 +90,7 @@ class HomePage(Page):
             st.text_input(
                 label="Encoder Model:",
                 key="encoder_model_name",
-                help="Path or name of the encoder to use",
+                help="Path or name of the encoder to use for duplicate detection",
             )
             ds_name = st.text_input(
                 label="Dataset:",
@@ -136,8 +135,9 @@ class HomePage(Page):
         emojis = list(json.load(open("subpages/emoji-en-US.json")).keys())
         for label in labels:
             if f"icon_{label}" not in st.session_state:
-                st.session_state[f"icon_{label}"] =
+                st.session_state[f"icon_{label}"] = classmap[label]
             st.selectbox(label, key=f"icon_{label}", options=emojis)
+            classmap[label] = st.session_state[f"icon_{label}"]
 
         # if st.button("Reset to defaults"):
         #     st.session_state.update(**get_home_page_defaults())
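home.py now seeds each icon picker from the shared classmap and writes the user's selection back into it. Here is a self-contained sketch of that Streamlit pattern; the trimmed-down classmap and emoji list are stand-ins for the map in utils.py and the contents of subpages/emoji-en-US.json.

import streamlit as st

classmap = {"PER": "🙂", "LOC": "🌍", "ORG": "🏠"}  # stand-in for utils.classmap
emojis = ["🙂", "🌍", "🏠", "📦"]  # the app loads these from subpages/emoji-en-US.json
labels = list(classmap)

for label in labels:
    # seed the widget's default from classmap only on the first run of the session
    if f"icon_{label}" not in st.session_state:
        st.session_state[f"icon_{label}"] = classmap[label]
    st.selectbox(label, options=emojis, key=f"icon_{label}")
    # write the current selection back so every page renders the same icon
    classmap[label] = st.session_state[f"icon_{label}"]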
utils.py
CHANGED
@@ -14,6 +14,19 @@ tokenizer_hash_funcs = {
 # device = torch.device("cuda" if torch.cuda.is_available() else "cpu" if torch.has_mps else "cpu")
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
+classmap = {
+    "O": "O",
+    "PER": "π",
+    "person": "π",
+    "LOC": "π",
+    "location": "π",
+    "ORG": "π€",
+    "corporation": "π€",
+    "product": "π±",
+    "creative": "π·",
+    "MISC": "π·",
+}
+
 
 def aggrid_interactive_table(df: pd.DataFrame) -> dict:
     """Creates an st-aggrid interactive table based on a dataframe.
@@ -159,18 +172,6 @@ def colorize_classes(df: pd.DataFrame) -> pd.DataFrame:
 
 def htmlify_labeled_example(example: pd.DataFrame) -> str:
     html = []
-    classmap = {
-        "O": "O",
-        "PER": "π",
-        "person": "π",
-        "LOC": "π",
-        "location": "π",
-        "ORG": "π€",
-        "corporation": "π€",
-        "product": "π±",
-        "creative": "π·",
-        "MISC": "π·",
-    }
 
     for _, row in example.iterrows():
         pred = row.preds.split("-")[1] if "-" in row.preds else "O"
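Hoisting classmap to module level lets htmlify_labeled_example (and the other pages) share one label-to-icon mapping. A rough usage sketch follows, with placeholder icons and a tiny DataFrame whose tokens/preds columns mirror the app's layout.

import pandas as pd

classmap = {"O": "O", "PER": "🙂", "LOC": "🌍"}  # placeholder icons for the shared map

example = pd.DataFrame(
    {"tokens": ["Alice", "visited", "Paris"], "preds": ["B-PER", "O", "B-LOC"]}
)

html = []
for _, row in example.iterrows():
    # strip the BIO prefix exactly as htmlify_labeled_example does
    pred = row.preds.split("-")[1] if "-" in row.preds else "O"
    html.append(f"<span title='{pred}'>{row.tokens} {classmap[pred]}</span>")

print(" ".join(html))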