aksell commited on
Commit
941854b
·
1 Parent(s): 4a93494

Fix labeling bug

Browse files
hexviz/pages/1_🗺️Identify_Interesting_Heads.py CHANGED
@@ -110,6 +110,11 @@ fig = plot_tiled_heatmap(attention, layer_sequence=layer_sequence, head_sequence
110
  st.pyplot(fig)
111
 
112
  st.subheader("Plot single head")
 
 
 
 
 
113
  left, mid, right = st.columns(3)
114
  with left:
115
  if "selected_layer" not in st.session_state:
@@ -134,10 +139,6 @@ with right:
134
  st.session_state.label_tokens = []
135
  tokens_to_label = st.multiselect("Label tokens", options=tokens, key="label_tokens")
136
 
137
- if selected_model.name == ModelType.PROT_T5:
138
- # Remove leading underscores from residue tokens
139
- tokens = [token[1:] if str(token) != "</s>" else token for token in tokens]
140
-
141
  if len(tokens_to_label) > 0:
142
  tokens = [token if token in tokens_to_label else "" for token in tokens]
143
 
 
110
  st.pyplot(fig)
111
 
112
  st.subheader("Plot single head")
113
+
114
+ if selected_model.name == ModelType.PROT_T5:
115
+ # Remove leading underscores from residue tokens
116
+ tokens = [token[1:] if str(token) != "</s>" else token for token in tokens]
117
+
118
  left, mid, right = st.columns(3)
119
  with left:
120
  if "selected_layer" not in st.session_state:
 
139
  st.session_state.label_tokens = []
140
  tokens_to_label = st.multiselect("Label tokens", options=tokens, key="label_tokens")
141
 
 
 
 
 
142
  if len(tokens_to_label) > 0:
143
  tokens = [token if token in tokens_to_label else "" for token in tokens]
144