Spaces:
Sleeping
Sleeping
Aksel Lenes
committed on
Commit
·
909a82d
1
Parent(s):
77c9ae7
Add fixed_scale parameter for grid plot view
Browse files
To see patterns in longer sequences.
Added a warning that this means each subplot in the grid has an
individual scale, so you cannot compare
attention intensities between grid cells.
hexviz/pages/1_🗺️Identify_Interesting_Heads.py
CHANGED
@@ -85,6 +85,9 @@ truncated_sequence = sequence[slice_start - 1 : slice_end]
|
|
85 |
remove_special_tokens = st.sidebar.checkbox(
|
86 |
"Hide attention to special tokens", key="remove_special_tokens"
|
87 |
)
|
|
|
|
|
|
|
88 |
|
89 |
|
90 |
layer_sequence, head_sequence = select_heads_and_layers(st.sidebar, selected_model)
|
@@ -104,7 +107,7 @@ attention, tokens = get_attention(
|
|
104 |
ec_number=ec_number,
|
105 |
)
|
106 |
|
107 |
-
fig = plot_tiled_heatmap(attention, layer_sequence=layer_sequence, head_sequence=head_sequence)
|
108 |
|
109 |
|
110 |
st.pyplot(fig)
|
@@ -143,5 +146,5 @@ if len(tokens_to_label) > 0:
|
|
143 |
tokens = [token if token in tokens_to_label else "" for token in tokens]
|
144 |
|
145 |
|
146 |
-
single_head_fig = plot_single_heatmap(attention, layer, head, tokens=tokens)
|
147 |
st.pyplot(single_head_fig)
|
|
|
85 |
remove_special_tokens = st.sidebar.checkbox(
|
86 |
"Hide attention to special tokens", key="remove_special_tokens"
|
87 |
)
|
88 |
+
if "fixed_scale" not in st.session_state:
|
89 |
+
st.session_state.fixed_scale = True
|
90 |
+
fixed_scale = st.sidebar.checkbox("Fixed scale", help="For long sequences the default fixed 0 to 1 scale can have very low contrast heatmaps, consider using a relative scale to increase the contrast between high attention and low attention areas. Note that each subplot will have separate color scales so don't compare colors between attention heads if using a non-fixed scale.", key="fixed_scale")
|
91 |
|
92 |
|
93 |
layer_sequence, head_sequence = select_heads_and_layers(st.sidebar, selected_model)
|
|
|
107 |
ec_number=ec_number,
|
108 |
)
|
109 |
|
110 |
+
fig = plot_tiled_heatmap(attention, layer_sequence=layer_sequence, head_sequence=head_sequence, fixed_scale=fixed_scale)
|
111 |
|
112 |
|
113 |
st.pyplot(fig)
|
|
|
146 |
tokens = [token if token in tokens_to_label else "" for token in tokens]
|
147 |
|
148 |
|
149 |
+
single_head_fig = plot_single_heatmap(attention, layer, head, tokens=tokens, fixed_scale=fixed_scale)
|
150 |
st.pyplot(single_head_fig)
|
hexviz/plot.py
CHANGED
@@ -6,7 +6,7 @@ from matplotlib.ticker import FixedLocator
|
|
6 |
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
7 |
|
8 |
|
9 |
-
def plot_tiled_heatmap(tensor, layer_sequence: List[int], head_sequence: List[int]):
|
10 |
tensor = tensor[layer_sequence, :][
|
11 |
:, head_sequence, :, :
|
12 |
] # Slice the tensor according to the provided sequences and sequence_count
|
@@ -18,9 +18,14 @@ def plot_tiled_heatmap(tensor, layer_sequence: List[int], head_sequence: List[in
|
|
18 |
fig, axes = plt.subplots(num_layers, num_heads, figsize=(x_size, y_size), squeeze=False)
|
19 |
for i in range(num_layers):
|
20 |
for j in range(num_heads):
|
21 |
-
|
22 |
-
|
23 |
-
|
|
|
|
|
|
|
|
|
|
|
24 |
axes[i, j].axis("off")
|
25 |
|
26 |
# Enumerate the axes
|
@@ -33,7 +38,7 @@ def plot_tiled_heatmap(tensor, layer_sequence: List[int], head_sequence: List[in
|
|
33 |
row_label = f"{layer_sequence[i]+1}"
|
34 |
row_pos = ax_row[num_heads - 1].get_position()
|
35 |
fig.text(row_pos.x1 + offset, (row_pos.y1 + row_pos.y0) / 2, row_label, va="center")
|
36 |
-
|
37 |
plt.subplots_adjust(wspace=0.1, hspace=0.1)
|
38 |
return fig
|
39 |
|
@@ -43,11 +48,15 @@ def plot_single_heatmap(
|
|
43 |
layer: int,
|
44 |
head: int,
|
45 |
tokens: list[str],
|
|
|
46 |
):
|
47 |
single_heatmap = tensor[layer, head, :, :].detach().numpy()
|
48 |
|
49 |
fig, ax = plt.subplots(figsize=(10, 10))
|
50 |
-
|
|
|
|
|
|
|
51 |
|
52 |
# Function to adjust font size based on the number of labels
|
53 |
def get_font_size(labels):
|
|
|
6 |
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
7 |
|
8 |
|
9 |
+
def plot_tiled_heatmap(tensor, layer_sequence: List[int], head_sequence: List[int], fixed_scale: bool = True):
|
10 |
tensor = tensor[layer_sequence, :][
|
11 |
:, head_sequence, :, :
|
12 |
] # Slice the tensor according to the provided sequences and sequence_count
|
|
|
18 |
fig, axes = plt.subplots(num_layers, num_heads, figsize=(x_size, y_size), squeeze=False)
|
19 |
for i in range(num_layers):
|
20 |
for j in range(num_heads):
|
21 |
+
if fixed_scale:
|
22 |
+
im = axes[i, j].imshow(
|
23 |
+
tensor[i, j].detach().numpy(), cmap="viridis", aspect="equal", vmin=0, vmax=1
|
24 |
+
)
|
25 |
+
else:
|
26 |
+
im = axes[i, j].imshow(
|
27 |
+
tensor[i, j].detach().numpy(), cmap="viridis", aspect="equal"
|
28 |
+
)
|
29 |
axes[i, j].axis("off")
|
30 |
|
31 |
# Enumerate the axes
|
|
|
38 |
row_label = f"{layer_sequence[i]+1}"
|
39 |
row_pos = ax_row[num_heads - 1].get_position()
|
40 |
fig.text(row_pos.x1 + offset, (row_pos.y1 + row_pos.y0) / 2, row_label, va="center")
|
41 |
+
|
42 |
plt.subplots_adjust(wspace=0.1, hspace=0.1)
|
43 |
return fig
|
44 |
|
|
|
48 |
layer: int,
|
49 |
head: int,
|
50 |
tokens: list[str],
|
51 |
+
fixed_scale : bool = True
|
52 |
):
|
53 |
single_heatmap = tensor[layer, head, :, :].detach().numpy()
|
54 |
|
55 |
fig, ax = plt.subplots(figsize=(10, 10))
|
56 |
+
if fixed_scale:
|
57 |
+
heatmap = ax.imshow(single_heatmap, cmap="viridis", aspect="equal", vmin=0, vmax=1)
|
58 |
+
else:
|
59 |
+
heatmap = ax.imshow(single_heatmap, cmap="viridis", aspect="equal")
|
60 |
|
61 |
# Function to adjust font size based on the number of labels
|
62 |
def get_font_size(labels):
|