# Copyright 2024 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Union

import numpy as np
import torch
import torch.nn as nn
from scipy.spatial.transform import Rotation

from protenix.utils.scatter_utils import scatter


def random_sample_watermark(pred_x, gt_augmented_x, N_sample):
    """Randomly pick, per sample, either the ground-truth or the predicted coords

    A label in {0, 1} is drawn uniformly for each of the N_sample samples.
    Where the label is 0 the entry is taken from gt_augmented_x, otherwise
    from pred_x.

    Args:
        pred_x (torch.Tensor): predicted coords
            presumably [N_sample, N_atom, 3]; the selection mask is built as
            [N_sample, 1, 1] -- TODO confirm against callers
        gt_augmented_x (torch.Tensor): ground-truth coords, broadcastable with pred_x
        N_sample (int): number of labels to draw

    Returns:
        tuple[torch.Tensor, torch.Tensor]: the mixed coords and the float
            label tensor [N_sample] (on the device of pred_x)
    """
    watermark_label = torch.randint(0, 2, (N_sample,)).float().to(pred_x.device)
    # (N_sample, 1, 1): True where the ground truth replaces the prediction
    replace_mask_expanded = (watermark_label == 0.0).unsqueeze(-1).unsqueeze(-1)
    sampled = torch.where(replace_mask_expanded, gt_augmented_x, pred_x)
    return sampled, watermark_label


def centre_random_augmentation(
    x_input_coords: torch.Tensor,
    N_sample: int = 1,
    s_trans: float = 1.0,
    centre_only: bool = False,
    mask: Optional[torch.Tensor] = None,
    eps: float = 1e-12,
) -> torch.Tensor:
    """Implements Algorithm 19 in AF3

    Args:
        x_input_coords (torch.Tensor): input coords
            [..., N_atom, 3]
        N_sample (int, optional): the total number of augmentation. Defaults to 1.
        s_trans (float, optional): scale factor of trans. Defaults to 1.0.
        centre_only (bool, optional): if set true, will only perform centering
            without applying random translation and rotation.
        mask (torch.Tensor, optional): masking for the coords
            [..., N_atom] or [N_atom]
        eps (float, optional): small number used for masked mean

    Returns:
        torch.Tensor: the Augmentation version of input coords
            [..., N_sample, N_atom, 3]
    """

    N_atom = x_input_coords.size(-2)
    device = x_input_coords.device

    # Move to origin [..., N_atom, 3]
    if mask is None:
        x_input_coords = x_input_coords - torch.mean(
            input=x_input_coords, dim=-2, keepdim=True
        )
    else:
        # keepdim=True keeps the atom count as [..., 1] so it broadcasts
        # against the [..., 3] coordinate sum for batched inputs as well.
        center = (x_input_coords * mask.unsqueeze(dim=-1)).sum(dim=-2) / (
            mask.sum(dim=-1, keepdim=True) + eps
        )
        x_input_coords = x_input_coords - center.unsqueeze(dim=-2)

    # Expand to [..., N_sample, N_atom, 3]
    x_input_coords = expand_at_dim(x_input_coords, dim=-3, n=N_sample)

    if centre_only:
        return x_input_coords

    # N_augment = batch_size * N_sample
    N_augment = torch.numel(x_input_coords[..., 0, 0])

    # Generate N_augment (rot, trans) pairs
    batch_size_shape = x_input_coords.shape[:-3]
    rot_matrix_random = (
        uniform_random_rotation(N_sample=N_augment)
        .to(device)
        .reshape(*batch_size_shape, N_sample, 3, 3)
    ).detach()  # [..., N_sample, 3, 3]
    trans_random = s_trans * torch.randn(size=(*batch_size_shape, N_sample, 3)).to(
        device
    )  # [..., N_sample, 3]
    x_augment_coords = (
        rot_vec_mul(
            r=expand_at_dim(rot_matrix_random, dim=-3, n=N_atom), t=x_input_coords
        )
        + trans_random[..., None, :]
    )  # [..., N_sample, N_atom, 3]
    return x_augment_coords


# Comment: Rotation.random is not supported by torch.compile()
def uniform_random_rotation(N_sample: int = 1) -> torch.Tensor:
    """Generate random rotation matrices with scipy.spatial.transform.Rotation

    Args:
        N_sample (int, optional): the total number of augmentation. Defaults to 1.

    Returns:
        torch.Tensor: N_sample rotation matrices (float32, on CPU)
            [N_sample, 3, 3]
    """
    rotation = Rotation.random(num=N_sample)
    rot_matrix = torch.from_numpy(rotation.as_matrix()).float()  # [N_sample, 3, 3]
    return rot_matrix


