Hukuna's picture
Upload 221 files
ce7bf5b verified
from collections import Counter
from itertools import product
import numpy as np
import pytest
import torch
from torch.nn import functional as F
from chroma.layers.structure.potts import (
GraphPotts,
compute_potts_energy,
fold_symmetry,
sample_potts,
)
def test_graphpotts():
# Testing symmetry
# Create non-symmetric Potts model and symmetrize using serial or not
potts = GraphPotts(128, 128, 20, symmetric_J=False)
node_h = torch.rand(1, 3, 128)
edge_h = torch.rand(1, 3, 2, 128)
edge_idx = torch.tensor([[[1, 2], [0, 2], [0, 1]]])
mask_i = torch.ones(1, 3)
mask_ij = torch.ones(1, 3, 2)
h, J = potts(node_h, edge_h, edge_idx, mask_i, mask_ij)
assert (
potts._symmetrize_J(J, edge_idx, mask_ij)
!= potts._symmetrize_J_serial(J, edge_idx, mask_ij)
).sum().detach().numpy() == 0
mask_ij = torch.tensor([[[1, 1], [1, 0], [1, 0]]])
h, J = potts(node_h, edge_h, edge_idx, mask_i, mask_ij)
assert (
potts._symmetrize_J(J, edge_idx, mask_ij)
!= potts._symmetrize_J_serial(J, edge_idx, mask_ij)
).sum().detach().numpy() == 0
def test_symmetry_folding():
N, Q = 12, 3
symmetry_order = 3
N_au = N // symmetry_order
# Testing symmetry
mask_i = torch.ones(1, N)
mask_ij = (1.0 - torch.eye(N))[None, ...]
h = torch.randn([1, N, Q])
J = torch.randn([1, N, N, Q, Q])
J = J + J.permute([0, 2, 1, 4, 3])
# J = torch.eye(Q)[None,None,None,...].expand([1, N, N, Q, Q])
J = J * mask_ij[..., None, None]
edge_idx = torch.arange(N).long()[None, None, :].expand([1, N, N])
h_fold, J_fold, edge_idx_fold, mask_i_fold, mask_ij_fold = fold_symmetry(
symmetry_order, h, J, edge_idx, mask_i, mask_ij, normalize=False
)
# Validate dimensions
assert tuple(h_fold.shape) == (1, N_au, Q)
assert tuple(J_fold.shape) == (1, N_au, N_au, Q, Q)
assert tuple(edge_idx_fold.shape) == (1, N_au, N_au)
assert tuple(mask_i_fold.shape) == (1, N_au)
assert tuple(mask_ij_fold.shape) == (1, N_au, N_au)
# Does the folded Potts model return same energies as full?
S_test_fold = torch.randint(high=Q, size=[1, N_au])
S_test = S_test_fold[:, None, :].expand([1, symmetry_order, N_au]).reshape([1, N])
U, U_i = compute_potts_energy(S_test, h, J, edge_idx)
U_fold, U_i_fold = compute_potts_energy(S_test_fold, h_fold, J_fold, edge_idx_fold)
assert torch.allclose(U, U_fold)
@pytest.mark.parametrize("proposal", ["dlmc", "chromatic"])
def test_potts_mcmc(proposal, debug=False):
"""MCMC test for Chromatic Gibbs sampling."""
# Build a test, fully connected Potts model
if debug:
# Heavy duty sampling with large state space
N = 5
q = 4
num_sweeps = 1000
num_chains = 1000
rtol = 0.05
else:
# Quick and dirty small state space
N = 3
q = 3
num_sweeps = 200
num_chains = 1000
rtol = 0.1
beta = 0.1
warmup_fraction = 0.1
torch.manual_seed(1)
mask_i = torch.ones([1, N]).float()
mask_ij = (1 - torch.eye(N))[None, ...].float()
edge_idx = torch.arange(N)[None, None, :].expand([1, N, N])
h = beta * torch.randn([1, N, q])
J = beta * torch.randn([1, N, N, q, q])
J = mask_ij[..., None, None] * (J + J.permute([0, 2, 1, 4, 3])) / np.sqrt(2)
# Enumerate all of sequence space
alphabet = "ABCDEFGHIJK"[:q]
sequences = ["".join(x) for x in product(alphabet, repeat=N)]
S_exact = torch.Tensor(
[[alphabet.index(s) for s in seq] for seq in sequences]
).long()
print(f"Enumerated {len(sequences)} sequences")
if torch.cuda.is_available():
device = "cuda"
h = h.to(device)
J = J.to(device)
edge_idx = edge_idx.to(device)
mask_i = mask_i.to(device)
mask_ij = mask_ij.to(device)
S_exact = S_exact.to(device)
# Compute exact distribution over sequence space
B = S_exact.shape[0]
h_expand = h.expand([B, -1, -1])
J_expand = J.expand([B, -1, -1, -1, -1])
edge_idx_expand = edge_idx.expand([B, -1, -1])
mask_i_expand = mask_i.expand([B, -1])
mask_ij_expand = mask_ij.expand([B, -1, -1])
U, _ = compute_potts_energy(S_exact, h_expand, J_expand, edge_idx_expand)
p_exact = F.softmax(-U, -1).tolist()
# Estimate distribution from sampled sequences
h_expand = h.expand([num_chains, -1, -1])
J_expand = J.expand([num_chains, -1, -1, -1, -1])
edge_idx_expand = edge_idx.expand([num_chains, -1, -1])
mask_i_expand = mask_i.expand([num_chains, -1])
mask_ij_expand = mask_ij.expand([num_chains, -1, -1])
S, U, S_trajectory, U_trajectory = sample_potts(
h_expand,
J_expand,
edge_idx_expand,
mask_i_expand,
mask_ij_expand,
num_sweeps=num_sweeps,
proposal=proposal,
rejection_step=True,
verbose=True,
return_trajectory=True,
)
if warmup_fraction is not None:
S_trajectory = S_trajectory[int(warmup_fraction * len(S_trajectory)) :]
S_samples = torch.cat(S_trajectory, 0)
U_trajectory = torch.stack(U_trajectory, 1).cpu().data.numpy()
S_samples = S_samples.cpu().data.numpy()
sample_counts = Counter(["".join([alphabet[c] for c in s]) for s in S_samples])
p_sample = [sample_counts[seq] / S_samples.shape[0] for seq in sequences]
if debug:
from matplotlib import pyplot as plt
plt.figure(figsize=(6, 3))
plt.subplot(1, 2, 1)
plt.plot(p_exact, p_sample, "k.")
plt.grid()
plt.axis("square")
plt.xlabel("Probability, exact enumeration")
plt.ylabel("Sampling frequencey (MCMC)")
plt.title(f"Random Potts model over {q}^{N} sequences")
plt.subplot(1, 2, 2)
plt.plot(U_trajectory[0, :])
plt.xlabel("Iterations")
plt.ylabel("Energy")
plt.tight_layout()
plt.show()
# The frequencies of states visited via MCMC should reproduce their
# exact probabilities (via enumeration) within rtol percent error
assert np.allclose(p_sample, p_exact, rtol=rtol)
def debug_potts_2D():
"""Debug test for Potts model"""
N = 100
q = 4
num_sites = N * N
mask_i = torch.ones([1, N]).float()
ix = torch.arange(num_sites).long()
# Build 2D lattice topology
edge_idx = torch.stack([ix + 1, ix - 1, ix + N, ix - N], -1)
mask_ij = torch.ones_like(edge_idx).float()[None, :, :]
edge_idx = torch.remainder(edge_idx, num_sites)[None, :, :].long()
# Ferromagnetic parameters
h = torch.zeros([1, num_sites, q])
h[:, :, 0] = h[:, :, 0]
mask_J = mask_ij[:, :, :, None, None] * torch.eye(q)[None, None, None, :, :]
if torch.cuda.is_available():
device = "cuda"
h = h.to(device)
edge_idx = edge_idx.to(device)
mask_J = mask_J.to(device)
mask_ij = mask_ij.to(device)
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation
temp_range = (1.2, 0.8)
plt.figure(figsize=(5, 5), dpi=600)
_, _, S_trajectory, U_trajectory = sample_potts(
h,
-mask_J,
edge_idx,
mask_i,
mask_ij,
num_sweeps=10000,
verbose=True,
return_trajectory=True,
S=None,
annealing_fraction=1.0,
temperature_init=1.2,
temperature=0.8,
)
# Define a function to update the plot for each frame
num_frames = len(S_trajectory)
temps = np.linspace(temp_range[0], temp_range[1], len(S_trajectory))
betas = 1.0 / temps
def update(frame):
plt.clf() # Clear the previous frame
plt.pcolor(S_trajectory[frame].cpu().data.numpy().reshape([N, N]), cmap="tab10")
plt.clim([0, 10])
plt.axis("square")
plt.axis("off")
plt.title(f"Beta = {betas[frame]:0.2f}")
print(frame)
# Create a figure and set the number of frames
fig = plt.figure(figsize=(4, 4), dpi=300)
animation = FuncAnimation(fig, update, frames=num_frames, interval=1000 / 60)
animation.save("potts.mp4", writer="ffmpeg")
return