aksell commited on
Commit
214c50a
·
1 Parent(s): 1804445

Add plot of single head

Browse files

This makes it possible to focus on on specific parts of a head
and gives you a link to the main view to look at the head on a structure.

hexviz/pages/1_🗺️Identify_Interesting_Heads.py CHANGED
@@ -2,7 +2,7 @@ import streamlit as st
2
 
3
  from hexviz.attention import clean_and_validate_sequence, get_attention, get_sequence
4
  from hexviz.models import Model, ModelType
5
- from hexviz.plot import plot_tiled_heatmap
6
  from hexviz.view import (
7
  menu_items,
8
  select_heads_and_layers,
@@ -75,4 +75,41 @@ st.write(attention.shape)
75
  fig = plot_tiled_heatmap(
76
  attention, layer_sequence=layer_sequence, head_sequence=head_sequence
77
  )
 
 
78
  st.pyplot(fig)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  from hexviz.attention import clean_and_validate_sequence, get_attention, get_sequence
4
  from hexviz.models import Model, ModelType
5
+ from hexviz.plot import plot_single_heatmap, plot_tiled_heatmap
6
  from hexviz.view import (
7
  menu_items,
8
  select_heads_and_layers,
 
75
  fig = plot_tiled_heatmap(
76
  attention, layer_sequence=layer_sequence, head_sequence=head_sequence
77
  )
78
+
79
+
80
  st.pyplot(fig)
81
+
82
+ st.subheader("Plot single head")
83
+ left, mid, right = st.columns(3)
84
+ with left:
85
+ if "selected_layer" not in st.session_state:
86
+ st.session_state["selected_layer"] = 5
87
+ layer_one = st.selectbox(
88
+ "Layer",
89
+ options=[i for i in range(1, selected_model.layers + 1)],
90
+ key="selected_layer",
91
+ )
92
+ layer = layer_one - 1
93
+ with mid:
94
+ if "selected_head" not in st.session_state:
95
+ st.session_state["selected_head"] = 1
96
+ head_one = st.selectbox(
97
+ "Head",
98
+ options=[i for i in range(1, selected_model.heads + 1)],
99
+ key="selected_head",
100
+ )
101
+ head = head_one - 1
102
+ with right:
103
+ st.markdown(
104
+ """
105
+
106
+
107
+ ### [🧬View attention from head on structure](Attention_Visualization)
108
+ """
109
+ )
110
+
111
+
112
+ single_head_fig = plot_single_heatmap(
113
+ attention, layer, head, slice_start, slice_end, max_labels=10
114
+ )
115
+ st.pyplot(single_head_fig)
hexviz/plot.py CHANGED
@@ -1,31 +1,97 @@
1
  from typing import List
2
 
3
  import matplotlib.pyplot as plt
 
 
 
4
 
5
 
6
  def plot_tiled_heatmap(tensor, layer_sequence: List[int], head_sequence: List[int]):
7
- tensor = tensor[layer_sequence, :][:, head_sequence, :, :] # Slice the tensor according to the provided sequences and sequence_count
 
 
8
  num_layers = len(layer_sequence)
9
  num_heads = len(head_sequence)
10
 
11
  x_size = num_heads * 2
12
  y_size = num_layers * 2
13
- fig, axes = plt.subplots(num_layers, num_heads, figsize=(x_size, y_size), squeeze=False)
 
 
14
  for i in range(num_layers):
15
  for j in range(num_heads):
16
- axes[i, j].imshow(tensor[i, j].detach().numpy(), cmap='viridis', aspect='equal')
17
- axes[i, j].axis('off')
 
 
18
 
19
  # Enumerate the axes
20
  if i == 0:
21
- axes[i, j].set_title(f'Head {head_sequence[j] + 1}', fontsize=10, y=1.05)
 
 
22
 
23
  # Calculate the row label offset based on the number of columns
24
  offset = 0.02 + (12 - num_heads) * 0.0015
25
  for i, ax_row in enumerate(axes):
26
  row_label = f"{layer_sequence[i]+1}"
27
- row_pos = ax_row[num_heads-1].get_position()
28
- fig.text(row_pos.x1+offset, (row_pos.y1+row_pos.y0)/2, row_label, va='center')
 
 
29
 
30
  plt.subplots_adjust(wspace=0.1, hspace=0.1)
31
- return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from typing import List
2
 
3
  import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ from matplotlib.ticker import MaxNLocator, MultipleLocator
6
+ from mpl_toolkits.axes_grid1 import make_axes_locatable
7
 
8
 
9
  def plot_tiled_heatmap(tensor, layer_sequence: List[int], head_sequence: List[int]):
10
+ tensor = tensor[layer_sequence, :][
11
+ :, head_sequence, :, :
12
+ ] # Slice the tensor according to the provided sequences and sequence_count
13
  num_layers = len(layer_sequence)
14
  num_heads = len(head_sequence)
15
 
16
  x_size = num_heads * 2
17
  y_size = num_layers * 2
18
+ fig, axes = plt.subplots(
19
+ num_layers, num_heads, figsize=(x_size, y_size), squeeze=False
20
+ )
21
  for i in range(num_layers):
22
  for j in range(num_heads):
23
+ axes[i, j].imshow(
24
+ tensor[i, j].detach().numpy(), cmap="viridis", aspect="equal"
25
+ )
26
+ axes[i, j].axis("off")
27
 
28
  # Enumerate the axes
29
  if i == 0:
30
+ axes[i, j].set_title(
31
+ f"Head {head_sequence[j] + 1}", fontsize=10, y=1.05
32
+ )
33
 
34
  # Calculate the row label offset based on the number of columns
35
  offset = 0.02 + (12 - num_heads) * 0.0015
36
  for i, ax_row in enumerate(axes):
37
  row_label = f"{layer_sequence[i]+1}"
38
+ row_pos = ax_row[num_heads - 1].get_position()
39
+ fig.text(
40
+ row_pos.x1 + offset, (row_pos.y1 + row_pos.y0) / 2, row_label, va="center"
41
+ )
42
 
43
  plt.subplots_adjust(wspace=0.1, hspace=0.1)
44
+ return fig
45
+
46
+
47
+ def plot_single_heatmap(
48
+ tensor,
49
+ layer: int,
50
+ head: int,
51
+ slice_start: int,
52
+ slice_end: int,
53
+ max_labels: int = 40,
54
+ ):
55
+ single_heatmap = tensor[layer, head, :, :].detach().numpy()
56
+
57
+ fig, ax = plt.subplots(figsize=(10, 10))
58
+ heatmap = ax.imshow(single_heatmap, cmap="viridis", aspect="equal")
59
+
60
+ # Set the x and y axis major ticks and labels
61
+ ax.xaxis.set_major_locator(
62
+ MaxNLocator(integer=True, steps=[1, 2, 5], prune="both", nbins=max_labels)
63
+ )
64
+ ax.yaxis.set_major_locator(
65
+ MaxNLocator(integer=True, steps=[1, 2, 5], prune="both", nbins=max_labels)
66
+ )
67
+
68
+ ax.set_xticklabels(
69
+ np.arange(slice_start, slice_end + 1)[ax.get_xticks().astype(int)], fontsize=8
70
+ )
71
+ ax.set_yticklabels(
72
+ np.arange(slice_start, slice_end + 1)[ax.get_yticks().astype(int)], fontsize=8
73
+ )
74
+
75
+ # Set the x and y axis minor ticks
76
+ ax.xaxis.set_minor_locator(MultipleLocator(1))
77
+ ax.yaxis.set_minor_locator(MultipleLocator(1))
78
+
79
+ # Set the axis labels
80
+ ax.set_xlabel("Residue Number")
81
+ ax.set_ylabel("Residue Number")
82
+
83
+ # Rotate the tick labels and set their alignment
84
+ plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
85
+
86
+ # Create custom colorbar axes with the desired dimensions
87
+ divider = make_axes_locatable(ax)
88
+ cax = divider.append_axes("right", size="5%", pad=0.1)
89
+
90
+ # Add a colorbar to show the scale
91
+ cbar = fig.colorbar(heatmap, cax=cax)
92
+ cbar.ax.set_ylabel("Attention Weight", rotation=-90, va="bottom")
93
+
94
+ # Set the title of the plot
95
+ ax.set_title(f"Layer {layer + 1} - Head {head + 1}")
96
+
97
+ return fig