aksell committed on
Commit
6d3f484
·
1 Parent(s): d3faba1

Show EC number in single attention plot and add option to show special tokens

Browse files
hexviz/pages/1_🗺️Identify_Interesting_Heads.py CHANGED
@@ -1,6 +1,8 @@
 
 
1
  import streamlit as st
2
 
3
- from hexviz.attention import clean_and_validate_sequence, get_attention, get_sequence
4
  from hexviz.config import URL
5
  from hexviz.models import Model, ModelType
6
  from hexviz.plot import plot_single_heatmap, plot_tiled_heatmap
@@ -51,11 +53,36 @@ chain_selection = st.sidebar.selectbox(
51
  )
52
 
53
  selected_chain = next(chain for chain in chains if chain.id == chain_selection)
54
- sequence = get_sequence(selected_chain)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  l = len(sequence)
57
  slice_start, slice_end = select_sequence_slice(l)
58
  truncated_sequence = sequence[slice_start - 1 : slice_end]
 
59
 
60
 
61
  layer_sequence, head_sequence = select_heads_and_layers(st.sidebar, selected_model)
@@ -68,11 +95,10 @@ st.markdown(
68
 
69
  # TODO: Decide if you should get attention for the full sequence or just the truncated sequence
70
  # Attention values will change depending on what we do.
71
- attention = get_attention(
72
  attention, tokens = get_attention(
73
  sequence=truncated_sequence,
74
  model_type=selected_model.name,
75
- remove_special_tokens=True,
76
  ec_number=ec_number,
77
  )
78
 
@@ -111,8 +137,5 @@ with right:
111
  unsafe_allow_html=True,
112
  )
113
 
114
-
115
- single_head_fig = plot_single_heatmap(
116
- attention, layer, head, slice_start, slice_end, max_labels=10
117
- )
118
  st.pyplot(single_head_fig)
 
1
+ import re
2
+
3
  import streamlit as st
4
 
5
+ from hexviz.attention import clean_and_validate_sequence, get_attention, res_to_1letter
6
  from hexviz.config import URL
7
  from hexviz.models import Model, ModelType
8
  from hexviz.plot import plot_single_heatmap, plot_tiled_heatmap
 
53
  )
54
 
55
  selected_chain = next(chain for chain in chains if chain.id == chain_selection)
56
+
57
+ ec_number = ""
58
+ if selected_model.name == ModelType.ZymCTRL:
59
+ st.sidebar.markdown(
60
+ """
61
+ ZymCTRL EC number
62
+ ---
63
+ """
64
+ )
65
+ try:
66
+ ec_number = structure.header["compound"]["1"]["ec"]
67
+ except KeyError:
68
+ pass
69
+ ec_number = st.sidebar.text_input("Enzyme Comission number (EC)", ec_number)
70
+
71
+ # Validate EC number
72
+ if not re.match(r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$", ec_number):
73
+ st.sidebar.error(
74
+ """Please enter a valid Enzyme Commission number in the format of 4
75
+ integers separated by periods (e.g., 1.2.3.21)"""
76
+ )
77
+
78
+
79
+ residues = [res for res in selected_chain.get_residues()]
80
+ sequence = res_to_1letter(residues)
81
 
82
  l = len(sequence)
83
  slice_start, slice_end = select_sequence_slice(l)
84
  truncated_sequence = sequence[slice_start - 1 : slice_end]
85
+ remove_special_tokens = st.sidebar.checkbox("Remove special tokens", key="remove_special_tokens")
86
 
87
 
88
  layer_sequence, head_sequence = select_heads_and_layers(st.sidebar, selected_model)
 
95
 
96
  # TODO: Decide if you should get attention for the full sequence or just the truncated sequence
97
  # Attention values will change depending on what we do.
 
98
  attention, tokens = get_attention(
99
  sequence=truncated_sequence,
100
  model_type=selected_model.name,
101
+ remove_special_tokens=remove_special_tokens,
102
  ec_number=ec_number,
103
  )
104
 
 
137
  unsafe_allow_html=True,
138
  )
