# NOTE: "Spaces: Sleeping" is hosting-page status text captured along with this
# file; it is not part of the test module.
from collections import Counter
from itertools import product

import numpy as np
import pytest
import torch
from torch.nn import functional as F

from chroma.layers.structure.potts import (
    GraphPotts,
    compute_potts_energy,
    fold_symmetry,
    sample_potts,
)


def test_graphpotts():
    """Check that the two symmetrization routines of ``GraphPotts`` agree.

    A ``GraphPotts`` layer is built with ``symmetric_J=False`` and evaluated on
    a 3-node graph in which every node lists the other two as neighbors. The
    coupling tensor it returns is then symmetrized with both
    ``_symmetrize_J`` and ``_symmetrize_J_serial``; the two results must match
    element for element, once with all edges enabled and once with some edges
    masked out.
    """
    model = GraphPotts(128, 128, 20, symmetric_J=False)
    node_h = torch.rand(1, 3, 128)
    edge_h = torch.rand(1, 3, 2, 128)
    edge_idx = torch.tensor([[[1, 2], [0, 2], [0, 1]]])
    mask_i = torch.ones(1, 3)

    # First pass keeps every edge, second pass masks the last edge of two nodes.
    edge_masks = (
        torch.ones(1, 3, 2),
        torch.tensor([[[1, 1], [1, 0], [1, 0]]]),
    )
    for mask_ij in edge_masks:
        _, J = model(node_h, edge_h, edge_idx, mask_i, mask_ij)
        J_default = model._symmetrize_J(J, edge_idx, mask_ij)
        J_serial = model._symmetrize_J_serial(J, edge_idx, mask_ij)
        num_mismatched = (J_default != J_serial).sum().detach().numpy()
        assert num_mismatched == 0
def test_symmetry_folding():
    """Check shapes and energy consistency of ``fold_symmetry``.

    A random, fully connected Potts model over ``num_nodes`` sites is folded
    with ``fold_symmetry`` at symmetry order 3. The folded tensors must have
    ``num_nodes // 3`` sites, and the folded model must assign the same energy
    to a folded sequence as the full model assigns to that sequence tiled
    ``symmetry_order`` times.
    """
    num_nodes, num_states = 12, 3
    symmetry_order = 3
    num_au = num_nodes // symmetry_order

    # Random symmetric couplings with the self-interaction (diagonal) masked.
    mask_i = torch.ones(1, num_nodes)
    mask_ij = (1.0 - torch.eye(num_nodes))[None, ...]
    h = torch.randn([1, num_nodes, num_states])
    J = torch.randn([1, num_nodes, num_nodes, num_states, num_states])
    J = J + J.permute([0, 2, 1, 4, 3])
    J = J * mask_ij[..., None, None]
    edge_idx = (
        torch.arange(num_nodes).long()[None, None, :].expand([1, num_nodes, num_nodes])
    )

    folded = fold_symmetry(
        symmetry_order, h, J, edge_idx, mask_i, mask_ij, normalize=False
    )
    h_fold, J_fold, edge_idx_fold, mask_i_fold, mask_ij_fold = folded

    # Validate dimensions
    expected_shapes = (
        (h_fold, (1, num_au, num_states)),
        (J_fold, (1, num_au, num_au, num_states, num_states)),
        (edge_idx_fold, (1, num_au, num_au)),
        (mask_i_fold, (1, num_au)),
        (mask_ij_fold, (1, num_au, num_au)),
    )
    for tensor, shape in expected_shapes:
        assert tuple(tensor.shape) == shape

    # Does the folded Potts model return same energies as full?
    S_fold = torch.randint(high=num_states, size=[1, num_au])
    S_full = S_fold.repeat(1, symmetry_order)
    U_full, _ = compute_potts_energy(S_full, h, J, edge_idx)
    U_fold, _ = compute_potts_energy(S_fold, h_fold, J_fold, edge_idx_fold)
    assert torch.allclose(U_full, U_fold)
