aksell commited on
Commit
d35acfc
·
1 Parent(s): cea9292

Fix vmin and vmax in heatmaps

Browse files

This way you can compare between plots of different sections
of a protein. The color scale is also consistens between the focus
plot and the grid plots so it is easier to compare.

Files changed (1) hide show
  1. hexviz/plot.py +4 -2
hexviz/plot.py CHANGED
@@ -18,7 +18,9 @@ def plot_tiled_heatmap(tensor, layer_sequence: List[int], head_sequence: List[in
18
  fig, axes = plt.subplots(num_layers, num_heads, figsize=(x_size, y_size), squeeze=False)
19
  for i in range(num_layers):
20
  for j in range(num_heads):
21
- axes[i, j].imshow(tensor[i, j].detach().numpy(), cmap="viridis", aspect="equal")
 
 
22
  axes[i, j].axis("off")
23
 
24
  # Enumerate the axes
@@ -45,7 +47,7 @@ def plot_single_heatmap(
45
  single_heatmap = tensor[layer, head, :, :].detach().numpy()
46
 
47
  fig, ax = plt.subplots(figsize=(10, 10))
48
- heatmap = ax.imshow(single_heatmap, cmap="viridis", aspect="equal")
49
 
50
  # Set the x and y axis ticks
51
  ax.xaxis.set_major_locator(FixedLocator(np.arange(0, len(tokens))))
 
18
  fig, axes = plt.subplots(num_layers, num_heads, figsize=(x_size, y_size), squeeze=False)
19
  for i in range(num_layers):
20
  for j in range(num_heads):
21
+ axes[i, j].imshow(
22
+ tensor[i, j].detach().numpy(), cmap="viridis", aspect="equal", vmin=0, vmax=1
23
+ )
24
  axes[i, j].axis("off")
25
 
26
  # Enumerate the axes
 
47
  single_heatmap = tensor[layer, head, :, :].detach().numpy()
48
 
49
  fig, ax = plt.subplots(figsize=(10, 10))
50
+ heatmap = ax.imshow(single_heatmap, cmap="viridis", aspect="equal", vmin=0, vmax=1)
51
 
52
  # Set the x and y axis ticks
53
  ax.xaxis.set_major_locator(FixedLocator(np.arange(0, len(tokens))))