# this is from openfold.utils.rigid_utils import rot_vec_mul
def rot_vec_mul(r: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Apply rot matrix to vector
    Applies a rotation to a vector. Written out by hand to avoid AMP downcasting.

    Args:
        r (torch.Tensor): the rotation matrices
            [..., 3, 3]
        t (torch.Tensor): the coordinate tensors
            [..., 3]

    Returns:
        torch.Tensor: the rotated coordinates
            [..., 3]
    """
    x, y, z = torch.unbind(input=t, dim=-1)
    return torch.stack(
        tensors=[
            r[..., 0, 0] * x + r[..., 0, 1] * y + r[..., 0, 2] * z,
            r[..., 1, 0] * x + r[..., 1, 1] * y + r[..., 1, 2] * z,
            r[..., 2, 0] * x + r[..., 2, 1] * y + r[..., 2, 2] * z,
        ],
        dim=-1,
    )


# from openfold.utils.tensor_utils.permute_final_dims
# from openfold.utils.tensor_utils.flatten_final_dims
def permute_final_dims(tensor: torch.Tensor, inds: list[int]) -> torch.Tensor:
    """Permute final dims of tensor

    Args:
        tensor (torch.Tensor): the input tensor
            [...]
        inds (List[int]): the new order of the last len(inds) dims, given as
            indices relative to the first of those dims

    Returns:
        torch.Tensor: the permuted tensor
    """
    zero_index = -1 * len(inds)
    first_inds = list(range(len(tensor.shape[:zero_index])))
    return tensor.permute(first_inds + [zero_index + i for i in inds])


def flatten_final_dims(t: torch.Tensor, num_dims: int) -> torch.Tensor:
    """Flatten final dims of tensor

    Args:
        t (torch.Tensor): the input tensor
            [...]
        num_dims (int): the number of final dims to flatten

    Returns:
        torch.Tensor: the flattened tensor
    """
    return t.reshape(shape=t.shape[:-num_dims] + (-1,))


def one_hot(
    x: torch.Tensor, lower_bins: torch.Tensor, upper_bins: torch.Tensor
) -> torch.Tensor:
    """Get one hot embedding of x from lower_bins and upper_bins

    A bin is active when lower < x < upper (both comparisons are strict).

    Args:
        x (torch.Tensor): the input x
            [...]
        lower_bins (torch.Tensor): the lower bounds of bins
            [bins]
        upper_bins (torch.Tensor): the upper bounds of bins
            [bins]

    Returns:
        torch.Tensor: the one hot embedding of x (float)
            [..., bins]
    """
    dgram = (x[..., None] > lower_bins) * (x[..., None] < upper_bins).float()
    return dgram


# this is mostly from openfold.utils.torch_utils import batched_gather
def batched_gather(
    data: torch.Tensor, inds: torch.Tensor, dim: int = 0, no_batch_dims: int = 0
) -> torch.Tensor:
    """Gather data according to indices specify by inds

    Args:
        data (torch.Tensor): the input data
            [..., K, ...]
        inds (torch.Tensor): the indices for gathering data
            [..., N]
        dim (int, optional): along which dimension to gather data by inds (the dim of "K" "N"). Defaults to 0.
        no_batch_dims (int, optional): length of dimensions before the "dim" dimension. Defaults to 0.

    Returns:
        torch.Tensor: gathered data
            [..., N, ...]
    """

    # for the naive case
    if len(inds.shape) == 1 and no_batch_dims == 0 and dim == 0:
        return data[inds]

    ranges = []
    for i, s in enumerate(data.shape[:no_batch_dims]):
        # Build the batch index on the device of the indexed tensor.
        r = torch.arange(s, device=data.device)
        r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
        ranges.append(r)

    remaining_dims = [slice(None) for _ in range(len(data.shape) - no_batch_dims)]
    remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
    ranges.extend(remaining_dims)
    # Index with a tuple: using a list for multidimensional indexing relies on
    # a deprecated sequence-as-tuple heuristic.
    return data[tuple(ranges)]


def broadcast_token_to_atom(
    x_token: torch.Tensor, atom_to_token_idx: torch.Tensor
) -> torch.Tensor:
    """Broadcast token-level embeddings to atom-level embeddings

    Args:
        x_token (torch.Tensor): token embedding
            [..., N_token, d]
        atom_to_token_idx (torch.Tensor): map atom idx to token idx
            [..., N_atom] or [N_atom]

    Returns:
        torch.Tensor: atom embedding
            [..., N_atom, d]
    """

    if len(atom_to_token_idx.shape) == 1:
        # shape = [N_atom], easy index
        return x_token[..., atom_to_token_idx, :]
    else:
        assert atom_to_token_idx.shape[:-1] == x_token.shape[:-2]

    return batched_gather(
        data=x_token,
        inds=atom_to_token_idx,
        dim=-2,
        no_batch_dims=len(x_token.shape[:-2]),
    )


def aggregate_atom_to_token(
    x_atom: torch.Tensor,
    atom_to_token_idx: torch.Tensor,
    n_token: Optional[int] = None,
    reduce: str = "mean",
) -> torch.Tensor:
    """Aggregate atom embedding to obtain token embedding

    Args:
        x_atom (torch.Tensor): atom-level embedding
            [..., N_atom, d]
        atom_to_token_idx (torch.Tensor): map atom to token idx
            [..., N_atom] or [N_atom]
        n_token (int, optional): number of tokens in total. Defaults to None.
        reduce (str, optional): aggregation method. Defaults to "mean".

    Returns:
        torch.Tensor: token-level embedding
            [..., N_token, d]
    """

    # Broadcasting in the given dim.
    out = scatter(
        src=x_atom, index=atom_to_token_idx, dim=-2, dim_size=n_token, reduce=reduce
    )

    return out


def sample_indices(
    n: int,
    device: torch.device = torch.device("cpu"),
    lower_bound: int = 1,
    strategy: str = "random",
) -> torch.Tensor:
    """Sample k msa indices, with k drawn uniformly from [min(lower_bound, n), n]

    Args:
        n (int): the msa num
        device (torch.device): the device of the returned indices
        lower_bound (int): the minimum number of indices to sample (capped at n)
        strategy (str): the strategy to sample msa index, random or topk.
            "random" takes the first k entries of a random permutation of
            range(n); "topk" takes the first k indices 0..k-1.

    Returns:
        torch.Tensor: the sampled indices k
    """
    assert strategy in ["random", "topk"]
    sample_size = torch.randint(low=min(lower_bound, n), high=n + 1, size=(1,)).item()
    if strategy == "random":
        indices = torch.randperm(n=n, device=device)[:sample_size]
    elif strategy == "topk":
        indices = torch.arange(sample_size, device=device)
    return indices


def sample_msa_feature_dict_random_without_replacement(
    feat_dict: dict[str, torch.Tensor],
    dim_dict: dict[str, int],
    cutoff: int = 512,
    lower_bound: int = 1,
    strategy: str = "random",
) -> dict[str, torch.Tensor]:
    """Sample a dict of MSA features randomly without replacement.

    Args:
        feat_dict (dict[str, torch.Tensor]): A dict containing the MSA features.
            Must contain the key "msa".
        dim_dict (dict[str, int]): A dict mapping each feature name to the
            dimension along which it is sampled. Must contain the key "msa".
        cutoff (int): The maximum number of features to sample. Values <= 0
            disable the cutoff.
        lower_bound (int): The minimum number of features to sample.
        strategy (str): The sampling strategy to use. Can be either "random" or "topk".

    Returns:
        dict[str, torch.Tensor]: A dict containing the sampled MSA features
            (only the features listed in dim_dict).
    """
    msa_len = feat_dict["msa"].size(dim=dim_dict["msa"])
    indices = sample_indices(
        n=msa_len,
        device=feat_dict["msa"].device,
        lower_bound=lower_bound,
        strategy=strategy,
    )
    if cutoff > 0:
        indices = indices[:cutoff]

    msa_feat_dict = {
        feat_name: torch.index_select(
            input=feat_dict[feat_name], dim=dim, index=indices
        )
        for feat_name, dim in dim_dict.items()
    }
    return msa_feat_dict


def expand_at_dim(x: torch.Tensor, dim: int, n: int) -> torch.Tensor:
    """expand a tensor at specific dim by n times

    A new axis is inserted at dim and then expanded (no copy) to size n.

    Args:
        x (torch.Tensor): input
        dim (int): dimension to expand
        n (int): expand size

    Returns:
        torch.Tensor: expanded tensor of shape [..., n, ...]
    """
    x = x.unsqueeze(dim=dim)
    if dim < 0:
        # Resolve the negative index against the tensor that now has the new axis.
        dim = x.dim() + dim
    before_shape = x.shape[:dim]
    after_shape = x.shape[dim + 1 :]
    return x.expand(*before_shape, n, *after_shape)


def pad_at_dim(
    x: torch.Tensor,
    dim: int,
    pad_length: Union[tuple[int], list[int]],
    value: float = 0,
) -> torch.Tensor:
    """pad to input x at dimension dim with length pad_length[0] to the left
        and pad_length[1] to the right.

    Args:
        x (torch.Tensor): input
        dim (int): padding dimension
        pad_length (Union[Tuple[int], List[int]]): length to pad to the beginning and end.
        value (float, optional): the fill value for the padding. Defaults to 0.

    Returns:
        torch.Tensor: padded tensor (x itself when both pad lengths are 0)
    """
    n_dim = len(x.shape)
    if dim < 0:
        dim = n_dim + dim

    pad = (pad_length[0], pad_length[1])
    if pad == (0, 0):
        return x
    # nn.functional.pad lists pads starting from the last dim, so prepend a
    # (0, 0) pair for every dim that comes after `dim`.
    k = n_dim - (dim + 1)
    if k > 0:
        pad_skip = (0, 0) * k
        pad = (*pad_skip, *pad)
    return nn.functional.pad(x, pad=pad, value=value)


def reshape_at_dim(
    x: torch.Tensor, dim: int, target_shape: Union[tuple[int], list[int]]
) -> torch.Tensor:
    """reshape dimension dim of x to target_shape

    Args:
        x (torch.Tensor): input
        dim (int): dimension to reshape
        target_shape (Union[Tuple[int], List[int]]): target_shape of dim

    Returns:
        torch.Tensor: reshaped tensor
    """
    n_dim = len(x.shape)
    if dim < 0:
        dim = n_dim + dim

    target_shape = tuple(target_shape)
    target_shape = (*x.shape[:dim], *target_shape)
    if dim + 1 < n_dim:
        target_shape = (*target_shape, *x.shape[dim + 1 :])
    return x.reshape(target_shape)


def move_final_dim_to_dim(x: torch.Tensor, dim: int) -> torch.Tensor:
    """
    Move the final dimension of a tensor to a specified dimension.

    Args:
        x (torch.Tensor): Input tensor.
        dim (int): Target dimension to move the final dimension to.

    Returns:
        torch.Tensor: Tensor with the final dimension moved to the specified dimension.
    """
    # permute_final_dims
    n_dim = len(x.shape)
    if dim < 0:
        dim = n_dim + dim
    if dim >= n_dim - 1:
        # The final dim is already in place.
        return x

    new_order = (n_dim - 1,)
    if dim > 0:
        new_order = tuple(range(dim)) + new_order
    if dim < n_dim - 1:
        new_order = new_order + tuple(range(dim, n_dim - 1))

    return x.permute(new_order)


def simple_merge_dict_list(dict_list: list[dict]) -> dict:
    """
    Merge a list of dictionaries into a single dictionary.

    Args:
        dict_list (list[dict]): List of dictionaries to merge. Values may be
            Python or NumPy scalars, torch tensors, or NumPy arrays.

    Returns:
        dict: Merged dictionary where values are concatenated arrays.

    Raises:
        ValueError: If a value has an unsupported type.
    """
    merged_dict = {}

    def add(key, value):
        merged_dict.setdefault(key, [])
        if isinstance(value, (float, int, np.generic)):
            value = np.array([value])
        elif isinstance(value, torch.Tensor):
            if value.dim() == 0:
                value = np.array([value.item()])
            else:
                value = value.detach().cpu().numpy()
        elif isinstance(value, np.ndarray):
            # np.concatenate rejects 0-dim arrays; promote them to shape (1,).
            value = np.atleast_1d(value)
        else:
            raise ValueError(f"Unsupported type for metric data: {type(value)}")
        merged_dict[key].append(value)

    for x in dict_list:
        for k, v in x.items():
            add(k, v)
    for k, v in merged_dict.items():
        merged_dict[k] = np.concatenate(v)
    return merged_dict