from functools import partial
from typing import Any, Optional, Union

import torch
import torch.nn as nn

from protenix.model.modules.primitives import LinearNoBias, Transition
from protenix.model.modules.transformer import AttentionPairBias
from protenix.model.utils import sample_msa_feature_dict_random_without_replacement
from protenix.openfold_local.model.dropout import DropoutRowwise
from protenix.openfold_local.model.outer_product_mean import OuterProductMean
from protenix.openfold_local.model.primitives import LayerNorm
from protenix.openfold_local.model.triangular_attention import TriangleAttention
from protenix.openfold_local.model.triangular_multiplicative_update import (
    TriangleMultiplicationIncoming,
    TriangleMultiplicationOutgoing,
)
from protenix.openfold_local.utils.checkpointing import checkpoint_blocks


class PairformerBlock(nn.Module):
    """Implements Algorithm 17 [Line 2 - Line 8] in AF3.

    c_hidden_mul follows the OpenFold setting, see:
    https://github.com/aqlaboratory/openfold/blob/feb45a521e11af1db241a33d58fb175e207f8ce0/openfold/model/evoformer.py#L123
    """

    def __init__(
        self,
        n_heads: int = 16,
        c_z: int = 128,
        c_s: int = 384,
        c_hidden_mul: int = 128,
        c_hidden_pair_att: int = 32,
        no_heads_pair: int = 4,
        dropout: float = 0.25,
    ) -> None:
        """
        Args:
            n_heads (int, optional): number of heads [for AttentionPairBias]. Defaults to 16.
            c_z (int, optional): hidden dim [for pair embedding]. Defaults to 128.
            c_s (int, optional): hidden dim [for single embedding]. Defaults to 384.
            c_hidden_mul (int, optional): hidden dim [for TriangleMultiplicationOutgoing].
                Defaults to 128.
            c_hidden_pair_att (int, optional): hidden dim [for TriangleAttention]. Defaults to 32.
            no_heads_pair (int, optional): number of heads [for TriangleAttention]. Defaults to 4.
            dropout (float, optional): dropout ratio [for TriangleUpdate]. Defaults to 0.25.
        """
        super(PairformerBlock, self).__init__()
        self.n_heads = n_heads
        self.tri_mul_out = TriangleMultiplicationOutgoing(
            c_z=c_z, c_hidden=c_hidden_mul
        )
        self.tri_mul_in = TriangleMultiplicationIncoming(c_z=c_z, c_hidden=c_hidden_mul)
        self.tri_att_start = TriangleAttention(
            c_in=c_z,
            c_hidden=c_hidden_pair_att,
            no_heads=no_heads_pair,
        )
        self.tri_att_end = TriangleAttention(
            c_in=c_z,
            c_hidden=c_hidden_pair_att,
            no_heads=no_heads_pair,
        )
        self.dropout_row = DropoutRowwise(dropout)
        self.pair_transition = Transition(c_in=c_z, n=4)
        self.c_s = c_s
        if self.c_s > 0:
            self.attention_pair_bias = AttentionPairBias(
                has_s=False, n_heads=n_heads, c_a=c_s, c_z=c_z
            )
            self.single_transition = Transition(c_in=c_s, n=4)

    def forward(
        self,
        s: Optional[torch.Tensor],
        z: torch.Tensor,
        pair_mask: torch.Tensor,
        use_memory_efficient_kernel: bool = False,
        use_deepspeed_evo_attention: bool = False,
        use_lma: bool = False,
        inplace_safe: bool = False,
        chunk_size: Optional[int] = None,
    ) -> tuple[Optional[torch.Tensor], torch.Tensor]:
        """
        Forward pass of the PairformerBlock.

        Args:
            s (Optional[torch.Tensor]): single feature
                [..., N_token, c_s]
            z (torch.Tensor): pair embedding
                [..., N_token, N_token, c_z]
            pair_mask (torch.Tensor): pair mask
                [..., N_token, N_token]
            use_memory_efficient_kernel (bool): Whether to use the memory-efficient kernel. Defaults to False.
            use_deepspeed_evo_attention (bool): Whether to use DeepSpeed Evoformer attention. Defaults to False.
            use_lma (bool): Whether to use low-memory attention. Defaults to False.
            inplace_safe (bool): Whether it is safe to use inplace operations. Defaults to False.
            chunk_size (Optional[int]): Chunk size for memory-efficient operations. Defaults to None.

        Returns:
            tuple[Optional[torch.Tensor], torch.Tensor]: the updated s [Optional] and z
                [..., N_token, c_s] | None
                [..., N_token, N_token, c_z]
        """
        if inplace_safe:
            z = self.tri_mul_out(
                z, mask=pair_mask, inplace_safe=inplace_safe, _add_with_inplace=True
            )
            z = self.tri_mul_in(
                z, mask=pair_mask, inplace_safe=inplace_safe, _add_with_inplace=True
            )
            z += self.tri_att_start(
                z,
                mask=pair_mask,
                use_memory_efficient_kernel=use_memory_efficient_kernel,
                use_deepspeed_evo_attention=use_deepspeed_evo_attention,
                use_lma=use_lma,
                inplace_safe=inplace_safe,
                chunk_size=chunk_size,
            )
            z = z.transpose(-2, -3).contiguous()
            z += self.tri_att_end(
                z,
                mask=pair_mask.transpose(-1, -2) if pair_mask is not None else None,
                use_memory_efficient_kernel=use_memory_efficient_kernel,
                use_deepspeed_evo_attention=use_deepspeed_evo_attention,
                use_lma=use_lma,
                inplace_safe=inplace_safe,
                chunk_size=chunk_size,
            )
            z = z.transpose(-2, -3).contiguous()
            z += self.pair_transition(z)
            if self.c_s > 0:
                s += self.attention_pair_bias(
                    a=s,
                    s=None,
                    z=z,
                )
                s += self.single_transition(s)
            return s, z
        else:
            tmu_update = self.tri_mul_out(
                z, mask=pair_mask, inplace_safe=inplace_safe, _add_with_inplace=False
            )
            z = z + self.dropout_row(tmu_update)
            del tmu_update
            tmu_update = self.tri_mul_in(
                z, mask=pair_mask, inplace_safe=inplace_safe, _add_with_inplace=False
            )
            z = z + self.dropout_row(tmu_update)
            del tmu_update
            z = z + self.dropout_row(
                self.tri_att_start(
                    z,
                    mask=pair_mask,
                    use_memory_efficient_kernel=use_memory_efficient_kernel,
                    use_deepspeed_evo_attention=use_deepspeed_evo_attention,
                    use_lma=use_lma,
                    inplace_safe=inplace_safe,
                    chunk_size=chunk_size,
                )
            )
            z = z.transpose(-2, -3)
            z = z + self.dropout_row(
                self.tri_att_end(
                    z,
                    mask=pair_mask.transpose(-1, -2) if pair_mask is not None else None,
                    use_memory_efficient_kernel=use_memory_efficient_kernel,
                    use_deepspeed_evo_attention=use_deepspeed_evo_attention,
                    use_lma=use_lma,
                    inplace_safe=inplace_safe,
                    chunk_size=chunk_size,
                )
            )
            z = z.transpose(-2, -3)

            z = z + self.pair_transition(z)
            if self.c_s > 0:
                s = s + self.attention_pair_bias(
                    a=s,
                    s=None,
                    z=z,
                )
                s = s + self.single_transition(s)
            return s, z


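# Usage sketch (illustrative comment, not part of the original module): a single
# PairformerBlock maps (s, z) -> (s, z). Shapes below assume the default dims
# (c_s=384, c_z=128) and a hypothetical batch of one with 32 tokens.
#
#   block = PairformerBlock()
#   s = torch.randn(1, 32, 384)        # [..., N_token, c_s]
#   z = torch.randn(1, 32, 32, 128)    # [..., N_token, N_token, c_z]
#   pair_mask = torch.ones(1, 32, 32)  # [..., N_token, N_token]
#   s, z = block(s=s, z=z, pair_mask=pair_mask)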
class PairformerStack(nn.Module):
    """
    Implements Algorithm 17 [PairformerStack] in AF3
    """

    def __init__(
        self,
        n_blocks: int = 48,
        n_heads: int = 16,
        c_z: int = 128,
        c_s: int = 384,
        dropout: float = 0.25,
        blocks_per_ckpt: Optional[int] = None,
    ) -> None:
        """
        Args:
            n_blocks (int, optional): number of blocks [for PairformerStack]. Defaults to 48.
            n_heads (int, optional): number of heads [for AttentionPairBias]. Defaults to 16.
            c_z (int, optional): hidden dim [for pair embedding]. Defaults to 128.
            c_s (int, optional): hidden dim [for single embedding]. Defaults to 384.
            dropout (float, optional): dropout ratio. Defaults to 0.25.
            blocks_per_ckpt (Optional[int], optional): number of Pairformer blocks in each
                activation checkpoint, i.e. the size of each checkpointed chunk. A higher
                value corresponds to fewer checkpoints and trades memory for speed.
                If None, no checkpointing is performed.
        """
        super(PairformerStack, self).__init__()
        self.n_blocks = n_blocks
        self.n_heads = n_heads
        self.blocks_per_ckpt = blocks_per_ckpt
        self.blocks = nn.ModuleList()

        for _ in range(n_blocks):
            block = PairformerBlock(n_heads=n_heads, c_z=c_z, c_s=c_s, dropout=dropout)
            self.blocks.append(block)

    def _prep_blocks(
        self,
        pair_mask: Optional[torch.Tensor],
        use_memory_efficient_kernel: bool = False,
        use_deepspeed_evo_attention: bool = False,
        use_lma: bool = False,
        inplace_safe: bool = False,
        chunk_size: Optional[int] = None,
        clear_cache_between_blocks: bool = False,
    ):
        blocks = [
            partial(
                b,
                pair_mask=pair_mask,
                use_memory_efficient_kernel=use_memory_efficient_kernel,
                use_deepspeed_evo_attention=use_deepspeed_evo_attention,
                use_lma=use_lma,
                inplace_safe=inplace_safe,
                chunk_size=chunk_size,
            )
            for b in self.blocks
        ]

        def clear_cache(b, *args, **kwargs):
            torch.cuda.empty_cache()
            return b(*args, **kwargs)

        if clear_cache_between_blocks:
            blocks = [partial(clear_cache, b) for b in blocks]
        return blocks

    def forward(
        self,
        s: Optional[torch.Tensor],
        z: torch.Tensor,
        pair_mask: torch.Tensor,
        use_memory_efficient_kernel: bool = False,
        use_deepspeed_evo_attention: bool = False,
        use_lma: bool = False,
        inplace_safe: bool = False,
        chunk_size: Optional[int] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            s (Optional[torch.Tensor]): single feature
                [..., N_token, c_s]
            z (torch.Tensor): pair embedding
                [..., N_token, N_token, c_z]
            pair_mask (torch.Tensor): pair mask
                [..., N_token, N_token]
            use_memory_efficient_kernel (bool): Whether to use the memory-efficient kernel. Defaults to False.
            use_deepspeed_evo_attention (bool): Whether to use DeepSpeed Evoformer attention. Defaults to False.
            use_lma (bool): Whether to use low-memory attention. Defaults to False.
            inplace_safe (bool): Whether it is safe to use inplace operations. Defaults to False.
            chunk_size (Optional[int]): Chunk size for memory-efficient operations. Defaults to None.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: the updated s and z
                [..., N_token, c_s]
                [..., N_token, N_token, c_z]
        """
        if z.shape[-2] > 2000 and (not self.training):
            clear_cache_between_blocks = True
        else:
            clear_cache_between_blocks = False
        blocks = self._prep_blocks(
            pair_mask=pair_mask,
            use_memory_efficient_kernel=use_memory_efficient_kernel,
            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
            use_lma=use_lma,
            inplace_safe=inplace_safe,
            chunk_size=chunk_size,
            clear_cache_between_blocks=clear_cache_between_blocks,
        )

        blocks_per_ckpt = self.blocks_per_ckpt
        if not torch.is_grad_enabled():
            blocks_per_ckpt = None
        s, z = checkpoint_blocks(
            blocks,
            args=(s, z),
            blocks_per_ckpt=blocks_per_ckpt,
        )
        return s, z


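# Usage sketch (illustrative comment, not part of the original module): the stack
# simply chains n_blocks PairformerBlocks, optionally wrapped in activation
# checkpointing via checkpoint_blocks. A small hypothetical configuration:
#
#   stack = PairformerStack(n_blocks=2, blocks_per_ckpt=1)
#   s = torch.randn(1, 32, 384)
#   z = torch.randn(1, 32, 32, 128)
#   pair_mask = torch.ones(1, 32, 32)
#   s, z = stack(s=s, z=z, pair_mask=pair_mask)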
class MSAPairWeightedAveraging(nn.Module):
    """
    Implements Algorithm 10 [MSAPairWeightedAveraging] in AF3
    """

    def __init__(
        self, c_m: int = 64, c: int = 32, c_z: int = 128, n_heads: int = 8
    ) -> None:
        """
        Args:
            c_m (int, optional): hidden dim [for msa embedding]. Defaults to 64.
            c (int, optional): hidden dim [for MSAPairWeightedAveraging]. Defaults to 32.
            c_z (int, optional): hidden dim [for pair embedding]. Defaults to 128.
            n_heads (int, optional): number of heads [for MSAPairWeightedAveraging]. Defaults to 8.
        """
        super(MSAPairWeightedAveraging, self).__init__()
        self.c_m = c_m
        self.c = c
        self.n_heads = n_heads
        self.c_z = c_z

        self.layernorm_m = LayerNorm(self.c_m)
        self.linear_no_bias_mv = LinearNoBias(
            in_features=self.c_m, out_features=self.c * self.n_heads
        )
        self.layernorm_z = LayerNorm(self.c_z)
        self.linear_no_bias_z = LinearNoBias(
            in_features=self.c_z, out_features=self.n_heads
        )
        self.linear_no_bias_mg = LinearNoBias(
            in_features=self.c_m, out_features=self.c * self.n_heads
        )

        self.softmax_w = nn.Softmax(dim=-2)

        self.linear_no_bias_out = LinearNoBias(
            in_features=self.c * self.n_heads, out_features=self.c_m
        )

    def forward(self, m: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """
        Args:
            m (torch.Tensor): msa embedding
                [..., n_msa_sampled, n_token, c_m]
            z (torch.Tensor): pair embedding
                [..., n_token, n_token, c_z]

        Returns:
            torch.Tensor: updated msa embedding
                [..., n_msa_sampled, n_token, c_m]
        """
        m = self.layernorm_m(m)
        # Values: [..., n_msa_sampled, n_token, n_heads, c]
        v = self.linear_no_bias_mv(m)
        v = v.reshape(*v.shape[:-1], self.n_heads, self.c)
        # Pair logits: [..., n_token, n_token, n_heads]
        b = self.linear_no_bias_z(self.layernorm_z(z))
        # Gate: [..., n_msa_sampled, n_token, n_heads, c]
        g = torch.sigmoid(self.linear_no_bias_mg(m))
        g = g.reshape(*g.shape[:-1], self.n_heads, self.c)
        # Softmax over the second token axis (j)
        w = self.softmax_w(b)
        # Pair-weighted average over tokens j: [..., n_msa_sampled, n_token, n_heads, c]
        wv = torch.einsum("...ijh,...mjhc->...mihc", w, v)
        o = g * wv
        o = o.reshape(*o.shape[:-2], self.n_heads * self.c)
        m = self.linear_no_bias_out(o)
        return m


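# Shape sketch (illustrative comment, not part of the original module): with the
# default c=32 and n_heads=8, the einsum in MSAPairWeightedAveraging.forward
# contracts the second token axis j:
#
#   w  : [..., i, j, h]       softmax of the pair logits over j
#   v  : [..., m, j, h, c]    per-sequence values
#   wv : [..., m, i, h, c]    pair-weighted average, then gated and flattened to h * c
#
#   m = torch.randn(1, 5, 32, 64)            # hypothetical: 5 sequences, 32 tokens
#   z = torch.randn(1, 32, 32, 128)
#   out = MSAPairWeightedAveraging()(m, z)   # [1, 5, 32, 64]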
class MSAStack(nn.Module):
    """
    Implements MSAStack [Line 7 - Line 8] in Algorithm 8
    """

    def __init__(self, c_m: int = 64, c: int = 8, dropout: float = 0.15) -> None:
        """
        Args:
            c_m (int, optional): hidden dim [for msa embedding]. Defaults to 64.
            c (int, optional): hidden dim [for MSAStack]. Defaults to 8.
            dropout (float, optional): dropout ratio. Defaults to 0.15.
        """
        super(MSAStack, self).__init__()
        self.c = c
        self.msa_pair_weighted_averaging = MSAPairWeightedAveraging(c=self.c)
        self.dropout_row = DropoutRowwise(dropout)
        self.transition_m = Transition(c_in=c_m, n=4)

    def forward(self, m: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """
        Args:
            m (torch.Tensor): msa embedding
                [..., n_msa_sampled, n_token, c_m]
            z (torch.Tensor): pair embedding
                [..., n_token, n_token, c_z]

        Returns:
            torch.Tensor: updated msa embedding
                [..., n_msa_sampled, n_token, c_m]
        """
        m = m + self.dropout_row(self.msa_pair_weighted_averaging(m, z))
        m = m + self.transition_m(m)
        return m


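# Usage sketch (illustrative comment, not part of the original module): MSAStack is
# the residual wrapper around MSAPairWeightedAveraging plus a transition, i.e.
# lines 7-8 of Algorithm 8:
#
#   msa_stack = MSAStack(c_m=64, c=8)
#   m = msa_stack(m, z)    # m keeps its shape [..., n_msa_sampled, n_token, 64]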
class MSABlock(nn.Module):
    """
    Base MSA Block, Line 6 - Line 13 in Algorithm 8
    """

    def __init__(
        self,
        c_m: int = 64,
        c_z: int = 128,
        c_hidden: int = 32,
        is_last_block: bool = False,
        msa_dropout: float = 0.15,
        pair_dropout: float = 0.25,
    ) -> None:
        """
        Args:
            c_m (int, optional): hidden dim [for msa embedding]. Defaults to 64.
            c_z (int, optional): hidden dim [for pair embedding]. Defaults to 128.
            c_hidden (int, optional): hidden dim [for MSABlock]. Defaults to 32.
            is_last_block (bool, optional): whether this is the last block of the MSAModule. Defaults to False.
            msa_dropout (float, optional): dropout ratio for the msa stack. Defaults to 0.15.
            pair_dropout (float, optional): dropout ratio for the pair stack. Defaults to 0.25.
        """
        super(MSABlock, self).__init__()
        self.c_m = c_m
        self.c_z = c_z
        self.c_hidden = c_hidden
        self.is_last_block = is_last_block

        self.outer_product_mean_msa = OuterProductMean(
            c_m=self.c_m, c_z=self.c_z, c_hidden=self.c_hidden
        )
        if not self.is_last_block:
            self.msa_stack = MSAStack(c_m=self.c_m, dropout=msa_dropout)
        self.pair_stack = PairformerBlock(c_z=c_z, c_s=0, dropout=pair_dropout)

    def forward(
        self,
        m: torch.Tensor,
        z: torch.Tensor,
        pair_mask: torch.Tensor,
        use_memory_efficient_kernel: bool = False,
        use_deepspeed_evo_attention: bool = False,
        use_lma: bool = False,
        inplace_safe: bool = False,
        chunk_size: Optional[int] = None,
    ) -> tuple[Optional[torch.Tensor], torch.Tensor]:
        """
        Args:
            m (torch.Tensor): msa embedding
                [..., n_msa_sampled, n_token, c_m]
            z (torch.Tensor): pair embedding
                [..., n_token, n_token, c_z]
            pair_mask (torch.Tensor): pair mask
                [..., N_token, N_token]
            use_memory_efficient_kernel (bool): Whether to use the memory-efficient kernel. Defaults to False.
            use_deepspeed_evo_attention (bool): Whether to use DeepSpeed Evoformer attention. Defaults to False.
            use_lma (bool): Whether to use low-memory attention. Defaults to False.
            inplace_safe (bool): Whether it is safe to use inplace operations. Defaults to False.
            chunk_size (Optional[int]): Chunk size for memory-efficient operations. Defaults to None.

        Returns:
            tuple[Optional[torch.Tensor], torch.Tensor]: the updated m (None for the last block) and z
                [..., n_msa_sampled, n_token, c_m] | None
                [..., n_token, n_token, c_z]
        """
        z = z + self.outer_product_mean_msa(
            m, inplace_safe=inplace_safe, chunk_size=chunk_size
        )
        if not self.is_last_block:
            m = self.msa_stack(m, z)

        _, z = self.pair_stack(
            s=None,
            z=z,
            pair_mask=pair_mask,
            use_memory_efficient_kernel=use_memory_efficient_kernel,
            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
            use_lma=use_lma,
            inplace_safe=inplace_safe,
            chunk_size=chunk_size,
        )

        if not self.is_last_block:
            return m, z
        else:
            return None, z


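# Usage sketch (illustrative comment, not part of the original module): each MSABlock
# first folds MSA statistics into the pair representation via OuterProductMean, then
# updates m (skipped in the last block, where m is no longer needed) and runs one
# pair stack over z:
#
#   block = MSABlock(is_last_block=False)
#   m, z = block(m=m, z=z, pair_mask=pair_mask)   # the last block returns (None, z)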
class MSAModule(nn.Module):
    """
    Implements Algorithm 8 [MSAModule] in AF3
    """

    def __init__(
        self,
        n_blocks: int = 4,
        c_m: int = 64,
        c_z: int = 128,
        c_s_inputs: int = 449,
        msa_dropout: float = 0.15,
        pair_dropout: float = 0.25,
        blocks_per_ckpt: Optional[int] = 1,
        msa_configs: Optional[dict] = None,
    ) -> None:
        """Main entry of the MSAModule.

        Args:
            n_blocks (int, optional): number of blocks [for MSAModule]. Defaults to 4.
            c_m (int, optional): hidden dim [for msa embedding]. Defaults to 64.
            c_z (int, optional): hidden dim [for pair embedding]. Defaults to 128.
            c_s_inputs (int, optional):
                hidden dim for the single embedding from InputFeatureEmbedder. Defaults to 449.
            msa_dropout (float, optional): dropout ratio for the msa stack. Defaults to 0.15.
            pair_dropout (float, optional): dropout ratio for the pair stack. Defaults to 0.25.
            blocks_per_ckpt (Optional[int], optional): number of MSAModule blocks in each
                activation checkpoint, i.e. the size of each checkpointed chunk. A higher
                value corresponds to fewer checkpoints and trades memory for speed.
                If None, no checkpointing is performed.
            msa_configs (Optional[dict], optional): a dictionary containing keys such as
                "enable" (whether to use the msa embedding), "strategy", "sample_cutoff"
                and "min_size".
        """
        super(MSAModule, self).__init__()
        self.n_blocks = n_blocks
        self.c_m = c_m
        self.c_s_inputs = c_s_inputs
        self.blocks_per_ckpt = blocks_per_ckpt
        self.input_feature = {
            "msa": 32,
            "has_deletion": 1,
            "deletion_value": 1,
        }

        # Guard against the default None so the .get calls below do not fail.
        msa_configs = msa_configs if msa_configs is not None else {}
        self.msa_configs = {
            "enable": msa_configs.get("enable", False),
            "strategy": msa_configs.get("strategy", "random"),
        }
        if "sample_cutoff" in msa_configs:
            self.msa_configs["train_cutoff"] = msa_configs["sample_cutoff"].get(
                "train", 512
            )
            self.msa_configs["test_cutoff"] = msa_configs["sample_cutoff"].get(
                "test", 16384
            )
        if "min_size" in msa_configs:
            self.msa_configs["train_lowerb"] = msa_configs["min_size"].get("train", 1)
            self.msa_configs["test_lowerb"] = msa_configs["min_size"].get("test", 1)

        self.linear_no_bias_m = LinearNoBias(
            in_features=32 + 1 + 1, out_features=self.c_m
        )

        self.linear_no_bias_s = LinearNoBias(
            in_features=self.c_s_inputs, out_features=self.c_m
        )
        self.blocks = nn.ModuleList()

        for i in range(n_blocks):
            block = MSABlock(
                c_m=self.c_m,
                c_z=c_z,
                is_last_block=(i + 1 == n_blocks),
                msa_dropout=msa_dropout,
                pair_dropout=pair_dropout,
            )
            self.blocks.append(block)

    def _prep_blocks(
        self,
        pair_mask: Optional[torch.Tensor],
        use_memory_efficient_kernel: bool = False,
        use_deepspeed_evo_attention: bool = False,
        use_lma: bool = False,
        inplace_safe: bool = False,
        chunk_size: Optional[int] = None,
        clear_cache_between_blocks: bool = False,
    ):
        blocks = [
            partial(
                b,
                pair_mask=pair_mask,
                use_memory_efficient_kernel=use_memory_efficient_kernel,
                use_deepspeed_evo_attention=use_deepspeed_evo_attention,
                use_lma=use_lma,
                inplace_safe=inplace_safe,
                chunk_size=chunk_size,
            )
            for b in self.blocks
        ]

        def clear_cache(b, *args, **kwargs):
            torch.cuda.empty_cache()
            return b(*args, **kwargs)

        if clear_cache_between_blocks:
            blocks = [partial(clear_cache, b) for b in blocks]
        return blocks

    def forward(
        self,
        input_feature_dict: dict[str, Any],
        z: torch.Tensor,
        s_inputs: torch.Tensor,
        pair_mask: torch.Tensor,
        use_memory_efficient_kernel: bool = False,
        use_deepspeed_evo_attention: bool = False,
        use_lma: bool = False,
        inplace_safe: bool = False,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Args:
            input_feature_dict (dict[str, Any]):
                input meta feature dict
            z (torch.Tensor): pair embedding
                [..., N_token, N_token, c_z]
            s_inputs (torch.Tensor): single embedding from InputFeatureEmbedder
                [..., N_token, c_s_inputs]
            pair_mask (torch.Tensor): pair mask
                [..., N_token, N_token]
            use_memory_efficient_kernel (bool): Whether to use the memory-efficient kernel. Defaults to False.
            use_deepspeed_evo_attention (bool): Whether to use DeepSpeed Evoformer attention. Defaults to False.
            use_lma (bool): Whether to use low-memory attention. Defaults to False.
            inplace_safe (bool): Whether it is safe to use inplace operations. Defaults to False.
            chunk_size (Optional[int]): Chunk size for memory-efficient operations. Defaults to None.

        Returns:
            torch.Tensor: the updated z
                [..., N_token, N_token, c_z]
        """
        if self.n_blocks < 1:
            return z

        if "msa" not in input_feature_dict:
            return z

        # Subsample the MSA features along the sequence dimension.
        msa_feat = sample_msa_feature_dict_random_without_replacement(
            feat_dict=input_feature_dict,
            dim_dict={feat_name: -2 for feat_name in self.input_feature},
            cutoff=(
                self.msa_configs["train_cutoff"]
                if self.training
                else self.msa_configs["test_cutoff"]
            ),
            lower_bound=(
                self.msa_configs["train_lowerb"]
                if self.training
                else self.msa_configs["test_lowerb"]
            ),
            strategy=self.msa_configs["strategy"],
        )

        msa_feat["msa"] = torch.nn.functional.one_hot(
            msa_feat["msa"],
            num_classes=self.input_feature["msa"],
        )

        # Concatenate [msa one-hot, has_deletion, deletion_value] along the channel dim.
        target_shape = msa_feat["msa"].shape[:-1]
        msa_sample = torch.cat(
            [
                msa_feat[name].reshape(*target_shape, d)
                for name, d in self.input_feature.items()
            ],
            dim=-1,
        )

        msa_sample = self.linear_no_bias_m(msa_sample)
        msa_sample = msa_sample + self.linear_no_bias_s(s_inputs)

        if z.shape[-2] > 2000 and (not self.training):
            clear_cache_between_blocks = True
        else:
            clear_cache_between_blocks = False
        blocks = self._prep_blocks(
            pair_mask=pair_mask,
            use_memory_efficient_kernel=use_memory_efficient_kernel,
            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
            use_lma=use_lma,
            inplace_safe=inplace_safe,
            chunk_size=chunk_size,
            clear_cache_between_blocks=clear_cache_between_blocks,
        )
        blocks_per_ckpt = self.blocks_per_ckpt
        if not torch.is_grad_enabled():
            blocks_per_ckpt = None
        msa_sample, z = checkpoint_blocks(
            blocks,
            args=(msa_sample, z),
            blocks_per_ckpt=blocks_per_ckpt,
        )
        if z.shape[-2] > 2000:
            torch.cuda.empty_cache()
        return z


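# Usage sketch (illustrative comment, not part of the original module; the exact
# feature-dict contract is defined elsewhere in protenix, so the keys and values
# below are assumptions): the module one-hot encodes the sampled "msa" feature,
# projects it together with "has_deletion"/"deletion_value" and s_inputs into c_m,
# runs n_blocks MSABlocks, and returns only the updated pair embedding z.
#
#   msa_module = MSAModule(
#       n_blocks=4,
#       msa_configs={
#           "enable": True,
#           "strategy": "random",
#           "sample_cutoff": {"train": 512, "test": 16384},
#           "min_size": {"train": 1, "test": 1},
#       },
#   )
#   z = msa_module(input_feature_dict, z, s_inputs, pair_mask)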
class TemplateEmbedder(nn.Module):
    """
    Implements Algorithm 16 in AF3
    """

    def __init__(
        self,
        n_blocks: int = 2,
        c: int = 64,
        c_z: int = 128,
        dropout: float = 0.25,
        blocks_per_ckpt: Optional[int] = None,
    ) -> None:
        """
        Args:
            n_blocks (int, optional): number of blocks for TemplateEmbedder. Defaults to 2.
            c (int, optional): hidden dim of TemplateEmbedder. Defaults to 64.
            c_z (int, optional): hidden dim [for pair embedding]. Defaults to 128.
            dropout (float, optional): dropout ratio for PairformerStack. Defaults to 0.25.
                Note: this value is not specified in Algorithm 16, so the default Pairformer
                ratio is used.
            blocks_per_ckpt (Optional[int], optional): number of TemplateEmbedder/Pairformer
                blocks in each activation checkpoint, i.e. the size of each checkpointed
                chunk. A higher value corresponds to fewer checkpoints and trades memory
                for speed. If None, no checkpointing is performed.
        """
        super(TemplateEmbedder, self).__init__()
        self.n_blocks = n_blocks
        self.c = c
        self.c_z = c_z
        self.input_feature1 = {
            "template_distogram": 39,
            "b_template_backbone_frame_mask": 1,
            "template_unit_vector": 3,
            "b_template_pseudo_beta_mask": 1,
        }
        self.input_feature2 = {
            "template_restype_i": 32,
            "template_restype_j": 32,
        }
        self.distogram = {"max_bin": 50.75, "min_bin": 3.25, "no_bins": 39}
        self.inf = 100000.0

        self.linear_no_bias_z = LinearNoBias(in_features=self.c_z, out_features=self.c)
        self.layernorm_z = LayerNorm(self.c_z)
        self.linear_no_bias_a = LinearNoBias(
            in_features=sum(self.input_feature1.values())
            + sum(self.input_feature2.values()),
            out_features=self.c,
        )
        self.pairformer_stack = PairformerStack(
            c_s=0,
            c_z=c,
            n_blocks=self.n_blocks,
            dropout=dropout,
            blocks_per_ckpt=blocks_per_ckpt,
        )
        self.layernorm_v = LayerNorm(self.c)
        self.linear_no_bias_u = LinearNoBias(in_features=self.c, out_features=self.c_z)

    def forward(
        self,
        input_feature_dict: dict[str, Any],
        z: torch.Tensor,
        pair_mask: Optional[torch.Tensor] = None,
        use_memory_efficient_kernel: bool = False,
        use_deepspeed_evo_attention: bool = False,
        use_lma: bool = False,
        inplace_safe: bool = False,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Args:
            input_feature_dict (dict[str, Any]): input feature dict
            z (torch.Tensor): pair embedding
                [..., N_token, N_token, c_z]
            pair_mask (Optional[torch.Tensor], optional): pair mask. Defaults to None.
                [..., N_token, N_token]
            use_memory_efficient_kernel (bool): Whether to use the memory-efficient kernel. Defaults to False.
            use_deepspeed_evo_attention (bool): Whether to use DeepSpeed Evoformer attention. Defaults to False.
            use_lma (bool): Whether to use low-memory attention. Defaults to False.
            inplace_safe (bool): Whether it is safe to use inplace operations. Defaults to False.
            chunk_size (Optional[int]): Chunk size for memory-efficient operations. Defaults to None.

        Returns:
            torch.Tensor: the template feature
                [..., N_token, N_token, c_z]
        """
        # Without template features (or with no blocks), the template contribution is zero.
        if "template_restype" not in input_feature_dict or self.n_blocks < 1:
            return 0
        return 0