# Copyright Generate Biomedicines, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Layers for computing sequence complexities. """ import numpy as np import torch import torch.nn.functional as F from chroma.constants import AA20 from chroma.layers.graph import collect_neighbors def compositions(S: torch.Tensor, C: torch.LongTensor, w: int = 30): """Compute local compositions per residue. Args: S (torch.Tensor): Sequence tensor with shape `(num_batch, num_residues)` (long) or `(num_batch, num_residues, num_alphabet)` (float). C (torch.LongTensor): Chain map with shape `(num_batch, num_residues)`. w (int, optional): Window size. Returns: P (torch.Tensor): Local compositions with shape `(num_batch, num_residues - w + 1, num_alphabet)`. N (torch.Tensor): Local counts with shape `(num_batch, num_residues - w + 1, num_alphabet)`. mask_P (torch.Tensor): Mask with shape `(num_batch, num_residues - w + 1)`. """ device = S.device Q = len(AA20) mask_i = (C > 0).float() if len(S.shape) == 2: S = F.one_hot(S, Q) # Build neighborhoods and masks S_onehot = mask_i[..., None] * S kx = torch.arange(w, device=S.device) - w // 2 edge_idx = ( torch.arange(S.shape[1], device=S.device)[None, :, None] + kx[None, None, :] ) mask_ij = (edge_idx > 0) & (edge_idx < S.shape[1]) edge_idx = edge_idx.clamp(min=0, max=S.shape[1] - 1) C_i = C[..., None] C_j = collect_neighbors(C_i, edge_idx)[..., 0] mask_ij = (mask_ij & C_j.eq(C_i) & (C_i > 0) & (C_j > 0)).float() # Sum neighborhood composition S_j = mask_ij[..., None] * collect_neighbors(S_onehot, edge_idx) N = S_j.sum(2) num_N = N.sum(-1, keepdims=True) P = N / (num_N + 1e-5) mask_i = ((num_N[..., 0] > 0) & (C > 0)).float() mask_ij = mask_i[..., None] * mask_ij return P, N, edge_idx, mask_i, mask_ij def complexity_lcp( S: torch.LongTensor, C: torch.LongTensor, w: int = 30, entropy_min: float = 2.32, method: str = "naive", differentiable=True, eps: float = 1e-5, min_coverage=0.9, # entropy_min: float = 2.52, # method = "chao-shen" ) -> torch.Tensor: """Compute the Local Composition Perplexity metric. Args: S (torch.Tensor): Sequence tensor with shape `(num_batch, num_residues)` (index tensor) or `(num_batch, num_residues, num_alphabet)`. C (torch.LongTensor): Chain map with shape `(num_batch, num_residues)`. w (int): Window size. grad_pseudocount (float): Pseudocount for stabilizing entropy gradients on backwards pass. eps (float): Small number for numerical stability in division and logarithms. Returns: U (torch.Tensor): Complexities with shape `(num_batch)`. """ # adjust window size based on sequence length if S.shape[1] < w: w = S.shape[1] P, N, edge_idx, mask_i, mask_ij = compositions(S, C, w) # Only count windows with `min_coverage` min_N = int(min_coverage * w) mask_coverage = N.sum(-1) > int(min_coverage * w) H = estimate_entropy(N, method=method) U = mask_coverage * (torch.exp(H) - np.exp(entropy_min)).clamp(max=0).square() # Compute entropy as a function of perturbed counts if differentiable and len(S.shape) == 3: # Compute how a mutation changes entropy for each neighbor N_neighbors = collect_neighbors(N, edge_idx) mask_coverage_j = collect_neighbors(mask_coverage[..., None], edge_idx) N_ij = (N_neighbors - S[:, :, None, :])[..., None, :] + torch.eye( N.shape[-1], device=N.device )[None, None, None, ...] N_ij = N_ij.clamp(min=0) H_ij = estimate_entropy(N_ij, method=method) U_ij = (torch.exp(H_ij) - np.exp(entropy_min)).clamp(max=0).square() U_ij = mask_ij[..., None] * mask_coverage_j * U_ij U_differentiable = (U_ij.detach() * S[:, :, None, :]).sum([-1, -2]) U = U.detach() + U_differentiable - U_differentiable.detach() U = (mask_i * U).sum(1) return U def complexity_scores_lcp_t( t, S: torch.LongTensor, C: torch.LongTensor, idx: torch.LongTensor, edge_idx_t: torch.LongTensor, mask_ij_t: torch.Tensor, w: int = 30, entropy_min: float = 2.515, eps: float = 1e-5, method: str = "chao-shen", ) -> torch.Tensor: """Compute local LCP scores for autoregressive decoding.""" Q = len(AA20) O = F.one_hot(S, Q) O_j = collect_neighbors(O, edge_idx_t) idx_i = idx[:, t, None] C_i = C[:, t, None] idx_j = collect_neighbors(idx[..., None], edge_idx_t)[..., 0] C_j = collect_neighbors(C[..., None], edge_idx_t)[..., 0] # Sum valid neighbor counts is_near = (idx_i - idx_j).abs() <= w / 2 same_chain = C_i == C_j valid_ij_t = (is_near * same_chain * (mask_ij_t > 0)).float()[..., None] N_k = (valid_ij_t * O_j).sum(-2) # Compute counts under all possible extensions N_k = N_k[:, :, None, :] + torch.eye(Q, device=N_k.device)[None, None, ...] H = estimate_entropy(N_k, method=method) U = -(torch.exp(H) - np.exp(entropy_min)).clamp(max=0).square() return U def estimate_entropy( N: torch.Tensor, method: str = "chao-shen", eps: float = 1e-11 ) -> torch.Tensor: """Estimate entropy from counts. See Chao, A., & Shen, T. J. (2003) for more details. Args: N (torch.Tensor): Tensor of counts with shape `(..., num_bins)`. Returns: H (torch.Tensor): Estimated entropy with shape `(...)`. """ N = N.float() N_total = N.sum(-1, keepdims=True) P = N / (N_total + eps) if method == "chao-shen": # Estimate coverage and adjusted frequencies singletons = N.long().eq(1).sum(-1, keepdims=True).float() C = 1.0 - singletons / (N_total + eps) P_adjust = C * P P_inclusion = (1.0 - (1.0 - P_adjust) ** N_total).clamp(min=eps) H = -(P_adjust * torch.log(P_adjust.clamp(min=eps)) / P_inclusion).sum(-1) elif method == "miller-maddow": bins = (N > 0).float().sum(-1) bias = (bins - 1) / (2 * N_total[..., 0] + eps) H = -(P * torch.log(P + eps)).sum(-1) + bias elif method == "laplace": N = N.float() + 1 / N.shape[-1] N_total = N.sum(-1, keepdims=True) P = N / (N_total + eps) H = -(P * torch.log(P)).sum(-1) else: H = -(P * torch.log(P + eps)).sum(-1) return H