# NOTE(review): the proposal names below are assumed to be the options that
# `sample_potts` accepts -- confirm against its documentation.
@pytest.mark.parametrize("proposal", ["dlmc", "chromatic"])
def test_potts_mcmc(proposal, debug=False):
    """Check that Potts MCMC sampling reproduces the exact distribution.

    A small, fully connected Potts model with random fields and couplings is
    built, the probability of every sequence is computed by exhaustive
    enumeration as ``softmax(-U)``, and the frequencies of states visited by
    ``sample_potts`` (after discarding a warm-up fraction of the trajectory)
    are compared against those probabilities.

    Args:
        proposal: Proposal name forwarded to ``sample_potts``.
        debug: If True, use a larger state space, more sweeps and a tighter
            tolerance, and show diagnostic plots with matplotlib.
    """
    # Build a test, fully connected Potts model
    if debug:
        # Heavy duty sampling with large state space
        N, q = 5, 4
        num_sweeps, num_chains = 1000, 1000
        rtol = 0.05
    else:
        # Quick and dirty small state space
        N, q = 3, 3
        num_sweeps, num_chains = 200, 1000
        rtol = 0.1
    beta = 0.1
    warmup_fraction = 0.1
    torch.manual_seed(1)

    mask_i = torch.ones([1, N]).float()
    mask_ij = (1 - torch.eye(N))[None, ...].float()
    edge_idx = torch.arange(N)[None, None, :].expand([1, N, N])
    h = beta * torch.randn([1, N, q])
    J = beta * torch.randn([1, N, N, q, q])
    # Symmetrize the couplings and zero out the self-interaction terms.
    J = mask_ij[..., None, None] * (J + J.permute([0, 2, 1, 4, 3])) / np.sqrt(2)

    # Enumerate all of sequence space
    alphabet = "ABCDEFGHIJK"[:q]
    sequences = ["".join(x) for x in product(alphabet, repeat=N)]
    S_exact = torch.tensor(
        [[alphabet.index(s) for s in seq] for seq in sequences],
        dtype=torch.long,
    )
    print(f"Enumerated {len(sequences)} sequences")

    if torch.cuda.is_available():
        device = "cuda"
        h = h.to(device)
        J = J.to(device)
        edge_idx = edge_idx.to(device)
        mask_i = mask_i.to(device)
        mask_ij = mask_ij.to(device)
        S_exact = S_exact.to(device)

    # Compute exact distribution over sequence space
    B = S_exact.shape[0]
    U_exact, _ = compute_potts_energy(
        S_exact,
        h.expand([B, -1, -1]),
        J.expand([B, -1, -1, -1, -1]),
        edge_idx.expand([B, -1, -1]),
    )
    p_exact = F.softmax(-U_exact, -1).tolist()

    # Estimate distribution from sampled sequences
    _, _, S_trajectory, U_trajectory = sample_potts(
        h.expand([num_chains, -1, -1]),
        J.expand([num_chains, -1, -1, -1, -1]),
        edge_idx.expand([num_chains, -1, -1]),
        mask_i.expand([num_chains, -1]),
        mask_ij.expand([num_chains, -1, -1]),
        num_sweeps=num_sweeps,
        proposal=proposal,
        rejection_step=True,
        verbose=True,
        return_trajectory=True,
    )

    # Drop the initial part of the trajectory so warm-up states do not bias
    # the empirical frequencies.
    S_trajectory = S_trajectory[int(warmup_fraction * len(S_trajectory)) :]
    S_samples = torch.cat(S_trajectory, 0).cpu().detach().numpy()
    sample_counts = Counter("".join(alphabet[c] for c in s) for s in S_samples)
    p_sample = [sample_counts[seq] / S_samples.shape[0] for seq in sequences]

    if debug:
        from matplotlib import pyplot as plt

        U_trajectory = torch.stack(U_trajectory, 1).cpu().detach().numpy()

        plt.figure(figsize=(6, 3))
        plt.subplot(1, 2, 1)
        plt.plot(p_exact, p_sample, "k.")
        plt.grid()
        plt.axis("square")
        plt.xlabel("Probability, exact enumeration")
        plt.ylabel("Sampling frequencey (MCMC)")
        plt.title(f"Random Potts model over {q}^{N} sequences")
        plt.subplot(1, 2, 2)
        plt.plot(U_trajectory[0, :])
        plt.xlabel("Iterations")
        plt.ylabel("Energy")
        plt.tight_layout()
        plt.show()

    # The frequencies of states visited via MCMC should reproduce their
    # exact probabilities (via enumeration) within rtol relative error
    assert np.allclose(p_sample, p_exact, rtol=rtol)
