Spaces:
Sleeping
Sleeping
Add plot of single head
Browse filesThis makes it possible to focus on on specific parts of a head
and gives you a link to the main view to look at the head on a structure.
- hexviz/pages/1_🗺️Identify_Interesting_Heads.py +38 -1
- hexviz/plot.py +74 -8
hexviz/pages/1_🗺️Identify_Interesting_Heads.py
CHANGED
@@ -2,7 +2,7 @@ import streamlit as st
|
|
2 |
|
3 |
from hexviz.attention import clean_and_validate_sequence, get_attention, get_sequence
|
4 |
from hexviz.models import Model, ModelType
|
5 |
-
from hexviz.plot import plot_tiled_heatmap
|
6 |
from hexviz.view import (
|
7 |
menu_items,
|
8 |
select_heads_and_layers,
|
@@ -75,4 +75,41 @@ st.write(attention.shape)
|
|
75 |
fig = plot_tiled_heatmap(
|
76 |
attention, layer_sequence=layer_sequence, head_sequence=head_sequence
|
77 |
)
|
|
|
|
|
78 |
st.pyplot(fig)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
from hexviz.attention import clean_and_validate_sequence, get_attention, get_sequence
|
4 |
from hexviz.models import Model, ModelType
|
5 |
+
from hexviz.plot import plot_single_heatmap, plot_tiled_heatmap
|
6 |
from hexviz.view import (
|
7 |
menu_items,
|
8 |
select_heads_and_layers,
|
|
|
75 |
fig = plot_tiled_heatmap(
|
76 |
attention, layer_sequence=layer_sequence, head_sequence=head_sequence
|
77 |
)
|
78 |
+
|
79 |
+
|
80 |
st.pyplot(fig)
|
81 |
+
|
82 |
+
st.subheader("Plot single head")
|
83 |
+
left, mid, right = st.columns(3)
|
84 |
+
with left:
|
85 |
+
if "selected_layer" not in st.session_state:
|
86 |
+
st.session_state["selected_layer"] = 5
|
87 |
+
layer_one = st.selectbox(
|
88 |
+
"Layer",
|
89 |
+
options=[i for i in range(1, selected_model.layers + 1)],
|
90 |
+
key="selected_layer",
|
91 |
+
)
|
92 |
+
layer = layer_one - 1
|
93 |
+
with mid:
|
94 |
+
if "selected_head" not in st.session_state:
|
95 |
+
st.session_state["selected_head"] = 1
|
96 |
+
head_one = st.selectbox(
|
97 |
+
"Head",
|
98 |
+
options=[i for i in range(1, selected_model.heads + 1)],
|
99 |
+
key="selected_head",
|
100 |
+
)
|
101 |
+
head = head_one - 1
|
102 |
+
with right:
|
103 |
+
st.markdown(
|
104 |
+
"""
|
105 |
+
|
106 |
+
|
107 |
+
### [🧬View attention from head on structure](Attention_Visualization)
|
108 |
+
"""
|
109 |
+
)
|
110 |
+
|
111 |
+
|
112 |
+
single_head_fig = plot_single_heatmap(
|
113 |
+
attention, layer, head, slice_start, slice_end, max_labels=10
|
114 |
+
)
|
115 |
+
st.pyplot(single_head_fig)
|
hexviz/plot.py
CHANGED
@@ -1,31 +1,97 @@
|
|
1 |
from typing import List
|
2 |
|
3 |
import matplotlib.pyplot as plt
|
|
|
|
|
|
|
4 |
|
5 |
|
6 |
def plot_tiled_heatmap(tensor, layer_sequence: List[int], head_sequence: List[int]):
|
7 |
-
tensor = tensor[layer_sequence, :][
|
|
|
|
|
8 |
num_layers = len(layer_sequence)
|
9 |
num_heads = len(head_sequence)
|
10 |
|
11 |
x_size = num_heads * 2
|
12 |
y_size = num_layers * 2
|
13 |
-
fig, axes = plt.subplots(
|
|
|
|
|
14 |
for i in range(num_layers):
|
15 |
for j in range(num_heads):
|
16 |
-
axes[i, j].imshow(
|
17 |
-
|
|
|
|
|
18 |
|
19 |
# Enumerate the axes
|
20 |
if i == 0:
|
21 |
-
axes[i, j].set_title(
|
|
|
|
|
22 |
|
23 |
# Calculate the row label offset based on the number of columns
|
24 |
offset = 0.02 + (12 - num_heads) * 0.0015
|
25 |
for i, ax_row in enumerate(axes):
|
26 |
row_label = f"{layer_sequence[i]+1}"
|
27 |
-
row_pos = ax_row[num_heads-1].get_position()
|
28 |
-
fig.text(
|
|
|
|
|
29 |
|
30 |
plt.subplots_adjust(wspace=0.1, hspace=0.1)
|
31 |
-
return fig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from typing import List
|
2 |
|
3 |
import matplotlib.pyplot as plt
|
4 |
+
import numpy as np
|
5 |
+
from matplotlib.ticker import MaxNLocator, MultipleLocator
|
6 |
+
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
7 |
|
8 |
|
9 |
def plot_tiled_heatmap(tensor, layer_sequence: List[int], head_sequence: List[int]):
|
10 |
+
tensor = tensor[layer_sequence, :][
|
11 |
+
:, head_sequence, :, :
|
12 |
+
] # Slice the tensor according to the provided sequences and sequence_count
|
13 |
num_layers = len(layer_sequence)
|
14 |
num_heads = len(head_sequence)
|
15 |
|
16 |
x_size = num_heads * 2
|
17 |
y_size = num_layers * 2
|
18 |
+
fig, axes = plt.subplots(
|
19 |
+
num_layers, num_heads, figsize=(x_size, y_size), squeeze=False
|
20 |
+
)
|
21 |
for i in range(num_layers):
|
22 |
for j in range(num_heads):
|
23 |
+
axes[i, j].imshow(
|
24 |
+
tensor[i, j].detach().numpy(), cmap="viridis", aspect="equal"
|
25 |
+
)
|
26 |
+
axes[i, j].axis("off")
|
27 |
|
28 |
# Enumerate the axes
|
29 |
if i == 0:
|
30 |
+
axes[i, j].set_title(
|
31 |
+
f"Head {head_sequence[j] + 1}", fontsize=10, y=1.05
|
32 |
+
)
|
33 |
|
34 |
# Calculate the row label offset based on the number of columns
|
35 |
offset = 0.02 + (12 - num_heads) * 0.0015
|
36 |
for i, ax_row in enumerate(axes):
|
37 |
row_label = f"{layer_sequence[i]+1}"
|
38 |
+
row_pos = ax_row[num_heads - 1].get_position()
|
39 |
+
fig.text(
|
40 |
+
row_pos.x1 + offset, (row_pos.y1 + row_pos.y0) / 2, row_label, va="center"
|
41 |
+
)
|
42 |
|
43 |
plt.subplots_adjust(wspace=0.1, hspace=0.1)
|
44 |
+
return fig
|
45 |
+
|
46 |
+
|
47 |
+
def plot_single_heatmap(
|
48 |
+
tensor,
|
49 |
+
layer: int,
|
50 |
+
head: int,
|
51 |
+
slice_start: int,
|
52 |
+
slice_end: int,
|
53 |
+
max_labels: int = 40,
|
54 |
+
):
|
55 |
+
single_heatmap = tensor[layer, head, :, :].detach().numpy()
|
56 |
+
|
57 |
+
fig, ax = plt.subplots(figsize=(10, 10))
|
58 |
+
heatmap = ax.imshow(single_heatmap, cmap="viridis", aspect="equal")
|
59 |
+
|
60 |
+
# Set the x and y axis major ticks and labels
|
61 |
+
ax.xaxis.set_major_locator(
|
62 |
+
MaxNLocator(integer=True, steps=[1, 2, 5], prune="both", nbins=max_labels)
|
63 |
+
)
|
64 |
+
ax.yaxis.set_major_locator(
|
65 |
+
MaxNLocator(integer=True, steps=[1, 2, 5], prune="both", nbins=max_labels)
|
66 |
+
)
|
67 |
+
|
68 |
+
ax.set_xticklabels(
|
69 |
+
np.arange(slice_start, slice_end + 1)[ax.get_xticks().astype(int)], fontsize=8
|
70 |
+
)
|
71 |
+
ax.set_yticklabels(
|
72 |
+
np.arange(slice_start, slice_end + 1)[ax.get_yticks().astype(int)], fontsize=8
|
73 |
+
)
|
74 |
+
|
75 |
+
# Set the x and y axis minor ticks
|
76 |
+
ax.xaxis.set_minor_locator(MultipleLocator(1))
|
77 |
+
ax.yaxis.set_minor_locator(MultipleLocator(1))
|
78 |
+
|
79 |
+
# Set the axis labels
|
80 |
+
ax.set_xlabel("Residue Number")
|
81 |
+
ax.set_ylabel("Residue Number")
|
82 |
+
|
83 |
+
# Rotate the tick labels and set their alignment
|
84 |
+
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
|
85 |
+
|
86 |
+
# Create custom colorbar axes with the desired dimensions
|
87 |
+
divider = make_axes_locatable(ax)
|
88 |
+
cax = divider.append_axes("right", size="5%", pad=0.1)
|
89 |
+
|
90 |
+
# Add a colorbar to show the scale
|
91 |
+
cbar = fig.colorbar(heatmap, cax=cax)
|
92 |
+
cbar.ax.set_ylabel("Attention Weight", rotation=-90, va="bottom")
|
93 |
+
|
94 |
+
# Set the title of the plot
|
95 |
+
ax.set_title(f"Layer {layer + 1} - Head {head + 1}")
|
96 |
+
|
97 |
+
return fig
|