# coding=utf-8 | |
# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import math | |
import sys | |
from dataclasses import dataclass | |
from functools import partial | |
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
from torch.nn import LayerNorm | |
from ...integrations.deepspeed import is_deepspeed_available | |
from ...modeling_outputs import ModelOutput | |
from ...utils import ( | |
ContextManagers, | |
add_start_docstrings, | |
add_start_docstrings_to_model_forward, | |
is_scipy_available, | |
logging, | |
replace_return_docstrings, | |
) | |
from .configuration_esm import EsmConfig | |
from .modeling_esm import ESM_START_DOCSTRING, EsmModel, EsmPreTrainedModel | |
from .openfold_utils import ( | |
OFProtein, | |
Rigid, | |
Rotation, | |
atom14_to_atom37, | |
chunk_layer, | |
compute_predicted_aligned_error, | |
compute_tm, | |
frames_and_literature_positions_to_atom14_pos, | |
make_atom14_masks, | |
residue_constants, | |
to_pdb, | |
torsion_angles_to_frames, | |
) | |
logger = logging.get_logger(__name__) | |
_CHECKPOINT_FOR_DOC = "facebook/esmfold_v1" | |
_CONFIG_FOR_DOC = "EsmConfig" | |
@dataclass
class EsmForProteinFoldingOutput(ModelOutput):
""" | |
    Output type of [`EsmForProteinFolding`].
Args: | |
frames (`torch.FloatTensor`): | |
Output frames. | |
sidechain_frames (`torch.FloatTensor`): | |
Output sidechain frames. | |
unnormalized_angles (`torch.FloatTensor`): | |
Predicted unnormalized backbone and side chain torsion angles. | |
angles (`torch.FloatTensor`): | |
Predicted backbone and side chain torsion angles. | |
positions (`torch.FloatTensor`): | |
Predicted positions of the backbone and side chain atoms. | |
states (`torch.FloatTensor`): | |
Hidden states from the protein folding trunk. | |
s_s (`torch.FloatTensor`): | |
Per-residue embeddings derived by concatenating the hidden states of each layer of the ESM-2 LM stem. | |
s_z (`torch.FloatTensor`): | |
Pairwise residue embeddings. | |
distogram_logits (`torch.FloatTensor`): | |
Input logits to the distogram used to compute residue distances. | |
lm_logits (`torch.FloatTensor`): | |
Logits output by the ESM-2 protein language model stem. | |
aatype (`torch.FloatTensor`): | |
Input amino acids (AlphaFold2 indices). | |
atom14_atom_exists (`torch.FloatTensor`): | |
Whether each atom exists in the atom14 representation. | |
residx_atom14_to_atom37 (`torch.FloatTensor`): | |
Mapping between atoms in the atom14 and atom37 representations. | |
residx_atom37_to_atom14 (`torch.FloatTensor`): | |
Mapping between atoms in the atom37 and atom14 representations. | |
atom37_atom_exists (`torch.FloatTensor`): | |
Whether each atom exists in the atom37 representation. | |
residue_index (`torch.FloatTensor`): | |
The index of each residue in the protein chain. Unless internal padding tokens are used, this will just be | |
a sequence of integers from 0 to `sequence_length`. | |
lddt_head (`torch.FloatTensor`): | |
Raw outputs from the lddt head used to compute plddt. | |
plddt (`torch.FloatTensor`): | |
Per-residue confidence scores. Regions of low confidence may indicate areas where the model's prediction is | |
uncertain, or where the protein structure is disordered. | |
ptm_logits (`torch.FloatTensor`): | |
Raw logits used for computing ptm. | |
ptm (`torch.FloatTensor`): | |
TM-score output representing the model's high-level confidence in the overall structure. | |
aligned_confidence_probs (`torch.FloatTensor`): | |
Per-residue confidence scores for the aligned structure. | |
predicted_aligned_error (`torch.FloatTensor`): | |
Predicted error between the model's prediction and the ground truth. | |
max_predicted_aligned_error (`torch.FloatTensor`): | |
Per-sample maximum predicted error. | |
""" | |
frames: torch.FloatTensor = None | |
sidechain_frames: torch.FloatTensor = None | |
unnormalized_angles: torch.FloatTensor = None | |
angles: torch.FloatTensor = None | |
positions: torch.FloatTensor = None | |
states: torch.FloatTensor = None | |
s_s: torch.FloatTensor = None | |
s_z: torch.FloatTensor = None | |
distogram_logits: torch.FloatTensor = None | |
lm_logits: torch.FloatTensor = None | |
aatype: torch.FloatTensor = None | |
atom14_atom_exists: torch.FloatTensor = None | |
residx_atom14_to_atom37: torch.FloatTensor = None | |
residx_atom37_to_atom14: torch.FloatTensor = None | |
atom37_atom_exists: torch.FloatTensor = None | |
residue_index: torch.FloatTensor = None | |
lddt_head: torch.FloatTensor = None | |
plddt: torch.FloatTensor = None | |
ptm_logits: torch.FloatTensor = None | |
ptm: torch.FloatTensor = None | |
aligned_confidence_probs: torch.FloatTensor = None | |
predicted_aligned_error: torch.FloatTensor = None | |
max_predicted_aligned_error: torch.FloatTensor = None | |
ESMFOLD_INPUTS_DOCSTRING = r""" | |
Args: | |
input_ids (`torch.LongTensor` of shape `({0})`): | |
Indices of input sequence tokens in the vocabulary. | |
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and | |
[`PreTrainedTokenizer.__call__`] for details. | |
[What are input IDs?](../glossary#input-ids) | |
attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): | |
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: | |
- 1 for tokens that are **not masked**, | |
- 0 for tokens that are **masked**. | |
[What are attention masks?](../glossary#attention-mask) | |
position_ids (`torch.LongTensor` of shape `({0})`, *optional*): | |
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, | |
config.max_position_embeddings - 1]`. | |
[What are position IDs?](../glossary#position-ids) | |
masking_pattern (`torch.LongTensor` of shape `({0})`, *optional*): | |
Locations of tokens to mask during training as a form of regularization. Mask values selected in `[0, 1]`. | |
num_recycles (`int`, *optional*, defaults to `None`): | |
Number of times to recycle the input sequence. If `None`, defaults to `config.num_recycles`. "Recycling" | |
consists of passing the output of the folding trunk back in as input to the trunk. During training, the | |
number of recycles should vary with each batch, to ensure that the model learns to output valid predictions | |
after each recycle. During inference, num_recycles should be set to the highest value that the model was | |
trained with for maximum accuracy. Accordingly, when this value is set to `None`, config.max_recycles is | |
used. | |
""" | |
def is_fp16_enabled(): | |
# Autocast world | |
fp16_enabled = torch.get_autocast_gpu_dtype() == torch.float16 | |
fp16_enabled = fp16_enabled and torch.is_autocast_enabled() | |
return fp16_enabled | |
def is_deepspeed_initialized(): | |
if is_deepspeed_available(): | |
return False | |
else: | |
try: | |
import deepspeed | |
# This is not available in all DeepSpeed versions. | |
return deepspeed.utils.is_initialized() | |
except Exception: | |
return False | |
def collate_dense_tensors(samples: List[torch.Tensor], pad_v: float = 0) -> torch.Tensor: | |
""" | |
Takes a list of tensors with the following dimensions: | |
[(d_11, ..., d_1K), | |
(d_21, ..., d_2K), ..., (d_N1, ..., d_NK)] | |
and stack + pads them into a single tensor of: | |
(N, max_i=1,N { d_i1 }, ..., max_i=1,N {diK}) | |
""" | |
if len(samples) == 0: | |
return torch.Tensor() | |
if len({x.dim() for x in samples}) != 1: | |
raise RuntimeError(f"Samples has varying dimensions: {[x.dim() for x in samples]}") | |
(device,) = tuple({x.device for x in samples}) # assumes all on same device | |
max_shape = [max(lst) for lst in zip(*[x.shape for x in samples])] | |
result = torch.empty(len(samples), *max_shape, dtype=samples[0].dtype, device=device) | |
result.fill_(pad_v) | |
for i in range(len(samples)): | |
result_i = result[i] | |
t = samples[i] | |
result_i[tuple(slice(0, k) for k in t.shape)] = t | |
return result | |
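# A small sketch (made-up shapes) of how collate_dense_tensors pads a ragged list of tensors
# into one dense batch:
#
#     >>> import torch
#     >>> out = collate_dense_tensors([torch.ones(2, 3), torch.ones(4, 1)], pad_v=0)
#     >>> out.shape
#     torch.Size([2, 4, 3])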
def flatten_final_dims(t: torch.Tensor, no_dims: int): | |
return t.reshape(t.shape[:-no_dims] + (-1,)) | |
def permute_final_dims(tensor: torch.Tensor, inds: List[int]): | |
zero_index = -1 * len(inds) | |
first_inds = list(range(len(tensor.shape[:zero_index]))) | |
return tensor.permute(first_inds + [zero_index + i for i in inds]) | |
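# Both helpers above operate only on trailing dimensions and leave any leading batch dims
# untouched. A quick sketch with made-up shapes:
#
#     >>> import torch
#     >>> t = torch.randn(2, 5, 7, 9)              # [*, I, J, C]
#     >>> permute_final_dims(t, (2, 0, 1)).shape   # -> [*, C, I, J]
#     torch.Size([2, 9, 5, 7])
#     >>> flatten_final_dims(t, 2).shape           # -> [*, I, J * C]
#     torch.Size([2, 5, 63])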
def dict_multimap(fn, dicts): | |
first = dicts[0] | |
new_dict = {} | |
for k, v in first.items(): | |
all_v = [d[k] for d in dicts] | |
if isinstance(v, dict): | |
new_dict[k] = dict_multimap(fn, all_v) | |
else: | |
new_dict[k] = fn(all_v) | |
return new_dict | |
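# dict_multimap applies `fn` to the list of values gathered under each (possibly nested) key.
# Sketch with toy dicts:
#
#     >>> import torch
#     >>> d1 = {"a": torch.zeros(2), "b": {"c": torch.zeros(2)}}
#     >>> d2 = {"a": torch.ones(2), "b": {"c": torch.ones(2)}}
#     >>> stacked = dict_multimap(torch.stack, [d1, d2])
#     >>> stacked["a"].shape, stacked["b"]["c"].shape
#     (torch.Size([2, 2]), torch.Size([2, 2]))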
def trunc_normal_init_(weights, scale=1.0, fan="fan_in"): | |
shape = weights.shape | |
scale = scale / max(1, shape[1]) | |
if not is_scipy_available(): | |
logger.warning( | |
"This init requires scipy, but scipy was not found, default to an approximation that might not be" | |
" equivalent." | |
) | |
std = math.sqrt(scale) | |
torch.nn.init.normal_(weights, std=std).clamp(min=0.0, max=2.0 * std) | |
else: | |
from scipy.stats import truncnorm | |
std = math.sqrt(scale) / truncnorm.std(a=-2, b=2, loc=0, scale=1) | |
samples = truncnorm.rvs(a=-2, b=2, loc=0, scale=std, size=weights.numel()) | |
samples = np.reshape(samples, shape) | |
weights.copy_(torch.tensor(samples, device=weights.device)) | |
def ipa_point_weights_init_(weights): | |
with torch.no_grad(): | |
softplus_inverse_1 = 0.541324854612918 | |
weights.fill_(softplus_inverse_1) | |
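# The constant above is softplus^{-1}(1): initializing the IPA head weights to it makes
# softplus(head_weights) start at exactly 1. A quick numerical check:
#
#     >>> import math
#     >>> math.log(math.exp(1.0) - 1.0)
#     0.5413248546129181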
class EsmFoldLinear(nn.Linear): | |
""" | |
A Linear layer with built-in nonstandard initializations. Called just like torch.nn.Linear. | |
Implements the initializers in 1.11.4, plus some additional ones found in the code. | |
""" | |
def __init__( | |
self, | |
in_dim: int, | |
out_dim: int, | |
bias: bool = True, | |
init: str = "default", | |
init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None, | |
): | |
""" | |
Args: | |
in_dim: | |
The final dimension of inputs to the layer | |
out_dim: | |
The final dimension of layer outputs | |
bias: | |
Whether to learn an additive bias. True by default | |
init: | |
The initializer to use. Choose from: | |
"default": LeCun fan-in truncated normal initialization "relu": He initialization w/ truncated normal | |
distribution "glorot": Fan-average Glorot uniform initialization "gating": Weights=0, Bias=1 "normal": | |
Normal initialization with std=1/sqrt(fan_in) "final": Weights=0, Bias=0 | |
Overridden by init_fn if the latter is not None. | |
init_fn: | |
A custom initializer taking weight and bias as inputs. Overrides init if not None. | |
""" | |
super().__init__(in_dim, out_dim, bias=bias) | |
if bias: | |
with torch.no_grad(): | |
self.bias.fill_(0) | |
self.init = init | |
self.init_fn = init_fn | |
if init not in ["default", "relu", "glorot", "gating", "normal", "final"]: | |
raise ValueError("Invalid init string.") | |
class EsmFoldLayerNorm(nn.Module): | |
def __init__(self, c_in, eps=1e-5): | |
super().__init__() | |
self.c_in = (c_in,) | |
self.eps = eps | |
self.weight = nn.Parameter(torch.ones(c_in)) | |
self.bias = nn.Parameter(torch.zeros(c_in)) | |
def forward(self, x): | |
d = x.dtype | |
if d is torch.bfloat16 and not is_deepspeed_initialized(): | |
with torch.cuda.amp.autocast(enabled=False): | |
out = nn.functional.layer_norm(x, self.c_in, self.weight.to(dtype=d), self.bias.to(dtype=d), self.eps) | |
else: | |
out = nn.functional.layer_norm(x, self.c_in, self.weight, self.bias, self.eps) | |
return out | |
def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor: | |
""" | |
Softmax, but without automatic casting to fp32 when the input is of type bfloat16 | |
""" | |
d = t.dtype | |
if d is torch.bfloat16 and not is_deepspeed_initialized(): | |
with torch.cuda.amp.autocast(enabled=False): | |
s = torch.nn.functional.softmax(t, dim=dim) | |
else: | |
s = torch.nn.functional.softmax(t, dim=dim) | |
return s | |
class EsmFoldAttention(nn.Module): | |
""" | |
Standard multi-head attention using AlphaFold's default layer initialization. Allows multiple bias vectors. | |
""" | |
def __init__( | |
self, | |
c_q: int, | |
c_k: int, | |
c_v: int, | |
c_hidden: int, | |
no_heads: int, | |
gating: bool = True, | |
): | |
""" | |
Args: | |
c_q: | |
Input dimension of query data | |
c_k: | |
Input dimension of key data | |
c_v: | |
Input dimension of value data | |
c_hidden: | |
Per-head hidden dimension | |
no_heads: | |
Number of attention heads | |
gating: | |
Whether the output should be gated using query data | |
""" | |
super().__init__() | |
self.c_q = c_q | |
self.c_k = c_k | |
self.c_v = c_v | |
self.c_hidden = c_hidden | |
self.no_heads = no_heads | |
self.gating = gating | |
# DISCREPANCY: c_hidden is not the per-head channel dimension, as | |
# stated in the supplement, but the overall channel dimension. | |
self.linear_q = EsmFoldLinear(self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot") | |
self.linear_k = EsmFoldLinear(self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot") | |
self.linear_v = EsmFoldLinear(self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot") | |
self.linear_o = EsmFoldLinear(self.c_hidden * self.no_heads, self.c_q, init="final") | |
self.linear_g = None | |
if self.gating: | |
self.linear_g = EsmFoldLinear(self.c_q, self.c_hidden * self.no_heads, init="gating") | |
self.sigmoid = nn.Sigmoid() | |
def _prep_qkv(self, q_x: torch.Tensor, kv_x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | |
# [*, Q/K/V, H * C_hidden] | |
q = self.linear_q(q_x) | |
k = self.linear_k(kv_x) | |
v = self.linear_v(kv_x) | |
# [*, Q/K, H, C_hidden] | |
q = q.view(q.shape[:-1] + (self.no_heads, -1)) | |
k = k.view(k.shape[:-1] + (self.no_heads, -1)) | |
v = v.view(v.shape[:-1] + (self.no_heads, -1)) | |
# [*, H, Q/K, C_hidden] | |
q = q.transpose(-2, -3) | |
k = k.transpose(-2, -3) | |
v = v.transpose(-2, -3) | |
q /= math.sqrt(self.c_hidden) | |
return q, k, v | |
def _wrap_up(self, o: torch.Tensor, q_x: torch.Tensor) -> torch.Tensor: | |
if self.linear_g is not None: | |
g = self.sigmoid(self.linear_g(q_x)) | |
# [*, Q, H, C_hidden] | |
g = g.view(g.shape[:-1] + (self.no_heads, -1)) | |
o = o * g | |
# [*, Q, H * C_hidden] | |
o = flatten_final_dims(o, 2) | |
# [*, Q, C_q] | |
o = self.linear_o(o) | |
return o | |
def forward( | |
self, | |
q_x: torch.Tensor, | |
kv_x: torch.Tensor, | |
biases: Optional[List[torch.Tensor]] = None, | |
use_memory_efficient_kernel: bool = False, | |
use_lma: bool = False, | |
lma_q_chunk_size: int = 1024, | |
lma_kv_chunk_size: int = 4096, | |
use_flash: bool = False, | |
flash_mask: Optional[torch.Tensor] = None, | |
) -> torch.Tensor: | |
""" | |
Args: | |
q_x: | |
[*, Q, C_q] query data | |
kv_x: | |
[*, K, C_k] key data | |
biases: | |
List of biases that broadcast to [*, H, Q, K] | |
use_memory_efficient_kernel: | |
Whether to use a custom memory-efficient attention kernel. This should be the default choice for most. | |
If none of the "use_<...>" flags are True, a stock PyTorch implementation is used instead | |
use_lma: | |
Whether to use low-memory attention (Staats & Rabe 2021). If none of the "use_<...>" flags are True, a | |
stock PyTorch implementation is used instead | |
lma_q_chunk_size: | |
Query chunk size (for LMA) | |
lma_kv_chunk_size: | |
Key/Value chunk size (for LMA) | |
Returns | |
[*, Q, C_q] attention update | |
""" | |
if use_lma and (lma_q_chunk_size is None or lma_kv_chunk_size is None): | |
raise ValueError("If use_lma is specified, lma_q_chunk_size and lma_kv_chunk_size must be provided") | |
if use_flash and biases is not None: | |
raise ValueError("use_flash is incompatible with the bias option. For masking, use flash_mask instead") | |
attn_options = [use_memory_efficient_kernel, use_lma, use_flash] | |
if sum(attn_options) > 1: | |
raise ValueError("Choose at most one alternative attention algorithm") | |
if biases is None: | |
biases = [] | |
# [*, H, Q/K, C_hidden] | |
query, key, value = self._prep_qkv(q_x, kv_x) | |
key = permute_final_dims(key, (1, 0)) | |
# [*, H, Q, K] | |
output = torch.matmul(query, key) | |
for b in biases: | |
output += b | |
output = softmax_no_cast(output, -1) | |
# [*, H, Q, C_hidden] | |
output = torch.matmul(output, value) | |
output = output.transpose(-2, -3) | |
output = self._wrap_up(output, q_x) | |
return output | |
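# Shape sketch for EsmFoldAttention with toy sizes and no biases. The module projects into
# `no_heads * c_hidden` channels, attends, then (optionally) gates and projects back to `c_q`:
#
#     >>> import torch
#     >>> attn = EsmFoldAttention(c_q=16, c_k=16, c_v=16, c_hidden=8, no_heads=4)
#     >>> x = torch.randn(1, 10, 16)
#     >>> attn(q_x=x, kv_x=x).shape
#     torch.Size([1, 10, 16])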
class EsmFoldTriangleAttention(nn.Module): | |
def __init__(self, c_in, c_hidden, no_heads, starting=True, inf=1e9): | |
""" | |
Args: | |
c_in: | |
Input channel dimension | |
c_hidden: | |
Overall hidden channel dimension (not per-head) | |
no_heads: | |
Number of attention heads | |
""" | |
super().__init__() | |
self.c_in = c_in | |
self.c_hidden = c_hidden | |
self.no_heads = no_heads | |
self.starting = starting | |
self.inf = inf | |
self.layer_norm = LayerNorm(self.c_in) | |
self.linear = EsmFoldLinear(c_in, self.no_heads, bias=False, init="normal") | |
self.mha = EsmFoldAttention(self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads) | |
def _chunk( | |
self, | |
x: torch.Tensor, | |
biases: List[torch.Tensor], | |
chunk_size: int, | |
use_memory_efficient_kernel: bool = False, | |
use_lma: bool = False, | |
inplace_safe: bool = False, | |
) -> torch.Tensor: | |
"triangle! triangle!" | |
mha_inputs = { | |
"q_x": x, | |
"kv_x": x, | |
"biases": biases, | |
} | |
return chunk_layer( | |
partial(self.mha, use_memory_efficient_kernel=use_memory_efficient_kernel, use_lma=use_lma), | |
mha_inputs, | |
chunk_size=chunk_size, | |
no_batch_dims=len(x.shape[:-2]), | |
_out=x if inplace_safe else None, | |
) | |
def forward( | |
self, | |
x: torch.Tensor, | |
mask: Optional[torch.Tensor] = None, | |
chunk_size: Optional[int] = None, | |
use_memory_efficient_kernel: bool = False, | |
use_lma: bool = False, | |
inplace_safe: bool = False, | |
) -> torch.Tensor: | |
""" | |
Args: | |
x: | |
[*, I, J, C_in] input tensor (e.g. the pair representation) | |
Returns: | |
[*, I, J, C_in] output tensor | |
""" | |
if mask is None: | |
# [*, I, J] | |
mask = x.new_ones( | |
x.shape[:-1], | |
) | |
if not self.starting: | |
x = x.transpose(-2, -3) | |
mask = mask.transpose(-1, -2) | |
# [*, I, J, C_in] | |
x = self.layer_norm(x) | |
# [*, I, 1, 1, J] | |
mask_bias = (self.inf * (mask - 1))[..., :, None, None, :] | |
# [*, H, I, J] | |
triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1)) | |
# [*, 1, H, I, J] | |
triangle_bias = triangle_bias.unsqueeze(-4) | |
biases = [mask_bias, triangle_bias] | |
if chunk_size is not None: | |
x = self._chunk( | |
x, | |
biases, | |
chunk_size, | |
use_memory_efficient_kernel=use_memory_efficient_kernel, | |
use_lma=use_lma, | |
inplace_safe=inplace_safe, | |
) | |
else: | |
x = self.mha( | |
q_x=x, kv_x=x, biases=biases, use_memory_efficient_kernel=use_memory_efficient_kernel, use_lma=use_lma | |
) | |
if not self.starting: | |
x = x.transpose(-2, -3) | |
return x | |
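# Triangle attention acts on a pair representation; with `starting=True` it attends along the
# last residue axis of each row (and transposes first when `starting=False`). Toy-sized sketch:
#
#     >>> import torch
#     >>> tri = EsmFoldTriangleAttention(c_in=32, c_hidden=8, no_heads=4, starting=True)
#     >>> x = torch.randn(1, 8, 8, 32)   # [*, I, J, C_in]
#     >>> tri(x).shape
#     torch.Size([1, 8, 8, 32])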
class EsmFoldTriangleMultiplicativeUpdate(nn.Module): | |
""" | |
Implements Algorithms 11 and 12. | |
""" | |
def __init__(self, config, _outgoing=True): | |
super().__init__() | |
c_hidden = config.pairwise_state_dim | |
self._outgoing = _outgoing | |
self.linear_a_p = EsmFoldLinear(c_hidden, c_hidden) | |
self.linear_a_g = EsmFoldLinear(c_hidden, c_hidden, init="gating") | |
self.linear_b_p = EsmFoldLinear(c_hidden, c_hidden) | |
self.linear_b_g = EsmFoldLinear(c_hidden, c_hidden, init="gating") | |
self.linear_g = EsmFoldLinear(c_hidden, c_hidden, init="gating") | |
self.linear_z = EsmFoldLinear(c_hidden, c_hidden, init="final") | |
self.layer_norm_in = LayerNorm(c_hidden) | |
self.layer_norm_out = LayerNorm(c_hidden) | |
self.sigmoid = nn.Sigmoid() | |
def _combine_projections( | |
self, a: torch.Tensor, b: torch.Tensor, _inplace_chunk_size: Optional[int] = None | |
) -> torch.Tensor: | |
if self._outgoing: | |
a = permute_final_dims(a, (2, 0, 1)) | |
b = permute_final_dims(b, (2, 1, 0)) | |
else: | |
a = permute_final_dims(a, (2, 1, 0)) | |
b = permute_final_dims(b, (2, 0, 1)) | |
if _inplace_chunk_size is not None: | |
# To be replaced by torch vmap | |
for i in range(0, a.shape[-3], _inplace_chunk_size): | |
a_chunk = a[..., i : i + _inplace_chunk_size, :, :] | |
b_chunk = b[..., i : i + _inplace_chunk_size, :, :] | |
a[..., i : i + _inplace_chunk_size, :, :] = torch.matmul( | |
a_chunk, | |
b_chunk, | |
) | |
p = a | |
else: | |
p = torch.matmul(a, b) | |
return permute_final_dims(p, (1, 2, 0)) | |
def _inference_forward( | |
self, | |
z: torch.Tensor, | |
mask: Optional[torch.Tensor] = None, | |
inplace_chunk_size: Optional[int] = None, | |
with_add: bool = True, | |
): | |
""" | |
Args: | |
z: | |
A [*, N, N, C_z] pair representation | |
mask: | |
A [*, N, N] pair mask | |
inplace_chunk_size: | |
Size of chunks used in the main computation. Increase to trade memory for speed. | |
with_add: | |
If True, z is overwritten with (z + update). Otherwise, it is overwritten with (update). | |
Returns: | |
A reference to the overwritten z | |
More memory-efficient, inference-only version of the forward function. Uses in-place operations, fusion of the | |
addition that happens after this module in the Evoformer, a smidge of recomputation, and a cache of overwritten | |
values to lower peak memory consumption of this module from 5x the size of the input tensor z to 2.5x its size. | |
Useful for inference on extremely long sequences. | |
It works as follows. We will make reference to variables used in the default forward implementation below. | |
Naively, triangle multiplication attention requires the manifestation of 5 tensors the size of z: 1) z, the | |
"square" input tensor, 2) a, the first projection of z, 3) b, the second projection of b, 4) g, a z-sized mask, | |
and 5) a z-sized tensor for intermediate computations. For large N, this is prohibitively expensive; for | |
N=4000, for example, z is more than 8GB alone. To avoid this problem, we compute b, g, and all intermediate | |
tensors in small chunks, noting that the chunks required to compute a chunk of the output depend only on the | |
tensor a and corresponding vertical and horizontal chunks of z. This suggests an algorithm that loops over | |
pairs of chunks of z: hereafter "columns" and "rows" of z, even though each "column" and "row" in fact contains | |
inplace_chunk_size contiguous true columns and rows of z. Writing output chunks to a new tensor would bring | |
total memory consumption down to 3x the size of z. However, more memory can be saved by writing output chunks | |
directly to z in-place. WLOG, we choose to write output chunks vertically, overwriting the ith "column" of z at | |
the end of the ith iteration of the main loop. Despite this overwriting, the ith column is always one column | |
ahead of previously overwritten columns and can be recovered directly from z. After the first iteration, | |
however, the ith row of z is always at least partially overwritten. For this reason, we introduce the z-cache, | |
a tensor one-half the size of z. The z-cache initially contains the left half (2nd and 3rd quadrants) of z. For | |
0 < i < N/2, the missing left part of the ith row of z is recovered from this cache at the beginning of the ith | |
iteration. Once i exceeds n/2, the cache is "reoriented" to encompass the 3rd and 4th quadrants of z instead. | |
Though the 3rd quadrant of the original z is entirely overwritten at this point, it can be recovered from the | |
z-cache itself. Thereafter, the ith row of z can be recovered in its entirety from the reoriented z-cache. | |
After the final iteration, z has been completely overwritten and contains the triangular multiplicative update. | |
If with_add is True, it instead contains the sum of z and the triangular multiplicative update. In either case, | |
peak memory consumption is just 2.5x the size of z, disregarding memory used for chunks and other small | |
variables. | |
""" | |
if mask is None: | |
mask = z.new_ones(z.shape[:-1]) | |
mask = mask.unsqueeze(-1) | |
def compute_projection_helper(pair, mask, a=True): | |
if a: | |
linear_g = self.linear_a_g | |
linear_p = self.linear_a_p | |
else: | |
linear_g = self.linear_b_g | |
linear_p = self.linear_b_p | |
pair = self.layer_norm_in(pair) | |
p = linear_g(pair) | |
p.sigmoid_() | |
p *= linear_p(pair) | |
p *= mask | |
p = permute_final_dims(p, (2, 0, 1)) | |
return p | |
def compute_projection(pair, mask, a=True, chunked=True): | |
need_transpose = self._outgoing ^ a | |
if not chunked: | |
p = compute_projection_helper(pair, mask, a) | |
if need_transpose: | |
p = p.transpose(-1, -2) | |
else: | |
# This computation is chunked so as not to exceed our 2.5x | |
# budget with a large intermediate tensor | |
linear_g = self.linear_a_g if a else self.linear_b_g | |
c = linear_g.bias.shape[-1] | |
out_shape = pair.shape[:-3] + (c,) + pair.shape[-3:-1] | |
p = pair.new_zeros(out_shape) | |
for i in range(0, pair.shape[-3], inplace_chunk_size): | |
pair_chunk = pair[..., i : i + inplace_chunk_size, :, :] | |
pair_chunk = compute_projection_helper( | |
pair[..., i : i + inplace_chunk_size, :, :], | |
mask[..., i : i + inplace_chunk_size, :, :], | |
a, | |
) | |
if need_transpose: | |
pair_chunk = pair_chunk.transpose(-1, -2) | |
p[..., i : i + inplace_chunk_size] = pair_chunk | |
else: | |
p[..., i : i + inplace_chunk_size, :] = pair_chunk | |
del pair_chunk | |
return p | |
# We start by fully manifesting a. In addition to the input, this | |
# brings total memory consumption to 2x z (disregarding size of chunks) | |
# [*, N, N, c] | |
a = compute_projection(z, mask, True, chunked=True) | |
if inplace_chunk_size is not None: | |
n = a.shape[-1] | |
half_n = n // 2 + n % 2 | |
row_dim = -3 | |
col_dim = -2 | |
b_chunk_dim = row_dim if self._outgoing else col_dim | |
def empty_slicer(t): | |
return [slice(None) for _ in t.shape] | |
def slice_tensor(t, start, end, dim): | |
# Slices start:end from the dim dimension of t | |
s = empty_slicer(t) | |
s[dim] = slice(start, end) | |
return t[s] | |
def flip_z_cache_(z_cache, z): | |
# "Reorient" the z_cache (see below), filling it with quadrants | |
# 3---recovered from the z_cache---and 4---recovered from z--- | |
# of the input tensor z. | |
quadrant_3 = slice_tensor(z_cache, half_n, None, row_dim) | |
z_cache = z_cache.transpose(row_dim, col_dim) | |
# If n is odd, we need to shrink the z_cache by one row | |
z_cache = z_cache[..., : (n // 2), :, :] | |
# Move the 3rd quadrant of z into the | |
first_half_slicer = empty_slicer(z_cache) | |
first_half_slicer[col_dim] = slice(0, half_n) | |
z_cache[first_half_slicer] = quadrant_3 | |
# Get the fourth quadrant of z | |
quadrant_4 = slice_tensor(z, half_n, None, row_dim) | |
quadrant_4 = slice_tensor(quadrant_4, half_n, None, col_dim) | |
# Insert said quadrant into the rotated z-cache | |
quadrant_3_slicer = empty_slicer(z_cache) | |
quadrant_3_slicer[col_dim] = slice(half_n, None) | |
z_cache[quadrant_3_slicer] = quadrant_4 | |
return z_cache | |
# Initialize the z cache to the left half of z. | |
z_cache_shape = list(z.shape) | |
z_cache_shape[col_dim] = half_n | |
z_cache = z.new_zeros(z_cache_shape) | |
z_cache_slicer = empty_slicer(z_cache) | |
z_cache_slicer[col_dim] = slice(0, half_n) | |
z_cache.copy_(z[z_cache_slicer]) | |
z_cache_rotated = False | |
# We need to reorient the z-cache at the halfway point, and we | |
# don't want a single chunk to straddle that point. We contract one | |
# of the chunks in the middle to address that problem. | |
i_range = list(range(0, half_n, inplace_chunk_size)) | |
initial_offsets = [i_2 - i_1 for i_1, i_2 in zip(i_range, i_range[1:] + [half_n])] | |
after_half = list(range(half_n, n, inplace_chunk_size)) | |
after_half_offsets = [inplace_chunk_size for _ in after_half] | |
combined_range_with_offsets = zip(i_range + after_half, initial_offsets + after_half_offsets) | |
for i, offset in combined_range_with_offsets: | |
if not z_cache_rotated and i >= half_n: | |
z_cache = flip_z_cache_(z_cache, z) | |
z_cache_rotated = True | |
z_chunk_b = slice_tensor(z, i, i + offset, b_chunk_dim) | |
mask_chunk = slice_tensor(mask, i, i + offset, b_chunk_dim) | |
z_chunk_b = z_chunk_b.clone() | |
if b_chunk_dim == col_dim: | |
z_chunk_b = slice_tensor(z, i, i + offset, col_dim) | |
else: # b_chunk_dim == row_dim | |
# In this case, the b-dimension (b_chunk_dim) is partially | |
# overwritten at the end of each iteration. We need to | |
# restore the missing component from the z-cache. | |
if not z_cache_rotated: | |
z_chunk_slicer = empty_slicer(z_chunk_b) | |
z_chunk_slicer[col_dim] = slice(0, half_n) | |
z_chunk_b[z_chunk_slicer] = slice_tensor(z_cache, i, i + offset, row_dim) | |
else: | |
z_cache_offset = i - half_n | |
z_chunk_b = slice_tensor(z_cache, z_cache_offset, z_cache_offset + offset, row_dim) | |
b_chunk = compute_projection(z_chunk_b, mask_chunk, a=False, chunked=False) | |
del z_chunk_b | |
x_chunk = torch.matmul(a, b_chunk) | |
x_chunk = permute_final_dims(x_chunk, (1, 2, 0)) | |
x_chunk = self.layer_norm_out(x_chunk) | |
x_chunk = self.linear_z(x_chunk) | |
# The g dimension (col_dim) is parallel to and ahead of the | |
# overwrites in z. We can extract the g chunk normally. | |
z_chunk_g = slice_tensor(z, i, i + offset, col_dim) | |
g_chunk = self.linear_g(self.layer_norm_in(z_chunk_g)) | |
g_chunk.sigmoid_() | |
del z_chunk_g | |
x_chunk *= g_chunk | |
# Write the columns into z in-place | |
z_slicer = empty_slicer(z) | |
z_slicer[col_dim] = slice(i, i + offset) | |
if with_add: | |
z[z_slicer] += x_chunk | |
else: | |
z[z_slicer] = x_chunk | |
else: | |
b = compute_projection(z, mask, False, False) | |
x = torch.matmul(a, b) | |
x = self.layer_norm_out(x) | |
x = self.linear_z(x) | |
g = self.linear_g(z) | |
g.sigmoid_() | |
x *= g | |
if with_add: | |
z += x | |
else: | |
z = x | |
return z | |
def forward( | |
self, | |
z: torch.Tensor, | |
mask: Optional[torch.Tensor] = None, | |
inplace_safe: bool = False, | |
_add_with_inplace: bool = False, | |
_inplace_chunk_size: Optional[int] = 256, | |
) -> torch.Tensor: | |
""" | |
Args: | |
            z:
[*, N_res, N_res, C_z] input tensor | |
mask: | |
[*, N_res, N_res] input mask | |
Returns: | |
[*, N_res, N_res, C_z] output tensor | |
""" | |
if inplace_safe: | |
x = self._inference_forward( | |
z, | |
mask, | |
inplace_chunk_size=_inplace_chunk_size, | |
with_add=_add_with_inplace, | |
) | |
return x | |
if mask is None: | |
mask = z.new_ones(z.shape[:-1]) | |
mask = mask.unsqueeze(-1) | |
z = self.layer_norm_in(z) | |
a = mask | |
a = a * self.sigmoid(self.linear_a_g(z)) | |
a = a * self.linear_a_p(z) | |
b = mask | |
b = b * self.sigmoid(self.linear_b_g(z)) | |
b = b * self.linear_b_p(z) | |
if is_fp16_enabled(): | |
with torch.cuda.amp.autocast(enabled=False): | |
x = self._combine_projections(a.float(), b.float()) | |
else: | |
x = self._combine_projections(a, b) | |
del a, b | |
x = self.layer_norm_out(x) | |
x = self.linear_z(x) | |
g = self.sigmoid(self.linear_g(z)) | |
x = x * g | |
return x | |
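# The module only reads `config.pairwise_state_dim`, so a bare namespace suffices for a quick
# shape check (a hypothetical minimal config, not the real EsmFoldConfig):
#
#     >>> import torch
#     >>> from types import SimpleNamespace
#     >>> cfg = SimpleNamespace(pairwise_state_dim=32)
#     >>> tmu = EsmFoldTriangleMultiplicativeUpdate(cfg, _outgoing=True)
#     >>> z = torch.randn(1, 16, 16, 32)  # [*, N_res, N_res, C_z]
#     >>> tmu(z).shape
#     torch.Size([1, 16, 16, 32])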
class EsmFoldPreTrainedModel(EsmPreTrainedModel): | |
""" | |
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained | |
models. | |
""" | |
    # Subclass `EsmPreTrainedModel` to deal with special init
def _init_weights(self, module): | |
"""Initialize the weights""" | |
if isinstance(module, EsmFoldLinear): | |
with torch.no_grad(): | |
if module.init_fn is not None: | |
module.init_fn(module.weight, module.bias) | |
elif module.init == "default": | |
trunc_normal_init_(module.weight, scale=1.0) | |
elif module.init == "relu": | |
trunc_normal_init_(module.weight, scale=2.0) | |
elif module.init == "glorot": | |
nn.init.xavier_uniform_(module.weight, gain=1) | |
elif module.init == "gating": | |
module.weight.fill_(0.0) | |
                    if module.bias is not None:
module.bias.fill_(1.0) | |
elif module.init == "normal": | |
torch.nn.init.kaiming_normal_(module.weight, nonlinearity="linear") | |
elif module.init == "final": | |
module.weight.fill_(0.0) | |
elif isinstance(module, EsmFoldInvariantPointAttention): | |
ipa_point_weights_init_(module.head_weights) | |
elif isinstance(module, EsmFoldTriangularSelfAttentionBlock): | |
torch.nn.init.zeros_(module.tri_mul_in.linear_z.weight) | |
torch.nn.init.zeros_(module.tri_mul_in.linear_z.bias) | |
torch.nn.init.zeros_(module.tri_mul_out.linear_z.weight) | |
torch.nn.init.zeros_(module.tri_mul_out.linear_z.bias) | |
torch.nn.init.zeros_(module.tri_att_start.mha.linear_o.weight) | |
torch.nn.init.zeros_(module.tri_att_start.mha.linear_o.bias) | |
torch.nn.init.zeros_(module.tri_att_end.mha.linear_o.weight) | |
torch.nn.init.zeros_(module.tri_att_end.mha.linear_o.bias) | |
torch.nn.init.zeros_(module.sequence_to_pair.o_proj.weight) | |
torch.nn.init.zeros_(module.sequence_to_pair.o_proj.bias) | |
torch.nn.init.zeros_(module.pair_to_sequence.linear.weight) | |
torch.nn.init.zeros_(module.seq_attention.o_proj.weight) | |
torch.nn.init.zeros_(module.seq_attention.o_proj.bias) | |
torch.nn.init.zeros_(module.mlp_seq.mlp[-2].weight) | |
torch.nn.init.zeros_(module.mlp_seq.mlp[-2].bias) | |
torch.nn.init.zeros_(module.mlp_pair.mlp[-2].weight) | |
torch.nn.init.zeros_(module.mlp_pair.mlp[-2].bias) | |
else: | |
super()._init_weights(module) | |
class EsmFoldSelfAttention(nn.Module): | |
def __init__(self, embed_dim, num_heads, head_width, gated=False): | |
super().__init__() | |
assert embed_dim == num_heads * head_width | |
self.embed_dim = embed_dim | |
self.num_heads = num_heads | |
self.head_width = head_width | |
self.proj = nn.Linear(embed_dim, embed_dim * 3, bias=False) | |
self.o_proj = nn.Linear(embed_dim, embed_dim, bias=True) | |
self.gated = gated | |
if gated: | |
self.g_proj = nn.Linear(embed_dim, embed_dim) | |
torch.nn.init.zeros_(self.g_proj.weight) | |
torch.nn.init.ones_(self.g_proj.bias) | |
self.rescale_factor = self.head_width**-0.5 | |
torch.nn.init.zeros_(self.o_proj.bias) | |
def forward(self, x, mask=None, bias=None, indices=None): | |
""" | |
Basic self attention with optional mask and external pairwise bias. To handle sequences of different lengths, | |
use mask. | |
        Inputs:
            x: batch of input sequences (.. x L x C)
            mask: batch of boolean masks where 1=valid, 0=padding position (.. x L_k)
            bias: batch of scalar pairwise attention biases (.. x Lq x Lk x num_heads)

        Outputs:
            sequence projection (B x L x embed_dim), attention maps (B x L x L x num_heads)
""" | |
t = self.proj(x).view(*x.shape[:2], self.num_heads, -1) | |
t = t.permute(0, 2, 1, 3) | |
q, k, v = t.chunk(3, dim=-1) | |
q = self.rescale_factor * q | |
a = torch.einsum("...qc,...kc->...qk", q, k) | |
# Add external attention bias. | |
if bias is not None: | |
a = a + bias.permute(0, 3, 1, 2) | |
# Do not attend to padding tokens. | |
if mask is not None: | |
mask = mask[:, None, None] | |
a = a.masked_fill(mask == False, -np.inf) # noqa: E712 | |
a = nn.functional.softmax(a, dim=-1) | |
y = torch.einsum("...hqk,...hkc->...qhc", a, v) | |
y = y.reshape(*y.shape[:2], -1) | |
if self.gated: | |
y = self.g_proj(x).sigmoid() * y | |
y = self.o_proj(y) | |
return y, a.permute(0, 3, 1, 2) | |
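# Shape sketch for the gated self-attention above (embed_dim must equal num_heads * head_width):
#
#     >>> import torch
#     >>> attn = EsmFoldSelfAttention(embed_dim=64, num_heads=4, head_width=16, gated=True)
#     >>> x = torch.randn(2, 10, 64)
#     >>> y, attn_maps = attn(x)
#     >>> y.shape, attn_maps.shape
#     (torch.Size([2, 10, 64]), torch.Size([2, 10, 10, 4]))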
class EsmFoldDropout(nn.Module): | |
""" | |
Implementation of dropout with the ability to share the dropout mask along a particular dimension. | |
""" | |
def __init__(self, r: float, batch_dim: Union[int, List[int]]): | |
super().__init__() | |
self.r = r | |
if isinstance(batch_dim, int): | |
batch_dim = [batch_dim] | |
self.batch_dim = batch_dim | |
self.dropout = nn.Dropout(self.r) | |
def forward(self, x: torch.Tensor) -> torch.Tensor: | |
shape = list(x.shape) | |
if self.batch_dim is not None: | |
for bd in self.batch_dim: | |
shape[bd] = 1 | |
return x * self.dropout(x.new_ones(shape)) | |
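# The dropout mask is drawn once per broadcast slice and shared along `batch_dim`; this is how
# the row-wise / column-wise pair dropout (`row_drop` / `col_drop`) is built further below.
# Sketch with a toy tensor:
#
#     >>> import torch
#     >>> drop = EsmFoldDropout(r=0.5, batch_dim=1).train()
#     >>> x = torch.ones(1, 4, 3)
#     >>> out = drop(x)  # each column out[0, :, j] is either all 0.0 or all 2.0 (scaled by 1/(1-r))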
class EsmFoldSequenceToPair(nn.Module): | |
def __init__(self, sequence_state_dim, inner_dim, pairwise_state_dim): | |
super().__init__() | |
self.layernorm = nn.LayerNorm(sequence_state_dim) | |
self.proj = nn.Linear(sequence_state_dim, inner_dim * 2, bias=True) | |
self.o_proj = nn.Linear(2 * inner_dim, pairwise_state_dim, bias=True) | |
torch.nn.init.zeros_(self.proj.bias) | |
torch.nn.init.zeros_(self.o_proj.bias) | |
def forward(self, sequence_state): | |
""" | |
Inputs: | |
sequence_state: B x L x sequence_state_dim | |
Output: | |
pairwise_state: B x L x L x pairwise_state_dim | |
Intermediate state: | |
B x L x L x 2*inner_dim | |
""" | |
assert len(sequence_state.shape) == 3 | |
s = self.layernorm(sequence_state) | |
s = self.proj(s) | |
q, k = s.chunk(2, dim=-1) | |
prod = q[:, None, :, :] * k[:, :, None, :] | |
diff = q[:, None, :, :] - k[:, :, None, :] | |
x = torch.cat([prod, diff], dim=-1) | |
x = self.o_proj(x) | |
return x | |
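# The pair features are built from an outer product and an outer difference of the projected
# sequence state, then projected to `pairwise_state_dim`. Shape check with toy sizes:
#
#     >>> import torch
#     >>> s2p = EsmFoldSequenceToPair(sequence_state_dim=64, inner_dim=32, pairwise_state_dim=128)
#     >>> s = torch.randn(1, 10, 64)
#     >>> s2p(s).shape
#     torch.Size([1, 10, 10, 128])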
class EsmFoldPairToSequence(nn.Module): | |
def __init__(self, pairwise_state_dim, num_heads): | |
super().__init__() | |
self.layernorm = nn.LayerNorm(pairwise_state_dim) | |
self.linear = nn.Linear(pairwise_state_dim, num_heads, bias=False) | |
def forward(self, pairwise_state): | |
""" | |
Inputs: | |
pairwise_state: B x L x L x pairwise_state_dim | |
Output: | |
pairwise_bias: B x L x L x num_heads | |
""" | |
assert len(pairwise_state.shape) == 4 | |
z = self.layernorm(pairwise_state) | |
pairwise_bias = self.linear(z) | |
return pairwise_bias | |
class EsmFoldResidueMLP(nn.Module): | |
def __init__(self, embed_dim, inner_dim, dropout=0): | |
super().__init__() | |
self.mlp = nn.Sequential( | |
nn.LayerNorm(embed_dim), | |
nn.Linear(embed_dim, inner_dim), | |
nn.ReLU(), | |
nn.Linear(inner_dim, embed_dim), | |
nn.Dropout(dropout), | |
) | |
def forward(self, x): | |
return x + self.mlp(x) | |
class EsmFoldTriangularSelfAttentionBlock(nn.Module): | |
def __init__(self, config): | |
super().__init__() | |
self.config = config | |
sequence_state_dim = config.sequence_state_dim | |
pairwise_state_dim = config.pairwise_state_dim | |
sequence_num_heads = sequence_state_dim // config.sequence_head_width | |
pairwise_num_heads = pairwise_state_dim // config.pairwise_head_width | |
self.layernorm_1 = nn.LayerNorm(sequence_state_dim) | |
self.sequence_to_pair = EsmFoldSequenceToPair(sequence_state_dim, pairwise_state_dim // 2, pairwise_state_dim) | |
self.pair_to_sequence = EsmFoldPairToSequence(pairwise_state_dim, sequence_num_heads) | |
self.seq_attention = EsmFoldSelfAttention( | |
sequence_state_dim, sequence_num_heads, config.sequence_head_width, gated=True | |
) | |
self.tri_mul_out = EsmFoldTriangleMultiplicativeUpdate(config, _outgoing=True) | |
self.tri_mul_in = EsmFoldTriangleMultiplicativeUpdate(config, _outgoing=False) | |
self.tri_att_start = EsmFoldTriangleAttention( | |
pairwise_state_dim, config.pairwise_head_width, pairwise_num_heads, inf=1e9, starting=True | |
) | |
self.tri_att_end = EsmFoldTriangleAttention( | |
pairwise_state_dim, config.pairwise_head_width, pairwise_num_heads, inf=1e9, starting=False | |
) | |
self.mlp_seq = EsmFoldResidueMLP(sequence_state_dim, 4 * sequence_state_dim, dropout=config.dropout) | |
self.mlp_pair = EsmFoldResidueMLP(pairwise_state_dim, 4 * pairwise_state_dim, dropout=config.dropout) | |
self.drop = nn.Dropout(config.dropout) | |
self.row_drop = EsmFoldDropout(config.dropout * 2, 2) | |
self.col_drop = EsmFoldDropout(config.dropout * 2, 1) | |
def forward(self, sequence_state, pairwise_state, mask=None, chunk_size=None, **__kwargs): | |
""" | |
        Inputs:
            sequence_state: B x L x sequence_state_dim
            pairwise_state: B x L x L x pairwise_state_dim
            mask: B x L boolean tensor of valid positions

        Output:
            sequence_state: B x L x sequence_state_dim
            pairwise_state: B x L x L x pairwise_state_dim
""" | |
if len(sequence_state.shape) != 3: | |
raise ValueError(f"`sequence_state` should be a 3d-tensor, got {len(sequence_state.shape)} dims.") | |
if len(pairwise_state.shape) != 4: | |
raise ValueError(f"`pairwise_state` should be a 4d-tensor, got {len(pairwise_state.shape)} dims.") | |
if mask is not None and len(mask.shape) != 2: | |
raise ValueError(f"`mask` should be a 2d-tensor, got {len(mask.shape)} dims.") | |
batch_dim, seq_dim, sequence_state_dim = sequence_state.shape | |
pairwise_state_dim = pairwise_state.shape[3] | |
if sequence_state_dim != self.config.sequence_state_dim: | |
raise ValueError( | |
"`sequence_state` last dimension should be equal to `self.sequence_state_dim`. Got " | |
f"{sequence_state_dim} != {self.config.sequence_state_dim}." | |
) | |
if pairwise_state_dim != self.config.pairwise_state_dim: | |
raise ValueError( | |
"`pairwise_state` last dimension should be equal to `self.pairwise_state_dim`. Got " | |
f"{pairwise_state_dim} != {self.config.pairwise_state_dim}." | |
) | |
if batch_dim != pairwise_state.shape[0]: | |
raise ValueError( | |
f"`sequence_state` and `pairwise_state` have inconsistent batch size: {batch_dim} != " | |
f"{pairwise_state.shape[0]}." | |
) | |
if seq_dim != pairwise_state.shape[1] or seq_dim != pairwise_state.shape[2]: | |
raise ValueError( | |
f"`sequence_state` and `pairwise_state` have inconsistent sequence length: {seq_dim} != " | |
f"{pairwise_state.shape[1]} or {pairwise_state.shape[2]}." | |
) | |
# Update sequence state | |
bias = self.pair_to_sequence(pairwise_state) | |
# Self attention with bias + mlp. | |
y = self.layernorm_1(sequence_state) | |
y, _ = self.seq_attention(y, mask=mask, bias=bias) | |
sequence_state = sequence_state + self.drop(y) | |
sequence_state = self.mlp_seq(sequence_state) | |
# Update pairwise state | |
pairwise_state = pairwise_state + self.sequence_to_pair(sequence_state) | |
# Axial attention with triangular bias. | |
tri_mask = mask.unsqueeze(2) * mask.unsqueeze(1) if mask is not None else None | |
pairwise_state = pairwise_state + self.row_drop(self.tri_mul_out(pairwise_state, mask=tri_mask)) | |
pairwise_state = pairwise_state + self.col_drop(self.tri_mul_in(pairwise_state, mask=tri_mask)) | |
pairwise_state = pairwise_state + self.row_drop( | |
self.tri_att_start(pairwise_state, mask=tri_mask, chunk_size=chunk_size) | |
) | |
pairwise_state = pairwise_state + self.col_drop( | |
self.tri_att_end(pairwise_state, mask=tri_mask, chunk_size=chunk_size) | |
) | |
# MLP over pairs. | |
pairwise_state = self.mlp_pair(pairwise_state) | |
return sequence_state, pairwise_state | |
class EsmCategoricalMixture: | |
def __init__(self, param, bins=50, start=0, end=1): | |
# All tensors are of shape ..., bins. | |
self.logits = param | |
bins = torch.linspace(start, end, bins + 1, device=self.logits.device, dtype=self.logits.dtype) | |
self.v_bins = (bins[:-1] + bins[1:]) / 2 | |
def log_prob(self, true): | |
# Shapes are: | |
        # self.logits: ... x bins
# true : ... | |
true_index = (true.unsqueeze(-1) - self.v_bins[[None] * true.ndim]).abs().argmin(-1) | |
nll = self.logits.log_softmax(-1) | |
return torch.take_along_dim(nll, true_index.unsqueeze(-1), dim=-1).squeeze(-1) | |
def mean(self): | |
return (self.logits.softmax(-1) @ self.v_bins.unsqueeze(1)).squeeze(-1) | |
def categorical_lddt(logits, bins=50): | |
# Logits are ..., 37, bins. | |
return EsmCategoricalMixture(logits, bins=bins).mean() | |
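# categorical_lddt converts per-bin logits into an expected value over evenly spaced bin
# centers in [0, 1]. Sketch with random logits over 37 atom positions:
#
#     >>> import torch
#     >>> logits = torch.randn(2, 10, 37, 50)   # ..., 37, bins
#     >>> categorical_lddt(logits, bins=50).shape
#     torch.Size([2, 10, 37])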
def get_axial_mask(mask): | |
""" | |
Helper to convert B x L mask of valid positions to axial mask used in row column attentions. | |
Input: | |
mask: B x L tensor of booleans | |
Output: | |
        mask: (B*L) x L tensor of booleans (the B x L x L axial mask, flattened over the first two dims)
""" | |
if mask is None: | |
return None | |
if len(mask.shape) != 2: | |
raise ValueError(f"`mask` should be a 2d-tensor, got {len(mask.shape)} dims.") | |
batch_dim, seq_dim = mask.shape | |
m = mask.unsqueeze(1).expand(batch_dim, seq_dim, seq_dim) | |
m = m.reshape(batch_dim * seq_dim, seq_dim) | |
return m | |
class EsmFoldRelativePosition(nn.Module): | |
def __init__(self, config): | |
super().__init__() | |
self.bins = config.position_bins | |
# Note an additional offset is used so that the 0th position | |
# is reserved for masked pairs. | |
self.embedding = torch.nn.Embedding(2 * self.bins + 2, config.pairwise_state_dim) | |
def forward(self, residue_index, mask=None): | |
""" | |
        Input:
            residue_index: B x L tensor of indices (dtype=torch.long)
            mask: B x L tensor of booleans

        Output:
            pairwise_state: B x L x L x pairwise_state_dim tensor of embeddings
""" | |
if residue_index.dtype != torch.long: | |
raise ValueError(f"`residue_index` has dtype {residue_index.dtype}, it should be `torch.long`.") | |
if mask is not None and residue_index.shape != mask.shape: | |
raise ValueError( | |
f"`residue_index` and `mask` have inconsistent shapes: {residue_index.shape} != {mask.shape}." | |
) | |
diff = residue_index[:, None, :] - residue_index[:, :, None] | |
diff = diff.clamp(-self.bins, self.bins) | |
diff = diff + self.bins + 1 # Add 1 to adjust for padding index. | |
if mask is not None: | |
mask = mask[:, None, :] * mask[:, :, None] | |
diff[mask == False] = 0 # noqa: E712 | |
output = self.embedding(diff) | |
return output | |
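# Relative-position embedding sketch using a hypothetical minimal config carrying only the two
# fields read above (position_bins, pairwise_state_dim):
#
#     >>> import torch
#     >>> from types import SimpleNamespace
#     >>> cfg = SimpleNamespace(position_bins=32, pairwise_state_dim=128)
#     >>> rel = EsmFoldRelativePosition(cfg)
#     >>> idx = torch.arange(10)[None, :]   # B x L, dtype torch.long
#     >>> rel(idx).shape
#     torch.Size([1, 10, 10, 128])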
class EsmFoldAngleResnetBlock(nn.Module): | |
def __init__(self, config): | |
super().__init__() | |
self.linear_1 = EsmFoldLinear(config.resnet_dim, config.resnet_dim, init="relu") | |
self.linear_2 = EsmFoldLinear(config.resnet_dim, config.resnet_dim, init="final") | |
self.relu = nn.ReLU() | |
def forward(self, a: torch.Tensor) -> torch.Tensor: | |
s_initial = a | |
a = self.relu(a) | |
a = self.linear_1(a) | |
a = self.relu(a) | |
a = self.linear_2(a) | |
return a + s_initial | |
class EsmFoldAngleResnet(nn.Module): | |
""" | |
Implements Algorithm 20, lines 11-14 | |
""" | |
def __init__(self, config): | |
super().__init__() | |
self.config = config | |
self.linear_in = EsmFoldLinear(config.sequence_dim, config.resnet_dim) | |
self.linear_initial = EsmFoldLinear(config.sequence_dim, config.resnet_dim) | |
self.layers = nn.ModuleList() | |
for _ in range(config.num_resnet_blocks): | |
layer = EsmFoldAngleResnetBlock(config) | |
self.layers.append(layer) | |
self.linear_out = EsmFoldLinear(config.resnet_dim, config.num_angles * 2) | |
self.relu = nn.ReLU() | |
def forward(self, s: torch.Tensor, s_initial: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: | |
""" | |
Args: | |
s: | |
[*, C_hidden] single embedding | |
s_initial: | |
[*, C_hidden] single embedding as of the start of the StructureModule | |
Returns: | |
[*, no_angles, 2] predicted angles | |
""" | |
# NOTE: The ReLU's applied to the inputs are absent from the supplement | |
# pseudocode but present in the source. For maximal compatibility with | |
# the pretrained weights, I'm going with the source. | |
# [*, C_hidden] | |
s_initial = self.relu(s_initial) | |
s_initial = self.linear_initial(s_initial) | |
s = self.relu(s) | |
s = self.linear_in(s) | |
s = s + s_initial | |
for l in self.layers: | |
s = l(s) | |
s = self.relu(s) | |
# [*, no_angles * 2] | |
s = self.linear_out(s) | |
# [*, no_angles, 2] | |
s = s.view(s.shape[:-1] + (-1, 2)) | |
unnormalized_s = s | |
norm_denom = torch.sqrt( | |
torch.clamp( | |
torch.sum(s**2, dim=-1, keepdim=True), | |
min=self.config.epsilon, | |
) | |
) | |
s = s / norm_denom | |
return unnormalized_s, s | |
class EsmFoldInvariantPointAttention(nn.Module): | |
""" | |
Implements Algorithm 22. | |
""" | |
def __init__(self, config): | |
super().__init__() | |
self.config = config | |
c_s = config.sequence_dim | |
c_z = config.pairwise_dim | |
self.hidden_dim = config.ipa_dim | |
self.num_heads = config.num_heads_ipa | |
self.num_qk_points = config.num_qk_points | |
self.num_v_points = config.num_v_points | |
# These linear layers differ from their specifications in the | |
# supplement. There, they lack bias and use Glorot initialization. | |
# Here as in the official source, they have bias and use the default | |
# Lecun initialization. | |
hc = config.ipa_dim * config.num_heads_ipa | |
self.linear_q = EsmFoldLinear(c_s, hc) | |
self.linear_kv = EsmFoldLinear(c_s, 2 * hc) | |
hpq = config.num_heads_ipa * config.num_qk_points * 3 | |
self.linear_q_points = EsmFoldLinear(c_s, hpq) | |
hpkv = config.num_heads_ipa * (config.num_qk_points + config.num_v_points) * 3 | |
self.linear_kv_points = EsmFoldLinear(c_s, hpkv) | |
self.linear_b = EsmFoldLinear(c_z, config.num_heads_ipa) | |
self.head_weights = nn.Parameter(torch.zeros((config.num_heads_ipa))) | |
concat_out_dim = config.num_heads_ipa * (c_z + config.ipa_dim + config.num_v_points * 4) | |
self.linear_out = EsmFoldLinear(concat_out_dim, c_s, init="final") | |
self.softmax = nn.Softmax(dim=-1) | |
self.softplus = nn.Softplus() | |
def forward( | |
self, | |
s: torch.Tensor, | |
z: Optional[torch.Tensor], | |
r: Rigid, | |
mask: torch.Tensor, | |
_offload_inference: bool = False, | |
_z_reference_list: Optional[Sequence[torch.Tensor]] = None, | |
) -> torch.Tensor: | |
""" | |
Args: | |
s: | |
[*, N_res, C_s] single representation | |
z: | |
[*, N_res, N_res, C_z] pair representation | |
r: | |
[*, N_res] transformation object | |
mask: | |
[*, N_res] mask | |
Returns: | |
[*, N_res, C_s] single representation update | |
""" | |
z = [z] | |
####################################### | |
# Generate scalar and point activations | |
####################################### | |
# [*, N_res, H * C_hidden] | |
q = self.linear_q(s) | |
kv = self.linear_kv(s) | |
# [*, N_res, H, C_hidden] | |
q = q.view(q.shape[:-1] + (self.num_heads, -1)) | |
# [*, N_res, H, 2 * C_hidden] | |
kv = kv.view(kv.shape[:-1] + (self.num_heads, -1)) | |
# [*, N_res, H, C_hidden] | |
k, v = torch.split(kv, self.hidden_dim, dim=-1) | |
# [*, N_res, H * P_q * 3] | |
q_pts = self.linear_q_points(s) | |
# This is kind of clunky, but it's how the original does it | |
# [*, N_res, H * P_q, 3] | |
q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1) | |
q_pts = torch.stack(q_pts, dim=-1) | |
q_pts = r[..., None].apply(q_pts) | |
# [*, N_res, H, P_q, 3] | |
q_pts = q_pts.view(q_pts.shape[:-2] + (self.num_heads, self.num_qk_points, 3)) | |
# [*, N_res, H * (P_q + P_v) * 3] | |
kv_pts = self.linear_kv_points(s) | |
# [*, N_res, H * (P_q + P_v), 3] | |
kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1) | |
kv_pts = torch.stack(kv_pts, dim=-1) | |
kv_pts = r[..., None].apply(kv_pts) | |
# [*, N_res, H, (P_q + P_v), 3] | |
kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.num_heads, -1, 3)) | |
# [*, N_res, H, P_q/P_v, 3] | |
k_pts, v_pts = torch.split(kv_pts, [self.num_qk_points, self.num_v_points], dim=-2) | |
########################## | |
# Compute attention scores | |
########################## | |
# [*, N_res, N_res, H] | |
b = self.linear_b(z[0]) | |
if _offload_inference: | |
assert sys.getrefcount(z[0]) == 2 | |
z[0] = z[0].cpu() | |
# [*, H, N_res, N_res] | |
if is_fp16_enabled(): | |
with torch.cuda.amp.autocast(enabled=False): | |
a = torch.matmul( | |
permute_final_dims(q.float(), (1, 0, 2)), # [*, H, N_res, C_hidden] | |
permute_final_dims(k.float(), (1, 2, 0)), # [*, H, C_hidden, N_res] | |
) | |
else: | |
a = torch.matmul( | |
permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden] | |
permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res] | |
) | |
a *= math.sqrt(1.0 / (3 * self.hidden_dim)) | |
a += math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)) | |
# [*, N_res, N_res, H, P_q, 3] | |
pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5) | |
pt_att = pt_att**2 | |
# [*, N_res, N_res, H, P_q] | |
pt_att = sum(torch.unbind(pt_att, dim=-1)) | |
head_weights = self.softplus(self.head_weights).view(*((1,) * len(pt_att.shape[:-2]) + (-1, 1))) | |
head_weights = head_weights * math.sqrt(1.0 / (3 * (self.num_qk_points * 9.0 / 2))) | |
pt_att = pt_att * head_weights | |
# [*, N_res, N_res, H] | |
pt_att = torch.sum(pt_att, dim=-1) * (-0.5) | |
# [*, N_res, N_res] | |
square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2) | |
square_mask = self.config.inf * (square_mask - 1) | |
# [*, H, N_res, N_res] | |
pt_att = permute_final_dims(pt_att, (2, 0, 1)) | |
a = a + pt_att | |
a = a + square_mask.unsqueeze(-3) | |
a = self.softmax(a) | |
################ | |
# Compute output | |
################ | |
# [*, N_res, H, C_hidden] | |
o = torch.matmul(a, v.transpose(-2, -3).to(dtype=a.dtype)).transpose(-2, -3) | |
# [*, N_res, H * C_hidden] | |
o = flatten_final_dims(o, 2) | |
# [*, H, 3, N_res, P_v] | |
o_pt = torch.sum( | |
(a[..., None, :, :, None] * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]), | |
dim=-2, | |
) | |
# [*, N_res, H, P_v, 3] | |
o_pt = permute_final_dims(o_pt, (2, 0, 3, 1)) | |
o_pt = r[..., None, None].invert_apply(o_pt) | |
# [*, N_res, H * P_v] | |
o_pt_norm = flatten_final_dims(torch.sqrt(torch.sum(o_pt**2, dim=-1) + self.config.epsilon), 2) | |
# [*, N_res, H * P_v, 3] | |
o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3) | |
if _offload_inference: | |
z[0] = z[0].to(o_pt.device) | |
# [*, N_res, H, C_z] | |
o_pair = torch.matmul(a.transpose(-2, -3), z[0].to(dtype=a.dtype)) | |
# [*, N_res, H * C_z] | |
o_pair = flatten_final_dims(o_pair, 2) | |
# [*, N_res, C_s] | |
s = self.linear_out( | |
torch.cat((o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1).to(dtype=z[0].dtype) | |
) | |
return s | |
class EsmFoldBackboneUpdate(nn.Module): | |
""" | |
Implements part of Algorithm 23. | |
""" | |
def __init__(self, config): | |
super().__init__() | |
self.linear = EsmFoldLinear(config.sequence_dim, 6, init="final") | |
def forward(self, s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: | |
""" | |
Args: | |
[*, N_res, C_s] single representation | |
Returns: | |
[*, N_res, 6] update vector | |
""" | |
# [*, 6] | |
update = self.linear(s) | |
return update | |
class EsmFoldStructureModuleTransitionLayer(nn.Module): | |
def __init__(self, config): | |
super().__init__() | |
self.linear_1 = EsmFoldLinear(config.sequence_dim, config.sequence_dim, init="relu") | |
self.linear_2 = EsmFoldLinear(config.sequence_dim, config.sequence_dim, init="relu") | |
self.linear_3 = EsmFoldLinear(config.sequence_dim, config.sequence_dim, init="final") | |
self.relu = nn.ReLU() | |
def forward(self, s): | |
s_initial = s | |
s = self.linear_1(s) | |
s = self.relu(s) | |
s = self.linear_2(s) | |
s = self.relu(s) | |
s = self.linear_3(s) | |
s = s + s_initial | |
return s | |
class EsmFoldStructureModuleTransition(nn.Module): | |
def __init__(self, config): | |
super().__init__() | |
self.config = config | |
self.layers = nn.ModuleList() | |
for _ in range(config.num_transition_layers): | |
l = EsmFoldStructureModuleTransitionLayer(config) | |
self.layers.append(l) | |
self.dropout = nn.Dropout(config.dropout_rate) | |
self.layer_norm = LayerNorm(config.sequence_dim) | |
def forward(self, s): | |
for l in self.layers: | |
s = l(s) | |
s = self.dropout(s) | |
s = self.layer_norm(s) | |
return s | |
class EsmFoldStructureModule(nn.Module): | |
def __init__(self, config): | |
super().__init__() | |
self.config = config | |
# Buffers to be lazily initialized later | |
# self.default_frames | |
# self.group_idx | |
# self.atom_mask | |
# self.lit_positions | |
self.layer_norm_s = LayerNorm(config.sequence_dim) | |
self.layer_norm_z = LayerNorm(config.pairwise_dim) | |
self.linear_in = EsmFoldLinear(config.sequence_dim, config.sequence_dim) | |
self.ipa = EsmFoldInvariantPointAttention(config) | |
self.ipa_dropout = nn.Dropout(config.dropout_rate) | |
self.layer_norm_ipa = LayerNorm(config.sequence_dim) | |
self.transition = EsmFoldStructureModuleTransition(config) | |
self.bb_update = EsmFoldBackboneUpdate(config) | |
self.angle_resnet = EsmFoldAngleResnet(config) | |
def forward( | |
self, | |
evoformer_output_dict, | |
aatype, | |
mask=None, | |
_offload_inference=False, | |
): | |
""" | |
Args: | |
evoformer_output_dict: | |
Dictionary containing: | |
"single": | |
[*, N_res, C_s] single representation | |
"pair": | |
[*, N_res, N_res, C_z] pair representation | |
aatype: | |
[*, N_res] amino acid indices | |
mask: | |
Optional [*, N_res] sequence mask | |
Returns: | |
A dictionary of outputs | |
""" | |
s = evoformer_output_dict["single"] | |
if mask is None: | |
# [*, N] | |
mask = s.new_ones(s.shape[:-1]) | |
# [*, N, C_s] | |
s = self.layer_norm_s(s) | |
# [*, N, N, C_z] | |
z = self.layer_norm_z(evoformer_output_dict["pair"]) | |
z_reference_list = None | |
if _offload_inference: | |
assert sys.getrefcount(evoformer_output_dict["pair"]) == 2 | |
evoformer_output_dict["pair"] = evoformer_output_dict["pair"].cpu() | |
z_reference_list = [z] | |
z = None | |
# [*, N, C_s] | |
s_initial = s | |
s = self.linear_in(s) | |
# [*, N] | |
rigids = Rigid.identity( | |
s.shape[:-1], | |
s.dtype, | |
s.device, | |
self.training, | |
fmt="quat", | |
) | |
outputs = [] | |
for i in range(self.config.num_blocks): | |
# [*, N, C_s] | |
s = s + self.ipa( | |
s, | |
z, | |
rigids, | |
mask, | |
_offload_inference=_offload_inference, | |
_z_reference_list=z_reference_list, | |
) | |
s = self.ipa_dropout(s) | |
s = self.layer_norm_ipa(s) | |
s = self.transition(s) | |
# [*, N] | |
rigids = rigids.compose_q_update_vec(self.bb_update(s)) | |
# To hew as closely as possible to AlphaFold, we convert our | |
# quaternion-based transformations to rotation-matrix ones | |
# here | |
backb_to_global = Rigid( | |
Rotation(rot_mats=rigids.get_rots().get_rot_mats(), quats=None), | |
rigids.get_trans(), | |
) | |
backb_to_global = backb_to_global.scale_translation(self.config.trans_scale_factor) | |
# [*, N, 7, 2] | |
unnormalized_angles, angles = self.angle_resnet(s, s_initial) | |
all_frames_to_global = self.torsion_angles_to_frames(backb_to_global, angles, aatype) | |
pred_xyz = self.frames_and_literature_positions_to_atom14_pos(all_frames_to_global, aatype) | |
scaled_rigids = rigids.scale_translation(self.config.trans_scale_factor) | |
preds = { | |
"frames": scaled_rigids.to_tensor_7(), | |
"sidechain_frames": all_frames_to_global.to_tensor_4x4(), | |
"unnormalized_angles": unnormalized_angles, | |
"angles": angles, | |
"positions": pred_xyz, | |
"states": s, | |
} | |
outputs.append(preds) | |
rigids = rigids.stop_rot_gradient() | |
del z, z_reference_list | |
if _offload_inference: | |
evoformer_output_dict["pair"] = evoformer_output_dict["pair"].to(s.device) | |
outputs = dict_multimap(torch.stack, outputs) | |
outputs["single"] = s | |
return outputs | |
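    # Note on the returned dictionary: the per-block prediction dicts are stacked along a new
    # leading dimension by dict_multimap(torch.stack, outputs), so e.g. outputs["positions"] has
    # shape [num_blocks, *, N_res, 14, 3] and downstream code indexes [-1] for the final block.
    # The final single representation is additionally returned under "single".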
def _init_residue_constants(self, float_dtype, device): | |
if not hasattr(self, "default_frames"): | |
self.register_buffer( | |
"default_frames", | |
torch.tensor( | |
residue_constants.restype_rigid_group_default_frame, | |
dtype=float_dtype, | |
device=device, | |
requires_grad=False, | |
), | |
persistent=False, | |
) | |
if not hasattr(self, "group_idx"): | |
self.register_buffer( | |
"group_idx", | |
torch.tensor( | |
residue_constants.restype_atom14_to_rigid_group, | |
device=device, | |
requires_grad=False, | |
), | |
persistent=False, | |
) | |
if not hasattr(self, "atom_mask"): | |
self.register_buffer( | |
"atom_mask", | |
torch.tensor( | |
residue_constants.restype_atom14_mask, | |
dtype=float_dtype, | |
device=device, | |
requires_grad=False, | |
), | |
persistent=False, | |
) | |
if not hasattr(self, "lit_positions"): | |
self.register_buffer( | |
"lit_positions", | |
torch.tensor( | |
residue_constants.restype_atom14_rigid_group_positions, | |
dtype=float_dtype, | |
device=device, | |
requires_grad=False, | |
), | |
persistent=False, | |
) | |
def torsion_angles_to_frames(self, r, alpha, f): | |
# Lazily initialize the residue constants on the correct device | |
self._init_residue_constants(alpha.dtype, alpha.device) | |
# Separated purely to make testing less annoying | |
return torsion_angles_to_frames(r, alpha, f, self.default_frames) | |
    def frames_and_literature_positions_to_atom14_pos(self, r, f):  # r: [*, N, 8] rigid frames, f: [*, N] aatype
# Lazily initialize the residue constants on the correct device | |
self._init_residue_constants(r.get_rots().dtype, r.get_rots().device) | |
return frames_and_literature_positions_to_atom14_pos( | |
r, | |
f, | |
self.default_frames, | |
self.group_idx, | |
self.atom_mask, | |
self.lit_positions, | |
) | |
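    # Terminology note (a brief gloss, not new behavior): "atom14" is the compact per-residue atom
    # layout used here, with up to 14 heavy-atom slots whose ordering depends on the residue type,
    # while "atom37" is the fixed 37-slot layout shared by all residue types; the openfold utilities
    # (e.g. atom14_to_atom37, make_atom14_masks) convert between the two.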
class EsmFoldingTrunk(nn.Module): | |
def __init__(self, config): | |
super().__init__() | |
self.config = config | |
c_s = config.sequence_state_dim | |
c_z = config.pairwise_state_dim | |
self.pairwise_positional_embedding = EsmFoldRelativePosition(config) | |
self.blocks = nn.ModuleList([EsmFoldTriangularSelfAttentionBlock(config) for _ in range(config.num_blocks)]) | |
self.recycle_bins = 15 | |
self.recycle_s_norm = nn.LayerNorm(c_s) | |
self.recycle_z_norm = nn.LayerNorm(c_z) | |
self.recycle_disto = nn.Embedding(self.recycle_bins, c_z) | |
self.recycle_disto.weight[0].detach().zero_() | |
self.structure_module = EsmFoldStructureModule(config.structure_module) | |
self.trunk2sm_s = nn.Linear(c_s, config.structure_module.sequence_dim) | |
self.trunk2sm_z = nn.Linear(c_z, config.structure_module.pairwise_dim) | |
self.chunk_size = config.chunk_size | |
def set_chunk_size(self, chunk_size): | |
        # When set, the axial attention is computed in a chunked manner, which reduces peak memory
        # from roughly O(L^2) to O(L). It is equivalent to running a for loop over chunks of the
        # dimension we are iterating over, where chunk_size is the size of each chunk; 128 means
        # the computation is performed over 128-length chunks at a time.
self.chunk_size = chunk_size | |
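        # A minimal usage sketch (assuming `model` is a loaded EsmForProteinFolding; passing None is
        # assumed to disable chunking and restore full-width axial attention):
        #
        #     model.trunk.set_chunk_size(64)    # lower memory, somewhat slower
        #     model.trunk.set_chunk_size(None)  # unchunked axial attention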
def forward(self, seq_feats, pair_feats, true_aa, residx, mask, no_recycles): | |
""" | |
Inputs: | |
seq_feats: B x L x C tensor of sequence features pair_feats: B x L x L x C tensor of pair features residx: B | |
x L long tensor giving the position in the sequence mask: B x L boolean tensor indicating valid residues | |
Output: | |
predicted_structure: B x L x (num_atoms_per_residue * 3) tensor wrapped in a Coordinates object | |
""" | |
device = seq_feats.device | |
s_s_0 = seq_feats | |
s_z_0 = pair_feats | |
if no_recycles is None: | |
no_recycles = self.config.max_recycles | |
else: | |
if no_recycles < 0: | |
raise ValueError("Number of recycles must not be negative.") | |
no_recycles += 1 # First 'recycle' is just the standard forward pass through the model. | |
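        # Recycling semantics (restating the code above): no_recycles=N from the caller means N + 1
        # passes through the trunk in total, and only the final pass is run with gradients enabled
        # (see the ContextManagers guard in the loop below).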
def trunk_iter(s, z, residx, mask): | |
z = z + self.pairwise_positional_embedding(residx, mask=mask) | |
for block in self.blocks: | |
s, z = block(s, z, mask=mask, residue_index=residx, chunk_size=self.chunk_size) | |
return s, z | |
s_s = s_s_0 | |
s_z = s_z_0 | |
recycle_s = torch.zeros_like(s_s) | |
recycle_z = torch.zeros_like(s_z) | |
recycle_bins = torch.zeros(*s_z.shape[:-1], device=device, dtype=torch.int64) | |
for recycle_idx in range(no_recycles): | |
with ContextManagers([] if recycle_idx == no_recycles - 1 else [torch.no_grad()]): | |
# === Recycling === | |
recycle_s = self.recycle_s_norm(recycle_s.detach()).to(device) | |
recycle_z = self.recycle_z_norm(recycle_z.detach()).to(device) | |
recycle_z += self.recycle_disto(recycle_bins.detach()).to(device) | |
s_s, s_z = trunk_iter(s_s_0 + recycle_s, s_z_0 + recycle_z, residx, mask) | |
# === Structure module === | |
structure = self.structure_module( | |
{"single": self.trunk2sm_s(s_s), "pair": self.trunk2sm_z(s_z)}, | |
true_aa, | |
mask.float(), | |
) | |
recycle_s = s_s | |
recycle_z = s_z | |
                # The distogram is computed from the N, CA and C coordinates, using the same bin constants as AlphaFold.
recycle_bins = EsmFoldingTrunk.distogram( | |
structure["positions"][-1][:, :, :3], | |
3.375, | |
21.375, | |
self.recycle_bins, | |
) | |
structure["s_s"] = s_s | |
structure["s_z"] = s_z | |
return structure | |
    @staticmethod
    def distogram(coords, min_bin, max_bin, num_bins):
# Coords are [... L x 3 x 3], where it's [N, CA, C] x 3 coordinates. | |
boundaries = torch.linspace( | |
min_bin, | |
max_bin, | |
num_bins - 1, | |
device=coords.device, | |
) | |
boundaries = boundaries**2 | |
N, CA, C = [x.squeeze(-2) for x in coords.chunk(3, dim=-2)] | |
# Infer CB coordinates. | |
b = CA - N | |
c = C - CA | |
a = b.cross(c, dim=-1) | |
CB = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA | |
dists = (CB[..., None, :, :] - CB[..., :, None, :]).pow(2).sum(dim=-1, keepdims=True) | |
bins = torch.sum(dists > boundaries, dim=-1) # [..., L, L] | |
return bins | |
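        # Notes on the construction above: the boundaries are squared so the squared pairwise
        # distances can be binned without taking a square root, and the fixed coefficients place an
        # idealized virtual C-beta from the backbone N, CA and C atoms (the AlphaFold-style
        # pseudo-beta used for distograms), so glycines are handled uniformly.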
# TODO Add information to the docstring about any methods that convert to PDB format, or otherwise prepare | |
# the outputs for downstream use. | |
class EsmForProteinFolding(EsmPreTrainedModel): | |
_no_split_modules = ["EsmFoldStructureModule", "EsmFoldTriangularSelfAttentionBlock"] | |
def __init__(self, config): | |
super().__init__(config) | |
self.config = config | |
self.distogram_bins = 64 | |
self.esm = EsmModel(config, add_pooling_layer=False) | |
self.esm.requires_grad_(False) | |
if self.config.esmfold_config.fp16_esm: | |
self.esm.half() | |
self.esm_feats = self.config.hidden_size | |
self.esm_attns = self.config.num_hidden_layers * self.config.num_attention_heads | |
self.esm_layers = self.config.num_hidden_layers | |
self.register_buffer("af2_to_esm", self._af2_to_esm_from_vocab_list(config.vocab_list)) | |
self.esm_s_combine = nn.Parameter(torch.zeros(self.esm_layers + 1)) | |
trunk_config = self.config.esmfold_config.trunk | |
c_s = trunk_config.sequence_state_dim | |
c_z = trunk_config.pairwise_state_dim | |
self.esm_s_mlp = nn.Sequential( | |
LayerNorm(self.esm_feats), | |
nn.Linear(self.esm_feats, c_s), | |
nn.ReLU(), | |
nn.Linear(c_s, c_s), | |
) | |
# 0 is padding, N is unknown residues, N + 1 is mask. | |
self.n_tokens_embed = residue_constants.restype_num + 3 | |
self.pad_idx = 0 | |
self.unk_idx = self.n_tokens_embed - 2 | |
self.mask_idx = self.n_tokens_embed - 1 | |
self.esm_dict_cls_idx = self.config.vocab_list.index("<cls>") | |
self.esm_dict_mask_idx = self.config.vocab_list.index("<mask>") | |
self.esm_dict_eos_idx = self.config.vocab_list.index("<eos>") | |
self.esm_dict_padding_idx = self.config.vocab_list.index("<pad>") | |
if self.config.esmfold_config.embed_aa: | |
self.embedding = nn.Embedding(self.n_tokens_embed, c_s, padding_idx=0) | |
self.trunk = EsmFoldingTrunk(trunk_config) | |
self.distogram_head = nn.Linear(c_z, self.distogram_bins) | |
self.ptm_head = nn.Linear(c_z, self.distogram_bins) | |
self.lm_head = nn.Linear(c_s, self.n_tokens_embed) | |
self.lddt_bins = 50 | |
structure_module_config = trunk_config.structure_module | |
self.lddt_head = nn.Sequential( | |
nn.LayerNorm(structure_module_config.sequence_dim), | |
nn.Linear(structure_module_config.sequence_dim, self.config.esmfold_config.lddt_head_hid_dim), | |
nn.Linear(self.config.esmfold_config.lddt_head_hid_dim, self.config.esmfold_config.lddt_head_hid_dim), | |
nn.Linear(self.config.esmfold_config.lddt_head_hid_dim, 37 * self.lddt_bins), | |
) | |
    @staticmethod
    def _af2_to_esm_from_vocab_list(vocab_list: List[str]) -> torch.Tensor:
        # Remember that the returned mapping is shifted from residue_constants by 1 (index 0 is padding).
esm_reorder = [vocab_list.index("<pad>")] + [vocab_list.index(v) for v in residue_constants.restypes_with_x] | |
return torch.tensor(esm_reorder) | |
def forward( | |
self, | |
input_ids: torch.Tensor, | |
attention_mask: Optional[torch.Tensor] = None, | |
position_ids: Optional[torch.Tensor] = None, | |
masking_pattern: Optional[torch.Tensor] = None, | |
num_recycles: Optional[int] = None, | |
) -> EsmForProteinFoldingOutput: | |
r""" | |
Returns: | |
Example: | |
```python | |
>>> from transformers import AutoTokenizer, EsmForProteinFolding | |
>>> model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1") | |
>>> tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1") | |
>>> inputs = tokenizer(["MLKNVQVQLV"], return_tensors="pt", add_special_tokens=False) # A tiny random peptide | |
>>> outputs = model(**inputs) | |
>>> folded_positions = outputs.positions | |
``` | |
""" | |
cfg = self.config.esmfold_config | |
aa = input_ids # B x L | |
B = aa.shape[0] | |
L = aa.shape[1] | |
device = input_ids.device | |
if attention_mask is None: | |
attention_mask = torch.ones_like(aa, device=device) | |
if position_ids is None: | |
position_ids = torch.arange(L, device=device).expand_as(input_ids) | |
# === ESM === | |
esmaa = self.af2_idx_to_esm_idx(aa, attention_mask) | |
if masking_pattern is not None: | |
masked_aa, esmaa, mlm_targets = self.bert_mask(aa, esmaa, attention_mask, masking_pattern) | |
else: | |
masked_aa = aa | |
mlm_targets = None | |
# We get sequence and pair representations from whatever version of ESM / | |
# configuration we are using. The sequence representation esm_s is always | |
# present. The pair embedding esm_z may be present depending on the | |
# configuration of the model. If esm_z is not used by the model then it | |
# is returned as None here. | |
esm_s = self.compute_language_model_representations(esmaa) | |
# Convert esm_s and esm_z, if present, to the precision used by the trunk and | |
# the structure module. These tensors may be a lower precision if, for example, | |
# we're running the language model in fp16 precision. | |
esm_s = esm_s.to(self.esm_s_combine.dtype) | |
if cfg.esm_ablate_sequence: | |
esm_s = esm_s * 0 | |
esm_s = esm_s.detach() | |
# === preprocessing === | |
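        # The (num_layers + 1) stacked ESM hidden states are combined with a learned softmax
        # weighting (self.esm_s_combine), i.e. a weighted average over layers, before being
        # projected to the trunk's sequence state dimension by esm_s_mlp.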
esm_s = (self.esm_s_combine.softmax(0).unsqueeze(0) @ esm_s).squeeze(2) | |
s_s_0 = self.esm_s_mlp(esm_s) | |
s_z_0 = s_s_0.new_zeros(B, L, L, cfg.trunk.pairwise_state_dim) | |
if self.config.esmfold_config.embed_aa: | |
s_s_0 += self.embedding(masked_aa) | |
structure: dict = self.trunk(s_s_0, s_z_0, aa, position_ids, attention_mask, no_recycles=num_recycles) | |
        # Keep only the keys documented in EsmForProteinFoldingOutput:
structure = { | |
k: v | |
for k, v in structure.items() | |
if k | |
in [ | |
"s_z", | |
"s_s", | |
"frames", | |
"sidechain_frames", | |
"unnormalized_angles", | |
"angles", | |
"positions", | |
"states", | |
] | |
} | |
# Add BERT mask for the loss to use, if available. | |
        if mlm_targets is not None:
structure["mlm_targets"] = mlm_targets | |
disto_logits = self.distogram_head(structure["s_z"]) | |
disto_logits = (disto_logits + disto_logits.transpose(1, 2)) / 2 | |
structure["distogram_logits"] = disto_logits | |
lm_logits = self.lm_head(structure["s_s"]) | |
structure["lm_logits"] = lm_logits | |
structure["aatype"] = aa | |
make_atom14_masks(structure) | |
        # make_atom14_masks does not know about the attention mask, so apply it to the atom-existence
        # tensors here. The index-mapping tensors ("residx_atom14_to_atom37" and
        # "residx_atom37_to_atom14") are intentionally left unmasked.
for k in [ | |
"atom14_atom_exists", | |
"atom37_atom_exists", | |
]: | |
structure[k] *= attention_mask.unsqueeze(-1) | |
structure["residue_index"] = position_ids | |
lddt_head = self.lddt_head(structure["states"]).reshape(structure["states"].shape[0], B, L, -1, self.lddt_bins) | |
structure["lddt_head"] = lddt_head | |
plddt = categorical_lddt(lddt_head[-1], bins=self.lddt_bins) | |
structure["plddt"] = plddt | |
ptm_logits = self.ptm_head(structure["s_z"]) | |
structure["ptm_logits"] = ptm_logits | |
structure["ptm"] = compute_tm(ptm_logits, max_bin=31, no_bins=self.distogram_bins) | |
structure.update(compute_predicted_aligned_error(ptm_logits, max_bin=31, no_bins=self.distogram_bins)) | |
return EsmForProteinFoldingOutput(**structure) | |
def af2_idx_to_esm_idx(self, aa, mask): | |
# avoid indexing on different devices | |
if self.af2_to_esm.device != aa.device: | |
self.af2_to_esm = self.af2_to_esm.to(aa.device) | |
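        # Shift the AlphaFold-style indices up by one so that 0 can serve as padding in the lookup
        # table built by _af2_to_esm_from_vocab_list, then map them to ESM vocabulary indices.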
aa = (aa + 1).masked_fill(mask != 1, 0) | |
return self.af2_to_esm[aa] | |
def compute_language_model_representations(self, esmaa: torch.Tensor) -> torch.Tensor: | |
device = next(self.parameters()).device | |
B, L = esmaa.shape # B = batch size, L = sequence length. | |
if self.config.esmfold_config.bypass_lm: | |
            # Return all-zero "hidden states" with the same shape as the stacked ESM layer outputs.
            esm_s = torch.zeros(B, L, self.esm_s_combine.size(0), self.esm_feats, device=device)
return esm_s | |
bosi, eosi = self.esm_dict_cls_idx, self.esm_dict_eos_idx | |
bos = esmaa.new_full((B, 1), bosi) | |
eos = esmaa.new_full((B, 1), self.esm_dict_padding_idx) | |
esmaa = torch.cat([bos, esmaa, eos], dim=1) | |
# Use the first padding index as eos during inference. | |
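        # (esmaa != 1).sum(1) counts the non-padding tokens in each row (1 is assumed to be the ESM
        # <pad> index here, matching the `esmaa != 1` attention mask below), which is exactly the
        # position immediately after the last real token; that slot is overwritten with <eos>.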
esmaa[range(B), (esmaa != 1).sum(1)] = eosi | |
# _, esm_z, esm_s = self.esm(esmaa, return_pairs=self.config.esmfold_config.use_esm_attn_map) | |
# Because we do not support use_esm_attn_map in the HF port as it is not used in any public models, | |
# esm_z is always None | |
esm_hidden_states = self.esm(esmaa, attention_mask=esmaa != 1, output_hidden_states=True)["hidden_states"] | |
esm_s = torch.stack(esm_hidden_states, dim=2) | |
esm_s = esm_s[:, 1:-1] # B, L, nLayers, C | |
return esm_s | |
def bert_mask(self, aa, esmaa, mask, pattern): | |
new_aa = aa.clone() | |
target = aa.clone() | |
new_esmaa = esmaa.clone() | |
new_aa[pattern == 1] = self.mask_idx | |
target[pattern != 1] = 0 | |
new_esmaa[pattern == 1] = self.esm_dict_mask_idx | |
return new_aa, new_esmaa, target | |
def infer( | |
self, | |
seqs: Union[str, List[str]], | |
position_ids=None, | |
): | |
if isinstance(seqs, str): | |
lst = [seqs] | |
else: | |
lst = seqs | |
# Returns the raw outputs of the model given an input sequence. | |
device = next(self.parameters()).device | |
aatype = collate_dense_tensors( | |
[ | |
torch.from_numpy( | |
residue_constants.sequence_to_onehot( | |
sequence=seq, | |
mapping=residue_constants.restype_order_with_x, | |
map_unknown_to_x=True, | |
) | |
) | |
.to(device) | |
.argmax(dim=1) | |
for seq in lst | |
] | |
) # B=1 x L | |
mask = collate_dense_tensors([aatype.new_ones(len(seq)) for seq in lst]) | |
position_ids = ( | |
torch.arange(aatype.shape[1], device=device).expand(len(lst), -1) | |
if position_ids is None | |
else position_ids.to(device) | |
) | |
if position_ids.ndim == 1: | |
position_ids = position_ids.unsqueeze(0) | |
return self.forward( | |
aatype, | |
mask, | |
position_ids=position_ids, | |
) | |
    def output_to_pdb(self, output: Dict) -> List[str]:
        """Returns the pdb (file) string from the model given the model output."""
output = {k: v.to("cpu").numpy() for k, v in output.items()} | |
pdbs = [] | |
final_atom_positions = atom14_to_atom37(output["positions"][-1], output) | |
final_atom_mask = output["atom37_atom_exists"] | |
for i in range(output["aatype"].shape[0]): | |
aa = output["aatype"][i] | |
pred_pos = final_atom_positions[i] | |
mask = final_atom_mask[i] | |
resid = output["residue_index"][i] + 1 | |
pred = OFProtein( | |
aatype=aa, | |
atom_positions=pred_pos, | |
atom_mask=mask, | |
residue_index=resid, | |
b_factors=output["plddt"][i], | |
) | |
pdbs.append(to_pdb(pred)) | |
return pdbs | |
def infer_pdb(self, seqs, *args, **kwargs) -> str: | |
"""Returns the pdb (file) string from the model given an input sequence.""" | |
assert isinstance(seqs, str) | |
output = self.infer(seqs, *args, **kwargs) | |
return self.output_to_pdb(output)[0] | |
def infer_pdbs(self, seqs: List[str], *args, **kwargs) -> List[str]: | |
"""Returns the pdb (file) string from the model given an input sequence.""" | |
output = self.infer(seqs, *args, **kwargs) | |
return self.output_to_pdb(output) | |
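# A minimal end-to-end usage sketch (comments only; the checkpoint name matches the docstring
# example in EsmForProteinFolding.forward, and running it requires downloading the weights):
#
#     import torch
#     from transformers import EsmForProteinFolding
#
#     model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1")
#     model.eval()
#     with torch.no_grad():
#         pdb_string = model.infer_pdb("MLKNVQVQLV")
#     with open("prediction.pdb", "w") as handle:
#         handle.write(pdb_string)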