|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Optional, Union |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
from protenix.model.modules.pairformer import PairformerStack |
|
from protenix.model.modules.primitives import LinearNoBias |
|
from protenix.model.utils import broadcast_token_to_atom, one_hot |
|
from protenix.openfold_local.model.primitives import LayerNorm |
|
from protenix.utils.torch_utils import cdist |
|
|
|
|
|
class ConfidenceHead(nn.Module):
    """
    Implements Algorithm 31 in AF3.

    Produces per-atom pLDDT / resolved logits and per-token-pair PAE / PDE
    logits from the trunk embeddings plus predicted coordinates. Diffusion
    samples are processed one at a time (see ``memory_efficient_forward``)
    to bound peak memory.
    """

    # Token count above which, in eval mode, per-sample PAE/PDE logits are
    # moved to CPU and the CUDA caching allocator is flushed, to keep peak
    # GPU memory down on large inputs.
    _LARGE_N_TOKEN_THRESHOLD = 2000

    def __init__(
        self,
        n_blocks: int = 4,
        c_s: int = 384,
        c_z: int = 128,
        c_s_inputs: int = 449,
        b_pae: int = 64,
        b_pde: int = 64,
        b_plddt: int = 50,
        b_resolved: int = 2,
        max_atoms_per_token: int = 20,
        pairformer_dropout: float = 0.0,
        blocks_per_ckpt: Optional[int] = None,
        distance_bin_start: float = 3.25,
        distance_bin_end: float = 52.0,
        distance_bin_step: float = 1.25,
        stop_gradient: bool = True,
    ) -> None:
        """
        Args:
            n_blocks (int, optional): number of Pairformer blocks for ConfidenceHead. Defaults to 4.
            c_s (int, optional): hidden dim [for single embedding]. Defaults to 384.
            c_z (int, optional): hidden dim [for pair embedding]. Defaults to 128.
            c_s_inputs (int, optional): hidden dim [for single embedding from InputFeatureEmbedder]. Defaults to 449.
            b_pae (int, optional): the bin number for pae. Defaults to 64.
            b_pde (int, optional): the bin number for pde. Defaults to 64.
            b_plddt (int, optional): the bin number for plddt. Defaults to 50.
            b_resolved (int, optional): the bin number for resolved. Defaults to 2.
            max_atoms_per_token (int, optional): max atoms in a token. Defaults to 20.
            pairformer_dropout (float, optional): dropout ratio for Pairformer. Defaults to 0.0.
            blocks_per_ckpt (Optional[int], optional): number of Pairformer blocks in each
                activation checkpoint. Defaults to None.
            distance_bin_start (float, optional): lower edge of the first distance bin.
                Defaults to 3.25.
            distance_bin_end (float, optional): exclusive end passed to ``torch.arange`` when
                generating the bin lower edges. Defaults to 52.0.
            distance_bin_step (float, optional): Step size for the distance bins. Defaults to 1.25.
            stop_gradient (bool, optional): if True, ``forward`` detaches ``s_inputs``,
                ``s_trunk`` and ``z_trunk`` so no gradient flows back through them.
                Defaults to True.
        """
        super().__init__()
        self.n_blocks = n_blocks
        self.c_s = c_s
        self.c_z = c_z
        self.c_s_inputs = c_s_inputs
        self.b_pae = b_pae
        self.b_pde = b_pde
        self.b_plddt = b_plddt
        self.b_resolved = b_resolved
        self.max_atoms_per_token = max_atoms_per_token
        self.stop_gradient = stop_gradient

        # NOTE: registration order below is kept stable on purpose: it fixes
        # state_dict / parameter ordering and the RNG draws made at init.
        self.linear_no_bias_s1 = LinearNoBias(
            in_features=self.c_s_inputs, out_features=self.c_z
        )
        self.linear_no_bias_s2 = LinearNoBias(
            in_features=self.c_s_inputs, out_features=self.c_z
        )

        # Distance bins: lower edges from arange; each upper edge is the next
        # lower edge, and the last bin's upper edge is 1e6 (effectively open).
        lower_bins = torch.arange(
            distance_bin_start, distance_bin_end, distance_bin_step
        )
        upper_bins = torch.cat([lower_bins[1:], torch.tensor([1e6])])

        # Frozen parameters (not buffers) so they follow the module across
        # devices/dtypes and keep their existing state_dict entries.
        self.lower_bins = nn.Parameter(lower_bins, requires_grad=False)
        self.upper_bins = nn.Parameter(upper_bins, requires_grad=False)
        self.num_bins = len(lower_bins)

        self.linear_no_bias_d = LinearNoBias(
            in_features=self.num_bins, out_features=self.c_z
        )

        self.pairformer_stack = PairformerStack(
            c_z=self.c_z,
            c_s=self.c_s,
            n_blocks=n_blocks,
            dropout=pairformer_dropout,
            blocks_per_ckpt=blocks_per_ckpt,
        )
        self.linear_no_bias_pae = LinearNoBias(
            in_features=self.c_z, out_features=self.b_pae
        )
        self.linear_no_bias_pde = LinearNoBias(
            in_features=self.c_z, out_features=self.b_pde
        )
        # Per-slot output weights, indexed by each atom's "atom_to_tokatom_idx"
        # in memory_efficient_forward: [max_atoms_per_token, c_s, bins].
        self.plddt_weight = nn.Parameter(
            data=torch.empty(size=(self.max_atoms_per_token, self.c_s, self.b_plddt))
        )
        self.resolved_weight = nn.Parameter(
            data=torch.empty(size=(self.max_atoms_per_token, self.c_s, self.b_resolved))
        )

        self.linear_no_bias_s_inputs = LinearNoBias(self.c_s_inputs, self.c_s)
        self.linear_no_bias_s_trunk = LinearNoBias(self.c_s, self.c_s)
        self.layernorm_s_trunk = LayerNorm(self.c_s)
        self.linear_no_bias_z_trunk = LinearNoBias(self.c_z, self.c_z)
        self.layernorm_z_trunk = LayerNorm(self.c_z)

        # Fuse the input-derived and trunk-derived embeddings (concatenated
        # along the channel axis, hence the doubled width).
        self.layernorm_no_bias_z_cat = nn.LayerNorm(self.c_z * 2, bias=False)
        self.layernorm_no_bias_s_cat = nn.LayerNorm(self.c_s * 2, bias=False)
        self.linear_no_bias_z_cat = LinearNoBias(self.c_z * 2, self.c_z)
        self.linear_no_bias_s_cat = LinearNoBias(self.c_s * 2, self.c_s)

        self.pae_ln = LayerNorm(self.c_z)
        self.pde_ln = LayerNorm(self.c_z)
        self.plddt_ln = LayerNorm(self.c_s)
        self.resolved_ln = LayerNorm(self.c_s)

        # Output projections start at zero, so the initial logits are all zero.
        with torch.no_grad():
            nn.init.zeros_(self.linear_no_bias_pae.weight)
            nn.init.zeros_(self.linear_no_bias_pde.weight)
            nn.init.zeros_(self.plddt_weight)
            nn.init.zeros_(self.resolved_weight)

    def forward(
        self,
        input_feature_dict: dict[str, Union[torch.Tensor, int, float, dict]],
        s_inputs: torch.Tensor,
        s_trunk: torch.Tensor,
        z_trunk: torch.Tensor,
        pair_mask: torch.Tensor,
        x_pred_coords: torch.Tensor,
        use_memory_efficient_kernel: bool = False,
        use_deepspeed_evo_attention: bool = False,
        use_lma: bool = False,
        inplace_safe: bool = False,
        chunk_size: Optional[int] = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            input_feature_dict: Dictionary containing input features. Reads
                "distogram_rep_atom_mask" here, plus "atom_to_token_idx" and
                "atom_to_tokatom_idx" in ``memory_efficient_forward``.
            s_inputs (torch.Tensor): single embedding from InputFeatureEmbedder
                [..., N_tokens, c_s_inputs]
            s_trunk (torch.Tensor): single feature embedding from PairFormer (Alg17)
                [..., N_tokens, c_s]
            z_trunk (torch.Tensor): pair feature embedding from PairFormer (Alg17)
                [..., N_tokens, N_tokens, c_z]
            pair_mask (torch.Tensor): pair mask
                [..., N_token, N_token]
            x_pred_coords (torch.Tensor): predicted coordinates
                [..., N_sample, N_atoms, 3]
            use_memory_efficient_kernel (bool, optional): Whether to use memory-efficient kernel. Defaults to False.
            use_deepspeed_evo_attention (bool, optional): Whether to use DeepSpeed Evo attention. Defaults to False.
            use_lma (bool, optional): Whether to use low-memory attention. Defaults to False.
            inplace_safe (bool, optional): Whether to use inplace operations. Defaults to False.
            chunk_size (Optional[int], optional): Chunk size for memory-efficient operations. Defaults to None.

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
                - plddt_preds: Predicted pLDDT scores [..., N_sample, N_atom, plddt_bins].
                - pae_preds: Predicted PAE scores [..., N_sample, N_token, N_token, pae_bins].
                - pde_preds: Predicted PDE scores [..., N_sample, N_token, N_token, pde_bins].
                - resolved_preds: Predicted resolved scores [..., N_sample, N_atom, 2].
            In eval mode with more than ``_LARGE_N_TOKEN_THRESHOLD`` tokens,
            pae_preds and pde_preds are returned on CPU.
        """
        if self.stop_gradient:
            s_inputs = s_inputs.detach()
            s_trunk = s_trunk.detach()
            z_trunk = z_trunk.detach()

        s_trunk = self.linear_no_bias_s_trunk(self.layernorm_s_trunk(s_trunk))
        z_trunk = self.linear_no_bias_z_trunk(self.layernorm_z_trunk(z_trunk))

        # Outer sum of two projections of s_inputs -> [..., N_token, N_token, c_z].
        z_init = (
            self.linear_no_bias_s1(s_inputs)[..., None, :, :]
            + self.linear_no_bias_s2(s_inputs)[..., None, :]
        )
        s_init = self.linear_no_bias_s_inputs(s_inputs)
        s_trunk = torch.cat([s_init, s_trunk], dim=-1)
        z_trunk = torch.cat([z_init, z_trunk], dim=-1)

        s_trunk = self.linear_no_bias_s_cat(self.layernorm_no_bias_s_cat(s_trunk))
        z_trunk = self.linear_no_bias_z_cat(self.layernorm_no_bias_z_cat(z_trunk))

        if not self.training:
            # Drop the (N_token x N_token) intermediate before the per-sample loop.
            del z_init
            torch.cuda.empty_cache()

        x_rep_atom_mask = input_feature_dict["distogram_rep_atom_mask"].bool()
        x_pred_rep_coords = x_pred_coords[..., x_rep_atom_mask, :]
        n_sample = x_pred_rep_coords.size(-3)

        # Same answer for every sample, so decide once.
        offload_to_cpu = (
            z_trunk.shape[-2] > self._LARGE_N_TOKEN_THRESHOLD and not self.training
        )

        plddt_preds, pae_preds, pde_preds, resolved_preds = [], [], [], []
        for i in range(n_sample):
            # With inplace_safe, memory_efficient_forward adds into z_pair in
            # place and runs the Pairformer in-place, so each sample gets copies.
            plddt_pred, pae_pred, pde_pred, resolved_pred = (
                self.memory_efficient_forward(
                    input_feature_dict=input_feature_dict,
                    s_trunk=s_trunk.clone() if inplace_safe else s_trunk,
                    z_pair=z_trunk.clone() if inplace_safe else z_trunk,
                    pair_mask=pair_mask,
                    x_pred_rep_coords=x_pred_rep_coords[..., i, :, :],
                    use_memory_efficient_kernel=use_memory_efficient_kernel,
                    use_deepspeed_evo_attention=use_deepspeed_evo_attention,
                    use_lma=use_lma,
                    inplace_safe=inplace_safe,
                    chunk_size=chunk_size,
                )
            )
            if offload_to_cpu:
                pae_pred = pae_pred.cpu()
                pde_pred = pde_pred.cpu()
                torch.cuda.empty_cache()
            plddt_preds.append(plddt_pred)
            pae_preds.append(pae_pred)
            pde_preds.append(pde_pred)
            resolved_preds.append(resolved_pred)

        plddt_preds = torch.stack(plddt_preds, dim=-3)
        pae_preds = torch.stack(pae_preds, dim=-4)
        pde_preds = torch.stack(pde_preds, dim=-4)
        resolved_preds = torch.stack(resolved_preds, dim=-3)
        return plddt_preds, pae_preds, pde_preds, resolved_preds

    def memory_efficient_forward(
        self,
        input_feature_dict: dict[str, Union[torch.Tensor, int, float, dict]],
        s_trunk: torch.Tensor,
        z_pair: torch.Tensor,
        pair_mask: torch.Tensor,
        x_pred_rep_coords: torch.Tensor,
        use_memory_efficient_kernel: bool = False,
        use_deepspeed_evo_attention: bool = False,
        use_lma: bool = False,
        inplace_safe: bool = False,
        chunk_size: Optional[int] = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Run the confidence head for a single sample (N_sample = 1 to avoid CUDA OOM).

        Args:
            input_feature_dict: Dictionary containing input features; reads
                "atom_to_token_idx" and "atom_to_tokatom_idx".
            s_trunk (torch.Tensor): single embedding [..., N_tokens, c_s]
            z_pair (torch.Tensor): pair embedding [..., N_tokens, N_tokens, c_z].
                Modified in place when ``inplace_safe`` is True.
            pair_mask (torch.Tensor): pair mask [..., N_tokens, N_tokens]
            x_pred_rep_coords (torch.Tensor): predicted coordinates of the atoms
                selected by "distogram_rep_atom_mask", one per token
                [..., N_tokens, 3]
            use_memory_efficient_kernel (bool, optional): forwarded to the Pairformer. Defaults to False.
            use_deepspeed_evo_attention (bool, optional): forwarded to the Pairformer. Defaults to False.
            use_lma (bool, optional): forwarded to the Pairformer. Defaults to False.
            inplace_safe (bool, optional): Whether to use inplace operations. Defaults to False.
            chunk_size (Optional[int], optional): forwarded to the Pairformer. Defaults to None.

        Returns:
            tuple of (plddt_pred [..., N_atom, b_plddt],
                      pae_pred [..., N_token, N_token, b_pae],
                      pde_pred [..., N_token, N_token, b_pde],
                      resolved_pred [..., N_atom, b_resolved]).
        """
        # Embed binned pairwise distances of the predicted structure.
        distance_pred = cdist(x_pred_rep_coords, x_pred_rep_coords)
        distance_emb = self.linear_no_bias_d(
            one_hot(
                x=distance_pred,
                lower_bins=self.lower_bins,
                upper_bins=self.upper_bins,
            )
        )
        if inplace_safe:
            z_pair += distance_emb
        else:
            z_pair = z_pair + distance_emb
        # Release the N_token x N_token temporaries before the Pairformer runs.
        del distance_pred, distance_emb

        s_single, z_pair = self.pairformer_stack(
            s_trunk,
            z_pair,
            pair_mask,
            use_memory_efficient_kernel=use_memory_efficient_kernel,
            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
            use_lma=use_lma,
            inplace_safe=inplace_safe,
            chunk_size=chunk_size,
        )

        pae_pred = self.linear_no_bias_pae(self.pae_ln(z_pair))
        # PDE uses the symmetrised pair embedding z_ij + z_ji.
        pde_pred = self.linear_no_bias_pde(
            self.pde_ln(z_pair + z_pair.transpose(-2, -3))
        )

        atom_to_token_idx = input_feature_dict["atom_to_token_idx"]
        atom_to_tokatom_idx = input_feature_dict["atom_to_tokatom_idx"]

        # Broadcast token features to atoms, then project each atom with the
        # weight slice selected by its atom_to_tokatom_idx entry.
        a = broadcast_token_to_atom(
            x_token=s_single, atom_to_token_idx=atom_to_token_idx
        )
        plddt_pred = torch.einsum(
            "...nc,ncb->...nb", self.plddt_ln(a), self.plddt_weight[atom_to_tokatom_idx]
        )
        resolved_pred = torch.einsum(
            "...nc,ncb->...nb",
            self.resolved_ln(a),
            self.resolved_weight[atom_to_tokatom_idx],
        )
        if not self.training and z_pair.shape[-2] > self._LARGE_N_TOKEN_THRESHOLD:
            torch.cuda.empty_cache()
        return plddt_pred, pae_pred, pde_pred, resolved_pred
|
|