Spaces:
Sleeping
Sleeping
Fix labeling bug
Browse files
hexviz/pages/1_🗺️Identify_Interesting_Heads.py
CHANGED
@@ -110,6 +110,11 @@ fig = plot_tiled_heatmap(attention, layer_sequence=layer_sequence, head_sequence
|
|
110 |
st.pyplot(fig)
|
111 |
|
112 |
st.subheader("Plot single head")
|
|
|
|
|
|
|
|
|
|
|
113 |
left, mid, right = st.columns(3)
|
114 |
with left:
|
115 |
if "selected_layer" not in st.session_state:
|
@@ -134,10 +139,6 @@ with right:
|
|
134 |
st.session_state.label_tokens = []
|
135 |
tokens_to_label = st.multiselect("Label tokens", options=tokens, key="label_tokens")
|
136 |
|
137 |
-
if selected_model.name == ModelType.PROT_T5:
|
138 |
-
# Remove leading underscores from residue tokens
|
139 |
-
tokens = [token[1:] if str(token) != "</s>" else token for token in tokens]
|
140 |
-
|
141 |
if len(tokens_to_label) > 0:
|
142 |
tokens = [token if token in tokens_to_label else "" for token in tokens]
|
143 |
|
|
|
110 |
st.pyplot(fig)
|
111 |
|
112 |
st.subheader("Plot single head")
|
113 |
+
|
114 |
+
if selected_model.name == ModelType.PROT_T5:
|
115 |
+
# Remove leading underscores from residue tokens
|
116 |
+
tokens = [token[1:] if str(token) != "</s>" else token for token in tokens]
|
117 |
+
|
118 |
left, mid, right = st.columns(3)
|
119 |
with left:
|
120 |
if "selected_layer" not in st.session_state:
|
|
|
139 |
st.session_state.label_tokens = []
|
140 |
tokens_to_label = st.multiselect("Label tokens", options=tokens, key="label_tokens")
|
141 |
|
|
|
|
|
|
|
|
|
142 |
if len(tokens_to_label) > 0:
|
143 |
tokens = [token if token in tokens_to_label else "" for token in tokens]
|
144 |
|