# Copyright Generate Biomedicines, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Layers for computing sequence complexities.
"""
import numpy as np
import torch
import torch.nn.functional as F

from chroma.constants import AA20
from chroma.layers.graph import collect_neighbors


def compositions(S: torch.Tensor, C: torch.LongTensor, w: int = 30):
"""Compute local compositions per residue.
Args:
S (torch.Tensor): Sequence tensor with shape `(num_batch, num_residues)`
(long) or `(num_batch, num_residues, num_alphabet)` (float).
C (torch.LongTensor): Chain map with shape `(num_batch, num_residues)`.
w (int, optional): Window size.
    Returns:
        P (torch.Tensor): Local compositions with shape
            `(num_batch, num_residues, num_alphabet)`.
        N (torch.Tensor): Local counts with shape
            `(num_batch, num_residues, num_alphabet)`.
        edge_idx (torch.LongTensor): Window neighbor indices with shape
            `(1, num_residues, w)`.
        mask_i (torch.Tensor): Residue mask with shape
            `(num_batch, num_residues)`.
        mask_ij (torch.Tensor): Window membership mask with shape
            `(num_batch, num_residues, w)`.
    """
device = S.device
Q = len(AA20)
mask_i = (C > 0).float()
if len(S.shape) == 2:
S = F.one_hot(S, Q)
# Build neighborhoods and masks
S_onehot = mask_i[..., None] * S
    kx = torch.arange(w, device=device) - w // 2
    edge_idx = (
        torch.arange(S.shape[1], device=device)[None, :, None] + kx[None, None, :]
    )
    mask_ij = (edge_idx >= 0) & (edge_idx < S.shape[1])
edge_idx = edge_idx.clamp(min=0, max=S.shape[1] - 1)
C_i = C[..., None]
C_j = collect_neighbors(C_i, edge_idx)[..., 0]
mask_ij = (mask_ij & C_j.eq(C_i) & (C_i > 0) & (C_j > 0)).float()
# Sum neighborhood composition
S_j = mask_ij[..., None] * collect_neighbors(S_onehot, edge_idx)
N = S_j.sum(2)
num_N = N.sum(-1, keepdims=True)
P = N / (num_N + 1e-5)
mask_i = ((num_N[..., 0] > 0) & (C > 0)).float()
mask_ij = mask_i[..., None] * mask_ij
return P, N, edge_idx, mask_i, mask_ij
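

# A minimal usage sketch for `compositions` (illustrative shapes, not from the
# library): windows are taken per residue, so outputs stay length-aligned with
# the input.
#
#   S = torch.randint(0, len(AA20), (2, 50))   # batch of 2 sequences, length 50
#   C = torch.ones(2, 50, dtype=torch.long)    # one chain, all residues valid
#   P, N, edge_idx, mask_i, mask_ij = compositions(S, C, w=30)
#   # P: (2, 50, 20) local amino acid frequencies around each residue
#   # N: (2, 50, 20) local counts; mask_i: (2, 50) valid-residue mask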


def complexity_lcp(
    S: torch.LongTensor,
    C: torch.LongTensor,
    w: int = 30,
    entropy_min: float = 2.32,
    method: str = "naive",
    differentiable: bool = True,
    eps: float = 1e-5,
    min_coverage: float = 0.9,
) -> torch.Tensor:
    """Compute the Local Composition Perplexity (LCP) metric.

    Args:
        S (torch.Tensor): Sequence tensor with shape `(num_batch, num_residues)`
            (index tensor) or `(num_batch, num_residues, num_alphabet)` (one-hot).
        C (torch.LongTensor): Chain map with shape `(num_batch, num_residues)`.
        w (int): Window size.
        entropy_min (float): Entropy threshold below which a window is
            penalized. The default pairs with `method="naive"`; a value of
            2.52 pairs with `method="chao-shen"`.
        method (str): Entropy estimator; see `estimate_entropy`.
        differentiable (bool): If True and `S` is one-hot, propagate gradients
            through a straight-through estimator.
        eps (float): Small number for numerical stability in division and
            logarithms.
        min_coverage (float): Minimum fraction of a window that must be valid
            for the window to be counted.

    Returns:
        U (torch.Tensor): Complexity penalties with shape `(num_batch,)`.
    """
    # Shrink the window for sequences shorter than `w`
    if S.shape[1] < w:
        w = S.shape[1]
P, N, edge_idx, mask_i, mask_ij = compositions(S, C, w)
    # Only count windows covering at least `min_coverage` of `w` residues
    min_N = int(min_coverage * w)
    mask_coverage = N.sum(-1) > min_N
H = estimate_entropy(N, method=method)
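    # Squared-hinge penalty: a window contributes only when its perplexity
    # exp(H) falls below the target perplexity exp(entropy_min)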
U = mask_coverage * (torch.exp(H) - np.exp(entropy_min)).clamp(max=0).square()
# Compute entropy as a function of perturbed counts
if differentiable and len(S.shape) == 3:
# Compute how a mutation changes entropy for each neighbor
N_neighbors = collect_neighbors(N, edge_idx)
mask_coverage_j = collect_neighbors(mask_coverage[..., None], edge_idx)
N_ij = (N_neighbors - S[:, :, None, :])[..., None, :] + torch.eye(
N.shape[-1], device=N.device
)[None, None, None, ...]
N_ij = N_ij.clamp(min=0)
H_ij = estimate_entropy(N_ij, method=method)
U_ij = (torch.exp(H_ij) - np.exp(entropy_min)).clamp(max=0).square()
U_ij = mask_ij[..., None] * mask_coverage_j * U_ij
U_differentiable = (U_ij.detach() * S[:, :, None, :]).sum([-1, -2])
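        # Straight-through estimator: the forward value equals the exact U,
        # while gradients flow through the per-mutation sensitivity term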
U = U.detach() + U_differentiable - U_differentiable.detach()
U = (mask_i * U).sum(1)
return U
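

# A sketch of differentiable use (assumed workflow; names here are
# illustrative): with a one-hot sequence that requires grad, the
# straight-through term exposes d(penalty)/d(sequence) for guided sampling.
#
#   S_onehot = F.one_hot(S, len(AA20)).float().requires_grad_()
#   U = complexity_lcp(S_onehot, C, differentiable=True)
#   U.sum().backward()   # gradients arrive in S_onehot.grad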


def complexity_scores_lcp_t(
    t: int,
    S: torch.LongTensor,
    C: torch.LongTensor,
    idx: torch.LongTensor,
    edge_idx_t: torch.LongTensor,
    mask_ij_t: torch.Tensor,
    w: int = 30,
    entropy_min: float = 2.515,
    eps: float = 1e-5,
    method: str = "chao-shen",
) -> torch.Tensor:
    """Compute local LCP scores at step `t` of autoregressive decoding.

    Args:
        t (int): Current decoding position.
        S (torch.LongTensor): Sequence tensor with shape
            `(num_batch, num_residues)`.
        C (torch.LongTensor): Chain map with shape `(num_batch, num_residues)`.
        idx (torch.LongTensor): Residue indices with shape
            `(num_batch, num_residues)`.
        edge_idx_t (torch.LongTensor): Neighbor indices at step `t`.
        mask_ij_t (torch.Tensor): Neighbor mask at step `t`.
        w (int): Window size.
        entropy_min (float): Entropy threshold below which extensions are
            penalized.
        eps (float): Small number for numerical stability.
        method (str): Entropy estimator; see `estimate_entropy`.

    Returns:
        U (torch.Tensor): Scores for each candidate amino acid at position `t`,
            with shape `(num_batch, 1, num_alphabet)`.
    """
Q = len(AA20)
O = F.one_hot(S, Q)
O_j = collect_neighbors(O, edge_idx_t)
idx_i = idx[:, t, None]
C_i = C[:, t, None]
idx_j = collect_neighbors(idx[..., None], edge_idx_t)[..., 0]
C_j = collect_neighbors(C[..., None], edge_idx_t)[..., 0]
# Sum valid neighbor counts
is_near = (idx_i - idx_j).abs() <= w / 2
same_chain = C_i == C_j
valid_ij_t = (is_near * same_chain * (mask_ij_t > 0)).float()[..., None]
N_k = (valid_ij_t * O_j).sum(-2)
# Compute counts under all possible extensions
N_k = N_k[:, :, None, :] + torch.eye(Q, device=N_k.device)[None, None, ...]
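    # Score each candidate residue: zero when the window perplexity exp(H)
    # stays above exp(entropy_min), increasingly negative below it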
H = estimate_entropy(N_k, method=method)
U = -(torch.exp(H) - np.exp(entropy_min)).clamp(max=0).square()
return U
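

# A hedged sketch of how these scores could modulate decoding (the sampler and
# the variables below live elsewhere; names are illustrative):
#
#   U_t = complexity_scores_lcp_t(t, S, C, idx, edge_idx_t, mask_ij_t)
#   logits_t = logits_t + U_t   # down-weight low-complexity extensions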


def estimate_entropy(
    N: torch.Tensor, method: str = "chao-shen", eps: float = 1e-11
) -> torch.Tensor:
    """Estimate entropy from counts.

    See Chao, A., & Shen, T. J. (2003) for details on the coverage-adjusted
    estimator.

    Args:
        N (torch.Tensor): Tensor of counts with shape `(..., num_bins)`.
        method (str): Estimator to use. One of `"chao-shen"`
            (coverage-adjusted), `"miller-maddow"` (bias-corrected plug-in),
            `"laplace"` (pseudocount-smoothed plug-in), or `"naive"` (plug-in).
        eps (float): Small number for numerical stability in division and
            logarithms.

    Returns:
        H (torch.Tensor): Estimated entropy with shape `(...)`.
    """
N = N.float()
N_total = N.sum(-1, keepdims=True)
P = N / (N_total + eps)
if method == "chao-shen":
# Estimate coverage and adjusted frequencies
singletons = N.long().eq(1).sum(-1, keepdims=True).float()
C = 1.0 - singletons / (N_total + eps)
P_adjust = C * P
P_inclusion = (1.0 - (1.0 - P_adjust) ** N_total).clamp(min=eps)
H = -(P_adjust * torch.log(P_adjust.clamp(min=eps)) / P_inclusion).sum(-1)
elif method == "miller-maddow":
bins = (N > 0).float().sum(-1)
bias = (bins - 1) / (2 * N_total[..., 0] + eps)
H = -(P * torch.log(P + eps)).sum(-1) + bias
elif method == "laplace":
N = N.float() + 1 / N.shape[-1]
N_total = N.sum(-1, keepdims=True)
P = N / (N_total + eps)
H = -(P * torch.log(P)).sum(-1)
else:
H = -(P * torch.log(P + eps)).sum(-1)
return H
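

# A small worked example for `estimate_entropy` (illustrative numbers): with
# uniform counts the estimators agree, while for sparse counts the plug-in
# estimate is biased low and "chao-shen" corrects toward unseen bins.
#
#   N_uniform = torch.tensor([5.0, 5.0, 5.0, 5.0])
#   estimate_entropy(N_uniform, method="naive")   # ~= log(4) ~= 1.386
#
#   N_sparse = torch.tensor([2.0, 1.0, 0.0, 0.0])
#   estimate_entropy(N_sparse, method="naive")      # ~= 0.64
#   estimate_entropy(N_sparse, method="chao-shen")  # ~= 1.07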