""" Visualisation utils. """ import chess import chess.svg import matplotlib import matplotlib.pyplot as plt import numpy as np import torch import torchviz from . import constants COLOR_MAP = matplotlib.colormaps["RdYlBu_r"].resampled(1000) ALPHA = 1.0 def render_heatmap( board, heatmap, square=None, vmin=None, vmax=None, arrows=None, normalise="none", ): """ Render a heatmap on the board. """ if normalise == "abs": a_max = heatmap.abs().max() if a_max != 0: heatmap = heatmap / a_max vmin = -1 vmax = 1 if vmin is None: vmin = heatmap.min() if vmax is None: vmax = heatmap.max() norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax, clip=False) color_dict = {} for square_index in range(64): color = COLOR_MAP(norm(heatmap[square_index])) color = (*color[:3], ALPHA) color_dict[square_index] = matplotlib.colors.to_hex(color, keep_alpha=True) fig = plt.figure(figsize=(6, 0.6)) ax = plt.gca() ax.axis("off") fig.colorbar( matplotlib.cm.ScalarMappable(norm=norm, cmap=COLOR_MAP), ax=ax, orientation="horizontal", fraction=1.0, ) if square is not None: try: check = chess.parse_square(square) except ValueError: check = None else: check = None if arrows is None: arrows = [] plt.close() return ( chess.svg.board( board, check=check, fill=color_dict, size=350, arrows=arrows, ), fig, ) def render_architecture(model, name: str = "model", directory: str = ""): """ Render the architecture of the model. """ out = model(torch.zeros(1, 112, 8, 8)) if len(out) == 2: policy, outcome_probs = out value = torch.zeros(outcome_probs.shape[0], 1) else: policy, outcome_probs, value = out torchviz.make_dot(policy, params=dict(list(model.named_parameters()))).render( f"{directory}/{name}_policy", format="svg" ) torchviz.make_dot(outcome_probs, params=dict(list(model.named_parameters()))).render( f"{directory}/{name}_outcome_probs", format="svg" ) torchviz.make_dot(value, params=dict(list(model.named_parameters()))).render( f"{directory}/{name}_value", format="svg" ) def render_policy_distribution( policy, legal_moves, n_bins=20, ): """ Render the policy distribution histogram. """ legal_mask = torch.Tensor([move in legal_moves for move in range(1858)]).bool() fig = plt.figure(figsize=(6, 6)) ax = plt.gca() _, bins = np.histogram(policy, bins=n_bins) ax.hist( policy[~legal_mask], bins=bins, alpha=0.5, density=True, label="Illegal moves", ) ax.hist( policy[legal_mask], bins=bins, alpha=0.5, density=True, label="Legal moves", ) plt.xlabel("Policy") plt.ylabel("Density") plt.legend() plt.yscale("log") return fig def render_policy_statistics( statistics, ): """ Render the policy statistics. """ fig = plt.figure(figsize=(6, 6)) ax = plt.gca() move_indices = list(statistics["mean_legal_logits"].keys()) legal_means_avg = [np.mean(statistics["mean_legal_logits"][move_idx]) for move_idx in move_indices] illegal_means_avg = [np.mean(statistics["mean_illegal_logits"][move_idx]) for move_idx in move_indices] legal_means_std = [np.std(statistics["mean_legal_logits"][move_idx]) for move_idx in move_indices] illegal_means_std = [np.std(statistics["mean_illegal_logits"][move_idx]) for move_idx in move_indices] ax.errorbar( move_indices, legal_means_avg, yerr=legal_means_std, label="Legal moves", ) ax.errorbar( move_indices, illegal_means_avg, yerr=illegal_means_std, label="Illegal moves", ) plt.xlabel("Move index") plt.ylabel("Mean policy logits") plt.legend() return fig def render_relevance_proportion(statistics, scaled=True): """ Render the relevance proportion statistics. """ norm = matplotlib.colors.Normalize(vmin=0, vmax=1, clip=False) fig_hist = plt.figure(figsize=(6, 6)) ax = plt.gca() move_indices = list(statistics["planes_relevance_proportion"].keys()) for h in range(8): relevance_proportion_avg = [ np.mean([rel[13 * h : 13 * (h + 1)].sum() for rel in statistics["planes_relevance_proportion"][move_idx]]) for move_idx in move_indices ] relevance_proportion_std = [ np.std([rel[13 * h : 13 * (h + 1)].sum() for rel in statistics["planes_relevance_proportion"][move_idx]]) for move_idx in move_indices ] ax.errorbar( move_indices[h + 1 :], relevance_proportion_avg[h + 1 :], yerr=relevance_proportion_std[h + 1 :], label=f"History {h}", c=COLOR_MAP(norm(h / 9)), ) relevance_proportion_avg = [ np.mean([rel[104:108].sum() for rel in statistics["planes_relevance_proportion"][move_idx]]) for move_idx in move_indices ] relevance_proportion_std = [ np.std([rel[104:108].sum() for rel in statistics["planes_relevance_proportion"][move_idx]]) for move_idx in move_indices ] ax.errorbar( move_indices, relevance_proportion_avg, yerr=relevance_proportion_std, label="Castling rights", c=COLOR_MAP(norm(8 / 9)), ) relevance_proportion_avg = [ np.mean([rel[108:].sum() for rel in statistics["planes_relevance_proportion"][move_idx]]) for move_idx in move_indices ] relevance_proportion_std = [ np.std([rel[108:].sum() for rel in statistics["planes_relevance_proportion"][move_idx]]) for move_idx in move_indices ] ax.errorbar( move_indices, relevance_proportion_avg, yerr=relevance_proportion_std, label="Remaining planes", c=COLOR_MAP(norm(9 / 9)), ) plt.xlabel("Move index") plt.ylabel("Absolute relevance proportion") plt.yscale("log") plt.legend() if scaled: stat_key = "planes_relevance_proportion_scaled" else: stat_key = "planes_relevance_proportion" fig_planes = plt.figure(figsize=(6, 6)) ax = plt.gca() move_indices = list(statistics[stat_key].keys()) for p in range(13): relevance_proportion_avg = [ np.mean([rel[p].item() for rel in statistics[stat_key][move_idx]]) for move_idx in move_indices ] relevance_proportion_std = [ np.std([rel[p].item() for rel in statistics[stat_key][move_idx]]) for move_idx in move_indices ] ax.errorbar( move_indices, relevance_proportion_avg, yerr=relevance_proportion_std, label=constants.PLANE_NAMES[p], c=COLOR_MAP(norm(p / 12)), ) plt.xlabel("Move index") plt.ylabel("Absolute relevance proportion") plt.yscale("log") plt.legend() fig_pieces = plt.figure(figsize=(6, 6)) ax = plt.gca() for p in range(1, 13): stat_key = f"configuration_relevance_proportion_threatened_piece{p}" n_attackers = list(statistics[stat_key].keys()) relevance_proportion_avg = [ np.mean(statistics[f"configuration_relevance_proportion_threatened_piece{p}"][n]) for n in n_attackers ] relevance_proportion_std = [np.std(statistics[stat_key][n]) for n in n_attackers] ax.errorbar( n_attackers, relevance_proportion_avg, yerr=relevance_proportion_std, label="PNBRQKpnbrqk"[p - 1], c=COLOR_MAP(norm(p / 12)), ) plt.xlabel("Number of attackers") plt.ylabel("Absolute configuration relevance proportion") plt.yscale("log") plt.legend() return fig_hist, fig_planes, fig_pieces def render_probing_statistics( statistics, ): """ Render the probing statistics. """ fig = plt.figure(figsize=(6, 6)) ax = plt.gca() n_blocks = len(statistics["metrics"]) for metric in statistics["metrics"]["block0"]: avg = [] std = [] for block_idx in range(n_blocks): metrics = statistics["metrics"] block_data = metrics[f"block{block_idx}"] avg.append(np.mean(block_data[metric])) std.append(np.std(block_data[metric])) ax.errorbar( range(n_blocks), avg, yerr=std, label=metric, ) plt.xlabel("Block index") plt.ylabel("Metric") plt.yscale("log") plt.legend() return fig