|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from functools import partial |
|
from typing import Optional, Union |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from protenix.model.modules.primitives import ( |
|
AdaptiveLayerNorm, |
|
Attention, |
|
BiasInitLinear, |
|
LinearNoBias, |
|
broadcast_token_to_local_atom_pair, |
|
rearrange_qk_to_dense_trunk, |
|
) |
|
from protenix.model.utils import ( |
|
aggregate_atom_to_token, |
|
broadcast_token_to_atom, |
|
permute_final_dims, |
|
) |
|
from protenix.openfold_local.model.primitives import LayerNorm |
|
from protenix.openfold_local.utils.checkpointing import checkpoint_blocks |
|
|
|
|
|
class AttentionPairBias(nn.Module): |
|
""" |
|
Implements Algorithm 24 in AF3 |
|
""" |
|
|
|
def __init__( |
|
self, |
|
has_s: bool = True, |
|
n_heads: int = 16, |
|
c_a: int = 768, |
|
c_s: int = 384, |
|
c_z: int = 128, |
|
biasinit: float = -2.0, |
|
) -> None: |
|
""" |
|
Args: |
|
            has_s (bool, optional): whether the conditioning single embedding s is provided (AF3 Algorithm 24, line 1). Defaults to True.
            n_heads (int, optional): number of attention heads in AttentionPairBias. Defaults to 16.
            c_a (int, optional): embedding dim of a (the single representation with aggregated atom info). Defaults to 768.
            c_s (int, optional): hidden dim of the single embedding s. Defaults to 384.
            c_z (int, optional): hidden dim of the pair embedding z. Defaults to 128.
            biasinit (float, optional): bias initialization value for BiasInitLinear. Defaults to -2.0.
|
""" |
|
super(AttentionPairBias, self).__init__() |
|
assert c_a % n_heads == 0 |
|
self.n_heads = n_heads |
|
self.has_s = has_s |
|
if has_s: |
|
|
|
self.layernorm_a = AdaptiveLayerNorm(c_a=c_a, c_s=c_s) |
|
|
|
self.linear_a_last = BiasInitLinear( |
|
in_features=c_s, out_features=c_a, bias=True, biasinit=biasinit |
|
) |
|
else: |
|
self.layernorm_a = LayerNorm(c_a) |
|
|
|
self.local_attention_method = "local_cross_attention" |
|
self.attention = Attention( |
|
c_q=c_a, |
|
c_k=c_a, |
|
c_v=c_a, |
|
c_hidden=c_a // n_heads, |
|
num_heads=n_heads, |
|
gating=True, |
|
q_linear_bias=True, |
|
local_attention_method=self.local_attention_method, |
|
) |
|
self.layernorm_z = LayerNorm(c_z) |
|
|
|
self.linear_nobias_z = LinearNoBias(in_features=c_z, out_features=n_heads) |
|
|
|
def glorot_init(self): |
|
nn.init.xavier_uniform_(self.attention.linear_q.weight) |
|
nn.init.xavier_uniform_(self.attention.linear_k.weight) |
|
nn.init.xavier_uniform_(self.attention.linear_v.weight) |
|
nn.init.zeros_(self.attention.linear_q.bias) |
|
|
|
def local_multihead_attention( |
|
self, |
|
a: torch.Tensor, |
|
s: torch.Tensor, |
|
z: torch.Tensor, |
|
n_queries: int = 32, |
|
n_keys: int = 128, |
|
inplace_safe: bool = False, |
|
chunk_size: Optional[int] = None, |
|
) -> torch.Tensor: |
|
"""Used by Algorithm 24, with beta_ij being the local mask. Used in AtomTransformer. |
|
|
|
Args: |
|
a (torch.Tensor): atom embedding |
|
[..., N_atom, c_a] |
|
s (torch.Tensor): atom embedding |
|
[..., N_atom, c_s] |
|
z (torch.Tensor): atom-atom pair embedding, in trunked dense shape. Used for computing pair bias. |
|
[..., n_blocks, n_queries, n_keys, c_z] |
|
n_queries (int, optional): local window size of query tensor. Defaults to 32. |
|
n_keys (int, optional): local window size of key tensor. Defaults to 128. |
|
inplace_safe (bool): Whether it is safe to use inplace operations. Defaults to False. |
|
chunk_size (Optional[int]): Chunk size for memory-efficient operations. Defaults to None. |
|
|
|
Returns: |
|
torch.Tensor: the updated a from AttentionPairBias |
|
[..., N_atom, c_a] |
|
""" |
|
|
|
assert n_queries == z.size(-3) |
|
assert n_keys == z.size(-2) |
|
assert len(z.shape) == len(a.shape) + 2 |
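        # Project the trunked pair embedding to one bias value per attention head, then move the head
        # dim to the front: [..., n_blocks, n_queries, n_keys, n_heads] -> [..., n_heads, n_blocks, n_queries, n_keys].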
|
|
|
|
|
        bias = self.linear_nobias_z(self.layernorm_z(z))
        bias = permute_final_dims(bias, [3, 0, 1, 2])
|
|
|
|
|
a = self.attention( |
|
q_x=a, |
|
kv_x=a, |
|
trunked_attn_bias=bias, |
|
n_queries=n_queries, |
|
n_keys=n_keys, |
|
inplace_safe=inplace_safe, |
|
chunk_size=chunk_size, |
|
) |
|
return a |
|
|
|
def standard_multihead_attention( |
|
self, |
|
a: torch.Tensor, |
|
s: torch.Tensor, |
|
z: torch.Tensor, |
|
inplace_safe: bool = False, |
|
) -> torch.Tensor: |
|
"""Used by Algorithm 7/20 |
|
|
|
Args: |
|
            a (torch.Tensor): the per-token single-feature activation
|
[..., N_token, c_a] |
|
s (torch.Tensor): single embedding |
|
[..., N_token, c_s] |
|
z (torch.Tensor): pair embedding, used for computing pair bias. |
|
[..., N_token, N_token, c_z] |
|
inplace_safe (bool): Whether it is safe to use inplace operations. Defaults to False. |
|
|
|
Returns: |
|
torch.Tensor: the updated a from AttentionPairBias |
|
[..., N_token, c_a] |
|
""" |
|
|
|
|
|
bias = self.linear_nobias_z(self.layernorm_z(z)) |
|
bias = permute_final_dims(bias, [2, 0, 1]) |
|
|
|
|
|
a = self.attention(q_x=a, kv_x=a, attn_bias=bias, inplace_safe=inplace_safe) |
|
|
|
return a |
|
|
|
def forward( |
|
self, |
|
a: torch.Tensor, |
|
s: torch.Tensor, |
|
z: torch.Tensor, |
|
n_queries: Optional[int] = None, |
|
n_keys: Optional[int] = None, |
|
inplace_safe: bool = False, |
|
chunk_size: Optional[int] = None, |
|
) -> torch.Tensor: |
|
"""Details are given in local_forward and standard_forward""" |
|
|
|
if self.has_s: |
|
a = self.layernorm_a(a=a, s=s) |
|
else: |
|
a = self.layernorm_a(a) |
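        # If local window sizes are given, run sequence-local attention; otherwise attend over the full sequence.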
|
|
|
|
|
if n_queries and n_keys: |
|
a = self.local_multihead_attention( |
|
a, |
|
s, |
|
z, |
|
n_queries, |
|
n_keys, |
|
inplace_safe=inplace_safe, |
|
chunk_size=chunk_size, |
|
) |
|
else: |
|
a = self.standard_multihead_attention(a, s, z, inplace_safe=inplace_safe) |
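        # Output gating of Algorithm 24: when s is provided, scale the attention output by sigmoid(Linear(s)).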
|
|
|
|
|
if self.has_s: |
|
if inplace_safe: |
|
a *= torch.sigmoid(self.linear_a_last(s)) |
|
else: |
|
a = torch.sigmoid(self.linear_a_last(s)) * a |
|
|
|
return a |
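
# Minimal usage sketch for AttentionPairBias (shapes and sizes below are illustrative assumptions,
# not part of the module contract):
#     attn = AttentionPairBias(has_s=True, n_heads=16, c_a=768, c_s=384, c_z=128)
#     a = torch.randn(1, 64, 768)      # activations
#     s = torch.randn(1, 64, 384)      # conditioning single embedding
#     z = torch.randn(1, 64, 64, 128)  # pair embedding
#     a = attn(a=a, s=s, z=z)          # full attention; pass n_queries/n_keys (with trunked z) for the local variant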
|
|
|
|
|
class DiffusionTransformerBlock(nn.Module): |
|
""" |
|
Implements Algorithm 23[Line2-Line3] in AF3 |
|
""" |
|
|
|
def __init__( |
|
self, |
|
c_a: int, |
|
c_s: int, |
|
c_z: int, |
|
n_heads: int, |
|
biasinit: float = -2.0, |
|
) -> None: |
|
""" |
|
Args: |
|
            c_a (int): embedding dim of the activation a.
            c_s (int): single embedding dimension.
            c_z (int): pair embedding dimension.
            n_heads (int): number of attention heads in DiffusionTransformerBlock.
            biasinit (float, optional): bias initialization value for BiasInitLinear. Defaults to -2.0.
|
""" |
|
super(DiffusionTransformerBlock, self).__init__() |
|
self.n_heads = n_heads |
|
self.c_a = c_a |
|
self.c_s = c_s |
|
self.c_z = c_z |
|
self.attention_pair_bias = AttentionPairBias( |
|
has_s=True, n_heads=n_heads, c_a=c_a, c_s=c_s, c_z=c_z, biasinit=biasinit |
|
) |
|
self.conditioned_transition_block = ConditionedTransitionBlock( |
|
n=2, c_a=c_a, c_s=c_s, biasinit=biasinit |
|
) |
|
|
|
def forward( |
|
self, |
|
a: torch.Tensor, |
|
s: torch.Tensor, |
|
z: torch.Tensor, |
|
n_queries: Optional[int] = None, |
|
n_keys: Optional[int] = None, |
|
inplace_safe: bool = False, |
|
chunk_size: Optional[int] = None, |
|
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
""" |
|
Args: |
|
            a (torch.Tensor): the single-feature activation (token- or atom-level)
|
[..., N, c_a] |
|
s (torch.Tensor): single embedding |
|
[..., N, c_s] |
|
z (torch.Tensor): pair embedding |
|
[..., N, N, c_z] or [..., n_block, n_queries, n_keys, c_z] |
|
n_queries (int, optional): local window size of query tensor. If not None, will perform local attention. Defaults to None. |
|
n_keys (int, optional): local window size of key tensor. Defaults to None. |
|
inplace_safe (bool): Whether it is safe to use inplace operations. Defaults to False. |
|
chunk_size (Optional[int]): Chunk size for memory-efficient operations. Defaults to None. |
|
|
|
Returns: |
|
            tuple[torch.Tensor, torch.Tensor, torch.Tensor]: the updated a, plus s and z returned
                unchanged so the block can be chained through checkpoint_blocks; the updated a has shape
                [..., N, c_a]
|
""" |
|
attn_out = self.attention_pair_bias( |
|
a=a, |
|
s=s, |
|
z=z, |
|
n_queries=n_queries, |
|
n_keys=n_keys, |
|
inplace_safe=inplace_safe, |
|
chunk_size=chunk_size, |
|
) |
|
if inplace_safe: |
|
attn_out += a |
|
else: |
|
attn_out = attn_out + a |
|
ff_out = self.conditioned_transition_block(a=attn_out, s=s) |
|
out_a = ff_out + attn_out |
|
|
|
return out_a, s, z |
|
|
|
|
|
class DiffusionTransformer(nn.Module): |
|
""" |
|
Implements Algorithm 23 in AF3 |
|
""" |
|
|
|
def __init__( |
|
self, |
|
c_a: int, |
|
c_s: int, |
|
c_z: int, |
|
n_blocks: int, |
|
n_heads: int, |
|
blocks_per_ckpt: Optional[int] = None, |
|
) -> None: |
|
""" |
|
Args: |
|
            c_a (int): embedding dim of the activation a.
            c_s (int): single embedding dimension.
            c_z (int): pair embedding dimension.
            n_blocks (int): number of blocks in DiffusionTransformer.
            n_heads (int): number of heads in attention.
            blocks_per_ckpt (Optional[int]): number of DiffusionTransformerBlock modules per
                activation checkpoint. If None, no checkpointing is performed.
|
""" |
|
super(DiffusionTransformer, self).__init__() |
|
self.n_blocks = n_blocks |
|
self.n_heads = n_heads |
|
self.c_a = c_a |
|
self.c_s = c_s |
|
self.c_z = c_z |
|
self.blocks_per_ckpt = blocks_per_ckpt |
|
|
|
self.blocks = nn.ModuleList() |
|
for _ in range(n_blocks): |
|
block = DiffusionTransformerBlock( |
|
n_heads=n_heads, c_a=c_a, c_s=c_s, c_z=c_z |
|
) |
|
self.blocks.append(block) |
|
|
|
def _prep_blocks( |
|
self, |
|
n_queries: Optional[int] = None, |
|
n_keys: Optional[int] = None, |
|
inplace_safe: bool = False, |
|
chunk_size: Optional[int] = None, |
|
clear_cache_between_blocks: bool = False, |
|
): |
|
blocks = [ |
|
partial( |
|
b, |
|
n_queries=n_queries, |
|
n_keys=n_keys, |
|
inplace_safe=inplace_safe, |
|
chunk_size=chunk_size, |
|
) |
|
for b in self.blocks |
|
] |
|
|
|
        def clear_cache(b, *args, **kwargs):
            # Free cached CUDA memory before each block to reduce peak usage on long inputs.
            torch.cuda.empty_cache()
            return b(*args, **kwargs)
|
|
|
if clear_cache_between_blocks: |
|
blocks = [partial(clear_cache, b) for b in blocks] |
|
return blocks |
|
|
|
def forward( |
|
self, |
|
a: torch.Tensor, |
|
s: torch.Tensor, |
|
z: torch.Tensor, |
|
n_queries: Optional[int] = None, |
|
n_keys: Optional[int] = None, |
|
inplace_safe: bool = False, |
|
chunk_size: Optional[int] = None, |
|
) -> torch.Tensor: |
|
""" |
|
Args: |
|
            a (torch.Tensor): the single-feature activation (token- or atom-level)
|
[..., N, c_a] |
|
s (torch.Tensor): single embedding |
|
[..., N, c_s] |
|
z (torch.Tensor): pair embedding |
|
[..., N, N, c_z] |
|
n_queries (int, optional): local window size of query tensor. If not None, will perform local attention. Defaults to None. |
|
n_keys (int, optional): local window size of key tensor. Defaults to None. |
|
|
|
Returns: |
|
torch.Tensor: the output of DiffusionTransformer |
|
[..., N, c_a] |
|
""" |
|
if z.shape[-2] > 2000 and (not self.training): |
|
clear_cache_between_blocks = True |
|
else: |
|
clear_cache_between_blocks = False |
|
blocks = self._prep_blocks( |
|
n_queries=n_queries, |
|
n_keys=n_keys, |
|
inplace_safe=inplace_safe, |
|
chunk_size=chunk_size, |
|
clear_cache_between_blocks=clear_cache_between_blocks, |
|
) |
|
blocks_per_ckpt = self.blocks_per_ckpt |
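        # Activation checkpointing only matters when gradients are enabled, so disable it for inference.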
|
if not torch.is_grad_enabled(): |
|
blocks_per_ckpt = None |
|
a, s, z = checkpoint_blocks( |
|
blocks, args=(a, s, z), blocks_per_ckpt=blocks_per_ckpt |
|
) |
|
del s, z |
|
return a |
|
|
|
|
|
class AtomTransformer(nn.Module): |
|
""" |
|
Implements Algorithm 7 in AF3 |
|
""" |
|
|
|
def __init__( |
|
self, |
|
c_atom: int = 128, |
|
c_atompair: int = 16, |
|
n_blocks: int = 3, |
|
n_heads: int = 4, |
|
n_queries: int = 32, |
|
n_keys: int = 128, |
|
blocks_per_ckpt: Optional[int] = None, |
|
) -> None: |
|
"""Performs local transformer among atom embeddings, with bias predicted from atom pair embeddings |
|
|
|
Args: |
|
            c_atom (int, optional): embedding dim of the atom feature. Defaults to 128.
            c_atompair (int, optional): embedding dim of the atompair feature. Defaults to 16.
            n_blocks (int, optional): number of blocks in AtomTransformer. Defaults to 3.
            n_heads (int, optional): number of heads in attention. Defaults to 4.
            n_queries (int, optional): local window size of the query tensor. Defaults to 32.
            n_keys (int, optional): local window size of the key tensor. Defaults to 128.
            blocks_per_ckpt (Optional[int]): number of DiffusionTransformerBlock modules per
                activation checkpoint. A higher value corresponds to fewer checkpoints and trades
                memory for speed. If None, no checkpointing is performed.
|
""" |
|
super(AtomTransformer, self).__init__() |
|
self.n_blocks = n_blocks |
|
self.n_heads = n_heads |
|
self.n_queries = n_queries |
|
self.n_keys = n_keys |
|
self.c_atom = c_atom |
|
self.c_atompair = c_atompair |
|
self.diffusion_transformer = DiffusionTransformer( |
|
n_blocks=n_blocks, |
|
n_heads=n_heads, |
|
c_a=c_atom, |
|
c_s=c_atom, |
|
c_z=c_atompair, |
|
blocks_per_ckpt=blocks_per_ckpt, |
|
) |
|
|
|
def forward( |
|
self, |
|
q: torch.Tensor, |
|
c: torch.Tensor, |
|
p: torch.Tensor, |
|
inplace_safe: bool = False, |
|
chunk_size: Optional[int] = None, |
|
) -> torch.Tensor: |
|
""" |
|
Args: |
|
q (torch.Tensor): atom single embedding |
|
[..., N_atom, c_atom] |
|
            c (torch.Tensor): atom single conditioning embedding
|
[..., N_atom, c_atom] |
|
p (torch.Tensor): atompair embedding in dense block shape. |
|
[..., n_blocks, n_queries, n_keys, c_atompair] |
|
|
|
Returns: |
|
torch.Tensor: the output of AtomTransformer |
|
[..., N_atom, c_atom] |
|
""" |
|
n_blocks, n_queries, n_keys = p.shape[-4:-1] |
|
|
|
assert n_queries == self.n_queries |
|
assert n_keys == self.n_keys |
|
return self.diffusion_transformer( |
|
a=q, |
|
s=c, |
|
z=p, |
|
n_queries=self.n_queries, |
|
n_keys=self.n_keys, |
|
inplace_safe=inplace_safe, |
|
chunk_size=chunk_size, |
|
) |
|
|
|
|
|
class ConditionedTransitionBlock(nn.Module): |
|
""" |
|
Implements Algorithm 25 in AF3 |
|
""" |
|
|
|
def __init__(self, c_a: int, c_s: int, n: int = 2, biasinit: float = -2.0) -> None: |
|
""" |
|
Args: |
|
            c_a (int): embedding dim of the activation a.
            c_s (int): single embedding dim.
            n (int, optional): channel expansion factor of the transition. Defaults to 2.
            biasinit (float, optional): bias initialization value for BiasInitLinear. Defaults to -2.0.
|
""" |
|
super(ConditionedTransitionBlock, self).__init__() |
|
self.c_a = c_a |
|
self.c_s = c_s |
|
self.n = n |
|
self.adaln = AdaptiveLayerNorm(c_a=c_a, c_s=c_s) |
|
self.linear_nobias_a1 = LinearNoBias(in_features=c_a, out_features=n * c_a) |
|
self.linear_nobias_a2 = LinearNoBias(in_features=c_a, out_features=n * c_a) |
|
self.linear_nobias_b = LinearNoBias(in_features=n * c_a, out_features=c_a) |
|
self.linear_s = BiasInitLinear( |
|
in_features=c_s, out_features=c_a, bias=True, biasinit=biasinit |
|
) |
|
|
|
def forward(self, a: torch.Tensor, s: torch.Tensor) -> torch.Tensor: |
|
""" |
|
Args: |
|
            a (torch.Tensor): the single-feature activation (token- or atom-level)
|
[..., N, c_a] |
|
s (torch.Tensor): single embedding |
|
[..., N, c_s] |
|
|
|
Returns: |
|
torch.Tensor: the updated a from ConditionedTransitionBlock |
|
[..., N, c_a] |
|
""" |
|
a = self.adaln(a, s) |
|
b = F.silu((self.linear_nobias_a1(a))) * self.linear_nobias_a2(a) |
|
|
|
a = torch.sigmoid(self.linear_s(s)) * self.linear_nobias_b(b) |
|
return a |
|
|
|
|
|
class AtomAttentionEncoder(nn.Module): |
|
""" |
|
Implements Algorithm 5 in AF3 |
|
""" |
|
|
|
def __init__( |
|
self, |
|
has_coords: bool, |
|
c_token: int, |
|
c_atom: int = 128, |
|
c_atompair: int = 16, |
|
c_s: int = 384, |
|
c_z: int = 128, |
|
n_blocks: int = 3, |
|
n_heads: int = 4, |
|
n_queries: int = 32, |
|
n_keys: int = 128, |
|
blocks_per_ckpt: Optional[int] = None, |
|
) -> None: |
|
""" |
|
Args: |
|
            has_coords (bool): whether the module input will contain coordinates (r_l).
            c_token (int): token embedding dim.
            c_atom (int, optional): atom embedding dim. Defaults to 128.
            c_atompair (int, optional): atompair embedding dim. Defaults to 16.
            c_s (int, optional): single embedding dim. Defaults to 384.
            c_z (int, optional): pair embedding dim. Defaults to 128.
            n_blocks (int, optional): number of blocks in AtomTransformer. Defaults to 3.
            n_heads (int, optional): number of heads in AtomTransformer. Defaults to 4.
            n_queries (int, optional): local window size of the query tensor. Defaults to 32.
            n_keys (int, optional): local window size of the key tensor. Defaults to 128.
            blocks_per_ckpt (Optional[int]): number of DiffusionTransformerBlock modules per
                activation checkpoint. A higher value corresponds to fewer checkpoints and trades
                memory for speed. If None, no checkpointing is performed.
|
""" |
|
super(AtomAttentionEncoder, self).__init__() |
|
self.has_coords = has_coords |
|
self.c_atom = c_atom |
|
self.c_atompair = c_atompair |
|
self.c_token = c_token |
|
self.c_s = c_s |
|
self.c_z = c_z |
|
self.n_queries = n_queries |
|
self.n_keys = n_keys |
|
self.local_attention_method = "local_cross_attention" |
|
|
|
self.input_feature = { |
|
"ref_pos": 3, |
|
"ref_charge": 1, |
|
"ref_mask": 1, |
|
"ref_element": 128, |
|
"ref_atom_name_chars": 4 * 64, |
|
} |
|
self.linear_no_bias_f = LinearNoBias( |
|
in_features=sum(self.input_feature.values()), out_features=self.c_atom |
|
) |
|
self.linear_no_bias_d = LinearNoBias( |
|
in_features=3, out_features=self.c_atompair |
|
) |
|
self.linear_no_bias_invd = LinearNoBias( |
|
in_features=1, out_features=self.c_atompair |
|
) |
|
self.linear_no_bias_v = LinearNoBias( |
|
in_features=1, out_features=self.c_atompair |
|
) |
|
|
|
if self.has_coords: |
|
|
|
self.layernorm_s = LayerNorm(self.c_s) |
|
self.linear_no_bias_s = LinearNoBias( |
|
in_features=self.c_s, out_features=self.c_atom |
|
) |
|
|
|
self.layernorm_z = LayerNorm(self.c_z) |
|
self.linear_no_bias_z = LinearNoBias( |
|
in_features=self.c_z, out_features=self.c_atompair |
|
) |
|
|
|
self.linear_no_bias_r = LinearNoBias( |
|
in_features=3, out_features=self.c_atom |
|
) |
|
self.linear_no_bias_cl = LinearNoBias( |
|
in_features=self.c_atom, out_features=self.c_atompair |
|
) |
|
self.linear_no_bias_cm = LinearNoBias( |
|
in_features=self.c_atom, out_features=self.c_atompair |
|
) |
|
self.small_mlp = nn.Sequential( |
|
nn.ReLU(), |
|
LinearNoBias(in_features=self.c_atompair, out_features=self.c_atompair), |
|
nn.ReLU(), |
|
LinearNoBias(in_features=self.c_atompair, out_features=self.c_atompair), |
|
nn.ReLU(), |
|
LinearNoBias(in_features=self.c_atompair, out_features=self.c_atompair), |
|
) |
|
self.atom_transformer = AtomTransformer( |
|
n_blocks=n_blocks, |
|
n_heads=n_heads, |
|
c_atom=c_atom, |
|
c_atompair=c_atompair, |
|
n_queries=n_queries, |
|
n_keys=n_keys, |
|
blocks_per_ckpt=blocks_per_ckpt, |
|
) |
|
self.linear_no_bias_q = LinearNoBias( |
|
in_features=self.c_atom, out_features=self.c_token |
|
) |
|
|
|
def linear_init( |
|
self, |
|
zero_init_atom_encoder_residual_linear: bool = False, |
|
he_normal_init_atom_encoder_small_mlp: bool = False, |
|
he_normal_init_atom_encoder_output: bool = False, |
|
): |
|
""" |
|
        Initializes the linear layers of the atom attention encoder according to the provided initialization configuration.
|
|
|
Args: |
|
zero_init_atom_encoder_residual_linear (bool): Whether to zero-initialize the residual linear layers. |
|
he_normal_init_atom_encoder_small_mlp (bool): Whether to initialize the small MLP layers with He normal initialization. |
|
he_normal_init_atom_encoder_output (bool): Whether to initialize the output layer with He normal initialization. |
|
""" |
|
|
|
if zero_init_atom_encoder_residual_linear: |
|
nn.init.zeros_(self.linear_no_bias_invd.weight) |
|
nn.init.zeros_(self.linear_no_bias_v.weight) |
|
nn.init.zeros_(self.linear_no_bias_s.weight) |
|
nn.init.zeros_(self.linear_no_bias_z.weight) |
|
nn.init.zeros_(self.linear_no_bias_r.weight) |
|
nn.init.zeros_(self.linear_no_bias_cl.weight) |
|
nn.init.zeros_(self.linear_no_bias_cm.weight) |
|
if he_normal_init_atom_encoder_small_mlp: |
|
for layer in self.small_mlp: |
|
if not isinstance(layer, torch.nn.modules.activation.ReLU): |
|
nn.init.kaiming_normal_( |
|
layer.weight, |
|
a=0, |
|
mode="fan_in", |
|
nonlinearity="relu", |
|
) |
|
if he_normal_init_atom_encoder_output: |
|
nn.init.kaiming_normal_( |
|
self.linear_no_bias_q.weight, a=0, mode="fan_in", nonlinearity="relu" |
|
) |
|
|
|
def forward( |
|
self, |
|
input_feature_dict: dict[str, Union[torch.Tensor, int, float, dict]], |
|
        r_l: Optional[torch.Tensor] = None,
        s: Optional[torch.Tensor] = None,
        z: Optional[torch.Tensor] = None,
|
inplace_safe: bool = False, |
|
chunk_size: Optional[int] = None, |
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: |
|
""" |
|
Args: |
|
input_feature_dict (dict[str, Union[torch.Tensor, int, float, dict]]): input meta feature dict |
|
r_l (torch.Tensor, optional): noisy position. |
|
[..., N_sample, N_atom, 3] if has_coords else None. |
|
s (torch.Tensor, optional): single embedding. |
|
[..., N_sample, N_token, c_s] if has_coords else None. |
|
z (torch.Tensor, optional): pair embedding |
|
[..., N_sample, N_token, N_token, c_z] if has_coords else None. |
|
|
|
Returns: |
|
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: the output of AtomAttentionEncoder |
|
a: |
|
[..., (N_sample), N_token, c_token] |
|
q_l: |
|
[..., (N_sample), N_atom, c_atom] |
|
c_l: |
|
[..., (N_sample), N_atom, c_atom] |
|
p_lm: |
|
                [..., (N_sample), n_blocks, n_queries, n_keys, c_atompair]
|
|
|
""" |
|
|
|
if self.has_coords: |
|
assert r_l is not None |
|
assert s is not None |
|
assert z is not None |
|
|
|
atom_to_token_idx = input_feature_dict["atom_to_token_idx"] |
|
|
|
|
|
batch_shape = input_feature_dict["ref_pos"].shape[:-2] |
|
N_atom = input_feature_dict["ref_pos"].shape[-2] |
|
c_l = self.linear_no_bias_f( |
|
torch.cat( |
|
[ |
|
input_feature_dict[name].reshape( |
|
*batch_shape, N_atom, self.input_feature[name] |
|
) |
|
for name in self.input_feature |
|
], |
|
dim=-1, |
|
) |
|
) |
|
|
|
|
|
|
|
|
|
q_trunked_list, k_trunked_list, pad_info = rearrange_qk_to_dense_trunk( |
|
q=[input_feature_dict["ref_pos"], input_feature_dict["ref_space_uid"]], |
|
k=[input_feature_dict["ref_pos"], input_feature_dict["ref_space_uid"]], |
|
dim_q=[-2, -1], |
|
dim_k=[-2, -1], |
|
n_queries=self.n_queries, |
|
n_keys=self.n_keys, |
|
compute_mask=True, |
|
) |
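        # Build local atom-pair features from the trunked reference frames: pairwise position
        # offsets d_lm and a same-reference-space mask v_lm (ref_space_uid match), masked to valid pairs.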
|
|
|
|
|
d_lm = ( |
|
q_trunked_list[0][..., None, :] - k_trunked_list[0][..., None, :, :] |
|
) |
|
v_lm = ( |
|
q_trunked_list[1][..., None].int() == k_trunked_list[1][..., None, :].int() |
|
).unsqueeze( |
|
dim=-1 |
|
) |
|
p_lm = (self.linear_no_bias_d(d_lm) * v_lm) * pad_info[ |
|
"mask_trunked" |
|
].unsqueeze( |
|
dim=-1 |
|
) |
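        # Add an embedding of the inverse squared distance and an embedding of the valid-pair mask,
        # both gated by v_lm so only atom pairs sharing a reference space contribute.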
|
|
|
|
|
if inplace_safe: |
|
p_lm += ( |
|
self.linear_no_bias_invd(1 / (1 + (d_lm**2).sum(dim=-1, keepdim=True))) |
|
* v_lm |
|
) |
|
p_lm += self.linear_no_bias_v(v_lm.to(dtype=p_lm.dtype)) * v_lm |
|
else: |
|
p_lm = ( |
|
p_lm |
|
+ self.linear_no_bias_invd( |
|
1 / (1 + (d_lm**2).sum(dim=-1, keepdim=True)) |
|
) |
|
* v_lm |
|
) |
|
p_lm = p_lm + self.linear_no_bias_v(v_lm.to(dtype=p_lm.dtype)) * v_lm |
|
|
|
|
|
q_l = c_l.clone() |
|
|
|
|
|
n_token = None |
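        # Diffusion path: when noisy coordinates r_l are given, further condition the atom embeddings
        # on the trunk single (s) and pair (z) representations and on the current positions.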
|
if r_l is not None: |
|
N_sample = r_l.size(-3) |
|
|
|
|
|
n_token = s.size(-2) |
|
c_l = c_l.unsqueeze(dim=-3) + self.linear_no_bias_s( |
|
self.layernorm_s( |
|
broadcast_token_to_atom( |
|
x_token=s, atom_to_token_idx=atom_to_token_idx |
|
) |
|
) |
|
) |
|
z_local_pairs, _ = broadcast_token_to_local_atom_pair( |
|
z_token=z, |
|
atom_to_token_idx=atom_to_token_idx, |
|
n_queries=self.n_queries, |
|
n_keys=self.n_keys, |
|
compute_mask=False, |
|
) |
|
p_lm = p_lm.unsqueeze(dim=-5) + self.linear_no_bias_z( |
|
self.layernorm_z(z_local_pairs) |
|
) |
|
|
|
|
|
            q_l = q_l.unsqueeze(dim=-3) + self.linear_no_bias_r(r_l)
|
|
|
|
|
c_l_q, c_l_k, _ = rearrange_qk_to_dense_trunk( |
|
q=c_l, |
|
k=c_l, |
|
dim_q=-2, |
|
dim_k=-2, |
|
n_queries=self.n_queries, |
|
n_keys=self.n_keys, |
|
compute_mask=False, |
|
) |
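        # Add row- and column-wise contributions of the conditioning embedding c_l to the local
        # pair representation, then refine it with a small residual MLP.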
|
if inplace_safe: |
|
p_lm += self.linear_no_bias_cl(F.relu(c_l_q[..., None, :])) |
|
p_lm += self.linear_no_bias_cm(F.relu(c_l_k[..., None, :, :])) |
|
p_lm += self.small_mlp(p_lm) |
|
else: |
|
p_lm = ( |
|
p_lm |
|
+ self.linear_no_bias_cl(F.relu(c_l_q[..., None, :])) |
|
+ self.linear_no_bias_cm(F.relu(c_l_k[..., None, :, :])) |
|
) |
|
|
|
|
|
p_lm = p_lm + self.small_mlp(p_lm) |
|
|
|
|
|
q_l = self.atom_transformer( |
|
q_l, c_l, p_lm, chunk_size=chunk_size |
|
) |
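        # Aggregate the refined per-atom representation back to per-token features (mean over the
        # atoms belonging to each token).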
|
|
|
|
|
a = aggregate_atom_to_token( |
|
x_atom=F.relu(self.linear_no_bias_q(q_l)), |
|
atom_to_token_idx=atom_to_token_idx, |
|
n_token=n_token, |
|
reduce="mean", |
|
) |
|
if (not self.training) and (a.shape[-2] > 2000 or q_l.shape[-2] > 20000): |
|
torch.cuda.empty_cache() |
|
return a, q_l, c_l, p_lm |
|
|
|
|
|
class AtomAttentionDecoder(nn.Module): |
|
""" |
|
Implements Algorithm 6 in AF3 |
|
""" |
|
|
|
def __init__( |
|
self, |
|
n_blocks: int = 3, |
|
n_heads: int = 4, |
|
c_token: int = 384, |
|
c_atom: int = 128, |
|
c_atompair: int = 16, |
|
n_queries: int = 32, |
|
n_keys: int = 128, |
|
blocks_per_ckpt: Optional[int] = None, |
|
) -> None: |
|
""" |
|
Args: |
|
n_blocks (int, optional): number of blocks for AtomTransformer. Defaults to 3. |
|
n_heads (int, optional): number of heads for AtomTransformer. Defaults to 4. |
|
c_token (int, optional): feature channel of token (single a). Defaults to 384. |
|
c_atom (int, optional): embedding dim for atom embedding. Defaults to 128. |
|
            c_atompair (int, optional): embedding dim for atom pair embedding. Defaults to 16.
            n_queries (int, optional): local window size of the query tensor. Defaults to 32.
            n_keys (int, optional): local window size of the key tensor. Defaults to 128.
            blocks_per_ckpt (Optional[int]): number of DiffusionTransformerBlock modules per
                activation checkpoint. A higher value corresponds to fewer checkpoints and trades
                memory for speed. If None, no checkpointing is performed.
|
""" |
|
super(AtomAttentionDecoder, self).__init__() |
|
self.n_blocks = n_blocks |
|
self.n_heads = n_heads |
|
self.c_token = c_token |
|
self.c_atom = c_atom |
|
self.c_atompair = c_atompair |
|
self.n_queries = n_queries |
|
self.n_keys = n_keys |
|
self.linear_no_bias_a = LinearNoBias(in_features=c_token, out_features=c_atom) |
|
self.layernorm_q = LayerNorm(c_atom) |
|
self.linear_no_bias_out = LinearNoBias(in_features=c_atom, out_features=3) |
|
self.atom_transformer = AtomTransformer( |
|
n_blocks=n_blocks, |
|
n_heads=n_heads, |
|
c_atom=c_atom, |
|
c_atompair=c_atompair, |
|
n_queries=n_queries, |
|
n_keys=n_keys, |
|
blocks_per_ckpt=blocks_per_ckpt, |
|
) |
|
|
|
def forward( |
|
self, |
|
input_feature_dict: dict[str, Union[torch.Tensor, int, float, dict]], |
|
a: torch.Tensor, |
|
q_skip: torch.Tensor, |
|
c_skip: torch.Tensor, |
|
p_skip: torch.Tensor, |
|
inplace_safe: bool = False, |
|
chunk_size: Optional[int] = None, |
|
) -> torch.Tensor: |
|
""" |
|
Args: |
|
input_feature_dict (dict[str, Union[torch.Tensor, int, float, dict]]): input meta feature dict |
|
            a (torch.Tensor): the per-token single-feature activation
|
[..., N_token, c_token] |
|
q_skip (torch.Tensor): atom single embedding |
|
[..., N_atom, c_atom] |
|
c_skip (torch.Tensor): atom single embedding |
|
[..., N_atom, c_atom] |
|
            p_skip (torch.Tensor): atom pair embedding, in trunked dense shape
|
[..., n_blocks, n_queries, n_keys, c_atompair] |
|
|
|
Returns: |
|
            torch.Tensor: the predicted update to the noisy coordinates
|
[..., N_atom, 3] |
|
""" |
|
|
|
q = ( |
|
self.linear_no_bias_a( |
|
broadcast_token_to_atom( |
|
x_token=a, atom_to_token_idx=input_feature_dict["atom_to_token_idx"] |
|
) |
|
) |
|
+ q_skip |
|
) |
|
|
|
|
|
q = self.atom_transformer( |
|
q, c_skip, p_skip, inplace_safe=inplace_safe, chunk_size=chunk_size |
|
) |
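        # Project the refined atom embeddings to a per-atom 3D coordinate update.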
|
|
|
|
|
r = self.linear_no_bias_out(self.layernorm_q(q)) |
|
|
|
return r |
|
|