# Copyright 2024 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from protenix.model.modules.primitives import LinearNoBias
from protenix.model.modules.transformer import AtomAttentionEncoder
class InputFeatureEmbedder(nn.Module):
"""
Implements Algorithm 2 in AF3
"""
def __init__(
self,
c_atom: int = 128,
c_atompair: int = 16,
c_token: int = 384,
) -> None:
"""
Args:
c_atom (int, optional): atom embedding dim. Defaults to 128.
c_atompair (int, optional): atom pair embedding dim. Defaults to 16.
c_token (int, optional): token embedding dim. Defaults to 384.
"""
super(InputFeatureEmbedder, self).__init__()
self.c_atom = c_atom
self.c_atompair = c_atompair
self.c_token = c_token
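        # Encodes per-atom features and aggregates them into per-token
        # embeddings. has_coords=False is read here as "no coordinate inputs
        # at the input-embedding stage" (denoised positions only exist inside
        # the diffusion module); this reading is an assumption from the flag
        # name, not from AtomAttentionEncoder's own documentation.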
self.atom_attention_encoder = AtomAttentionEncoder(
c_atom=c_atom,
c_atompair=c_atompair,
c_token=c_token,
has_coords=False,
)
        # Per-token input features (and their feature dims) concatenated with
        # the atom-attention output, following line 2 of Algorithm 2.
        self.input_feature = {"restype": 32, "profile": 32, "deletion_mean": 1}
def forward(
self,
input_feature_dict: dict[str, Any],
inplace_safe: bool = False,
chunk_size: Optional[int] = None,
) -> torch.Tensor:
"""
Args:
input_feature_dict (Dict[str, Any]): dict of input features
inplace_safe (bool): Whether it is safe to use inplace operations. Defaults to False.
chunk_size (Optional[int]): Chunk size for memory-efficient operations. Defaults to None.
Returns:
torch.Tensor: token embedding
                [..., N_token, c_token + 32 + 32 + 1] (449 with the default c_token=384)
"""
# Embed per-atom features.
a, _, _, _ = self.atom_attention_encoder(
input_feature_dict=input_feature_dict,
inplace_safe=inplace_safe,
chunk_size=chunk_size,
) # [..., N_token, c_token]
# Concatenate the per-token features.
batch_shape = input_feature_dict["restype"].shape[:-1]
s_inputs = torch.cat(
[a]
+ [
input_feature_dict[name].reshape(*batch_shape, d)
for name, d in self.input_feature.items()
],
dim=-1,
)
return s_inputs
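

# Minimal usage sketch (illustrative; the exact keys and shapes of
# input_feature_dict are assumptions based on self.input_feature and the
# per-atom features expected by AtomAttentionEncoder, not a documented API):
#
#   embedder = InputFeatureEmbedder(c_atom=128, c_atompair=16, c_token=384)
#   s_inputs = embedder(input_feature_dict)
#   # s_inputs: [..., N_token, 384 + 32 + 32 + 1] = [..., N_token, 449]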
class RelativePositionEncoding(nn.Module):
"""
Implements Algorithm 3 in AF3
"""
def __init__(self, r_max: int = 32, s_max: int = 2, c_z: int = 128) -> None:
"""
Args:
r_max (int, optional): Relative position indices clip value. Defaults to 32.
s_max (int, optional): Relative chain indices clip value. Defaults to 2.
c_z (int, optional): hidden dim [for pair embedding]. Defaults to 128.
"""
super(RelativePositionEncoding, self).__init__()
self.r_max = r_max
self.s_max = s_max
self.c_z = c_z
self.linear_no_bias = LinearNoBias(
in_features=(4 * self.r_max + 2 * self.s_max + 7), out_features=self.c_z
)
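        # in_features decomposes as 2*(r_max + 1) + 2*(r_max + 1) + 1 + 2*(s_max + 1)
        # = 4*r_max + 2*s_max + 7: one-hot relative residue bins, one-hot
        # relative token bins, the same-entity flag, and one-hot relative
        # chain bins (139 with the defaults r_max=32, s_max=2).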
self.input_feature = {
"asym_id": 1,
"residue_index": 1,
"entity_id": 1,
"sym_id": 1,
"token_index": 1,
}
def forward(self, input_feature_dict: dict[str, Any]) -> torch.Tensor:
"""
Args:
input_feature_dict (Dict[str, Any]): input meta feature dict.
asym_id / residue_index / entity_id / sym_id / token_index
[..., N_tokens]
Returns:
torch.Tensor: relative position encoding
[..., N_token, N_token, c_z]
"""
b_same_chain = (
input_feature_dict["asym_id"][..., :, None]
== input_feature_dict["asym_id"][..., None, :]
).long() # [..., N_token, N_token]
b_same_residue = (
input_feature_dict["residue_index"][..., :, None]
== input_feature_dict["residue_index"][..., None, :]
).long() # [..., N_token, N_token]
b_same_entity = (
input_feature_dict["entity_id"][..., :, None]
== input_feature_dict["entity_id"][..., None, :]
).long() # [..., N_token, N_token]
d_residue = torch.clip(
input=input_feature_dict["residue_index"][..., :, None]
- input_feature_dict["residue_index"][..., None, :]
+ self.r_max,
min=0,
max=2 * self.r_max,
) * b_same_chain + (1 - b_same_chain) * (
2 * self.r_max + 1
) # [..., N_token, N_token]
a_rel_pos = F.one_hot(d_residue, 2 * (self.r_max + 1))
d_token = torch.clip(
input=input_feature_dict["token_index"][..., :, None]
- input_feature_dict["token_index"][..., None, :]
+ self.r_max,
min=0,
max=2 * self.r_max,
) * b_same_chain * b_same_residue + (1 - b_same_chain * b_same_residue) * (
2 * self.r_max + 1
) # [..., N_token, N_token]
a_rel_token = F.one_hot(d_token, 2 * (self.r_max + 1))
d_chain = torch.clip(
input=input_feature_dict["sym_id"][..., :, None]
- input_feature_dict["sym_id"][..., None, :]
+ self.s_max,
min=0,
max=2 * self.s_max,
) * b_same_entity + (1 - b_same_entity) * (
2 * self.s_max + 1
) # [..., N_token, N_token]
a_rel_chain = F.one_hot(d_chain, 2 * (self.s_max + 1))
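        # Worked example with r_max=32: two tokens in the same chain with
        # residue offset +5 land in bin 5 + 32 = 37; tokens from different
        # chains always map to the overflow bin 2*32 + 1 = 65, so each one-hot
        # has 2*(r_max + 1) = 66 classes.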
if self.training:
p = self.linear_no_bias(
torch.cat(
[a_rel_pos, a_rel_token, b_same_entity[..., None], a_rel_chain],
dim=-1,
).float()
            ) # [..., N_token, N_token, 4 * r_max + 2 * s_max + 7] -> [..., N_token, N_token, c_z]
return p
else:
del d_chain, d_token, d_residue, b_same_chain, b_same_residue
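            # Inference-only memory optimization: flatten the pair dimensions
            # and feed the concatenated one-hot features through the linear
            # layer in chunks, so the full [N_token * N_token, 139] input is
            # never materialized at once for large structures.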
origin_shape = a_rel_pos.shape[:-1]
Ntoken = a_rel_pos.shape[-2]
            a_rel_pos = a_rel_pos.reshape(-1, a_rel_pos.shape[-1])
            chunk_num = 1 if Ntoken < 3200 else 8
            a_rel_pos_chunks = torch.chunk(a_rel_pos, chunk_num, dim=-2)
)
a_rel_token_chunks = torch.chunk(
a_rel_token.reshape(-1, a_rel_token.shape[-1]), chunk_num, dim=-2
)
b_same_entity_chunks = torch.chunk(
b_same_entity.reshape(-1, 1), chunk_num, dim=-2
)
a_rel_chain_chunks = torch.chunk(
a_rel_chain.reshape(-1, a_rel_chain.shape[-1]), chunk_num, dim=-2
)
start = 0
p = None
for i in range(len(a_rel_pos_chunks)):
data = torch.cat(
[
a_rel_pos_chunks[i],
a_rel_token_chunks[i],
b_same_entity_chunks[i],
a_rel_chain_chunks[i],
],
dim=-1,
).float()
result = self.linear_no_bias(data)
del data
if p is None:
p = torch.empty(
(a_rel_pos.shape[-2], self.c_z),
device=a_rel_pos.device,
dtype=result.dtype,
)
p[start : start + result.shape[0]] = result
start += result.shape[0]
del result
del a_rel_pos, a_rel_token, b_same_entity, a_rel_chain
p = p.reshape(*origin_shape, -1)
return p
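

# Minimal usage sketch (illustrative; asym_id / residue_index / entity_id /
# sym_id / token_index are integer tensors of shape [..., N_token], as
# assumed from the forward() docstring):
#
#   rel_pos_enc = RelativePositionEncoding(r_max=32, s_max=2, c_z=128)
#   z_init = rel_pos_enc(input_feature_dict)
#   # z_init: [..., N_token, N_token, 128]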
class FourierEmbedding(nn.Module):
"""
Implements Algorithm 22 in AF3
"""
def __init__(self, c: int, seed: int = 42) -> None:
"""
Args:
c (int): embedding dim.
"""
super(FourierEmbedding, self).__init__()
self.c = c
self.seed = seed
generator = torch.Generator()
generator.manual_seed(seed)
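        # w and b are random Fourier features: sampled once from N(0, 1) with
        # a fixed seed and frozen (requires_grad=False), so the embedding is
        # deterministic given the seed and is never updated during training.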
w_value = torch.randn(size=(c,), generator=generator)
self.w = nn.Parameter(w_value, requires_grad=False)
b_value = torch.randn(size=(c,), generator=generator)
self.b = nn.Parameter(b_value, requires_grad=False)
def forward(self, t_hat_noise_level: torch.Tensor) -> torch.Tensor:
"""
Args:
t_hat_noise_level (torch.Tensor): the noise level
[..., N_sample]
Returns:
torch.Tensor: the output fourier embedding
[..., N_sample, c]
"""
return torch.cos(
input=2 * torch.pi * (t_hat_noise_level.unsqueeze(dim=-1) * self.w + self.b)
)
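

# Minimal usage sketch (illustrative): each output entry is
# cos(2 * pi * (t * w_i + b_i)) for the weights w_i, b_i frozen at init.
#
#   fourier = FourierEmbedding(c=256)
#   t_hat = torch.rand(2, 5)      # [..., N_sample] noise levels
#   emb = fourier(t_hat)          # [2, 5, 256], values in [-1, 1]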