139
 
140
+ single_head_fig = plot_single_heatmap(attention, layer, head, tokens=tokens)
 
 
 
141
  st.pyplot(single_head_fig)
hexviz/plot.py CHANGED
@@ -2,7 +2,7 @@ from typing import List
2
 
3
  import matplotlib.pyplot as plt
4
  import numpy as np
5
- from matplotlib.ticker import MaxNLocator, MultipleLocator
6
  from mpl_toolkits.axes_grid1 import make_axes_locatable
7
 
8
 
@@ -40,42 +40,24 @@ def plot_single_heatmap(
40
  tensor,
41
  layer: int,
42
  head: int,
43
- slice_start: int,
44
- slice_end: int,
45
- max_labels: int = 40,
46
  ):
47
  single_heatmap = tensor[layer, head, :, :].detach().numpy()
48
 
49
  fig, ax = plt.subplots(figsize=(10, 10))
50
  heatmap = ax.imshow(single_heatmap, cmap="viridis", aspect="equal")
51
 
52
- # Set the x and y axis major ticks and labels
53
- ax.xaxis.set_major_locator(
54
- MaxNLocator(integer=True, steps=[1, 2, 5], prune="both", nbins=max_labels)
55
- )
56
- ax.yaxis.set_major_locator(
57
- MaxNLocator(integer=True, steps=[1, 2, 5], prune="both", nbins=max_labels)
58
- )
59
-
60
- tick_indices_x = np.clip((ax.get_xticks()).astype(int), 0, slice_end - slice_start)
61
- tick_indices_y = np.clip((ax.get_yticks()).astype(int), 0, slice_end - slice_start)
62
- ax.set_xticklabels(
63
- np.arange(slice_start, slice_end + 1)[tick_indices_x], fontsize=8
64
- )
65
- ax.set_yticklabels(
66
- np.arange(slice_start, slice_end + 1)[tick_indices_y], fontsize=8
67
- )
68
-
69
- # Set the x and y axis minor ticks
70
- ax.xaxis.set_minor_locator(MultipleLocator(1))
71
- ax.yaxis.set_minor_locator(MultipleLocator(1))
72
 
73
- # Set the axis labels
74
- ax.set_xlabel("Residue Number")
75
- ax.set_ylabel("Residue Number")
76
 
77
- # Rotate the tick labels and set their alignment
78
- plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
 
79
 
80
  # Create custom colorbar axes with the desired dimensions
81
  divider = make_axes_locatable(ax)
 
2
 
3
  import matplotlib.pyplot as plt
4
  import numpy as np
5
+ from matplotlib.ticker import FixedLocator
6
  from mpl_toolkits.axes_grid1 import make_axes_locatable
7
 
8
 
 
40
  tensor,
41
  layer: int,
42
  head: int,
43
+ tokens: list[str],
 
 
44
  ):
45
  single_heatmap = tensor[layer, head, :, :].detach().numpy()
46
 
47
  fig, ax = plt.subplots(figsize=(10, 10))
48
  heatmap = ax.imshow(single_heatmap, cmap="viridis", aspect="equal")
49
 
50
+ # Set the x and y axis ticks
51
+ ax.xaxis.set_major_locator(FixedLocator(np.arange(0, len(tokens))))
52
+ ax.yaxis.set_major_locator(FixedLocator(np.arange(0, len(tokens))))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
+ # Set tick labels as sequence values
55
+ ax.set_xticklabels(tokens, fontsize=8, rotation=45, ha="right", rotation_mode="anchor")
56
+ ax.set_yticklabels(tokens, fontsize=8)
57
 
58
+ # Set the axis labels
59
+ ax.set_xlabel("Sequence tokens")
60
+ ax.set_ylabel("Sequence tokens")
61
 
62
  # Create custom colorbar axes with the desired dimensions
63
  divider = make_axes_locatable(ax)