from collections import Counter from itertools import product import numpy as np import pytest import torch from torch.nn import functional as F from chroma.layers.structure.potts import ( GraphPotts, compute_potts_energy, fold_symmetry, sample_potts, ) def test_graphpotts(): # Testing symmetry # Create non-symmetric Potts model and symmetrize using serial or not potts = GraphPotts(128, 128, 20, symmetric_J=False) node_h = torch.rand(1, 3, 128) edge_h = torch.rand(1, 3, 2, 128) edge_idx = torch.tensor([[[1, 2], [0, 2], [0, 1]]]) mask_i = torch.ones(1, 3) mask_ij = torch.ones(1, 3, 2) h, J = potts(node_h, edge_h, edge_idx, mask_i, mask_ij) assert ( potts._symmetrize_J(J, edge_idx, mask_ij) != potts._symmetrize_J_serial(J, edge_idx, mask_ij) ).sum().detach().numpy() == 0 mask_ij = torch.tensor([[[1, 1], [1, 0], [1, 0]]]) h, J = potts(node_h, edge_h, edge_idx, mask_i, mask_ij) assert ( potts._symmetrize_J(J, edge_idx, mask_ij) != potts._symmetrize_J_serial(J, edge_idx, mask_ij) ).sum().detach().numpy() == 0 def test_symmetry_folding(): N, Q = 12, 3 symmetry_order = 3 N_au = N // symmetry_order # Testing symmetry mask_i = torch.ones(1, N) mask_ij = (1.0 - torch.eye(N))[None, ...] h = torch.randn([1, N, Q]) J = torch.randn([1, N, N, Q, Q]) J = J + J.permute([0, 2, 1, 4, 3]) # J = torch.eye(Q)[None,None,None,...].expand([1, N, N, Q, Q]) J = J * mask_ij[..., None, None] edge_idx = torch.arange(N).long()[None, None, :].expand([1, N, N]) h_fold, J_fold, edge_idx_fold, mask_i_fold, mask_ij_fold = fold_symmetry( symmetry_order, h, J, edge_idx, mask_i, mask_ij, normalize=False ) # Validate dimensions assert tuple(h_fold.shape) == (1, N_au, Q) assert tuple(J_fold.shape) == (1, N_au, N_au, Q, Q) assert tuple(edge_idx_fold.shape) == (1, N_au, N_au) assert tuple(mask_i_fold.shape) == (1, N_au) assert tuple(mask_ij_fold.shape) == (1, N_au, N_au) # Does the folded Potts model return same energies as full? S_test_fold = torch.randint(high=Q, size=[1, N_au]) S_test = S_test_fold[:, None, :].expand([1, symmetry_order, N_au]).reshape([1, N]) U, U_i = compute_potts_energy(S_test, h, J, edge_idx) U_fold, U_i_fold = compute_potts_energy(S_test_fold, h_fold, J_fold, edge_idx_fold) assert torch.allclose(U, U_fold) @pytest.mark.parametrize("proposal", ["dlmc", "chromatic"]) def test_potts_mcmc(proposal, debug=False): """MCMC test for Chromatic Gibbs sampling.""" # Build a test, fully connected Potts model if debug: # Heavy duty sampling with large state space N = 5 q = 4 num_sweeps = 1000 num_chains = 1000 rtol = 0.05 else: # Quick and dirty small state space N = 3 q = 3 num_sweeps = 200 num_chains = 1000 rtol = 0.1 beta = 0.1 warmup_fraction = 0.1 torch.manual_seed(1) mask_i = torch.ones([1, N]).float() mask_ij = (1 - torch.eye(N))[None, ...].float() edge_idx = torch.arange(N)[None, None, :].expand([1, N, N]) h = beta * torch.randn([1, N, q]) J = beta * torch.randn([1, N, N, q, q]) J = mask_ij[..., None, None] * (J + J.permute([0, 2, 1, 4, 3])) / np.sqrt(2) # Enumerate all of sequence space alphabet = "ABCDEFGHIJK"[:q] sequences = ["".join(x) for x in product(alphabet, repeat=N)] S_exact = torch.Tensor( [[alphabet.index(s) for s in seq] for seq in sequences] ).long() print(f"Enumerated {len(sequences)} sequences") if torch.cuda.is_available(): device = "cuda" h = h.to(device) J = J.to(device) edge_idx = edge_idx.to(device) mask_i = mask_i.to(device) mask_ij = mask_ij.to(device) S_exact = S_exact.to(device) # Compute exact distribution over sequence space B = S_exact.shape[0] h_expand = h.expand([B, -1, -1]) J_expand = J.expand([B, -1, -1, -1, -1]) edge_idx_expand = edge_idx.expand([B, -1, -1]) mask_i_expand = mask_i.expand([B, -1]) mask_ij_expand = mask_ij.expand([B, -1, -1]) U, _ = compute_potts_energy(S_exact, h_expand, J_expand, edge_idx_expand) p_exact = F.softmax(-U, -1).tolist() # Estimate distribution from sampled sequences h_expand = h.expand([num_chains, -1, -1]) J_expand = J.expand([num_chains, -1, -1, -1, -1]) edge_idx_expand = edge_idx.expand([num_chains, -1, -1]) mask_i_expand = mask_i.expand([num_chains, -1]) mask_ij_expand = mask_ij.expand([num_chains, -1, -1]) S, U, S_trajectory, U_trajectory = sample_potts( h_expand, J_expand, edge_idx_expand, mask_i_expand, mask_ij_expand, num_sweeps=num_sweeps, proposal=proposal, rejection_step=True, verbose=True, return_trajectory=True, ) if warmup_fraction is not None: S_trajectory = S_trajectory[int(warmup_fraction * len(S_trajectory)) :] S_samples = torch.cat(S_trajectory, 0) U_trajectory = torch.stack(U_trajectory, 1).cpu().data.numpy() S_samples = S_samples.cpu().data.numpy() sample_counts = Counter(["".join([alphabet[c] for c in s]) for s in S_samples]) p_sample = [sample_counts[seq] / S_samples.shape[0] for seq in sequences] if debug: from matplotlib import pyplot as plt plt.figure(figsize=(6, 3)) plt.subplot(1, 2, 1) plt.plot(p_exact, p_sample, "k.") plt.grid() plt.axis("square") plt.xlabel("Probability, exact enumeration") plt.ylabel("Sampling frequencey (MCMC)") plt.title(f"Random Potts model over {q}^{N} sequences") plt.subplot(1, 2, 2) plt.plot(U_trajectory[0, :]) plt.xlabel("Iterations") plt.ylabel("Energy") plt.tight_layout() plt.show() # The frequencies of states visited via MCMC should reproduce their # exact probabilities (via enumeration) within rtol percent error assert np.allclose(p_sample, p_exact, rtol=rtol) def debug_potts_2D(): """Debug test for Potts model""" N = 100 q = 4 num_sites = N * N mask_i = torch.ones([1, N]).float() ix = torch.arange(num_sites).long() # Build 2D lattice topology edge_idx = torch.stack([ix + 1, ix - 1, ix + N, ix - N], -1) mask_ij = torch.ones_like(edge_idx).float()[None, :, :] edge_idx = torch.remainder(edge_idx, num_sites)[None, :, :].long() # Ferromagnetic parameters h = torch.zeros([1, num_sites, q]) h[:, :, 0] = h[:, :, 0] mask_J = mask_ij[:, :, :, None, None] * torch.eye(q)[None, None, None, :, :] if torch.cuda.is_available(): device = "cuda" h = h.to(device) edge_idx = edge_idx.to(device) mask_J = mask_J.to(device) mask_ij = mask_ij.to(device) import numpy as np from matplotlib import pyplot as plt from matplotlib.animation import FuncAnimation temp_range = (1.2, 0.8) plt.figure(figsize=(5, 5), dpi=600) _, _, S_trajectory, U_trajectory = sample_potts( h, -mask_J, edge_idx, mask_i, mask_ij, num_sweeps=10000, verbose=True, return_trajectory=True, S=None, annealing_fraction=1.0, temperature_init=1.2, temperature=0.8, ) # Define a function to update the plot for each frame num_frames = len(S_trajectory) temps = np.linspace(temp_range[0], temp_range[1], len(S_trajectory)) betas = 1.0 / temps def update(frame): plt.clf() # Clear the previous frame plt.pcolor(S_trajectory[frame].cpu().data.numpy().reshape([N, N]), cmap="tab10") plt.clim([0, 10]) plt.axis("square") plt.axis("off") plt.title(f"Beta = {betas[frame]:0.2f}") print(frame) # Create a figure and set the number of frames fig = plt.figure(figsize=(4, 4), dpi=300) animation = FuncAnimation(fig, update, frames=num_frames, interval=1000 / 60) animation.save("potts.mp4", writer="ffmpeg") return