def debug_potts_2D():
    """Animate annealed sampling of a q-state Potts model on a 2D lattice.

    Builds an ``N x N`` lattice in which every site is linked to the sites at
    index offsets ``+1``, ``-1``, ``+N`` and ``-N`` (wrapped modulo the number
    of sites), samples it with ``sample_potts`` while the temperature goes
    from ``temp_range[0]`` to ``temp_range[1]``, and writes the trajectory to
    ``potts.mp4``.

    This is a manual debugging helper, not a pytest test: it needs matplotlib
    and an ``ffmpeg`` writer, and it runs 10000 sweeps.
    """
    from matplotlib import pyplot as plt
    from matplotlib.animation import FuncAnimation

    N = 100
    q = 4
    num_sites = N * N
    temp_range = (1.2, 0.8)

    # One mask entry per lattice site, matching the node dimension of h,
    # edge_idx and mask_ij below.
    mask_i = torch.ones([1, num_sites]).float()
    ix = torch.arange(num_sites).long()

    # Build 2D lattice topology
    neighbors = torch.stack([ix + 1, ix - 1, ix + N, ix - N], -1)
    mask_ij = torch.ones_like(neighbors).float()[None, :, :]
    edge_idx = torch.remainder(neighbors, num_sites)[None, :, :].long()

    # Ferromagnetic parameters: zero fields, and couplings (passed negated to
    # the sampler) that are nonzero only when both sites share a state.
    h = torch.zeros([1, num_sites, q])
    mask_J = mask_ij[:, :, :, None, None] * torch.eye(q)[None, None, None, :, :]

    if torch.cuda.is_available():
        device = "cuda"
        h = h.to(device)
        edge_idx = edge_idx.to(device)
        mask_J = mask_J.to(device)
        mask_i = mask_i.to(device)
        mask_ij = mask_ij.to(device)

    _, _, S_trajectory, U_trajectory = sample_potts(
        h,
        -mask_J,
        edge_idx,
        mask_i,
        mask_ij,
        num_sweeps=10000,
        verbose=True,
        return_trajectory=True,
        S=None,
        annealing_fraction=1.0,
        temperature_init=temp_range[0],
        temperature=temp_range[1],
    )

    num_frames = len(S_trajectory)
    # NOTE(review): the frame titles assume the sampler anneals temperature
    # linearly across the trajectory -- confirm against sample_potts.
    temps = np.linspace(temp_range[0], temp_range[1], num_frames)
    betas = 1.0 / temps

    def update(frame):
        """Redraw the lattice state for one frame of the animation."""
        plt.clf()  # Clear the previous frame
        lattice = S_trajectory[frame].cpu().detach().numpy().reshape([N, N])
        plt.pcolor(lattice, cmap="tab10")
        plt.clim([0, 10])
        plt.axis("square")
        plt.axis("off")
        plt.title(f"Beta = {betas[frame]:0.2f}")
        print(frame)

    # Create a figure and set the number of frames
    fig = plt.figure(figsize=(4, 4), dpi=300)
    animation = FuncAnimation(fig, update, frames=num_frames, interval=1000 / 60)
    animation.save("potts.mp4", writer="ffmpeg")