Spaces:
Sleeping
Sleeping
Show EC number in single attention plot and add option to show special tokens
Browse files- hexviz/pages/1_🗺️Identify_Interesting_Heads.py +31 -8
- hexviz/plot.py +11 -29
hexviz/pages/1_🗺️Identify_Interesting_Heads.py
CHANGED
@@ -1,6 +1,8 @@
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
|
3 |
-
from hexviz.attention import clean_and_validate_sequence, get_attention,
|
4 |
from hexviz.config import URL
|
5 |
from hexviz.models import Model, ModelType
|
6 |
from hexviz.plot import plot_single_heatmap, plot_tiled_heatmap
|
@@ -51,11 +53,36 @@ chain_selection = st.sidebar.selectbox(
|
|
51 |
)
|
52 |
|
53 |
selected_chain = next(chain for chain in chains if chain.id == chain_selection)
|
54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
l = len(sequence)
|
57 |
slice_start, slice_end = select_sequence_slice(l)
|
58 |
truncated_sequence = sequence[slice_start - 1 : slice_end]
|
|
|
59 |
|
60 |
|
61 |
layer_sequence, head_sequence = select_heads_and_layers(st.sidebar, selected_model)
|
@@ -68,11 +95,10 @@ st.markdown(
|
|
68 |
|
69 |
# TODO: Decide if you should get attention for the full sequence or just the truncated sequence
|
70 |
# Attention values will change depending on what we do.
|
71 |
-
attention = get_attention(
|
72 |
attention, tokens = get_attention(
|
73 |
sequence=truncated_sequence,
|
74 |
model_type=selected_model.name,
|
75 |
-
remove_special_tokens=
|
76 |
ec_number=ec_number,
|
77 |
)
|
78 |
|
@@ -111,8 +137,5 @@ with right:
|
|
111 |
unsafe_allow_html=True,
|
112 |
)
|
113 |
|
114 |
-
|
115 |
-
single_head_fig = plot_single_heatmap(
|
116 |
-
attention, layer, head, slice_start, slice_end, max_labels=10
|
117 |
-
)
|
118 |
st.pyplot(single_head_fig)
|
|
|
1 |
+
import re
|
2 |
+
|
3 |
import streamlit as st
|
4 |
|
5 |
+
from hexviz.attention import clean_and_validate_sequence, get_attention, res_to_1letter
|
6 |
from hexviz.config import URL
|
7 |
from hexviz.models import Model, ModelType
|
8 |
from hexviz.plot import plot_single_heatmap, plot_tiled_heatmap
|
|
|
53 |
)
|
54 |
|
55 |
selected_chain = next(chain for chain in chains if chain.id == chain_selection)
|
56 |
+
|
57 |
+
ec_number = ""
|
58 |
+
if selected_model.name == ModelType.ZymCTRL:
|
59 |
+
st.sidebar.markdown(
|
60 |
+
"""
|
61 |
+
ZymCTRL EC number
|
62 |
+
---
|
63 |
+
"""
|
64 |
+
)
|
65 |
+
try:
|
66 |
+
ec_number = structure.header["compound"]["1"]["ec"]
|
67 |
+
except KeyError:
|
68 |
+
pass
|
69 |
+
ec_number = st.sidebar.text_input("Enzyme Comission number (EC)", ec_number)
|
70 |
+
|
71 |
+
# Validate EC number
|
72 |
+
if not re.match(r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$", ec_number):
|
73 |
+
st.sidebar.error(
|
74 |
+
"""Please enter a valid Enzyme Commission number in the format of 4
|
75 |
+
integers separated by periods (e.g., 1.2.3.21)"""
|
76 |
+
)
|
77 |
+
|
78 |
+
|
79 |
+
residues = [res for res in selected_chain.get_residues()]
|
80 |
+
sequence = res_to_1letter(residues)
|
81 |
|
82 |
l = len(sequence)
|
83 |
slice_start, slice_end = select_sequence_slice(l)
|
84 |
truncated_sequence = sequence[slice_start - 1 : slice_end]
|
85 |
+
remove_special_tokens = st.sidebar.checkbox("Remove special tokens", key="remove_special_tokens")
|
86 |
|
87 |
|
88 |
layer_sequence, head_sequence = select_heads_and_layers(st.sidebar, selected_model)
|
|
|
95 |
|
96 |
# TODO: Decide if you should get attention for the full sequence or just the truncated sequence
|
97 |
# Attention values will change depending on what we do.
|
|
|
98 |
attention, tokens = get_attention(
|
99 |
sequence=truncated_sequence,
|
100 |
model_type=selected_model.name,
|
101 |
+
remove_special_tokens=remove_special_tokens,
|
102 |
ec_number=ec_number,
|
103 |
)
|
104 |
|
|
|
137 |
unsafe_allow_html=True,
|
138 |
)
|
139 |
|
140 |
+
single_head_fig = plot_single_heatmap(attention, layer, head, tokens=tokens)
|
|
|
|
|
|
|
141 |
st.pyplot(single_head_fig)
|
hexviz/plot.py
CHANGED
@@ -2,7 +2,7 @@ from typing import List
|
|
2 |
|
3 |
import matplotlib.pyplot as plt
|
4 |
import numpy as np
|
5 |
-
from matplotlib.ticker import
|
6 |
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
7 |
|
8 |
|
@@ -40,42 +40,24 @@ def plot_single_heatmap(
|
|
40 |
tensor,
|
41 |
layer: int,
|
42 |
head: int,
|
43 |
-
|
44 |
-
slice_end: int,
|
45 |
-
max_labels: int = 40,
|
46 |
):
|
47 |
single_heatmap = tensor[layer, head, :, :].detach().numpy()
|
48 |
|
49 |
fig, ax = plt.subplots(figsize=(10, 10))
|
50 |
heatmap = ax.imshow(single_heatmap, cmap="viridis", aspect="equal")
|
51 |
|
52 |
-
# Set the x and y axis
|
53 |
-
ax.xaxis.set_major_locator(
|
54 |
-
|
55 |
-
)
|
56 |
-
ax.yaxis.set_major_locator(
|
57 |
-
MaxNLocator(integer=True, steps=[1, 2, 5], prune="both", nbins=max_labels)
|
58 |
-
)
|
59 |
-
|
60 |
-
tick_indices_x = np.clip((ax.get_xticks()).astype(int), 0, slice_end - slice_start)
|
61 |
-
tick_indices_y = np.clip((ax.get_yticks()).astype(int), 0, slice_end - slice_start)
|
62 |
-
ax.set_xticklabels(
|
63 |
-
np.arange(slice_start, slice_end + 1)[tick_indices_x], fontsize=8
|
64 |
-
)
|
65 |
-
ax.set_yticklabels(
|
66 |
-
np.arange(slice_start, slice_end + 1)[tick_indices_y], fontsize=8
|
67 |
-
)
|
68 |
-
|
69 |
-
# Set the x and y axis minor ticks
|
70 |
-
ax.xaxis.set_minor_locator(MultipleLocator(1))
|
71 |
-
ax.yaxis.set_minor_locator(MultipleLocator(1))
|
72 |
|
73 |
-
# Set
|
74 |
-
ax.
|
75 |
-
ax.
|
76 |
|
77 |
-
#
|
78 |
-
|
|
|
79 |
|
80 |
# Create custom colorbar axes with the desired dimensions
|
81 |
divider = make_axes_locatable(ax)
|
|
|
2 |
|
3 |
import matplotlib.pyplot as plt
|
4 |
import numpy as np
|
5 |
+
from matplotlib.ticker import FixedLocator
|
6 |
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
7 |
|
8 |
|
|
|
40 |
tensor,
|
41 |
layer: int,
|
42 |
head: int,
|
43 |
+
tokens: list[str],
|
|
|
|
|
44 |
):
|
45 |
single_heatmap = tensor[layer, head, :, :].detach().numpy()
|
46 |
|
47 |
fig, ax = plt.subplots(figsize=(10, 10))
|
48 |
heatmap = ax.imshow(single_heatmap, cmap="viridis", aspect="equal")
|
49 |
|
50 |
+
# Set the x and y axis ticks
|
51 |
+
ax.xaxis.set_major_locator(FixedLocator(np.arange(0, len(tokens))))
|
52 |
+
ax.yaxis.set_major_locator(FixedLocator(np.arange(0, len(tokens))))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
|
54 |
+
# Set tick labels as sequence values
|
55 |
+
ax.set_xticklabels(tokens, fontsize=8, rotation=45, ha="right", rotation_mode="anchor")
|
56 |
+
ax.set_yticklabels(tokens, fontsize=8)
|
57 |
|
58 |
+
# Set the axis labels
|
59 |
+
ax.set_xlabel("Sequence tokens")
|
60 |
+
ax.set_ylabel("Sequence tokens")
|
61 |
|
62 |
# Create custom colorbar axes with the desired dimensions
|
63 |
divider = make_axes_locatable(ax)
|