demo / src /visualisation.py
Xmaster6y's picture
new repo structure
0d998a6 unverified
raw
history blame
2.4 kB
"""
Visualisation utils.
"""
import chess
import chess.svg
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
# Diverging "RdYlBu_r" colormap, resampled to 1000 entries; used by
# render_heatmap to turn normalised heatmap values into square colours.
COLOR_MAP = matplotlib.colormaps["RdYlBu_r"].resampled(1000)
# Alpha channel applied to every square fill colour in render_heatmap.
ALPHA = 1.0
# Fixed [0, 1] normaliser without clipping. Not referenced by the functions
# visible in this module -- NOTE(review): presumably imported elsewhere; confirm.
NORM = matplotlib.colors.Normalize(vmin=0, vmax=1, clip=False)
def render_heatmap(
    board,
    heatmap,
    square=None,
    vmin=None,
    vmax=None,
    arrows=None,
    normalise="none",
):
    """
    Colour the 64 squares of a board from a heatmap and build a colorbar.

    Args:
        board: Board object forwarded unchanged to ``chess.svg.board``.
        heatmap: Indexable by square index 0..63 and supporting ``.abs()``,
            ``.min()``, ``.max()`` and division (presumably a 1-D torch
            tensor of 64 values -- confirm against callers).
        square: Optional square name; it is parsed with
            ``chess.parse_square`` and passed as the ``check`` square.
            A name that fails to parse is silently ignored.
        vmin: Lower bound of the colour scale; defaults to ``heatmap.min()``.
        vmax: Upper bound of the colour scale; defaults to ``heatmap.max()``.
        arrows: Arrows forwarded to ``chess.svg.board``; ``None`` means none.
        normalise: When ``"abs"``, the heatmap is divided by its largest
            absolute value (if non-zero) and the scale is forced to [-1, 1],
            overriding any ``vmin``/``vmax`` given by the caller.

    Returns:
        A ``(svg, fig)`` tuple: the value returned by ``chess.svg.board`` and
        the matplotlib figure holding the horizontal colorbar. The figure is
        closed in pyplot before being returned.
    """
    if normalise == "abs":
        peak = heatmap.abs().max()
        # An all-zero heatmap is left as is to avoid dividing by zero.
        if peak != 0:
            heatmap = heatmap / peak
        vmin, vmax = -1, 1
    vmin = heatmap.min() if vmin is None else vmin
    vmax = heatmap.max() if vmax is None else vmax
    scale = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax, clip=False)

    def _fill_for(index):
        # Keep the colormap's RGB and replace its alpha with the module ALPHA.
        rgb = COLOR_MAP(scale(heatmap[index]))[:3]
        return matplotlib.colors.to_hex((*rgb, ALPHA), keep_alpha=True)

    fills = {index: _fill_for(index) for index in range(64)}

    # Thin, axis-less figure whose only content is the colorbar.
    colorbar_fig = plt.figure(figsize=(6, 0.6))
    colorbar_ax = plt.gca()
    colorbar_ax.axis("off")
    colorbar_fig.colorbar(
        matplotlib.cm.ScalarMappable(norm=scale, cmap=COLOR_MAP),
        ax=colorbar_ax,
        orientation="horizontal",
        fraction=1.0,
    )

    check = None
    if square is not None:
        try:
            check = chess.parse_square(square)
        except ValueError:
            check = None

    board_arrows = [] if arrows is None else arrows
    # Close the current pyplot figure (the colorbar figure created above).
    plt.close()
    board_svg = chess.svg.board(
        board,
        check=check,
        fill=fills,
        size=350,
        arrows=board_arrows,
    )
    return board_svg, colorbar_fig
def render_policy_distribution(
    policy,
    legal_moves,
    n_bins=20,
):
    """
    Render the policy distribution histogram.

    Two overlaid, density-normalised histograms are drawn on a log-scaled
    y axis: one for policy entries whose index is in ``legal_moves`` and one
    for all remaining entries.

    Args:
        policy: Policy values, indexable with a boolean mask of length 1858
            and accepted by ``np.histogram`` (presumably a 1-D CPU torch
            tensor of 1858 entries -- confirm against callers).
        legal_moves: Container of policy indices treated as legal; only
            membership (``index in legal_moves``) is used.
        n_bins: Number of histogram bins computed over the whole policy.

    Returns:
        The matplotlib figure containing the histogram. It is closed in
        pyplot before being returned, so it is no longer tracked by pyplot's
        global figure manager; the figure object itself remains usable.
    """
    # True at every policy index listed in ``legal_moves``.
    legal_mask = torch.tensor(
        [move in legal_moves for move in range(1858)],
        dtype=torch.bool,
    )
    fig = plt.figure(figsize=(6, 6))
    ax = fig.gca()
    # Bin edges come from the full policy so both histograms share them.
    _, bins = np.histogram(policy, bins=n_bins)
    ax.hist(
        policy[~legal_mask],
        bins=bins,
        alpha=0.5,
        density=True,
        label="Illegal moves",
    )
    ax.hist(
        policy[legal_mask],
        bins=bins,
        alpha=0.5,
        density=True,
        label="Legal moves",
    )
    # Configure through the axes object rather than pyplot's implicit
    # "current axes", so nothing below depends on global pyplot state.
    ax.set_xlabel("Policy")
    ax.set_ylabel("Density")
    ax.legend()
    ax.set_yscale("log")
    # Unregister the figure from pyplot, as render_heatmap does; otherwise
    # every call leaves one more open figure alive in pyplot's registry.
    plt.close(fig)
    return fig