Spaces:
Sleeping
Sleeping
Fix vmin and vmax in heatmaps
Browse filesThis way you can compare between plots of different sections
of a protein. The color scale is also consistens between the focus
plot and the grid plots so it is easier to compare.
- hexviz/plot.py +4 -2
hexviz/plot.py
CHANGED
@@ -18,7 +18,9 @@ def plot_tiled_heatmap(tensor, layer_sequence: List[int], head_sequence: List[in
|
|
18 |
fig, axes = plt.subplots(num_layers, num_heads, figsize=(x_size, y_size), squeeze=False)
|
19 |
for i in range(num_layers):
|
20 |
for j in range(num_heads):
|
21 |
-
axes[i, j].imshow(
|
|
|
|
|
22 |
axes[i, j].axis("off")
|
23 |
|
24 |
# Enumerate the axes
|
@@ -45,7 +47,7 @@ def plot_single_heatmap(
|
|
45 |
single_heatmap = tensor[layer, head, :, :].detach().numpy()
|
46 |
|
47 |
fig, ax = plt.subplots(figsize=(10, 10))
|
48 |
-
heatmap = ax.imshow(single_heatmap, cmap="viridis", aspect="equal")
|
49 |
|
50 |
# Set the x and y axis ticks
|
51 |
ax.xaxis.set_major_locator(FixedLocator(np.arange(0, len(tokens))))
|
|
|
18 |
fig, axes = plt.subplots(num_layers, num_heads, figsize=(x_size, y_size), squeeze=False)
|
19 |
for i in range(num_layers):
|
20 |
for j in range(num_heads):
|
21 |
+
axes[i, j].imshow(
|
22 |
+
tensor[i, j].detach().numpy(), cmap="viridis", aspect="equal", vmin=0, vmax=1
|
23 |
+
)
|
24 |
axes[i, j].axis("off")
|
25 |
|
26 |
# Enumerate the axes
|
|
|
47 |
single_heatmap = tensor[layer, head, :, :].detach().numpy()
|
48 |
|
49 |
fig, ax = plt.subplots(figsize=(10, 10))
|
50 |
+
heatmap = ax.imshow(single_heatmap, cmap="viridis", aspect="equal", vmin=0, vmax=1)
|
51 |
|
52 |
# Set the x and y axis ticks
|
53 |
ax.xaxis.set_major_locator(FixedLocator(np.arange(0, len(tokens))))
|