# Copyright Generate Biomedicines, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Layers for building Potts models.

This module contains layers for parameterizing Potts models from graph
embeddings.
"""

from typing import Callable, List, Literal, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm

from chroma.layers import graph


class GraphPotts(nn.Module):
    """Conditional Random Field (conditional Potts model) layer on a graph.

    Arguments:
        dim_nodes (int): Hidden dimension of node tensor.
        dim_edges (int): Hidden dimension of edge tensor.
        num_states (int): Size of the vocabulary.
        parameterization (str): Parameterization choice in `{'linear', 'factor',
            'score', 'score_zsum', 'score_scale'}`, or any of those suffixed
            with `_beta`, which will add in a globally learnable temperature
            scaling parameter.
        symmetric_J (bool): If True, enforce symmetry of the Potts model, i.e.
            `J_ij(s_i, s_j) = J_ji(s_j, s_i)`.
        init_scale (float): Scale factor for the weights and couplings at
            initialization.
        dropout (float): Per-dimension dropout probability on `[0,1]`.
        label_smoothing (float): Label smoothing probability applied when
            computing per-token likelihoods.
        num_factors (int): Number of factors to use for the `factor`
            parameterization mode.
        beta_init (float): Initial temperature scaling factor for
            parameterizations with the `_beta` suffix.

    Inputs:
        node_h (torch.Tensor): Node features with shape
            `(num_batch, num_nodes, dim_nodes)`.
        edge_h (torch.Tensor): Edge features with shape
            `(num_batch, num_nodes, num_neighbors, dim_edges)`.
        edge_idx (torch.LongTensor): Edge indices with shape
            `(num_batch, num_nodes, num_neighbors)`.
        mask_i (torch.Tensor): Node mask with shape `(num_batch, num_nodes)`.
        mask_ij (torch.Tensor): Edge mask with shape
            `(num_batch, num_nodes, num_neighbors)`.

    Outputs:
        h (torch.Tensor): Potts model fields :math:`h_i(s_i)` with shape
            `(num_batch, num_nodes, num_states)`.
        J (torch.Tensor): Potts model couplings :math:`J_{ij}(s_i, s_j)` with
            shape `(num_batch, num_nodes, num_neighbors, num_states, num_states)`.
""" def __init__( self, dim_nodes: int, dim_edges: int, num_states: int, parameterization: str = "score", symmetric_J: bool = True, init_scale: float = 0.1, dropout: float = 0.0, label_smoothing: float = 0.0, num_factors: Optional[int] = None, beta_init: float = 10.0, ): super(GraphPotts, self).__init__() self.dim_nodes = dim_nodes self.dim_edges = dim_edges self.num_states = num_states self.label_smoothing = label_smoothing # Beta parameterization support temperature learning self.scale_beta = False if parameterization.endswith("_beta"): parameterization = parameterization.split("_beta")[0] self.scale_beta = True self.log_beta = nn.Parameter(np.log(beta_init) * torch.ones(1)) self.init_scale = init_scale self.parameterization = parameterization self.symmetric_J = symmetric_J if self.parameterization == "linear": self.log_scale = nn.Parameter(np.log(init_scale) * torch.ones(1)) self.W_h = nn.Linear(self.dim_nodes, self.num_states, bias=True) self.W_J = nn.Linear(self.dim_edges, self.num_states ** 2, bias=True) elif self.parameterization == "factor": self.log_scale = nn.Parameter(np.log(init_scale) * torch.ones(1)) self.W_h = nn.Linear(self.dim_nodes, self.num_states, bias=True) self.W_J_left = nn.Linear(self.dim_edges, self.num_states ** 2, bias=True) self.W_J_right = nn.Linear(self.dim_edges, self.num_states ** 2, bias=True) elif self.parameterization == "score": if num_factors is None: num_factors = dim_edges self.num_factors = num_factors self.log_scale = nn.Parameter(np.log(init_scale) * torch.ones(1)) self.W_h_bg = nn.Linear(self.dim_nodes, 1) self.W_J_bg = nn.Linear(self.dim_edges, 1) self.W_h = nn.Linear(self.dim_nodes, self.num_states, bias=True) self.W_J_left = nn.Linear( self.dim_edges, self.num_states * num_factors, bias=True ) self.W_J_right = nn.Linear( self.dim_edges, self.num_states * num_factors, bias=True ) elif self.parameterization == "score_zsum": if num_factors is None: num_factors = dim_edges self.num_factors = num_factors self.log_scale = nn.Parameter(np.log(init_scale) * torch.ones(1)) self.W_h = nn.Linear(self.dim_nodes, self.num_states, bias=True) self.W_J_left = nn.Linear( self.dim_edges, self.num_states * num_factors, bias=True ) self.W_J_right = nn.Linear( self.dim_edges, self.num_states * num_factors, bias=True ) elif self.parameterization == "score_scale": if num_factors is None: num_factors = dim_edges self.num_factors = num_factors self.W_h_bg = nn.Linear(self.dim_nodes, 1) self.W_J_bg = nn.Linear(self.dim_edges, 1) self.W_h_log_scale = nn.Linear(self.dim_nodes, 1) self.W_J_log_scale = nn.Linear(self.dim_edges, 1) self.W_h = nn.Linear(self.dim_nodes, self.num_states) self.W_J_left = nn.Linear(self.dim_edges, self.num_states * num_factors) self.W_J_right = nn.Linear(self.dim_edges, self.num_states * num_factors) else: print(f"Unknown potts parameterization: {parameterization}") raise NotImplementedError self.dropout = nn.Dropout(dropout) def _mask_J(self, edge_idx, mask_i, mask_ij): # Remove self edges device = edge_idx.device ii = torch.arange(edge_idx.shape[1]).view((1, -1, 1)).to(device) not_self = torch.ne(edge_idx, ii).type(torch.float32) # Remove missing edges self_present = mask_i.unsqueeze(-1) neighbor_present = graph.collect_neighbors(self_present, edge_idx) neighbor_present = neighbor_present.squeeze(-1) mask_J = not_self * self_present * neighbor_present if mask_ij is not None: mask_J = mask_ij * mask_J return mask_J def forward( self, node_h: torch.Tensor, edge_h: torch.Tensor, edge_idx: torch.LongTensor, mask_i: torch.Tensor, mask_ij: 
torch.Tensor, ): mask_J = self._mask_J(edge_idx, mask_i, mask_ij) if self.parameterization == "linear": # Compute site params (h) from node embeddings # Compute coupling params (J) from edge embeddings scale = torch.exp(self.log_scale) h = scale * mask_i.unsqueeze(-1) * self.W_h(node_h) J = scale * mask_J.unsqueeze(-1) * self.W_J(edge_h) J = J.view(list(edge_h.size())[:3] + ([self.num_states] * 2)) elif self.parameterization == "factor": scale = torch.exp(self.log_scale) h = scale * mask_i.unsqueeze(-1) * self.W_h(node_h) mask_J = scale * mask_J.unsqueeze(-1) shape_J = list(edge_h.size())[:3] + ([self.num_states] * 2) J_left = (mask_J * self.W_J_left(edge_h)).view(shape_J) J_right = (mask_J * self.W_J_right(edge_h)).view(shape_J) J = torch.matmul(J_left, J_right) J = self.dropout(J) # Zero-sum gauge h = h - h.mean(-1, keepdim=True) J = ( J - J.mean(-1, keepdim=True) - J.mean(-2, keepdim=True) + J.mean(dim=[-1, -2], keepdim=True) ) elif self.parameterization == "score": node_h = self.dropout(node_h) edge_h = self.dropout(edge_h) scale = torch.exp(self.log_scale) mask_h = scale * mask_i.unsqueeze(-1) mask_J = scale * mask_J.unsqueeze(-1) h = mask_h * self.W_h(node_h) shape_J_prefix = list(edge_h.size())[:3] J_left = (mask_J * self.W_J_left(edge_h)).view( shape_J_prefix + [self.num_states, self.num_factors] ) J_right = (mask_J * self.W_J_right(edge_h)).view( shape_J_prefix + [self.num_factors, self.num_states] ) J = torch.matmul(J_left, J_right) # Zero-sum gauge h = h - h.mean(-1, keepdim=True) J = ( J - J.mean(-1, keepdim=True) - J.mean(-2, keepdim=True) + J.mean(dim=[-1, -2], keepdim=True) ) # Background components h = h + mask_h * self.W_h_bg(node_h) J = J + (mask_J * self.W_J_bg(edge_h)).unsqueeze(-1) elif self.parameterization == "score_zsum": node_h = self.dropout(node_h) edge_h = self.dropout(edge_h) scale = torch.exp(self.log_scale) mask_h_scale = scale * mask_i.unsqueeze(-1) mask_J_scale = scale * mask_J.unsqueeze(-1) h = mask_h_scale * self.W_h(node_h) shape_J_prefix = list(edge_h.size())[:3] J_left = (mask_J_scale * self.W_J_left(edge_h)).view( shape_J_prefix + [self.num_states, self.num_factors] ) J_right = (mask_J_scale * self.W_J_right(edge_h)).view( shape_J_prefix + [self.num_factors, self.num_states] ) J = torch.matmul(J_left, J_right) J = self.dropout(J) # Zero-sum gauge J = ( J - J.mean(-1, keepdim=True) - J.mean(-2, keepdim=True) + J.mean(dim=[-1, -2], keepdim=True) ) # Subtract off J background average mask_J = mask_J.view(list(mask_J.size()) + [1, 1]) J_i_avg = J.sum(dim=[1, 2], keepdim=True) / mask_J.sum([1, 2], keepdim=True) J = mask_J * (J - J_i_avg) elif self.parameterization == "score_scale": node_h = self.dropout(node_h) edge_h = self.dropout(edge_h) mask_h = mask_i.unsqueeze(-1) mask_J = mask_J.unsqueeze(-1) h = mask_h * self.W_h(node_h) shape_J_prefix = list(edge_h.size())[:3] J_left = (mask_J * self.W_J_left(edge_h)).view( shape_J_prefix + [self.num_states, self.num_factors] ) J_right = (mask_J * self.W_J_right(edge_h)).view( shape_J_prefix + [self.num_factors, self.num_states] ) J = torch.matmul(J_left, J_right) # Zero-sum gauge h = h - h.mean(-1, keepdim=True) J = ( J - J.mean(-1, keepdim=True) - J.mean(-2, keepdim=True) + J.mean(dim=[-1, -2], keepdim=True) ) # Background components log_scale = np.log(self.init_scale) h_scale = torch.exp(self.W_h_log_scale(node_h) + log_scale) J_scale = torch.exp(self.W_J_log_scale(edge_h) + 2 * log_scale).unsqueeze( -1 ) h_bg = mask_h * self.W_h_bg(node_h) J_bg = (mask_J * self.W_J_bg(edge_h)).unsqueeze(-1) h = h_scale * (h + 
h_bg) J = J_scale * (J + J_bg) if self.symmetric_J: J = self._symmetrize_J(J, edge_idx, mask_ij) if self.scale_beta: beta = torch.exp(self.log_beta) h = beta * h J = beta * J return h, J def _symmetrize_J_serial(self, J, edge_idx, mask_ij): """Enforce symmetry of J matrices, serial version.""" num_batch, num_residues, num_k, num_states, _ = list(J.size()) # Symmetrization based on raw indexing - extremely slow; for debugging import time _start = time.time() J_symm = torch.zeros_like(J) for b in range(J.size(0)): for i in range(J.size(1)): for k_i in range(J.size(2)): for k_j in range(J.size(2)): j = edge_idx[b, i, k_i] if edge_idx[b, j, k_j] == i: J_symm[b, i, k_i, :, :] = ( J[b, i, k_i, :, :] + J[b, j, k_j, :, :].transpose(-1, -2) ) / 2.0 speed = J.size(0) * J.size(1) / (time.time() - _start) print(f"symmetrized at {speed} residue/s") return J_symm def _symmetrize_J(self, J, edge_idx, mask_ij): """Enforce symmetry of J matrices via adding J_ij + J_ji^T""" num_batch, num_residues, num_k, num_states, _ = list(J.size()) # Flatten and gather J_ji matrices using transpose indexing J_flat = J.view(num_batch, num_residues, num_k, -1) J_flat_transpose, mask_ji = graph.collect_edges_transpose( J_flat, edge_idx, mask_ij ) J_transpose = J_flat_transpose.view( num_batch, num_residues, num_k, num_states, num_states ) # Transpose J_ji matrices to symmetrize as (J_ij + J_ji^T)/2 J_transpose = J_transpose.transpose(-2, -1) mask_ji = (0.5 * mask_ji).view(num_batch, num_residues, num_k, 1, 1) J_symm = mask_ji * (J + J_transpose) return J_symm def energy( self, S: torch.LongTensor, h: torch.Tensor, J: torch.Tensor, edge_idx: torch.LongTensor, ) -> torch.Tensor: """Compute Potts model energy from sequence. Inputs: S (torch.LongTensor): Sequence with shape `(num_batch, num_nodes)`. h (torch.Tensor): Potts model fields :math:`h_i(s_i)` with shape `(num_batch, num_nodes, num_states)`. J (Tensor): Potts model couplings :math:`J_{ij}(s_i, s_j)` with shape `(num_batch, num_nodes, num_neighbors, num_states, num_states)`. edge_idx (torch.LongTensor): Edge indices with shape `(num_batch, num_nodes, num_neighbors)`. Outputs: U (torch.Tensor): Potts total energies with shape `(num_batch)`. Lower energies are more favorable. """ # Gather J [Batch,i,j,A_i,A_j] => J_ij(:,A_j) [Batch,i,j,A_i] S_j = graph.collect_neighbors(S.unsqueeze(-1), edge_idx) S_j = S_j.unsqueeze(-1).expand(-1, -1, -1, self.num_states, -1) J_ij = torch.gather(J, -1, S_j).squeeze(-1) # Sum out J contributions J_i = J_ij.sum(2) / 2.0 r_i = h + J_i U_i = torch.gather(r_i, 2, S.unsqueeze(-1)) U = U_i.sum([1, 2]) return U def pseudolikelihood( self, S: torch.LongTensor, h: torch.Tensor, J: torch.Tensor, edge_idx: torch.LongTensor, ) -> torch.Tensor: """Compute Potts pseudolikelihood from sequence Inputs: S (torch.LongTensor): Sequence with shape `(num_batch, num_nodes)`. h (torch.Tensor): Potts model fields :math:`h_i(s_i)` with shape `(num_batch, num_nodes, num_states)`. J (Tensor): Potts model couplings :math:`J_{ij}(s_i, s_j)` with shape `(num_batch, num_nodes, num_neighbors, num_states, num_states)`. edge_idx (torch.LongTensor): Edge indices with shape `(num_batch, num_nodes, num_neighbors)`. Outputs: log_probs (torch.Tensor): Potts log-pseudolihoods with shape `(num_batch, num_nodes, num_states)`. 
""" # Gather J [Batch,i,j,A_i,A_j] => J_ij(:,A_j) [Batch,i,j,A_i] S_j = graph.collect_neighbors(S.unsqueeze(-1), edge_idx) S_j = S_j.unsqueeze(-1).expand(-1, -1, -1, self.num_states, -1) J_ij = torch.gather(J, -1, S_j).squeeze(-1) # Sum out J contributions J_i = J_ij.sum(2) logits = h + J_i log_probs = F.log_softmax(-logits, dim=-1) return log_probs def log_composite_likelihood( self, S: torch.LongTensor, h: torch.Tensor, J: torch.Tensor, edge_idx: torch.LongTensor, mask_i: torch.Tensor, mask_ij: torch.Tensor, smoothing_alpha: float = 0.0, ) -> Tuple[torch.Tensor, torch.Tensor]: """Compute Potts pairwise composite likelihoods from sequence. Inputs: S (torch.LongTensor): Sequence with shape `(num_batch, num_nodes)`. h (torch.Tensor): Potts model fields :math:`h_i(s_i)` with shape `(num_batch, num_nodes, num_states)`. J (Tensor): Potts model couplings :math:`J_{ij}(s_i, s_j)` with shape `(num_batch, num_nodes, num_neighbors, num_states, num_states)`. edge_idx (torch.LongTensor): Edge indices with shape `(num_batch, num_nodes, num_neighbors)`. mask_i (torch.Tensor): Node mask with shape `(num_batch, num_nodes)` mask_ij (torch.Tensor): Edge mask with shape `(num_batch, num_nodes, num_neighbors)`. smoothing_alpha (float): Label smoothing probability on `(0,1)`. Outputs: logp_ij (torch.Tensor): Potts pairwise composite likelihoods evaluated for the current sequence with shape `(num_batch, num_nodes, num_neighbors)`. mask_p_ij (torch.Tensor): Edge mask with shape `(num_batch, num_nodes, num_neighbors)`. """ num_batch, num_residues, num_k, num_states, _ = list(J.size()) # Gather J clamped at j # [Batch,i,j,A_i,A_j] => J_ij(:,A_j) [Batch,i,j,A_i] S_j = graph.collect_neighbors(S.unsqueeze(-1), edge_idx) S_j = S_j.unsqueeze(-1).expand(-1, -1, -1, num_states, -1) # (B,i,j,A_i) J_clamp_j = torch.gather(J, -1, S_j).squeeze(-1) # Gather J clamped at i S_i = S.view(num_batch, num_residues, 1, 1, 1) S_i = S_i.expand(-1, -1, num_k, num_states, num_states) # (B,i,j,1,A_j) J_clamp_i = torch.gather(J, -2, S_i) # Compute background per-site contributions that sum out J # (B,i,j,A_i) => (B,i,A_i) r_i = h + J_clamp_j.sum(2) r_j = graph.collect_neighbors(r_i, edge_idx) # Remove J_ij from the i contributions # (B,i,A_i) => (B,i,:,A_i,:) r_i = r_i.view([num_batch, num_residues, 1, num_states, 1]) r_i_minus_ij = r_i - J_clamp_j.unsqueeze(-1) # Remove J_ji from the j contributions # (B,j,A_j) => (B,:,j,:,A_j) r_j = r_j.view([num_batch, num_residues, num_k, 1, num_states]) r_j_minus_ji = r_j - J_clamp_i # Composite likelihood (B,i,j,A_i,A_j) logits_ij = r_i_minus_ij + r_j_minus_ji + J logits_ij = logits_ij.view([num_batch, num_residues, num_k, -1]) logp = F.log_softmax(-logits_ij, dim=-1) logp = logp.view([num_batch, num_residues, num_k, num_states, num_states]) # Score the current sequence under # (B,i,j,A_i,A_j) => (B,i,j,A_i) => (B,i,j) logp_j = torch.gather(logp, -1, S_j).squeeze(-1) S_i = S.view(num_batch, num_residues, 1, 1).expand(-1, -1, num_k, -1) logp_ij = torch.gather(logp_j, -1, S_i).squeeze(-1) # Optional label smoothing (scaled assuming per-token smoothing ) if smoothing_alpha > 0.0: # Foreground probability num_bins = num_states ** 2 prob_no_smooth = (1.0 - smoothing_alpha) ** 2 prob_background = (1.0 - prob_no_smooth) / float(num_bins - 1) # The second term corrects for double counting in background sum p_foreground = prob_no_smooth - prob_background logp_ij = p_foreground * logp_ij + prob_background * logp.sum([-2, -1]) mask_p_ij = self._mask_J(edge_idx, mask_i, mask_ij) logp_ij = mask_p_ij * logp_ij 
return logp_ij, mask_p_ij def loss( self, S: torch.LongTensor, node_h: torch.Tensor, edge_h: torch.Tensor, edge_idx: torch.LongTensor, mask_i: torch.Tensor, mask_ij: torch.Tensor, ) -> torch.Tensor: """Compute per-residue losses given a sequence. Inputs: S (torch.LongTensor): Sequence with shape `(num_batch, num_nodes)`. node_h (torch.Tensor): Node features with shape `(num_batch, num_nodes, dim_nodes)`. edge_h (torch.Tensor): Edge features with shape `(num_batch, num_nodes, num_neighbors, dim_edges)`. edge_idx (torch.LongTensor): Edge indices with shape `(num_batch, num_nodes, num_neighbors)`. mask_i (torch.Tensor): Node mask with shape `(num_batch, num_nodes)` mask_ij (torch.Tensor): Edge mask with shape `(num_batch, num_nodes, num_neighbors)` Outputs: logp_i (torch.Tensor): Potts per-residue normalized composite log likelihoods with shape`(num_batch, num_nodes)`. """ # Compute parameters h, J = self.forward(node_h, edge_h, edge_idx, mask_i, mask_ij) # Log composite likelihood logp_ij, mask_p_ij = self.log_composite_likelihood( S, h, J, edge_idx, mask_i, mask_ij, smoothing_alpha=self.label_smoothing if self.training else 0.0, ) # Map into approximate local likelihoods logp_i = ( mask_i * torch.sum(mask_p_ij * logp_ij, dim=-1) / (2.0 * torch.sum(mask_p_ij, dim=-1) + 1e-3) ) return logp_i def sample( self, node_h: torch.Tensor, edge_h: torch.Tensor, edge_idx: torch.LongTensor, mask_i: torch.Tensor, mask_ij: torch.Tensor, S: Optional[torch.LongTensor] = None, mask_sample: Optional[torch.Tensor] = None, num_sweeps: int = 100, temperature: float = 0.1, temperature_init: float = 1.0, penalty_func: Optional[Callable[[torch.LongTensor], torch.Tensor]] = None, differentiable_penalty: bool = True, rejection_step: bool = False, proposal: Literal["dlmc", "chromatic"] = "dlmc", verbose: bool = False, edge_idx_coloring: Optional[torch.LongTensor] = None, mask_ij_coloring: Optional[torch.Tensor] = None, symmetry_order: Optional[int] = None, ) -> Tuple[torch.LongTensor, torch.Tensor]: """Sample from Potts model with Chromatic Gibbs sampling. Args: node_h (torch.Tensor): Node features with shape `(num_batch, num_nodes, dim_nodes)`. edge_h (torch.Tensor): Edge features with shape `(num_batch, num_nodes, num_neighbors, dim_edges)`. edge_idx (torch.LongTensor): Edge indices with shape `(num_batch, num_nodes, num_neighbors)`. mask_i (torch.Tensor): Node mask with shape `(num_batch, num_nodes)`. mask_ij (torch.Tensor): Edge mask with shape `(num_batch, num_nodes, num_neighbors)`. S (torch.LongTensor, optional): Sequence for initialization with shape `(num_batch, num_nodes)`. mask_sample (torch.Tensor, optional): Binary sampling mask indicating positions which are free to change with shape `(num_batch, num_nodes)` or which tokens are acceptable at each position with shape `(num_batch, num_nodes, alphabet)`. num_sweeps (int): Number of sweeps of Chromatic Gibbs to perform, i.e. the depth of sampling as measured by the number of times every position has had an opportunity to update. temperature (float): Final sampling temperature. temperature_init (float): Initial sampling temperature, which will be linearly interpolated to `temperature` over the course of the burn in phase. penalty_func (Callable, optional): An optional penalty function which takes a sequence `S` and outputes a `(num_batch)` shaped tensor of energy adjustments, for example as regularization. differentiable_penalty (bool): If True, gradients of penalty function will be used to adjust the proposals. 
rejection_step (bool): If True, perform a Metropolis-Hastings rejection step. proposal (str): MCMC proposal for Potts sampling. Currently implemented proposals are `dlmc` for Discrete Langevin Monte Carlo [1] or `chromatic` for Gibbs sampling with graph coloring. [1] Sun et al. Discrete Langevin Sampler via Wasserstein Gradient Flow (2023). verbose (bool): If True print verbose output during sampling. edge_idx_coloring (torch.LongerTensor, optional): Alternative graph dependency structure that can be provided for the Chromatic Gibbs algorithm when it performs initial graph coloring. Has shape `(num_batch, num_nodes, num_neighbors_coloring)`. mask_ij_coloring (torch.Tensor): Edge mask for the alternative dependency structure with shape `(num_batch, num_nodes, num_neighbors_coloring)`. symmetry_order (int, optional): Optional integer argument to enable symmetric sequence decoding under `symmetry_order`-order symmetry. The first `(num_nodes // symmetry_order)` states will be free to move, and all consecutively tiled sets of states will be locked to these during decoding. Internally this is accomplished by summing the parameters Potts model under a symmetry constraint into this reduced sized system and then back imputing at the end. Returns: S (torch.LongTensor): Sampled sequences with shape `(num_batch, num_nodes)`. U (torch.Tensor): Sampled energies with shape `(num_batch)`. Lower is more favorable. """ B, N, _ = node_h.shape # Compute parameters h, J = self.forward(node_h, edge_h, edge_idx, mask_i, mask_ij) if symmetry_order is not None: h, J, edge_idx, mask_i, mask_ij = fold_symmetry( symmetry_order, h, J, edge_idx, mask_i, mask_ij ) S = S[:, : (N // symmetry_order)] if mask_sample is not None: mask_sample = mask_sample[:, : (N // symmetry_order)] S_sample, U_sample = sample_potts( h, J, edge_idx, mask_i, mask_ij, S=S, mask_sample=mask_sample, num_sweeps=num_sweeps, temperature=temperature, temperature_init=temperature_init, penalty_func=penalty_func, differentiable_penalty=differentiable_penalty, rejection_step=rejection_step, proposal=proposal, verbose=verbose, edge_idx_coloring=edge_idx_coloring, mask_ij_coloring=mask_ij_coloring, ) if symmetry_order is not None: assert N % symmetry_order == 0 S_sample = ( S_sample[:, None, :].expand([-1, symmetry_order, -1]).reshape([B, N]) ) return S_sample, U_sample def compute_potts_energy( S: torch.LongTensor, h: torch.Tensor, J: torch.Tensor, edge_idx: torch.LongTensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """Compute Potts model energies from sequence. Args: S (torch.LongTensor): Sequence with shape `(num_batch, num_nodes)`. h (torch.Tensor): Potts model fields :math:`h_i(s_i)` with shape `(num_batch, num_nodes, num_states)`. J (Tensor): Potts model couplings :math:`J_{ij}(s_i, s_j)` with shape `(num_batch, num_nodes, num_neighbors, num_states, num_states)`. edge_idx (torch.LongTensor): Edge indices with shape `(num_batch, num_nodes, num_neighbors)`. Returns: U (torch.Tensor): Potts total energies with shape `(num_batch)`. Lower energies are more favorable. U_i (torch.Tensor): Potts local conditional energies with shape `(num_batch, num_nodes, num_states)`. 
""" S_j = graph.collect_neighbors(S.unsqueeze(-1), edge_idx) S_j = S_j.unsqueeze(-1).expand(-1, -1, -1, h.shape[-1], -1) J_ij = torch.gather(J, -1, S_j).squeeze(-1) # Sum out J contributions to yield local conditionals J_i = J_ij.sum(2) U_i = h + J_i # Correct for double counting in total energy S_expand = S[..., None] U = ( torch.gather(U_i, -1, S[..., None]) - 0.5 * torch.gather(J_i, -1, S[..., None]) ).sum((1, 2)) return U, U_i def fold_symmetry( symmetry_order: int, h: torch.Tensor, J: torch.Tensor, edge_idx: torch.LongTensor, mask_i: torch.Tensor, mask_ij: torch.Tensor, normalize=True, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Fold Potts model symmetrically. Args: symmetry_order (int): The order of symmetry by which to fold the Potts model such that the first `(num_nodes // symmetry_order)` states represent the entire system and all fields and couplings to and among other copies of this base system are collected together in single reduced Potts model. h (torch.Tensor): Potts model fields :math:`h_i(s_i)` with shape `(num_batch, num_nodes, num_states)`. J (Tensor): Potts model couplings :math:`J_{ij}(s_i, s_j)` with shape `(num_batch, num_nodes, num_neighbors, num_states, num_states)`. edge_idx (torch.LongTensor): Edge indices with shape `(num_batch, num_nodes, num_neighbors)`. mask_i (torch.Tensor): Node mask with shape `(num_batch, num_nodes)`. mask_ij (torch.Tensor): Edge mask with shape `(num_batch, num_nodes, num_neighbors)`. normalize (bool): If True (default), aggregate the Potts model as an average energy across asymmetric units instead of as a sum. Returns: h_fold (torch.Tensor): Potts model fields :math:`h_i(s_i)` with shape `(num_batch, num_nodes_folded, num_states)`, where `num_nodes_folded = num_nodes // symmetry_order`. J_fold (Tensor): Potts model couplings :math:`J_{ij}(s_i, s_j)` with shape `(num_batch, num_nodes_folded, num_neighbors, num_states, num_states)`. edge_idx_fold (torch.LongTensor): Edge indices with shape `(num_batch, num_nodes_folded, num_neighbors)`. mask_i_fold (torch.Tensor): Node mask with shape `(num_batch, num_nodes_folded)`. mask_ij_fold (torch.Tensor): Edge mask with shape `(num_batch, num_nodes_folded, num_neighbors)`. 
""" B, N, K, Q, _ = J.shape device = h.device N_asymmetric = N // symmetry_order # Fold edges by densifying the assymetric unit and averaging edge_idx_au = torch.remainder(edge_idx, N_asymmetric).clamp(max=N_asymmetric - 1) def _pairwise_fold(_T): # Fold-sum along neighbor dimension shape = list(_T.shape) shape[2] = N_asymmetric _T_au_expand = torch.zeros(shape, device=device).float() extra_dims = len(_T.shape) - len(edge_idx_au.shape) edge_idx_au_expand = edge_idx_au.reshape( list(edge_idx_au.shape) + [1] * extra_dims ).expand([-1, -1, -1] + [Q] * extra_dims) _T_au_expand.scatter_add_(2, edge_idx_au_expand, _T.float()) # Fold-mean along self dimension shape_out = [shape[0], -1, N_asymmetric, N_asymmetric] + shape[3:] _T_au = _T_au_expand.reshape(shape_out).sum(1) return _T_au J_fold = _pairwise_fold(J) mask_ij_fold = (_pairwise_fold(mask_ij) > 0).float() edge_idx_fold = ( torch.arange(N_asymmetric, device=device) .long()[None, None, :] .expand(mask_ij_fold.shape) ) # Drop unused edges K_fold = mask_ij_fold.sum(2).max().item() _, sort_ix = torch.sort(mask_ij_fold, dim=2, descending=True) sort_ix_J = sort_ix[..., None, None].expand(list(sort_ix.shape) + [Q, Q]) edge_idx_fold = torch.gather(edge_idx_fold, 2, sort_ix) mask_ij_fold = torch.gather(mask_ij_fold, 2, sort_ix) J_fold = torch.gather(J_fold, 2, sort_ix_J) # Fold-mean along self dimension h_fold = h.reshape([B, -1, N_asymmetric, Q]).sum(1) mask_i_fold = (mask_i.reshape([B, -1, N_asymmetric]).sum(1) > 0).float() if normalize: h_fold = h_fold / symmetry_order J_fold = J_fold / symmetry_order return h_fold, J_fold, edge_idx_fold, mask_i_fold, mask_ij_fold @torch.no_grad() def _color_graph(edge_idx, mask_ij, max_iter=100): """Stochastic graph coloring.""" # Randomly assign initial colors B, N, K = edge_idx.shape # By Brooks we only need K + 1, but one extra color aids convergence num_colors = K + 2 S = torch.randint(0, num_colors, (B, N), device=edge_idx.device) # Ignore self-attachement ix = torch.arange(edge_idx.shape[1], device=edge_idx.device)[None, ..., None] mask_ij = (mask_ij * torch.ne(edge_idx, ix).float())[..., None] # Iteratively replace clashing sites with an available color i = 0 total_clashes = 1 while total_clashes > 0 and i < max_iter: # Tabulate available colors in neighborhood O_i = F.one_hot(S, num_colors).float() N_i = (mask_ij * graph.collect_neighbors(O_i, edge_idx)).sum(2) clashes = (O_i * N_i).sum(-1) N_i = torch.where(N_i > 0, -float("inf") * torch.ones_like(N_i), N_i) # Resample from this distribution where clashing S_new = torch.distributions.categorical.Categorical(logits=N_i).sample() S = torch.where(clashes > 0, S_new, S) i += 1 total_clashes = clashes.sum().item() return S @torch.no_grad() def sample_potts( h: torch.Tensor, J: torch.Tensor, edge_idx: torch.LongTensor, mask_i: torch.Tensor, mask_ij: torch.Tensor, S: Optional[torch.LongTensor] = None, mask_sample: Optional[torch.Tensor] = None, num_sweeps: int = 100, temperature: float = 1.0, temperature_init: float = 1.0, annealing_fraction: float = 0.8, penalty_func: Optional[Callable[[torch.LongTensor], torch.Tensor]] = None, differentiable_penalty: bool = True, rejection_step: bool = False, proposal: Literal["dlmc", "chromatic"] = "dlmc", verbose: bool = True, return_trajectory: bool = False, thin_sweeps: int = 3, edge_idx_coloring: Optional[torch.LongTensor] = None, mask_ij_coloring: Optional[torch.Tensor] = None, ) -> Union[ Tuple[torch.LongTensor, torch.Tensor], Tuple[torch.LongTensor, torch.Tensor, List[torch.LongTensor], List[torch.Tensor]], ]: 
"""Sample from Potts model with Chromatic Gibbs sampling. Args: h (torch.Tensor): Potts model fields :math:`h_i(s_i)` with shape `(num_batch, num_nodes, num_states)`. J (Tensor): Potts model couplings :math:`J_{ij}(s_i, s_j)` with shape `(num_batch, num_nodes, num_neighbors, num_states, num_states)`. edge_idx (torch.LongTensor): Edge indices with shape `(num_batch, num_nodes, num_neighbors)`. mask_i (torch.Tensor): Node mask with shape `(num_batch, num_nodes)`. mask_ij (torch.Tensor): Edge mask with shape `(num_batch, num_nodes, num_neighbors)`. S (torch.LongTensor, optional): Sequence for initialization with shape `(num_batch, num_nodes)`. mask_sample (torch.Tensor, optional): Binary sampling mask indicating positions which are free to change with shape `(num_batch, num_nodes)` or which tokens are acceptable at each position with shape `(num_batch, num_nodes, alphabet)`. num_sweeps (int): Number of sweeps of Chromatic Gibbs to perform, i.e. the depth of sampling as measured by the number of times every position has had an opportunity to update. temperature (float): Final sampling temperature. temperature_init (float): Initial sampling temperature, which will be linearly interpolated to `temperature` over the course of the burn in phase. annealing_fraction (float): Fraction of the total sampling run during which temperature annealing occurs. penalty_func (Callable, optional): An optional penalty function which takes a sequence `S` and outputes a `(num_batch)` shaped tensor of energy adjustments, for example as regularization. differentiable_penalty (bool): If True, gradients of penalty function will be used to adjust the proposals. rejection_step (bool): If True, perform a Metropolis-Hastings rejection step. proposal (str): MCMC proposal for Potts sampling. Currently implemented proposals are `dlmc` for Discrete Langevin Monte Carlo [1] or `chromatic` for Gibbs sampling with graph coloring. [1] Sun et al. Discrete Langevin Sampler via Wasserstein Gradient Flow (2023). verbose (bool): If True print verbose output during sampling. return_trajectory (bool): If True, also output the sampling trajectories of `S` and `U`. thin_sweeps (int): When returning trajectories, only save every `thin_sweeps` state to reduce memory usage. edge_idx_coloring (torch.LongerTensor, optional): Alternative graph dependency structure that can be provided for the Chromatic Gibbs algorithm when it performs initial graph coloring. Has shape `(num_batch, num_nodes, num_neighbors_coloring)`. mask_ij_coloring (torch.Tensor): Edge mask for the alternative dependency structure with shape `(num_batch, num_nodes, num_neighbors_coloring)`. Returns: S (torch.LongTensor): Sampled sequences with shape `(num_batch, num_nodes)`. U (torch.Tensor): Sampled energies with shape `(num_batch)`. Lower is more favorable. S_trajectory (List[torch.LongTensor]): List of sampled sequences through time each with shape `(num_batch, num_nodes)`. U_trajectory (List[torch.Tensor]): List of sampled energies through time each with shape `(num_batch)`. 
""" # Initialize masked proposals and mask h mask_S, mask_mutatable, S = init_sampling_masks(-h, mask_sample, S) h_numerical_zero = h.max() + 1e3 * max(1.0, temperature) h = torch.where(mask_S > 0, h, h_numerical_zero * torch.ones_like(h)) # Block update schedule if proposal == "chromatic": if edge_idx_coloring is None: edge_idx_coloring = edge_idx if mask_ij_coloring is None: mask_ij_coloring = mask_ij schedule = _color_graph(edge_idx_coloring, mask_ij_coloring) num_colors = schedule.max() + 1 num_iterations = num_colors * num_sweeps else: num_iterations = num_sweeps num_iterations_annealing = int(annealing_fraction * num_iterations) temperatures = np.linspace( temperature_init, temperature, num_iterations_annealing ).tolist() + [temperature] * (num_iterations - num_iterations_annealing) if proposal == "chromatic": _energy_proposal = lambda _S, _T: _potts_proposal_gibbs( _S, h, J, edge_idx, T=_T, penalty_func=penalty_func, differentiable_penalty=differentiable_penalty, ) elif proposal == "dlmc": _energy_proposal = lambda _S, _T: _potts_proposal_dlmc( _S, h, J, edge_idx, T=_T, penalty_func=penalty_func, differentiable_penalty=differentiable_penalty, ) else: raise NotImplementedError cumulative_sweeps = 0 if return_trajectory: S_trajectory = [] U_trajectory = [] for i, T_i in enumerate(tqdm(temperatures, desc="Potts Sampling")): # Cycle through Gibbs updates random sites to the update with fixed prob if proposal == "chromatic": mask_update = schedule.eq(i % num_colors) else: mask_update = torch.ones_like(S) > 0 if mask_mutatable is not None: mask_update = mask_update * (mask_mutatable > 0) # Compute current energy and local conditionals U, logp = _energy_proposal(S, T_i) # Propose S_new = torch.distributions.categorical.Categorical(logits=logp).sample() S_new = torch.where(mask_update, S_new, S) # Metropolis-Hastings adjusment if rejection_step: def _flux(_U, _logp, _S): logp_transition = torch.gather(_logp, -1, _S[..., None]) _logp_ij = (mask_update.float() * logp_transition[..., 0]).sum(1) flux = -_U / T_i + _logp_ij return flux U_new, logp_new = _energy_proposal(S_new, T_i) _flux_backward = _flux(U_new, logp_new, S) _flux_forward = _flux(U, logp, S_new) acc_ratio = torch.exp((_flux_backward - _flux_forward)).clamp(max=1.0) if verbose: # and i % 100 == 0: print( f"{(U_new - U).mean().item():0.2f}" f"\t{(_flux_backward - _flux_forward).mean().item():0.2f}" f"\t{acc_ratio.mean().item():0.2f}" ) u = torch.bernoulli(acc_ratio)[..., None] S = torch.where(u > 0, S_new, S) cumulative_sweeps += (u * mask_update).sum(1).mean().item() / S.shape[1] else: S = S_new cumulative_sweeps += (mask_update).float().sum(1).mean().item() / S.shape[1] if return_trajectory and i % (thin_sweeps) == 0: S_trajectory.append(S) U_trajectory.append(U) U, _ = compute_potts_energy(S, h, J, edge_idx) if verbose: print(f"Effective number of sweeps: {cumulative_sweeps}") if return_trajectory: return S, U, S_trajectory, U_trajectory else: return S, U def init_sampling_masks( logits_init: torch.Tensor, mask_sample: Optional[torch.Tensor] = None, S: Optional[torch.LongTensor] = None, ban_S: Optional[List[int]] = None, ): """Parse sampling masks and an initial sequence. Args: logits_init (torch.Tensor): Logits for sequence initialization with shape `(num_batch, num_nodes, alphabet)`. mask_sample (torch.Tensor, optional): Binary sampling mask indicating which positions are free to change with shape `(num_batch, num_nodes)` or which tokens are valid at each position with shape `(num_batch, num_nodes, alphabet)`. 
            In the latter case, `mask_sample` will take priority over `S`
            except for positions in which `mask_sample` is all zero.
        S (torch.LongTensor, optional): Initial sequence with shape
            `(num_batch, num_nodes)`.
        ban_S (list of int, optional): Optional list of alphabet indices to
            ban from all positions during sampling.

    Returns:
        mask_S (torch.Tensor): Finalized position-specific mask with shape
            `(num_batch, num_nodes, alphabet)`.
        mask_S_1D (torch.Tensor): Binary mask of positions that remain free to
            change with shape `(num_batch, num_nodes)`.
        S (torch.LongTensor): Self-consistent initial `S` with shape
            `(num_batch, num_nodes)`.
    """
    if S is None and mask_sample is not None:
        raise Exception("To use masked sampling, please provide an initial S")

    if mask_sample is None:
        mask_S = torch.ones_like(logits_init)
    elif mask_sample.dim() == 2:
        # Position-restricted sampling
        mask_sample_expand = mask_sample[..., None].expand(logits_init.shape)
        O_init = F.one_hot(S, logits_init.shape[-1]).float()
        mask_S = mask_sample_expand + (1 - mask_sample_expand) * O_init
    elif mask_sample.dim() == 3:
        # Mutation-restricted sampling
        O_init = F.one_hot(S, logits_init.shape[-1]).float()
        mask_zero = (mask_sample.sum(-1, keepdim=True) == 0).float()
        mask_S = ((mask_zero * O_init + mask_sample) > 0).float()
    else:
        raise NotImplementedError
    if ban_S is not None:
        mask_S[:, :, ban_S] = 0.0
    mask_S_1D = (mask_S.sum(-1) > 1).float()

    logits_init_masked = 1000 * mask_S + logits_init
    S = torch.distributions.categorical.Categorical(logits=logits_init_masked).sample()
    return mask_S, mask_S_1D, S


def _potts_proposal_gibbs(
    S, h, J, edge_idx, T=1.0, penalty_func=None, differentiable_penalty=True
):
    U, U_i = compute_potts_energy(S, h, J, edge_idx)

    if penalty_func is not None:
        if differentiable_penalty:
            with torch.enable_grad():
                S_onehot = F.one_hot(S, h.shape[-1]).float()
                S_onehot.requires_grad = True
                U_penalty = penalty_func(S_onehot)
                U_i_adjustment = torch.autograd.grad(
                    U_penalty.sum(), [S_onehot]
                )[0].detach()
                U_penalty = U_penalty.detach()
            U_i = U_i + 0.5 * U_i_adjustment
        else:
            # One-hot encode S for the penalty function
            S_onehot = F.one_hot(S, h.shape[-1]).float()
            U_penalty = penalty_func(S_onehot)
        U = U + U_penalty

    logp_i = F.log_softmax(-U_i / T, dim=-1)
    return U, logp_i


def _potts_proposal_dlmc(
    S,
    h,
    J,
    edge_idx,
    T=1.0,
    penalty_func=None,
    differentiable_penalty=True,
    dt=0.1,
    autoscale=True,
    balancing_func="sigmoid",
):
    # Compute energy gap
    U, U_i = compute_potts_energy(S, h, J, edge_idx)

    if penalty_func is not None:
        O = F.one_hot(S, h.shape[-1]).float()
        if differentiable_penalty:
            with torch.enable_grad():
                O.requires_grad = True
                U_penalty = penalty_func(O)
                U_i_adjustment = torch.autograd.grad(U_penalty.sum(), [O])[0].detach()
                U_penalty = U_penalty.detach()
            U_i_adjustment = U_i_adjustment - torch.gather(
                U_i_adjustment, -1, S[..., None]
            )
            U_i_mutate = U_i - torch.gather(U_i, -1, S[..., None])
            U_i = U_i + U_i_adjustment
        else:
            U_penalty = penalty_func(O)
        U = U + U_penalty

    # Compute local equilibrium distribution
    logP_j = F.log_softmax(-U_i / T, dim=-1)

    # Compute transition log probabilities
    O = F.one_hot(S, h.shape[-1]).float()
    logP_i = torch.gather(logP_j, -1, S[..., None])
    if balancing_func == "sqrt":
        log_Q_ij = 0.5 * (logP_j - logP_i)
    elif balancing_func == "sigmoid":
        log_Q_ij = F.logsigmoid(logP_j - logP_i)
    else:
        raise NotImplementedError
    rate = torch.exp(log_Q_ij - logP_j)

    # Compute transition probability
    logP_ij = logP_j + (-(-dt * rate).expm1()).log()
    p_flip = ((1.0 - O) * logP_ij.exp()).sum(-1, keepdim=True)

    # DEBUG:
    # flux = ((1.0 - O) * torch.exp(log_Q_ij)).mean([1, 2], keepdim=True)
    # print(f"  ->Flux is {flux.item():0.2f}, FlipProb is {p_flip.mean():0.2f}")

    logP_ii = (1.0 - p_flip).clamp(1e-5).log()
    logP_ij = (1.0 - O) * logP_ij + O * logP_ii
    return U, logP_ij
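

# -----------------------------------------------------------------------------
# Illustrative usage sketch (not part of the library API). All shapes,
# hyperparameters, and the random inputs below are arbitrary assumptions chosen
# only to show how GraphPotts, compute_potts_energy, and sample_potts fit
# together; real callers derive node_h / edge_h / edge_idx from a graph encoder.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    num_batch, num_nodes, num_neighbors, num_states = 2, 16, 8, 20
    dim_nodes, dim_edges = 32, 48

    # Random graph features and a fully unmasked neighbor topology
    node_h = torch.randn(num_batch, num_nodes, dim_nodes)
    edge_h = torch.randn(num_batch, num_nodes, num_neighbors, dim_edges)
    edge_idx = torch.randint(0, num_nodes, (num_batch, num_nodes, num_neighbors))
    mask_i = torch.ones(num_batch, num_nodes)
    mask_ij = torch.ones(num_batch, num_nodes, num_neighbors)

    # Parameterize fields h_i(s_i) and couplings J_ij(s_i, s_j) from embeddings
    potts = GraphPotts(dim_nodes, dim_edges, num_states=num_states)
    h, J = potts(node_h, edge_h, edge_idx, mask_i, mask_ij)

    # Score a random sequence and draw approximate samples from the Potts model
    S = torch.randint(0, num_states, (num_batch, num_nodes))
    U, U_i = compute_potts_energy(S, h, J, edge_idx)
    S_sample, U_sample = sample_potts(
        h, J, edge_idx, mask_i, mask_ij, num_sweeps=10, proposal="chromatic"
    )
    print(U.shape, S_sample.shape, U_sample.shape)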