Spaces:
Runtime error
Runtime error
File size: 2,396 Bytes
0d998a6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
"""
Visualisation utils.
"""
import chess
import chess.svg
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
# Reversed RdYlBu colormap (blue = low, red = high), resampled to 1000 levels.
COLOR_MAP = matplotlib.colormaps["RdYlBu_r"].resampled(1000)
# Alpha channel forced onto every square fill colour in render_heatmap.
ALPHA = 1.0
# Fixed [0, 1] normalisation, non-clipping. Not referenced by the functions
# below (render_heatmap builds its own Normalize); presumably kept for
# external callers -- TODO confirm before removing.
NORM = matplotlib.colors.Normalize(0, 1, clip=False)
def render_heatmap(
    board,
    heatmap,
    square=None,
    vmin=None,
    vmax=None,
    arrows=None,
    normalise="none",
):
    """
    Render a heatmap on the board.

    Each of the 64 squares is filled with a colour taken from ``COLOR_MAP``
    according to its heatmap value, and a separate figure holding a matching
    horizontal colorbar is built.

    Args:
        board: Position passed straight to ``chess.svg.board``.
        heatmap: Collection of 64 per-square values, indexed by python-chess
            square index (the keys of the ``fill`` mapping). Torch tensors and
            NumPy arrays are both accepted.
        square: Optional square name (e.g. ``"e4"``) marked using the
            ``check`` highlight of ``chess.svg.board``. Names that
            ``chess.parse_square`` rejects are ignored.
        vmin, vmax: Colour-scale limits; default to the heatmap's min / max.
        arrows: Optional arrows forwarded to ``chess.svg.board``.
        normalise: ``"abs"`` divides the heatmap by its maximum absolute value
            (when non-zero) and forces the scale to [-1, 1], overriding any
            ``vmin`` / ``vmax`` passed in. Any other value leaves the heatmap
            unchanged.

    Returns:
        Tuple ``(svg, fig)``: the board as returned by ``chess.svg.board`` and
        the colorbar figure. The figure is closed in pyplot's figure manager so
        repeated calls do not accumulate open figures; render it explicitly
        (e.g. ``fig.savefig``).
    """
    if normalise == "abs":
        # Builtin abs() dispatches to __abs__, so this works for torch tensors
        # and NumPy arrays alike (``.abs()`` exists only on tensors).
        a_max = abs(heatmap).max()
        if a_max != 0:
            heatmap = heatmap / a_max
        vmin = -1
        vmax = 1
    if vmin is None:
        vmin = heatmap.min()
    if vmax is None:
        vmax = heatmap.max()
    norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax, clip=False)

    color_dict = {}
    for square_index in range(64):
        rgba = COLOR_MAP(norm(heatmap[square_index]))
        # Keep the colormap's RGB but apply the module-wide alpha.
        color_dict[square_index] = matplotlib.colors.to_hex(
            (*rgba[:3], ALPHA), keep_alpha=True
        )

    fig = plt.figure(figsize=(6, 0.6))
    ax = fig.gca()
    ax.axis("off")
    fig.colorbar(
        matplotlib.cm.ScalarMappable(norm=norm, cmap=COLOR_MAP),
        ax=ax,
        orientation="horizontal",
        fraction=1.0,
    )
    # Detach this specific figure from pyplot right away so it cannot leak,
    # even if the board rendering below raises.
    plt.close(fig)

    check = None
    if square is not None:
        try:
            check = chess.parse_square(square)
        except ValueError:
            # Unknown square names are deliberately ignored (no highlight).
            check = None

    return (
        chess.svg.board(
            board,
            check=check,
            fill=color_dict,
            size=350,
            arrows=arrows if arrows is not None else [],
        ),
        fig,
    )
def render_policy_distribution(
    policy,
    legal_moves,
    n_bins=20,
):
    """
    Render the policy distribution histogram.

    Draws two overlaid, density-normalised histograms of the policy values,
    one for illegal moves and one for legal moves. Both use the same bin edges
    (computed over the whole policy) and a logarithmic y-axis.

    Args:
        policy: 1-D tensor of policy values, one per move index. Previously
            required to have exactly 1858 entries (presumably the Leela policy
            size -- TODO confirm); any length is now accepted.
        legal_moves: Container of legal move indices supporting ``in``
            (e.g. a list, set or tensor of ints).
        n_bins: Number of histogram bins.

    Returns:
        The matplotlib figure. It is closed in pyplot's figure manager so
        repeated calls do not accumulate open figures; render it explicitly
        (e.g. ``fig.savefig``).
    """
    n_moves = len(policy)
    legal_mask = torch.tensor(
        [move in legal_moves for move in range(n_moves)], dtype=torch.bool
    )

    fig = plt.figure(figsize=(6, 6))
    ax = fig.gca()
    # Shared bin edges so the two histograms are directly comparable.
    _, bins = np.histogram(policy, bins=n_bins)
    subsets = (
        (~legal_mask, "Illegal moves"),
        (legal_mask, "Legal moves"),
    )
    for index, (mask, label) in enumerate(subsets):
        values = policy[mask]
        if len(values) == 0:
            # A density histogram of zero samples is all-NaN (and warns);
            # e.g. a position with no legal moves. Skip it instead.
            continue
        ax.hist(
            values,
            bins=bins,
            alpha=0.5,
            density=True,
            label=label,
            # Pin the default cycle colour so skipping one subset does not
            # recolour the other.
            color=f"C{index}",
        )
    ax.set_xlabel("Policy")
    ax.set_ylabel("Density")
    ax.legend()
    ax.set_yscale("log")
    # Match render_heatmap: release pyplot's reference to avoid a figure leak.
    plt.close(fig)
    return fig
|