diff --git a/app.py b/app.py
index 7ff4e282f00e340fc00bb404d3a4c305d26e491a..0e081bde24391eae5013fae0aeffaee8a4f67169 100644
--- a/app.py
+++ b/app.py
@@ -2,10 +2,13 @@ import streamlit as st
 from Bio import SeqIO
 import torch
 import torch.nn as nn
-from esm.model.esm2 import ESM2 as ESM2_SISS
-from esm import Alphabet, FastaBatchedDataset
 import pandas as pd
-from tqdm import tqdm
+
+import esm
+from esm.data import *
+from esm.model.esm2_secondarystructure import ESM2 as ESM2_SISS
+
+from esm import Alphabet, FastaBatchedDataset
 
 
 from io import StringIO
@@ -18,14 +21,14 @@ modelfile = 'model.pkl'
 layers = 6
 heads = 16
 embed_dim = 128
-batch_toks = 4096
+batch_toks = 1024
 
 inp_len = 50
 
 device = "cpu"
 
-alphabet = Alphabet(prepend_toks=("", "", ""), standard_toks = 'AGCT', append_toks=("", "", ""))
-alphabet.tok_to_idx = {'': 0, '': 1, '': 2, 'A': 3, 'G': 4, 'C': 5, 'T': 6, '': 7, '': 8, '': 9}
+alphabet = Alphabet(standard_toks = 'AGCT')
+assert alphabet.tok_to_idx == {'': 0, '': 1, '': 2, 'A': 3, 'G': 4, 'C': 5, 'T': 6, '': 7, '': 8, '': 9}
 
 class CNN_linear(nn.Module):
     def __init__(self,
@@ -68,8 +71,8 @@ class CNN_linear(nn.Module):
 
     def forward(self, tokens, need_head_weights=True, return_contacts=False, return_representation=True):
-        # x = self.esm2(tokens, [layers], need_head_weights, return_contacts, return_representation)
-        x = self.esm2(tokens, [layers])
+        x = self.esm2(tokens, [layers], need_head_weights, return_contacts, return_representation)
+        # x = self.esm2(tokens, [layers])
 
         x = x["representations"][layers][:, 0]
         x_o = x.unsqueeze(2)
 
@@ -81,15 +84,14 @@ class CNN_linear(nn.Module):
         o = self.output(o_dropout)
         return o
 
-def eval_step(dataloader, model, threshold = 0.5):
+def eval_step(dataloader, model, threshold=0.5):
     model.eval()
     logits_list= []
     # y_pred_list, y_prob_list = [], []
     ids_list, strs_list = [], []
     my_bar = st.progress(0, text="Running UTR_LM")
     with torch.no_grad():
-        # for (ids, strs, _, toks, _, _) in tqdm(dataloader):
-        for i, (ids, strs, toks) in enumerate(dataloader):
+        for i, (ids, strs, _, toks, _, _) in enumerate(dataloader):
             ids_list.extend(ids)
             strs_list.extend(strs)
             # toks = toks.to(device)
@@ -106,6 +108,7 @@ def eval_step(dataloader, model, threshold = 0.5):
             # y_pred_list.extend(y_pred.tolist())
 
     st.success('Done', icon="✅")
+    # data_pred = pd.DataFrame({'ID':ids_list, 'Sequence':strs_list, "Translation Efficiency":logits_list, "prob":y_prob_list, "pred":y_pred_list})
     data_pred = pd.DataFrame({'ID':ids_list, 'Sequence':strs_list, "Translation Efficiency":logits_list})
     return data_pred
 
@@ -129,8 +132,9 @@ def read_raw(raw_input):
     return ids, sequences
 
 def generate_dataset_dataloader(ids, seqs):
-    # dataset = FastaBatchedDataset(ids, seqs, mask_prob = 0.0)
-    dataset = FastaBatchedDataset(ids, seqs)
+    dataset = FastaBatchedDataset(ids, seqs, mask_prob = 0.0)
+
+    # dataset = FastaBatchedDataset(ids, seqs)
     batches = dataset.get_batch_indices(toks_per_batch=batch_toks, extra_toks_per_seq=2)
     dataloader = torch.utils.data.DataLoader(dataset,
                                              collate_fn=alphabet.get_batch_converter(),
@@ -166,13 +170,14 @@ uploaded = st.file_uploader("Sequence file in FASTA format")
 
 if st.button("Predict"):
     if uploaded:
         result = predict_raw(uploaded.getvalue().decode())
-        result_file = result.to_csv(index=False)
-        st.download_button("Download", result_file, file_name="UTR_LM_prediction.csv")
-        st.dataframe(result)
+        # result_file = result.to_csv(index=False)
+        # st.download_button("Download", result_file, file_name="UTR_LM_prediction.csv")
+        # st.dataframe(result)
     else:
         result = predict_raw(seq)
-        result_file = result.to_csv(index=False)
-        st.download_button("Download", result_file, file_name="UTR_LM_prediction.csv")
-        st.dataframe(result)
+
+    result_file = result.to_csv(index=False)
+    st.download_button("Download", result_file, file_name="UTR_LM_prediction.csv")
+    st.dataframe(result)
 
diff --git a/esm/.DS_Store b/esm/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..8543000e002a12004c962c66a73fd62f526d435a
Binary files /dev/null and b/esm/.DS_Store differ
diff --git a/esm/._.DS_Store b/esm/._.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..8e82ed96c0d694f6a640da6c163a3ef7e4194513
Binary files /dev/null and b/esm/._.DS_Store differ
diff --git a/esm/._multihead_attention.py b/esm/._multihead_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9959ba1b4b55ed6fdaca9e450120196a9d3a4dd
Binary files /dev/null and b/esm/._multihead_attention.py differ
diff --git a/esm/._pretrained.py b/esm/._pretrained.py
new file mode 100644
index 0000000000000000000000000000000000000000..76ae07cba88a1b44cee6f93c9ed74bbb5163c084
Binary files /dev/null and b/esm/._pretrained.py differ
diff --git a/esm/__init__.py b/esm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6813f920797c4777cb2f368e6ae88abc86831e09
--- /dev/null
+++ b/esm/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from .version import version as __version__  # noqa
+
+from .data import Alphabet, BatchConverter, FastaBatchedDataset  # noqa
+from .model.esm1 import ProteinBertModel  # noqa
+from .model.esm2 import ESM2  # noqa
+from .model.msa_transformer import MSATransformer  #noqa
+from . import pretrained  # noqa
+
+# from .version import version as __version__  # noqa
+
+# from .data import Alphabet, BatchConverter, FastaBatchedDataset  # noqa
+# from .model import ProteinBertModel, MSATransformer, ESM2  # noqa
+# from . import pretrained  # noqa
diff --git a/esm/__pycache__/__init__.cpython-36.pyc b/esm/__pycache__/__init__.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2d4dd520e6b4b47bad959caad242fe6ce6c07ab0
Binary files /dev/null and b/esm/__pycache__/__init__.cpython-36.pyc differ
diff --git a/esm/__pycache__/__init__.cpython-39.pyc b/esm/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aee86d6847b14a0d93fa7a17790dae598cf07e2d
Binary files /dev/null and b/esm/__pycache__/__init__.cpython-39.pyc differ
diff --git a/esm/__pycache__/axial_attention.cpython-36.pyc b/esm/__pycache__/axial_attention.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9a5b0968a3b8ede99e0c26f912b347ec66a431fa
Binary files /dev/null and b/esm/__pycache__/axial_attention.cpython-36.pyc differ
diff --git a/esm/__pycache__/axial_attention.cpython-39.pyc b/esm/__pycache__/axial_attention.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3511771014abdcc9e0156e14e8237c10b8e0a641
Binary files /dev/null and b/esm/__pycache__/axial_attention.cpython-39.pyc differ
diff --git a/esm/__pycache__/constants.cpython-36.pyc b/esm/__pycache__/constants.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c954df1a46096d4e92eb77c383099ff3e94dd708
Binary files /dev/null and b/esm/__pycache__/constants.cpython-36.pyc differ
diff --git a/esm/__pycache__/constants.cpython-39.pyc b/esm/__pycache__/constants.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a5c96ab2e25547e3c48572614eda43c865c0f349
Binary files /dev/null and b/esm/__pycache__/constants.cpython-39.pyc differ
diff --git a/esm/__pycache__/data.cpython-36.pyc b/esm/__pycache__/data.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..380bab366c14492bb8e49ac49002a278b516aceb
Binary files /dev/null and b/esm/__pycache__/data.cpython-36.pyc differ
diff --git a/esm/__pycache__/data.cpython-39.pyc b/esm/__pycache__/data.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7de7cf743b1bf859954635ac6d33a96de4205fe4
Binary files /dev/null and b/esm/__pycache__/data.cpython-39.pyc differ
diff --git a/esm/__pycache__/data_protein.cpython-36.pyc b/esm/__pycache__/data_protein.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0e54142abe4d95cd49aad012740f7f50dec7a31d
Binary files /dev/null and b/esm/__pycache__/data_protein.cpython-36.pyc differ
diff --git a/esm/__pycache__/model.cpython-36.pyc b/esm/__pycache__/model.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9897f8ebe47e3c2e527c865f686f5682b08d2cb9
Binary files /dev/null and b/esm/__pycache__/model.cpython-36.pyc differ
diff --git a/esm/__pycache__/model.cpython-39.pyc b/esm/__pycache__/model.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f69d3521c73c63440a0fd47ef8f5d05d2c8ec18d
Binary files /dev/null and b/esm/__pycache__/model.cpython-39.pyc differ
diff --git a/esm/__pycache__/modules.cpython-36.pyc b/esm/__pycache__/modules.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8acb182cec1507c2bbc529144b9ef2e6db4f5e5b
Binary files /dev/null and b/esm/__pycache__/modules.cpython-36.pyc differ
diff --git a/esm/__pycache__/modules.cpython-39.pyc b/esm/__pycache__/modules.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f85a49c62df9d330569385ca9328787c061a6ee2
Binary files /dev/null and b/esm/__pycache__/modules.cpython-39.pyc differ
diff --git a/esm/__pycache__/multihead_attention.cpython-36.pyc b/esm/__pycache__/multihead_attention.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e015f92c52ba669840749a235ed59f55e5c375bc
Binary files /dev/null and b/esm/__pycache__/multihead_attention.cpython-36.pyc differ
diff --git a/esm/__pycache__/multihead_attention.cpython-39.pyc b/esm/__pycache__/multihead_attention.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7de8f0d72d09eb000fedb046ba7e27d8f688f813
Binary files /dev/null and b/esm/__pycache__/multihead_attention.cpython-39.pyc differ
diff --git a/esm/__pycache__/pretrained.cpython-36.pyc b/esm/__pycache__/pretrained.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1332d5463cee7fcfc8811d10c7b947fc77a40d9f
Binary files /dev/null and b/esm/__pycache__/pretrained.cpython-36.pyc differ
diff --git a/esm/__pycache__/pretrained.cpython-39.pyc b/esm/__pycache__/pretrained.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a162b0fd9dd912f6cc886c062f26c673803f36ef
Binary files /dev/null and b/esm/__pycache__/pretrained.cpython-39.pyc differ
diff --git a/esm/__pycache__/rotary_embedding.cpython-36.pyc b/esm/__pycache__/rotary_embedding.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..39d77ad2bbcc5807df4602f097fd93a9e588d514
Binary files /dev/null and b/esm/__pycache__/rotary_embedding.cpython-36.pyc differ
diff --git a/esm/__pycache__/rotary_embedding.cpython-39.pyc b/esm/__pycache__/rotary_embedding.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..afe800f34d5f5eabdb8e9f91744747b411ca8839
Binary files /dev/null and b/esm/__pycache__/rotary_embedding.cpython-39.pyc differ
diff --git a/esm/__pycache__/version.cpython-36.pyc b/esm/__pycache__/version.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dff46c70ddbe8103af55d52fea8a8f81c12738e5
Binary files /dev/null and b/esm/__pycache__/version.cpython-36.pyc differ
diff --git a/esm/__pycache__/version.cpython-39.pyc b/esm/__pycache__/version.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..944094460afd22f197e7a7ffbd22f6258744ae7e
Binary files /dev/null and b/esm/__pycache__/version.cpython-39.pyc differ
diff --git a/esm/axial_attention.py b/esm/axial_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..f95f287fccef8fb79e814f1108aef7a95d7b90f1
--- /dev/null
+++ b/esm/axial_attention.py
@@ -0,0 +1,239 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+ +import math +import torch +import torch.nn as nn + + +class RowSelfAttention(nn.Module): + """Compute self-attention over rows of a 2D input.""" + + def __init__( + self, + embed_dim, + num_heads, + dropout=0.0, + max_tokens_per_msa: int = 2 ** 16, + ): + super().__init__() + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + self.scaling = self.head_dim ** -0.5 + self.max_tokens_per_msa = max_tokens_per_msa + self.attn_shape = "hnij" + + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + + self.out_proj = nn.Linear(embed_dim, embed_dim) + self.dropout_module = nn.Dropout(dropout) + + def align_scaling(self, q): + num_rows = q.size(0) + return self.scaling / math.sqrt(num_rows) + + def _batched_forward( + self, + x, + self_attn_mask=None, + self_attn_padding_mask=None, + ): + num_rows, num_cols, batch_size, embed_dim = x.size() + max_rows = max(1, self.max_tokens_per_msa // num_cols) + attns = 0 + scaling = self.align_scaling(x) + for start in range(0, num_rows, max_rows): + attn_weights = self.compute_attention_weights( + x[start : start + max_rows], + scaling, + self_attn_mask=self_attn_mask, + self_attn_padding_mask=self_attn_padding_mask[:, start : start + max_rows] + if self_attn_padding_mask is not None + else None, + ) + attns += attn_weights + attn_probs = attns.softmax(-1) + attn_probs = self.dropout_module(attn_probs) + + outputs = [] + for start in range(0, num_rows, max_rows): + output = self.compute_attention_update(x[start : start + max_rows], attn_probs) + outputs.append(output) + + output = torch.cat(outputs, 0) + return output, attn_probs + + def compute_attention_weights( + self, + x, + scaling: float, + self_attn_mask=None, + self_attn_padding_mask=None, + ): + num_rows, num_cols, batch_size, embed_dim = x.size() + q = self.q_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim) + k = self.k_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim) + q *= scaling + if self_attn_padding_mask is not None: + # Zero out any padded aligned positions - this is important since + # we take a sum across the alignment axis. 
+ q *= 1 - self_attn_padding_mask.permute(1, 2, 0).unsqueeze(3).unsqueeze(4).to(q) + + attn_weights = torch.einsum(f"rinhd,rjnhd->{self.attn_shape}", q, k) + + if self_attn_mask is not None: + raise NotImplementedError + # Mask Size: [B x R x C], Weights Size: [H x B x C x C] + + if self_attn_padding_mask is not None: + attn_weights = attn_weights.masked_fill( + self_attn_padding_mask[:, 0].unsqueeze(0).unsqueeze(2), + -10000, + ) + + return attn_weights + + def compute_attention_update( + self, + x, + attn_probs, + ): + num_rows, num_cols, batch_size, embed_dim = x.size() + v = self.v_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim) + context = torch.einsum(f"{self.attn_shape},rjnhd->rinhd", attn_probs, v) + context = context.contiguous().view(num_rows, num_cols, batch_size, embed_dim) + output = self.out_proj(context) + return output + + def forward( + self, + x, + self_attn_mask=None, + self_attn_padding_mask=None, + ): + num_rows, num_cols, batch_size, embed_dim = x.size() + if (num_rows * num_cols > self.max_tokens_per_msa) and not torch.is_grad_enabled(): + return self._batched_forward(x, self_attn_mask, self_attn_padding_mask) + else: + scaling = self.align_scaling(x) + attn_weights = self.compute_attention_weights( + x, scaling, self_attn_mask, self_attn_padding_mask + ) + attn_probs = attn_weights.softmax(-1) + attn_probs = self.dropout_module(attn_probs) + output = self.compute_attention_update(x, attn_probs) + return output, attn_probs + + +class ColumnSelfAttention(nn.Module): + """Compute self-attention over columns of a 2D input.""" + + def __init__( + self, + embed_dim, + num_heads, + dropout=0.0, + max_tokens_per_msa: int = 2 ** 16, + ): + super().__init__() + + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + self.scaling = self.head_dim ** -0.5 + self.max_tokens_per_msa = max_tokens_per_msa + + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + + self.out_proj = nn.Linear(embed_dim, embed_dim) + self.dropout_module = nn.Dropout(dropout) + + def _batched_forward( + self, + x, + self_attn_mask=None, + self_attn_padding_mask=None, + ): + num_rows, num_cols, batch_size, embed_dim = x.size() + max_cols = max(1, self.max_tokens_per_msa // num_rows) + outputs = [] + attns = [] + for start in range(0, num_cols, max_cols): + output, attn = self( + x[:, start : start + max_cols], + self_attn_mask=self_attn_mask, + self_attn_padding_mask=self_attn_padding_mask[:, :, start : start + max_cols] + if self_attn_padding_mask is not None + else None, + ) + outputs.append(output) + attns.append(attn) + output = torch.cat(outputs, 1) + attns = torch.cat(attns, 1) + return output, attns + + def compute_attention_update( + self, + x, + self_attn_mask=None, + self_attn_padding_mask=None, + ): + num_rows, num_cols, batch_size, embed_dim = x.size() + if num_rows == 1: + # if there is only 1 position, this is equivalent and doesn't break with padding + attn_probs = torch.ones( + self.num_heads, + num_cols, + batch_size, + num_rows, + num_rows, + device=x.device, + dtype=x.dtype, + ) + output = self.out_proj(self.v_proj(x)) + else: + q = self.q_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim) + k = self.k_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim) + v = self.v_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim) + q *= self.scaling + + attn_weights = 
torch.einsum("icnhd,jcnhd->hcnij", q, k) + + if self_attn_mask is not None: + raise NotImplementedError + if self_attn_padding_mask is not None: + attn_weights = attn_weights.masked_fill( + self_attn_padding_mask.permute(2, 0, 1).unsqueeze(0).unsqueeze(3), + -10000, + ) + + attn_probs = attn_weights.softmax(-1) + attn_probs = self.dropout_module(attn_probs) + context = torch.einsum("hcnij,jcnhd->icnhd", attn_probs, v) + context = context.contiguous().view(num_rows, num_cols, batch_size, embed_dim) + output = self.out_proj(context) + return output, attn_probs + + def forward( + self, + x, + self_attn_mask=None, + self_attn_padding_mask=None, + ): + num_rows, num_cols, batch_size, embed_dim = x.size() + # if False and num_rows * num_cols > 2 ** 14 and not torch.is_grad_enabled(): + if (num_rows * num_cols) > self.max_tokens_per_msa and not torch.is_grad_enabled(): + return self._batched_forward( + x, + self_attn_mask, + self_attn_padding_mask, + ) + else: + return self.compute_attention_update(x, self_attn_mask, self_attn_padding_mask) diff --git a/esm/constants.py b/esm/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..e7ef02cc924356acfcf8d4c9897d1bf09da6f58d --- /dev/null +++ b/esm/constants.py @@ -0,0 +1,14 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# fmt: off +proteinseq_toks = { + 'toks': ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C', 'X', 'B', 'U', 'Z', 'O', '.', '-'] +} + +rnaseq_toks = { + 'toks': ['A', 'G', 'T', 'C'] +} +# fmt: on diff --git a/esm/data.py b/esm/data.py new file mode 100644 index 0000000000000000000000000000000000000000..5a43c1f605f49bd35e614c3eca56a4caf33c714a --- /dev/null +++ b/esm/data.py @@ -0,0 +1,524 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import itertools +import os +from typing import Sequence, Tuple, List, Union +import pickle +import re +import shutil +import torch +from pathlib import Path +from .constants import proteinseq_toks, rnaseq_toks +import math +import random +from copy import deepcopy + +RawMSA = Sequence[Tuple[str, str]] + + +class Alphabet(object): + def __init__( + self, + standard_toks: Sequence[str], + prepend_toks: Sequence[str] = ("", "", ""), # "", + append_toks: Sequence[str] = ("", "", ""), # + prepend_bos: bool = True, + append_eos: bool = True, + use_msa: bool = False, + mask_prob: float = 0.15, ###--- + ): + self.mask_prob = mask_prob ###--- + self.standard_toks = list(standard_toks) + self.prepend_toks = list(prepend_toks) + self.append_toks = list(append_toks) + self.prepend_bos = prepend_bos + self.append_eos = append_eos + self.use_msa = use_msa + + self.all_toks = list(self.prepend_toks) + self.all_toks.extend(self.standard_toks) +# for i in range((8 - (len(self.all_toks) % 8)) % 8): +# self.all_toks.append(f"") + self.all_toks.extend(self.append_toks) + + self.tok_to_idx = {tok: i for i, tok in enumerate(self.all_toks)} +# print(self.tok_to_idx) + self.unk_idx = self.tok_to_idx[""] + self.padding_idx = self.get_idx("") + self.cls_idx = self.get_idx("") + self.mask_idx = self.get_idx("") + self.eos_idx = self.get_idx("") + self.all_special_tokens = ['', '', ''] # , '', '' + self.unique_no_split_tokens = self.all_toks + + def __len__(self): + return len(self.all_toks) + + def get_idx(self, tok): + return self.tok_to_idx.get(tok, self.unk_idx) + + def get_tok(self, ind): + return self.all_toks[ind] + + def to_dict(self): + return self.tok_to_idx.copy() + + def get_batch_converter(self): + if self.use_msa: + return MSABatchConverter(self) + else: + return BatchConverter(self) + + @classmethod + def from_architecture(cls, name: str) -> "Alphabet": + if name in ("ESM-1", "protein_bert_base"): + standard_toks = proteinseq_toks["toks"] + prepend_toks: Tuple[str, ...] = ("", "", "", "") + append_toks: Tuple[str, ...] = ("", "", "") + prepend_bos = True + append_eos = False + use_msa = False + elif name in ("ESM-1b", "roberta_large"): + standard_toks = proteinseq_toks["toks"] ###---rnaseq + prepend_toks = ("", "", "", "") + append_toks = ("",) + prepend_bos = True + append_eos = True + use_msa = False + elif name in ("MSA Transformer", "msa_transformer"): + standard_toks = proteinseq_toks["toks"] + prepend_toks = ("", "", "", "") + append_toks = ("",) + prepend_bos = True + append_eos = False + use_msa = True + else: + raise ValueError("Unknown architecture selected") + return cls(standard_toks, prepend_toks, append_toks, prepend_bos, append_eos, use_msa) + + def _tokenize(self, text) -> str: + return text.split() + + def tokenize(self, text, **kwargs) -> List[str]: + """ + Inspired by https://github.com/huggingface/transformers/blob/master/src/transformers/tokenization_utils.py + Converts a string in a sequence of tokens, using the tokenizer. + + Args: + text (:obj:`str`): + The sequence to be encoded. + + Returns: + :obj:`List[str]`: The list of tokens. + """ + + def split_on_token(tok, text): + result = [] + split_text = text.split(tok) + for i, sub_text in enumerate(split_text): + # AddedToken can control whitespace stripping around them. + # We use them for GPT2 and Roberta to have different behavior depending on the special token + # Cf. 
https://github.com/huggingface/transformers/pull/2778 + # and https://github.com/huggingface/transformers/issues/3788 + # We strip left and right by default + if i < len(split_text) - 1: + sub_text = sub_text.rstrip() + if i > 0: + sub_text = sub_text.lstrip() + + if i == 0 and not sub_text: + result.append(tok) + elif i == len(split_text) - 1: + if sub_text: + result.append(sub_text) + else: + pass + else: + if sub_text: + result.append(sub_text) + result.append(tok) + return result + + def split_on_tokens(tok_list, text): + if not text.strip(): + return [] + + tokenized_text = [] + text_list = [text] + for tok in tok_list: + tokenized_text = [] + for sub_text in text_list: + if sub_text not in self.unique_no_split_tokens: + tokenized_text.extend(split_on_token(tok, sub_text)) + else: + tokenized_text.append(sub_text) + text_list = tokenized_text + + return list( + itertools.chain.from_iterable( + ( + self._tokenize(token) + if token not in self.unique_no_split_tokens + else [token] + for token in tokenized_text + ) + ) + ) + + no_split_token = self.unique_no_split_tokens + tokenized_text = split_on_tokens(no_split_token, text) + return tokenized_text + + def encode(self, text): + return [self.tok_to_idx[tok] for tok in self.tokenize(text)] + +class FastaBatchedDataset(object): + def __init__(self, sequence_labels, sequence_strs, mask_prob = 0.15): + self.sequence_labels = list(sequence_labels) + self.sequence_strs = list(sequence_strs) + self.mask_prob = mask_prob + + @classmethod + def from_file(cls, fasta_file, mask_prob = 0.15): + sequence_labels, sequence_strs = [], [] + cur_seq_label = None + buf = [] + + def _flush_current_seq(): + nonlocal cur_seq_label, buf + if cur_seq_label is None: + return + sequence_labels.append(cur_seq_label) + sequence_strs.append("".join(buf)) + cur_seq_label = None + buf = [] + + with open(fasta_file, "r") as infile: + for line_idx, line in enumerate(infile): + if line.startswith(">"): # label line + _flush_current_seq() + line = line[1:].strip() + if len(line) > 0: + cur_seq_label = line + else: + cur_seq_label = f"seqnum{line_idx:09d}" + else: # sequence line + buf.append(line.strip()) + + _flush_current_seq() + + assert len(set(sequence_strs)) == len( + sequence_strs + ), "Found duplicate sequence labels" + + return cls(sequence_labels, sequence_strs, mask_prob) + + def __len__(self): + return len(self.sequence_labels) + + def mask_sequence(self, seq): ###--- + length = len(seq) +# print(self.mask_prob) + max_length = math.ceil(length * self.mask_prob) + rand = random.sample(range(0, length), max_length) + res = ''.join(['' if idx in rand else ele for idx, ele in enumerate(seq)]) + #print(seq, rand, res) + return rand, res + + def __getitem__(self, idx): + sequence_str = self.sequence_strs[idx] + sequence_label = self.sequence_labels[idx] + masked_indices, masked_sequence_str = self.mask_sequence(sequence_str) + return sequence_label, sequence_str, masked_sequence_str, masked_indices + + def get_batch_indices(self, toks_per_batch, extra_toks_per_seq=0): + sizes = [(len(s), i) for i, s in enumerate(self.sequence_strs)] + sizes.sort() + batches = [] + buf = [] + max_len = 0 + + def _flush_current_buf(): + nonlocal max_len, buf + if len(buf) == 0: + return + batches.append(buf) + buf = [] + max_len = 0 + + for sz, i in sizes: + sz += extra_toks_per_seq + if max(sz, max_len) * (len(buf) + 1) > toks_per_batch: + _flush_current_buf() + max_len = max(max_len, sz) + buf.append(i) + + _flush_current_buf() + return batches + +class BatchConverter(object): + 
"""Callable to convert an unprocessed (labels + strings) batch to a + processed (labels + tensor) batch. + """ + + def __init__(self, alphabet): + self.alphabet = alphabet + + def __call__(self, raw_batch: Sequence[Tuple[str, str]]): + # RoBERTa uses an eos token, while ESM-1 does not. + batch_size = len(raw_batch) + batch_labels, seq_str_list, masked_seq_str_list, masked_indices_list = zip(*raw_batch) + + masked_seq_encoded_list = [self.alphabet.encode(seq_str) for seq_str in masked_seq_str_list] ###--- + seq_encoded_list = [self.alphabet.encode(seq_str) for seq_str in seq_str_list] ###--- +# print('====', seq_str_list) +# print('----', masked_seq_str_list) +# print('++++', masked_seq_encoded_list) +# print('****', seq_encoded_list) + + max_len = max(len(seq_encoded) for seq_encoded in masked_seq_encoded_list) + tokens = torch.empty( + ( + batch_size, + max_len + int(self.alphabet.prepend_bos) + int(self.alphabet.append_eos), + ), + dtype=torch.int64, + ) + tokens.fill_(self.alphabet.padding_idx) + masked_tokens = deepcopy(tokens) + + labels = [] + strs, masked_strs = [], [] + masked_indices = [] +# print('=================') + for i, (label, seq_str, masked_seq_str, seq_encoded, masked_seq_encoded, indices_mask) in enumerate( + zip(batch_labels, seq_str_list, masked_seq_str_list, seq_encoded_list, masked_seq_encoded_list, masked_indices_list) ###--- + ): + labels.append(label) + strs.append(seq_str) + masked_strs.append(masked_seq_str) + masked_indices.append(indices_mask) + + if self.alphabet.prepend_bos: + tokens[i, 0] = self.alphabet.cls_idx + masked_tokens[i, 0] = self.alphabet.cls_idx + + seq = torch.tensor(seq_encoded, dtype=torch.int64) + masked_seq = torch.tensor(masked_seq_encoded, dtype=torch.int64) +# print(tokens, masked_tokens) + tokens[ + i, + int(self.alphabet.prepend_bos) : len(seq_encoded) + + int(self.alphabet.prepend_bos), + ] = seq + + masked_tokens[ + i, + int(self.alphabet.prepend_bos) : len(masked_seq_encoded) + + int(self.alphabet.prepend_bos), + ] = masked_seq +# print(tokens, masked_tokens) + if self.alphabet.append_eos: + tokens[i, len(seq_encoded) + int(self.alphabet.prepend_bos)] = self.alphabet.eos_idx + masked_tokens[i, len(masked_seq_encoded) + int(self.alphabet.prepend_bos)] = self.alphabet.eos_idx +# print(tokens, masked_tokens) + return labels, strs, masked_strs, tokens, masked_tokens, masked_indices + + +class MSABatchConverter(BatchConverter): + def __call__(self, inputs: Union[Sequence[RawMSA], RawMSA]): + if isinstance(inputs[0][0], str): + # Input is a single MSA + raw_batch: Sequence[RawMSA] = [inputs] # type: ignore + else: + raw_batch = inputs # type: ignore + + batch_size = len(raw_batch) + max_alignments = max(len(msa) for msa in raw_batch) + max_seqlen = max(len(msa[0][1]) for msa in raw_batch) + + tokens = torch.empty( + ( + batch_size, + max_alignments, + max_seqlen + int(self.alphabet.prepend_bos) + int(self.alphabet.append_eos), + ), + dtype=torch.int64, + ) + tokens.fill_(self.alphabet.padding_idx) + labels = [] + strs = [] + + for i, msa in enumerate(raw_batch): + msa_seqlens = set(len(seq) for _, seq in msa) + if not len(msa_seqlens) == 1: + raise RuntimeError( + "Received unaligned sequences for input to MSA, all sequence " + "lengths must be equal." 
+ ) + msa_labels, msa_strs, msa_tokens = super().__call__(msa) + labels.append(msa_labels) + strs.append(msa_strs) + tokens[i, : msa_tokens.size(0), : msa_tokens.size(1)] = msa_tokens + + return labels, strs, tokens + + +def read_fasta( + path, + keep_gaps=True, + keep_insertions=True, + to_upper=False, +): + with open(path, "r") as f: + for result in read_alignment_lines( + f, keep_gaps=keep_gaps, keep_insertions=keep_insertions, to_upper=to_upper + ): + yield result + + +def read_alignment_lines( + lines, + keep_gaps=True, + keep_insertions=True, + to_upper=False, +): + seq = desc = None + + def parse(s): + if not keep_gaps: + s = re.sub("-", "", s) + if not keep_insertions: + s = re.sub("[a-z]", "", s) + return s.upper() if to_upper else s + + for line in lines: + # Line may be empty if seq % file_line_width == 0 + if len(line) > 0 and line[0] == ">": + if seq is not None: + yield desc, parse(seq) + desc = line.strip() + seq = "" + else: + assert isinstance(seq, str) + seq += line.strip() + assert isinstance(seq, str) and isinstance(desc, str) + yield desc, parse(seq) + + +class ESMStructuralSplitDataset(torch.utils.data.Dataset): + """ + Structural Split Dataset as described in section A.10 of the supplement of our paper. + https://doi.org/10.1101/622803 + + We use the full version of SCOPe 2.07, clustered at 90% sequence identity, + generated on January 23, 2020. + + For each SCOPe domain: + - We extract the sequence from the corresponding PDB file + - We extract the 3D coordinates of the Carbon beta atoms, aligning them + to the sequence. We put NaN where Cb atoms are missing. + - From the 3D coordinates, we calculate a pairwise distance map, based + on L2 distance + - We use DSSP to generate secondary structure labels for the corresponding + PDB file. This is also aligned to the sequence. We put - where SSP + labels are missing. + + For each SCOPe classification level of family/superfamily/fold (in order of difficulty), + we have split the data into 5 partitions for cross validation. These are provided + in a downloaded splits folder, in the format: + splits/{split_level}/{cv_partition}/{train|valid}.txt + where train is the partition and valid is the concatentation of the remaining 4. 
+ + For each SCOPe domain, we provide a pkl dump that contains: + - seq : The domain sequence, stored as an L-length string + - ssp : The secondary structure labels, stored as an L-length string + - dist : The distance map, stored as an LxL numpy array + - coords : The 3D coordinates, stored as an Lx3 numpy array + + """ + + base_folder = "structural-data" + file_list = [ + # url tar filename filename MD5 Hash + ( + "https://dl.fbaipublicfiles.com/fair-esm/structural-data/splits.tar.gz", + "splits.tar.gz", + "splits", + "456fe1c7f22c9d3d8dfe9735da52411d", + ), + ( + "https://dl.fbaipublicfiles.com/fair-esm/structural-data/pkl.tar.gz", + "pkl.tar.gz", + "pkl", + "644ea91e56066c750cd50101d390f5db", + ), + ] + + def __init__( + self, + split_level, + cv_partition, + split, + root_path=os.path.expanduser("~/.cache/torch/data/esm"), + download=False, + ): + super().__init__() + assert split in [ + "train", + "valid", + ], "train_valid must be 'train' or 'valid'" + self.root_path = root_path + self.base_path = os.path.join(self.root_path, self.base_folder) + + # check if root path has what you need or else download it + if download: + self.download() + + self.split_file = os.path.join( + self.base_path, "splits", split_level, cv_partition, f"{split}.txt" + ) + self.pkl_dir = os.path.join(self.base_path, "pkl") + self.names = [] + with open(self.split_file) as f: + self.names = f.read().splitlines() + + def __len__(self): + return len(self.names) + + def _check_exists(self) -> bool: + for (_, _, filename, _) in self.file_list: + fpath = os.path.join(self.base_path, filename) + if not os.path.exists(fpath) or not os.path.isdir(fpath): + return False + return True + + def download(self): + + if self._check_exists(): + print("Files already downloaded and verified") + return + + from torchvision.datasets.utils import download_url + + for url, tar_filename, filename, md5_hash in self.file_list: + download_path = os.path.join(self.base_path, tar_filename) + download_url(url=url, root=self.base_path, filename=tar_filename, md5=md5_hash) + shutil.unpack_archive(download_path, self.base_path) + + def __getitem__(self, idx): + """ + Returns a dict with the following entires + - seq : Str (domain sequence) + - ssp : Str (SSP labels) + - dist : np.array (distance map) + - coords : np.array (3D coordinates) + """ + name = self.names[idx] + pkl_fname = os.path.join(self.pkl_dir, name[1:3], f"{name}.pkl") + with open(pkl_fname, "rb") as f: + obj = pickle.load(f) + return obj diff --git a/esm/data_supervised.py b/esm/data_supervised.py new file mode 100644 index 0000000000000000000000000000000000000000..62789c0dc7ea5bfd52130eb724c090cc9837df93 --- /dev/null +++ b/esm/data_supervised.py @@ -0,0 +1,524 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import itertools +import os +from typing import Sequence, Tuple, List, Union +import pickle +import re +import shutil +import torch +from pathlib import Path +from .constants import proteinseq_toks, rnaseq_toks +import math +import random +from copy import deepcopy + +RawMSA = Sequence[Tuple[str, str]] + + +class Alphabet(object): + def __init__( + self, + standard_toks: Sequence[str], + prepend_toks: Sequence[str] = ("", "", ""), # "", + append_toks: Sequence[str] = ("", "", ""), # + prepend_bos: bool = True, + append_eos: bool = True, + use_msa: bool = False, + mask_prob: float = 0.15, ###--- + ): + self.mask_prob = mask_prob ###--- + self.standard_toks = list(standard_toks) + self.prepend_toks = list(prepend_toks) + self.append_toks = list(append_toks) + self.prepend_bos = prepend_bos + self.append_eos = append_eos + self.use_msa = use_msa + + self.all_toks = list(self.prepend_toks) + self.all_toks.extend(self.standard_toks) +# for i in range((8 - (len(self.all_toks) % 8)) % 8): +# self.all_toks.append(f"") + self.all_toks.extend(self.append_toks) + + self.tok_to_idx = {tok: i for i, tok in enumerate(self.all_toks)} +# print(self.tok_to_idx) + self.unk_idx = self.tok_to_idx[""] + self.padding_idx = self.get_idx("") + self.cls_idx = self.get_idx("") + self.mask_idx = self.get_idx("") + self.eos_idx = self.get_idx("") + self.all_special_tokens = ['', '', ''] # , '', '' + self.unique_no_split_tokens = self.all_toks + + def __len__(self): + return len(self.all_toks) + + def get_idx(self, tok): + return self.tok_to_idx.get(tok, self.unk_idx) + + def get_tok(self, ind): + return self.all_toks[ind] + + def to_dict(self): + return self.tok_to_idx.copy() + + def get_batch_converter(self): + if self.use_msa: + return MSABatchConverter(self) + else: + return BatchConverter(self) + + @classmethod + def from_architecture(cls, name: str) -> "Alphabet": + if name in ("ESM-1", "protein_bert_base"): + standard_toks = proteinseq_toks["toks"] + prepend_toks: Tuple[str, ...] = ("", "", "", "") + append_toks: Tuple[str, ...] = ("", "", "") + prepend_bos = True + append_eos = False + use_msa = False + elif name in ("ESM-1b", "roberta_large"): + standard_toks = proteinseq_toks["toks"] ###---rnaseq + prepend_toks = ("", "", "", "") + append_toks = ("",) + prepend_bos = True + append_eos = True + use_msa = False + elif name in ("MSA Transformer", "msa_transformer"): + standard_toks = proteinseq_toks["toks"] + prepend_toks = ("", "", "", "") + append_toks = ("",) + prepend_bos = True + append_eos = False + use_msa = True + else: + raise ValueError("Unknown architecture selected") + return cls(standard_toks, prepend_toks, append_toks, prepend_bos, append_eos, use_msa) + + def _tokenize(self, text) -> str: + return text.split() + + def tokenize(self, text, **kwargs) -> List[str]: + """ + Inspired by https://github.com/huggingface/transformers/blob/master/src/transformers/tokenization_utils.py + Converts a string in a sequence of tokens, using the tokenizer. + + Args: + text (:obj:`str`): + The sequence to be encoded. + + Returns: + :obj:`List[str]`: The list of tokens. + """ + + def split_on_token(tok, text): + result = [] + split_text = text.split(tok) + for i, sub_text in enumerate(split_text): + # AddedToken can control whitespace stripping around them. + # We use them for GPT2 and Roberta to have different behavior depending on the special token + # Cf. 
https://github.com/huggingface/transformers/pull/2778 + # and https://github.com/huggingface/transformers/issues/3788 + # We strip left and right by default + if i < len(split_text) - 1: + sub_text = sub_text.rstrip() + if i > 0: + sub_text = sub_text.lstrip() + + if i == 0 and not sub_text: + result.append(tok) + elif i == len(split_text) - 1: + if sub_text: + result.append(sub_text) + else: + pass + else: + if sub_text: + result.append(sub_text) + result.append(tok) + return result + + def split_on_tokens(tok_list, text): + if not text.strip(): + return [] + + tokenized_text = [] + text_list = [text] + for tok in tok_list: + tokenized_text = [] + for sub_text in text_list: + if sub_text not in self.unique_no_split_tokens: + tokenized_text.extend(split_on_token(tok, sub_text)) + else: + tokenized_text.append(sub_text) + text_list = tokenized_text + + return list( + itertools.chain.from_iterable( + ( + self._tokenize(token) + if token not in self.unique_no_split_tokens + else [token] + for token in tokenized_text + ) + ) + ) + + no_split_token = self.unique_no_split_tokens + tokenized_text = split_on_tokens(no_split_token, text) + return tokenized_text + + def encode(self, text): + return [self.tok_to_idx[tok] for tok in self.tokenize(text)] + +class FastaBatchedDataset(object): + def __init__(self, sequence_labels, sequence_strs, mask_prob = 0.15): + self.sequence_labels = list(sequence_labels) + self.sequence_strs = list(sequence_strs) + self.mask_prob = mask_prob + + @classmethod + def from_file(cls, fasta_file, mask_prob = 0.15): + sequence_labels, sequence_strs = [], [] + cur_seq_label = None + buf = [] + + def _flush_current_seq(): + nonlocal cur_seq_label, buf + if cur_seq_label is None: + return + sequence_labels.append(cur_seq_label) + sequence_strs.append("".join(buf)) + cur_seq_label = None + buf = [] + + with open(fasta_file, "r") as infile: + for line_idx, line in enumerate(infile): + if line.startswith(">"): # label line + _flush_current_seq() + line = line[1:].strip() + if len(line) > 0: + cur_seq_label = line + else: + cur_seq_label = f"seqnum{line_idx:09d}" + else: # sequence line + buf.append(line.strip()) + + _flush_current_seq() + + assert len(set(sequence_labels)) == len( + sequence_labels + ), "Found duplicate sequence labels" + + return cls(sequence_labels, sequence_strs, mask_prob) + + def __len__(self): + return len(self.sequence_labels) + + def mask_sequence(self, seq): ###--- + length = len(seq) +# print(self.mask_prob) + max_length = math.ceil(length * self.mask_prob) + rand = random.sample(range(0, length), max_length) + res = ''.join(['' if idx in rand else ele for idx, ele in enumerate(seq)]) + #print(seq, rand, res) + return rand, res + + def __getitem__(self, idx): + sequence_str = self.sequence_strs[idx] + sequence_label = self.sequence_labels[idx] + masked_indices, masked_sequence_str = self.mask_sequence(sequence_str) + return sequence_label, sequence_str, masked_sequence_str, masked_indices + + def get_batch_indices(self, toks_per_batch, extra_toks_per_seq=0): + sizes = [(len(s), i) for i, s in enumerate(self.sequence_strs)] + sizes.sort() + batches = [] + buf = [] + max_len = 0 + + def _flush_current_buf(): + nonlocal max_len, buf + if len(buf) == 0: + return + batches.append(buf) + buf = [] + max_len = 0 + + for sz, i in sizes: + sz += extra_toks_per_seq + if max(sz, max_len) * (len(buf) + 1) > toks_per_batch: + _flush_current_buf() + max_len = max(max_len, sz) + buf.append(i) + + _flush_current_buf() + return batches + +class BatchConverter(object): 
+ """Callable to convert an unprocessed (labels + strings) batch to a + processed (labels + tensor) batch. + """ + + def __init__(self, alphabet): + self.alphabet = alphabet + + def __call__(self, raw_batch: Sequence[Tuple[str, str]]): + # RoBERTa uses an eos token, while ESM-1 does not. + batch_size = len(raw_batch) + batch_labels, seq_str_list, masked_seq_str_list, masked_indices_list = zip(*raw_batch) + + masked_seq_encoded_list = [self.alphabet.encode(seq_str) for seq_str in masked_seq_str_list] ###--- + seq_encoded_list = [self.alphabet.encode(seq_str) for seq_str in seq_str_list] ###--- +# print('====', seq_str_list) +# print('----', masked_seq_str_list) +# print('++++', masked_seq_encoded_list) +# print('****', seq_encoded_list) + + max_len = max(len(seq_encoded) for seq_encoded in masked_seq_encoded_list) + tokens = torch.empty( + ( + batch_size, + max_len + int(self.alphabet.prepend_bos) + int(self.alphabet.append_eos), + ), + dtype=torch.int64, + ) + tokens.fill_(self.alphabet.padding_idx) + masked_tokens = deepcopy(tokens) + + labels = [] + strs, masked_strs = [], [] + masked_indices = [] +# print('=================') + for i, (label, seq_str, masked_seq_str, seq_encoded, masked_seq_encoded, indices_mask) in enumerate( + zip(batch_labels, seq_str_list, masked_seq_str_list, seq_encoded_list, masked_seq_encoded_list, masked_indices_list) ###--- + ): + labels.append(label) + strs.append(seq_str) + masked_strs.append(masked_seq_str) + masked_indices.append(indices_mask) + + if self.alphabet.prepend_bos: + tokens[i, 0] = self.alphabet.cls_idx + masked_tokens[i, 0] = self.alphabet.cls_idx + + seq = torch.tensor(seq_encoded, dtype=torch.int64) + masked_seq = torch.tensor(masked_seq_encoded, dtype=torch.int64) +# print(tokens, masked_tokens) + tokens[ + i, + int(self.alphabet.prepend_bos) : len(seq_encoded) + + int(self.alphabet.prepend_bos), + ] = seq + + masked_tokens[ + i, + int(self.alphabet.prepend_bos) : len(masked_seq_encoded) + + int(self.alphabet.prepend_bos), + ] = masked_seq +# print(tokens, masked_tokens) + if self.alphabet.append_eos: + tokens[i, len(seq_encoded) + int(self.alphabet.prepend_bos)] = self.alphabet.eos_idx + masked_tokens[i, len(masked_seq_encoded) + int(self.alphabet.prepend_bos)] = self.alphabet.eos_idx +# print(tokens, masked_tokens) + return labels, strs, masked_strs, tokens, masked_tokens, masked_indices + + +class MSABatchConverter(BatchConverter): + def __call__(self, inputs: Union[Sequence[RawMSA], RawMSA]): + if isinstance(inputs[0][0], str): + # Input is a single MSA + raw_batch: Sequence[RawMSA] = [inputs] # type: ignore + else: + raw_batch = inputs # type: ignore + + batch_size = len(raw_batch) + max_alignments = max(len(msa) for msa in raw_batch) + max_seqlen = max(len(msa[0][1]) for msa in raw_batch) + + tokens = torch.empty( + ( + batch_size, + max_alignments, + max_seqlen + int(self.alphabet.prepend_bos) + int(self.alphabet.append_eos), + ), + dtype=torch.int64, + ) + tokens.fill_(self.alphabet.padding_idx) + labels = [] + strs = [] + + for i, msa in enumerate(raw_batch): + msa_seqlens = set(len(seq) for _, seq in msa) + if not len(msa_seqlens) == 1: + raise RuntimeError( + "Received unaligned sequences for input to MSA, all sequence " + "lengths must be equal." 
+ ) + msa_labels, msa_strs, msa_tokens = super().__call__(msa) + labels.append(msa_labels) + strs.append(msa_strs) + tokens[i, : msa_tokens.size(0), : msa_tokens.size(1)] = msa_tokens + + return labels, strs, tokens + + +def read_fasta( + path, + keep_gaps=True, + keep_insertions=True, + to_upper=False, +): + with open(path, "r") as f: + for result in read_alignment_lines( + f, keep_gaps=keep_gaps, keep_insertions=keep_insertions, to_upper=to_upper + ): + yield result + + +def read_alignment_lines( + lines, + keep_gaps=True, + keep_insertions=True, + to_upper=False, +): + seq = desc = None + + def parse(s): + if not keep_gaps: + s = re.sub("-", "", s) + if not keep_insertions: + s = re.sub("[a-z]", "", s) + return s.upper() if to_upper else s + + for line in lines: + # Line may be empty if seq % file_line_width == 0 + if len(line) > 0 and line[0] == ">": + if seq is not None: + yield desc, parse(seq) + desc = line.strip() + seq = "" + else: + assert isinstance(seq, str) + seq += line.strip() + assert isinstance(seq, str) and isinstance(desc, str) + yield desc, parse(seq) + + +class ESMStructuralSplitDataset(torch.utils.data.Dataset): + """ + Structural Split Dataset as described in section A.10 of the supplement of our paper. + https://doi.org/10.1101/622803 + + We use the full version of SCOPe 2.07, clustered at 90% sequence identity, + generated on January 23, 2020. + + For each SCOPe domain: + - We extract the sequence from the corresponding PDB file + - We extract the 3D coordinates of the Carbon beta atoms, aligning them + to the sequence. We put NaN where Cb atoms are missing. + - From the 3D coordinates, we calculate a pairwise distance map, based + on L2 distance + - We use DSSP to generate secondary structure labels for the corresponding + PDB file. This is also aligned to the sequence. We put - where SSP + labels are missing. + + For each SCOPe classification level of family/superfamily/fold (in order of difficulty), + we have split the data into 5 partitions for cross validation. These are provided + in a downloaded splits folder, in the format: + splits/{split_level}/{cv_partition}/{train|valid}.txt + where train is the partition and valid is the concatentation of the remaining 4. 
+ + For each SCOPe domain, we provide a pkl dump that contains: + - seq : The domain sequence, stored as an L-length string + - ssp : The secondary structure labels, stored as an L-length string + - dist : The distance map, stored as an LxL numpy array + - coords : The 3D coordinates, stored as an Lx3 numpy array + + """ + + base_folder = "structural-data" + file_list = [ + # url tar filename filename MD5 Hash + ( + "https://dl.fbaipublicfiles.com/fair-esm/structural-data/splits.tar.gz", + "splits.tar.gz", + "splits", + "456fe1c7f22c9d3d8dfe9735da52411d", + ), + ( + "https://dl.fbaipublicfiles.com/fair-esm/structural-data/pkl.tar.gz", + "pkl.tar.gz", + "pkl", + "644ea91e56066c750cd50101d390f5db", + ), + ] + + def __init__( + self, + split_level, + cv_partition, + split, + root_path=os.path.expanduser("~/.cache/torch/data/esm"), + download=False, + ): + super().__init__() + assert split in [ + "train", + "valid", + ], "train_valid must be 'train' or 'valid'" + self.root_path = root_path + self.base_path = os.path.join(self.root_path, self.base_folder) + + # check if root path has what you need or else download it + if download: + self.download() + + self.split_file = os.path.join( + self.base_path, "splits", split_level, cv_partition, f"{split}.txt" + ) + self.pkl_dir = os.path.join(self.base_path, "pkl") + self.names = [] + with open(self.split_file) as f: + self.names = f.read().splitlines() + + def __len__(self): + return len(self.names) + + def _check_exists(self) -> bool: + for (_, _, filename, _) in self.file_list: + fpath = os.path.join(self.base_path, filename) + if not os.path.exists(fpath) or not os.path.isdir(fpath): + return False + return True + + def download(self): + + if self._check_exists(): + print("Files already downloaded and verified") + return + + from torchvision.datasets.utils import download_url + + for url, tar_filename, filename, md5_hash in self.file_list: + download_path = os.path.join(self.base_path, tar_filename) + download_url(url=url, root=self.base_path, filename=tar_filename, md5=md5_hash) + shutil.unpack_archive(download_path, self.base_path) + + def __getitem__(self, idx): + """ + Returns a dict with the following entires + - seq : Str (domain sequence) + - ssp : Str (SSP labels) + - dist : np.array (distance map) + - coords : np.array (3D coordinates) + """ + name = self.names[idx] + pkl_fname = os.path.join(self.pkl_dir, name[1:3], f"{name}.pkl") + with open(pkl_fname, "rb") as f: + obj = pickle.load(f) + return obj diff --git a/esm/model/._esm2_secondarystructure.py b/esm/model/._esm2_secondarystructure.py new file mode 100644 index 0000000000000000000000000000000000000000..1ac2328d14be5db179803252d104e28dc46cdd8d Binary files /dev/null and b/esm/model/._esm2_secondarystructure.py differ diff --git a/esm/model/__pycache__/esm1.cpython-36.pyc b/esm/model/__pycache__/esm1.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22466f8ab2dbf380a96aff39fc4650e6b9715834 Binary files /dev/null and b/esm/model/__pycache__/esm1.cpython-36.pyc differ diff --git a/esm/model/__pycache__/esm1.cpython-39.pyc b/esm/model/__pycache__/esm1.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37cd9ecd46791231d8553b82ccd57ba56fa94f89 Binary files /dev/null and b/esm/model/__pycache__/esm1.cpython-39.pyc differ diff --git a/esm/model/__pycache__/esm2.cpython-36.pyc b/esm/model/__pycache__/esm2.cpython-36.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..d5672f4c6642622e8ed36922402cd699d095ebf6 Binary files /dev/null and b/esm/model/__pycache__/esm2.cpython-36.pyc differ diff --git a/esm/model/__pycache__/esm2.cpython-39.pyc b/esm/model/__pycache__/esm2.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee7242a86abf4949dbd21e546aab2c35ef103a27 Binary files /dev/null and b/esm/model/__pycache__/esm2.cpython-39.pyc differ diff --git a/esm/model/__pycache__/esm2_only_secondarystructure.cpython-39.pyc b/esm/model/__pycache__/esm2_only_secondarystructure.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d8f8b63e8eb839ea1d6e4273117ae8c410f11df Binary files /dev/null and b/esm/model/__pycache__/esm2_only_secondarystructure.cpython-39.pyc differ diff --git a/esm/model/__pycache__/esm2_secondarystructure.cpython-39.pyc b/esm/model/__pycache__/esm2_secondarystructure.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..997de2043bcc9be3966f400e26bbe09f2a3aa8f9 Binary files /dev/null and b/esm/model/__pycache__/esm2_secondarystructure.cpython-39.pyc differ diff --git a/esm/model/__pycache__/esm2_supervised.cpython-39.pyc b/esm/model/__pycache__/esm2_supervised.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ace8626659123f71caf84f8533cc53dcc90422d8 Binary files /dev/null and b/esm/model/__pycache__/esm2_supervised.cpython-39.pyc differ diff --git a/esm/model/__pycache__/msa_transformer.cpython-36.pyc b/esm/model/__pycache__/msa_transformer.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17b24e3a60ba04a83983eb2d539768eff46e1879 Binary files /dev/null and b/esm/model/__pycache__/msa_transformer.cpython-36.pyc differ diff --git a/esm/model/__pycache__/msa_transformer.cpython-39.pyc b/esm/model/__pycache__/msa_transformer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38736311a55b5ecb90edf996c003cd0e4f5b4f5e Binary files /dev/null and b/esm/model/__pycache__/msa_transformer.cpython-39.pyc differ diff --git a/esm/model/esm1.py b/esm/model/esm1.py new file mode 100644 index 0000000000000000000000000000000000000000..5933fd64d84e2a72aaf46d8e2bacad67174d278e --- /dev/null +++ b/esm/model/esm1.py @@ -0,0 +1,203 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..modules import ( + TransformerLayer, + LearnedPositionalEmbedding, + SinusoidalPositionalEmbedding, + RobertaLMHead, + ESM1bLayerNorm, + ContactPredictionHead, +) + + +class ProteinBertModel(nn.Module): + @classmethod + def add_args(cls, parser): + parser.add_argument( + "--num_layers", default=36, type=int, metavar="N", help="number of layers" + ) + parser.add_argument( + "--embed_dim", default=1280, type=int, metavar="N", help="embedding dimension" + ) + parser.add_argument( + "--logit_bias", action="store_true", help="whether to apply bias to logits" + ) + parser.add_argument( + "--ffn_embed_dim", + default=5120, + type=int, + metavar="N", + help="embedding dimension for FFN", + ) + parser.add_argument( + "--attention_heads", + default=20, + type=int, + metavar="N", + help="number of attention heads", + ) + + def __init__(self, args, alphabet): + super().__init__() + self.args = args + self.alphabet_size = len(alphabet) + self.padding_idx = alphabet.padding_idx + self.mask_idx = alphabet.mask_idx + self.cls_idx = alphabet.cls_idx + self.eos_idx = alphabet.eos_idx + self.prepend_bos = alphabet.prepend_bos + self.append_eos = alphabet.append_eos + self.emb_layer_norm_before = getattr(self.args, "emb_layer_norm_before", False) + if self.args.arch == "roberta_large": + self.model_version = "ESM-1b" + self._init_submodules_esm1b() + else: + self.model_version = "ESM-1" + self._init_submodules_esm1() + + def _init_submodules_common(self): + self.embed_tokens = nn.Embedding( + self.alphabet_size, self.args.embed_dim, padding_idx=self.padding_idx + ) + self.layers = nn.ModuleList( + [ + TransformerLayer( + self.args.embed_dim, + self.args.ffn_embed_dim, + self.args.attention_heads, + add_bias_kv=(self.model_version != "ESM-1b"), + use_esm1b_layer_norm=(self.model_version == "ESM-1b"), + ) + for _ in range(self.args.layers) + ] + ) + + self.contact_head = ContactPredictionHead( + self.args.layers * self.args.attention_heads, + self.prepend_bos, + self.append_eos, + eos_idx=self.eos_idx, + ) + + def _init_submodules_esm1b(self): + self._init_submodules_common() + self.embed_scale = 1 + self.embed_positions = LearnedPositionalEmbedding( + self.args.max_positions, self.args.embed_dim, self.padding_idx + ) + self.emb_layer_norm_before = ( + ESM1bLayerNorm(self.args.embed_dim) if self.emb_layer_norm_before else None + ) + self.emb_layer_norm_after = ESM1bLayerNorm(self.args.embed_dim) + self.lm_head = RobertaLMHead( + embed_dim=self.args.embed_dim, + output_dim=self.alphabet_size, + weight=self.embed_tokens.weight, + ) + + def _init_submodules_esm1(self): + self._init_submodules_common() + self.embed_scale = math.sqrt(self.args.embed_dim) + self.embed_positions = SinusoidalPositionalEmbedding(self.args.embed_dim, self.padding_idx) + self.embed_out = nn.Parameter(torch.zeros((self.alphabet_size, self.args.embed_dim))) + self.embed_out_bias = None + if self.args.final_bias: + self.embed_out_bias = nn.Parameter(torch.zeros(self.alphabet_size)) + + def forward(self, tokens, repr_layers=[], need_head_weights=False, return_contacts=False, return_representation=False): + if return_contacts: + need_head_weights = True + + assert tokens.ndim == 2 + padding_mask = tokens.eq(self.padding_idx) # B, T + + x = self.embed_scale * self.embed_tokens(tokens) + + if getattr(self.args, "token_dropout", False): + x.masked_fill_((tokens == self.mask_idx).unsqueeze(-1), 0.0) + # x: B x T x C + mask_ratio_train = 0.15 * 0.8 + 
src_lengths = (~padding_mask).sum(-1) + mask_ratio_observed = (tokens == self.mask_idx).sum(-1).float() / src_lengths + x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None] + + x = x + self.embed_positions(tokens) + + if self.model_version == "ESM-1b": + if self.emb_layer_norm_before: + x = self.emb_layer_norm_before(x) + if padding_mask is not None: + x = x * (1 - padding_mask.unsqueeze(-1).type_as(x)) + + repr_layers = set(repr_layers) + hidden_representations = {} + if 0 in repr_layers: + hidden_representations[0] = x + + if need_head_weights: + attn_weights = [] + + # (B, T, E) => (T, B, E) + x = x.transpose(0, 1) + + if not padding_mask.any(): + padding_mask = None + + for layer_idx, layer in enumerate(self.layers): + x, attn = layer( + x, self_attn_padding_mask=padding_mask, need_head_weights=need_head_weights + ) + if (layer_idx + 1) in repr_layers: + hidden_representations[layer_idx + 1] = x.transpose(0, 1) + if need_head_weights: + # (H, B, T, T) => (B, H, T, T) + attn_weights.append(attn.transpose(1, 0)) + + if self.model_version == "ESM-1b": + x = self.emb_layer_norm_after(x) + x = x.transpose(0, 1) # (T, B, E) => (B, T, E) + + # last hidden representation should have layer norm applied + if (layer_idx + 1) in repr_layers: + hidden_representations[layer_idx + 1] = x + x = self.lm_head(x) + else: + x = F.linear(x, self.embed_out, bias=self.embed_out_bias) + x = x.transpose(0, 1) # (T, B, E) => (B, T, E) + + if return_representation: + result = {"logits": x, "representations": hidden_representations} + else: + result = {"logits": x} + if need_head_weights: + # attentions: B x L x H x T x T + attentions = torch.stack(attn_weights, 1) + if self.model_version == "ESM-1": + # ESM-1 models have an additional null-token for attention, which we remove + attentions = attentions[..., :-1] + if padding_mask is not None: + attention_mask = 1 - padding_mask.type_as(attentions) + attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2) + attentions = attentions * attention_mask[:, None, None, :, :] + result["attentions"] = attentions + if return_contacts: + contacts = self.contact_head(tokens, attentions) + result["contacts"] = contacts + + return result + + def predict_contacts(self, tokens): + return self(tokens, return_contacts=True)["contacts"] + + @property + def num_layers(self): + return self.args.layers \ No newline at end of file diff --git a/esm/model/esm2.py b/esm/model/esm2.py new file mode 100644 index 0000000000000000000000000000000000000000..cbb023c53542024d84ddaa1dc0a214ec29e09d8a --- /dev/null +++ b/esm/model/esm2.py @@ -0,0 +1,163 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
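# A minimal usage sketch for the ESM2 module defined below; the tiny layer/head
# counts and random token ids are illustrative assumptions, not values used
# elsewhere in this repository:
import torch
from esm.model.esm2 import ESM2

model = ESM2(num_layers=2, embed_dim=64, attention_heads=4, alphabet="ESM-1b")
toks = torch.randint(4, 24, (1, 10))      # fake token ids within the alphabet
out = model(toks, repr_layers=[2], return_representation=True, return_contacts=True)
logits = out["logits"]                    # (B, T, alphabet_size) masked-LM logits
reprs = out["representations"][2]         # (B, T, embed_dim) hidden states after the final layer
attns = out["attentions"]                 # (B, num_layers, heads, T, T) attention maps
contacts = out["contacts"]                # (B, T-2, T-2) contact probabilities (BOS/EOS positions removed)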
+ +from typing import Union +import torch +import torch.nn as nn + +import esm +from esm.modules import ContactPredictionHead, ESM1bLayerNorm, RobertaLMHead, TransformerLayer + + +class ESM2(nn.Module): + def __init__( + self, + num_layers: int = 33, + embed_dim: int = 1280, + attention_heads: int = 20, + alphabet: Union[esm.data.Alphabet, str] = "ESM-1b", + token_dropout: bool = True, + ): + super().__init__() + self.num_layers = num_layers + self.embed_dim = embed_dim + self.attention_heads = attention_heads + if not isinstance(alphabet, esm.data.Alphabet): + alphabet = esm.data.Alphabet.from_architecture(alphabet) + self.alphabet = alphabet + self.alphabet_size = len(alphabet) + self.padding_idx = alphabet.padding_idx + self.mask_idx = alphabet.mask_idx + self.cls_idx = alphabet.cls_idx + self.eos_idx = alphabet.eos_idx + self.prepend_bos = alphabet.prepend_bos + self.append_eos = alphabet.append_eos + self.token_dropout = token_dropout + + self._init_submodules() + + def _init_submodules(self): + self.embed_scale = 1 + self.embed_tokens = nn.Embedding( + self.alphabet_size, + self.embed_dim, + padding_idx=self.padding_idx, + ) + + self.layers = nn.ModuleList( + [ + TransformerLayer( + self.embed_dim, + 4 * self.embed_dim, + self.attention_heads, + add_bias_kv=False, + use_esm1b_layer_norm=True, + use_rotary_embeddings=True, + ) + for _ in range(self.num_layers) + ] + ) + + self.contact_head = ContactPredictionHead( + self.num_layers * self.attention_heads, + self.prepend_bos, + self.append_eos, + eos_idx=self.eos_idx, + ) + self.emb_layer_norm_after = ESM1bLayerNorm(self.embed_dim) + + self.lm_head = RobertaLMHead( + embed_dim=self.embed_dim, + output_dim=self.alphabet_size, + weight=self.embed_tokens.weight, + ) + + def forward(self, tokens, repr_layers=[], need_head_weights=False, return_contacts=False, return_representation=False): + if return_contacts: + need_head_weights = True + + assert tokens.ndim == 2 + padding_mask = tokens.eq(self.padding_idx) # B, T + + x = self.embed_scale * self.embed_tokens(tokens) + + if self.token_dropout: + x.masked_fill_((tokens == self.mask_idx).unsqueeze(-1), 0.0) + # x: B x T x C + mask_ratio_train = 0.15 * 0.8 + src_lengths = (~padding_mask).sum(-1) + mask_ratio_observed = (tokens == self.mask_idx).sum(-1).to(x.dtype) / src_lengths + x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None] + + if padding_mask is not None: + x = x * (1 - padding_mask.unsqueeze(-1).type_as(x)) + + repr_layers = set(repr_layers) + hidden_representations = {} + if 0 in repr_layers: + hidden_representations[0] = x + + if need_head_weights: + attn_weights = [] + + # (B, T, E) => (T, B, E) + x = x.transpose(0, 1) + + if not padding_mask.any(): + padding_mask = None + + for layer_idx, layer in enumerate(self.layers): + x, attn = layer( + x, + self_attn_padding_mask=padding_mask, + need_head_weights=need_head_weights, + ) + if (layer_idx + 1) in repr_layers: + hidden_representations[layer_idx + 1] = x.transpose(0, 1) + if need_head_weights: + # (H, B, T, T) => (B, H, T, T) + attn_weights.append(attn.transpose(1, 0)) +# print(x.shape) # 73, 2, 1280 + x = self.emb_layer_norm_after(x) + x = x.transpose(0, 1) # (T, B, E) => (B, T, E) + + # last hidden representation should have layer norm applied + if (layer_idx + 1) in repr_layers: + hidden_representations[layer_idx + 1] = x + x = self.lm_head(x) + + if return_representation: + result = {"logits": x, "representations": hidden_representations} + else: + result = {"logits": x} + if need_head_weights: + # 
attentions: B x L x H x T x T + attentions = torch.stack(attn_weights, 1) + if padding_mask is not None: + attention_mask = 1 - padding_mask.type_as(attentions) + attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2) + attentions = attentions * attention_mask[:, None, None, :, :] + result["attentions"] = attentions + if return_contacts: + attentions_symm, contacts = self.contact_head(tokens, attentions) + result["contacts"] = contacts + result["attentions_symm"] = attentions_symm + + return result + + def predict_contacts(self, tokens): + return self(tokens, return_contacts=True)["contacts"] + + def predict_symmetric_attentions(self, tokens): + return self(tokens, return_contacts=True)["attentions_symm"] + + def predict_attentions(self, tokens): + return self(tokens, need_head_weights=True)["attentions"] + + def predict_representations(self, tokens): + return self(tokens, return_representation=True)['representations'] + + def predict_logits(self, tokens): + return self(tokens)['logits'] diff --git a/esm/model/esm2_only_secondarystructure.py b/esm/model/esm2_only_secondarystructure.py new file mode 100644 index 0000000000000000000000000000000000000000..2dae4912a55f6ee6491cbafb0fefbbf6a64f382f --- /dev/null +++ b/esm/model/esm2_only_secondarystructure.py @@ -0,0 +1,179 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Union +import torch +import torch.nn as nn + +import esm +from esm.modules import ContactPredictionHead, ESM1bLayerNorm, RobertaLMHead, TransformerLayer + + +class ESM2(nn.Module): + def __init__( + self, + num_layers: int = 33, + embed_dim: int = 1280, + attention_heads: int = 20, + alphabet: Union[esm.data.Alphabet, str] = "ESM-1b", + token_dropout: bool = True, + ): + super().__init__() + self.num_layers = num_layers + self.embed_dim = embed_dim + self.attention_heads = attention_heads + if not isinstance(alphabet, esm.data.Alphabet): + alphabet = esm.data.Alphabet.from_architecture(alphabet) + self.alphabet = alphabet + self.alphabet_size = len(alphabet) + self.padding_idx = alphabet.padding_idx + self.mask_idx = alphabet.mask_idx + self.cls_idx = alphabet.cls_idx + self.eos_idx = alphabet.eos_idx + self.prepend_bos = alphabet.prepend_bos + self.append_eos = alphabet.append_eos + self.token_dropout = token_dropout + + self._init_submodules() + + def _init_submodules(self): + self.embed_scale = 1 + self.embed_tokens = nn.Embedding( + self.alphabet_size, + self.embed_dim, + padding_idx=self.padding_idx, + ) + + self.layers = nn.ModuleList( + [ + TransformerLayer( + self.embed_dim, + 4 * self.embed_dim, + self.attention_heads, + add_bias_kv=False, + use_esm1b_layer_norm=True, + use_rotary_embeddings=True, + ) + for _ in range(self.num_layers) + ] + ) + + self.contact_head = ContactPredictionHead( + self.num_layers * self.attention_heads, + self.prepend_bos, + self.append_eos, + eos_idx=self.eos_idx, + ) + self.emb_layer_norm_after = ESM1bLayerNorm(self.embed_dim) + + self.lm_head = RobertaLMHead( + embed_dim=self.embed_dim, + output_dim=self.alphabet_size, + weight=self.embed_tokens.weight, + ) +# self.supervised_linear = nn.Linear(self.embed_dim, 1) + self.structure_linear = nn.Linear(self.embed_dim, 3) + def forward(self, tokens, repr_layers=[], need_head_weights=True, return_contacts=True, return_representation=True, return_attentions_symm = False, return_attentions = False): + if return_contacts: + 
need_head_weights = True + + assert tokens.ndim == 2 + padding_mask = tokens.eq(self.padding_idx) # B, T + + x = self.embed_scale * self.embed_tokens(tokens) + + if self.token_dropout: + x.masked_fill_((tokens == self.mask_idx).unsqueeze(-1), 0.0) + #print(f'tokens = {tokens}') + #print(f'self.mask_idx = {self.mask_idx}') + #print('x.shape = ', x.shape) + # x: B x T x C + mask_ratio_train = 0.15 * 0.8 + src_lengths = (~padding_mask).sum(-1) + #print(f'mask_ratio_train = {mask_ratio_train}') + #print(f'padding_mask = {padding_mask}') + #print(f'src_lengths = {src_lengths}') + mask_ratio_observed = (tokens == self.mask_idx).sum(-1).to(x.dtype) / src_lengths + #print('mask_ratio_observed = ',mask_ratio_observed) + x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None] + #print(f'x.shape = {x.shape}:\n', x) + if padding_mask is not None: + x = x * (1 - padding_mask.unsqueeze(-1).type_as(x)) + #print(f'x.shape = {x.shape}:\n', x) + repr_layers = set(repr_layers) + hidden_representations = {} + if 0 in repr_layers: + hidden_representations[0] = x + + if need_head_weights: + attn_weights = [] + + # (B, T, E) => (T, B, E) + x = x.transpose(0, 1) + + if not padding_mask.any(): + padding_mask = None + + for layer_idx, layer in enumerate(self.layers): + x, attn = layer( + x, + self_attn_padding_mask=padding_mask, + need_head_weights=need_head_weights, + ) + if (layer_idx + 1) in repr_layers: + hidden_representations[layer_idx + 1] = x.transpose(0, 1) + if need_head_weights: + # (H, B, T, T) => (B, H, T, T) + attn_weights.append(attn.transpose(1, 0)) +# print(x.shape) # 73, 2, 1280 + x = self.emb_layer_norm_after(x) + x = x.transpose(0, 1) # (T, B, E) => (B, T, E) + + # last hidden representation should have layer norm applied + if (layer_idx + 1) in repr_layers: + hidden_representations[layer_idx + 1] = x +# x_supervised = self.supervised_linear(x[:,0,:]) + x_structure = self.structure_linear(x) + x = self.lm_head(x) + + if return_representation: + result = {"logits": x, "logits_structure": x_structure, "representations": hidden_representations} + else: + result = {"logits": x, "logits_structure": x_structure} + if need_head_weights: + # attentions: B x L x H x T x T + attentions = torch.stack(attn_weights, 1) + if padding_mask is not None: + attention_mask = 1 - padding_mask.type_as(attentions) + attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2) + attentions = attentions * attention_mask[:, None, None, :, :] + if return_attentions: result["attentions"] = attentions + if return_contacts: + attentions_symm, contacts = self.contact_head(tokens, attentions) + result["contacts"] = contacts + if return_attentions_symm: result["attentions_symm"] = attentions_symm + + return result + + def predict_contacts(self, tokens): + return self(tokens, return_contacts=True)["contacts"] + + def predict_symmetric_attentions(self, tokens): + return self(tokens, return_contacts=True)["attentions_symm"] + + def predict_attentions(self, tokens): + return self(tokens, need_head_weights=True)["attentions"] + + def predict_representations(self, tokens): + return self(tokens, return_representation=True)['representations'] + + def predict_logits(self, tokens): + return self(tokens)['logits'] + +# def predict_logits_supervised(self, tokens): +# return self(tokens)['logits_supervised'] + + def predict_logits_structure(self, tokens): + return self(tokens)['logits_structure'] diff --git a/esm/model/esm2_secondarystructure.py b/esm/model/esm2_secondarystructure.py new file mode 100644 
index 0000000000000000000000000000000000000000..a5f08e7b98a66213d9097b44002aa0f865feb9d8 --- /dev/null +++ b/esm/model/esm2_secondarystructure.py @@ -0,0 +1,179 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Union +import torch +import torch.nn as nn + +import esm +from esm.modules import ContactPredictionHead, ESM1bLayerNorm, RobertaLMHead, TransformerLayer +# This code defines a PyTorch model named ESM2 that inherits from nn.Module. The __init__ method defines hyperparameters such as num_layers, embed_dim and attention_heads, and initializes submodules such as the embedding layer embed_tokens, a stack of Transformer layers (layers), the contact-prediction head contact_head, and the linear heads lm_head, supervised_linear and structure_linear. The forward pass is defined in the forward method, which takes a token sequence tokens and returns the predicted outputs together with additional information. + +class ESM2(nn.Module): + def __init__( + self, + num_layers: int = 33, + embed_dim: int = 1280, + attention_heads: int = 20, + alphabet: Union[esm.data.Alphabet, str] = "ESM-1b", + token_dropout: bool = True, + ): + super().__init__() + self.num_layers = num_layers + self.embed_dim = embed_dim + self.attention_heads = attention_heads + if not isinstance(alphabet, esm.data.Alphabet): + alphabet = esm.data.Alphabet.from_architecture(alphabet) + self.alphabet = alphabet + self.alphabet_size = len(alphabet) + self.padding_idx = alphabet.padding_idx + self.mask_idx = alphabet.mask_idx + self.cls_idx = alphabet.cls_idx + self.eos_idx = alphabet.eos_idx + self.prepend_bos = alphabet.prepend_bos + self.append_eos = alphabet.append_eos + self.token_dropout = token_dropout + + self._init_submodules() + + def _init_submodules(self): + self.embed_scale = 1 + self.embed_tokens = nn.Embedding( + self.alphabet_size, + self.embed_dim, + padding_idx=self.padding_idx, + ) + + self.layers = nn.ModuleList( + [ + TransformerLayer( + self.embed_dim, + 4 * self.embed_dim, + self.attention_heads, + add_bias_kv=False, + use_esm1b_layer_norm=True, + use_rotary_embeddings=True, + ) + for _ in range(self.num_layers) + ] + ) + + self.contact_head = ContactPredictionHead( + self.num_layers * self.attention_heads, + self.prepend_bos, + self.append_eos, + eos_idx=self.eos_idx, + ) + self.emb_layer_norm_after = ESM1bLayerNorm(self.embed_dim) + + self.lm_head = RobertaLMHead( + embed_dim=self.embed_dim, + output_dim=self.alphabet_size, + weight=self.embed_tokens.weight, + ) + self.supervised_linear = nn.Linear(self.embed_dim, 1) + self.structure_linear = nn.Linear(self.embed_dim, 3) + def forward(self, tokens, repr_layers=[], need_head_weights=True, return_contacts=True, return_representation=True, return_attentions_symm = False, return_attentions = False): + if return_contacts: + need_head_weights = True + + assert tokens.ndim == 2 + padding_mask = tokens.eq(self.padding_idx) # B, T + + x = self.embed_scale * self.embed_tokens(tokens) + + if self.token_dropout: + x.masked_fill_((tokens == self.mask_idx).unsqueeze(-1), 0.0) + #print(f'tokens = {tokens}') + #print(f'self.mask_idx = {self.mask_idx}') + #print('x.shape = ', x.shape) + # x: B x T x C + mask_ratio_train = 0.15 * 0.8 + src_lengths = (~padding_mask).sum(-1) + #print(f'mask_ratio_train = {mask_ratio_train}') + #print(f'padding_mask = {padding_mask}') + #print(f'src_lengths = {src_lengths}') + mask_ratio_observed = (tokens == self.mask_idx).sum(-1).to(x.dtype) / src_lengths + #print('mask_ratio_observed = ',mask_ratio_observed) + x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None] +
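# Note on the rescaling above: mask_ratio_train = 0.15 * 0.8 = 0.12 is the expected
# fraction of positions zeroed out during masked-language-model pre-training
# (15% of tokens selected, 80% of those replaced by the mask token). Scaling by
# (1 - 0.12) / (1 - observed mask ratio) keeps the expected embedding magnitude
# consistent between training batches (which contain mask tokens) and inference
# batches (which usually contain none).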
#print(f'x.shape = {x.shape}:\n', x) + if padding_mask is not None: + x = x * (1 - padding_mask.unsqueeze(-1).type_as(x)) + #print(f'x.shape = {x.shape}:\n', x) + repr_layers = set(repr_layers) + hidden_representations = {} + if 0 in repr_layers: + hidden_representations[0] = x + + if need_head_weights: + attn_weights = [] + + # (B, T, E) => (T, B, E) + x = x.transpose(0, 1) + + if not padding_mask.any(): + padding_mask = None + + for layer_idx, layer in enumerate(self.layers): + x, attn = layer( + x, + self_attn_padding_mask=padding_mask, + need_head_weights=need_head_weights, + ) + if (layer_idx + 1) in repr_layers: + hidden_representations[layer_idx + 1] = x.transpose(0, 1) + if need_head_weights: + # (H, B, T, T) => (B, H, T, T) + attn_weights.append(attn.transpose(1, 0)) +# print(x.shape) # 73, 2, 1280 + x = self.emb_layer_norm_after(x) + x = x.transpose(0, 1) # (T, B, E) => (B, T, E) + + # last hidden representation should have layer norm applied + if (layer_idx + 1) in repr_layers: + hidden_representations[layer_idx + 1] = x + x_supervised = self.supervised_linear(x[:,0,:]) + x_structure = self.structure_linear(x) + x = self.lm_head(x) + + if return_representation: + result = {"logits": x, "logits_supervised": x_supervised, "logits_structure": x_structure, "representations": hidden_representations} + else: + result = {"logits": x, "logits_supervised": x_supervised, "logits_structure": x_structure} + if need_head_weights: + # attentions: B x L x H x T x T + attentions = torch.stack(attn_weights, 1) + if padding_mask is not None: + attention_mask = 1 - padding_mask.type_as(attentions) + attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2) + attentions = attentions * attention_mask[:, None, None, :, :] + if return_attentions: result["attentions"] = attentions + if return_contacts: + attentions_symm, contacts = self.contact_head(tokens, attentions) + result["contacts"] = contacts + if return_attentions_symm: result["attentions_symm"] = attentions_symm + + return result + + def predict_contacts(self, tokens): + return self(tokens, return_contacts=True)["contacts"] + + def predict_symmetric_attentions(self, tokens): + return self(tokens, return_contacts=True)["attentions_symm"] + + def predict_attentions(self, tokens): + return self(tokens, need_head_weights=True)["attentions"] + + def predict_representations(self, tokens): + return self(tokens, return_representation=True)['representations'] + + def predict_logits(self, tokens): + return self(tokens)['logits'] + + def predict_logits_supervised(self, tokens): + return self(tokens)['logits_supervised'] + + def predict_logits_structure(self, tokens): + return self(tokens)['logits_structure'] diff --git a/esm/model/esm2_supervised.py b/esm/model/esm2_supervised.py new file mode 100644 index 0000000000000000000000000000000000000000..ee4965b883399f9128859e6334cb4742fbc881e4 --- /dev/null +++ b/esm/model/esm2_supervised.py @@ -0,0 +1,174 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
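# A minimal sketch of the extra regression head added in this variant; the tiny
# sizes and random token ids below are illustrative assumptions only:
import torch
from esm.model.esm2_supervised import ESM2

model = ESM2(num_layers=2, embed_dim=64, attention_heads=4)
toks = torch.randint(4, 24, (1, 10))
out = model(toks, repr_layers=[2])
score = out["logits_supervised"]   # (B, 1): per-sequence output read from position 0 (the <cls>/BOS slot in real batches)
lm_logits = out["logits"]          # (B, T, alphabet_size): the usual masked-LM logits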
+ +from typing import Union +import torch +import torch.nn as nn + +import esm +from esm.modules import ContactPredictionHead, ESM1bLayerNorm, RobertaLMHead, TransformerLayer + + +class ESM2(nn.Module): + def __init__( + self, + num_layers: int = 33, + embed_dim: int = 1280, + attention_heads: int = 20, + alphabet: Union[esm.data.Alphabet, str] = "ESM-1b", + token_dropout: bool = True, + ): + super().__init__() + self.num_layers = num_layers + self.embed_dim = embed_dim + self.attention_heads = attention_heads + if not isinstance(alphabet, esm.data.Alphabet): + alphabet = esm.data.Alphabet.from_architecture(alphabet) + self.alphabet = alphabet + self.alphabet_size = len(alphabet) + self.padding_idx = alphabet.padding_idx + self.mask_idx = alphabet.mask_idx + self.cls_idx = alphabet.cls_idx + self.eos_idx = alphabet.eos_idx + self.prepend_bos = alphabet.prepend_bos + self.append_eos = alphabet.append_eos + self.token_dropout = token_dropout + + self._init_submodules() + + def _init_submodules(self): + self.embed_scale = 1 + self.embed_tokens = nn.Embedding( + self.alphabet_size, + self.embed_dim, + padding_idx=self.padding_idx, + ) + + self.layers = nn.ModuleList( + [ + TransformerLayer( + self.embed_dim, + 4 * self.embed_dim, + self.attention_heads, + add_bias_kv=False, + use_esm1b_layer_norm=True, + use_rotary_embeddings=True, + ) + for _ in range(self.num_layers) + ] + ) + + self.contact_head = ContactPredictionHead( + self.num_layers * self.attention_heads, + self.prepend_bos, + self.append_eos, + eos_idx=self.eos_idx, + ) + self.emb_layer_norm_after = ESM1bLayerNorm(self.embed_dim) + + self.lm_head = RobertaLMHead( + embed_dim=self.embed_dim, + output_dim=self.alphabet_size, + weight=self.embed_tokens.weight, + ) + self.supervised_linear = nn.Linear(self.embed_dim, 1) + def forward(self, tokens, repr_layers=[], need_head_weights=True, return_contacts=True, return_representation=True, return_attentions_symm = False, return_attentions = False): + if return_contacts: + need_head_weights = True + + assert tokens.ndim == 2 + padding_mask = tokens.eq(self.padding_idx) # B, T + + x = self.embed_scale * self.embed_tokens(tokens) + + if self.token_dropout: + x.masked_fill_((tokens == self.mask_idx).unsqueeze(-1), 0.0) + #print(f'tokens = {tokens}') + #print(f'self.mask_idx = {self.mask_idx}') + #print('x.shape = ', x.shape) + # x: B x T x C + mask_ratio_train = 0.15 * 0.8 + src_lengths = (~padding_mask).sum(-1) + #print(f'mask_ratio_train = {mask_ratio_train}') + #print(f'padding_mask = {padding_mask}') + #print(f'src_lengths = {src_lengths}') + mask_ratio_observed = (tokens == self.mask_idx).sum(-1).to(x.dtype) / src_lengths + #print('mask_ratio_observed = ',mask_ratio_observed) + x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None] + #print(f'x.shape = {x.shape}:\n', x) + if padding_mask is not None: + x = x * (1 - padding_mask.unsqueeze(-1).type_as(x)) + #print(f'x.shape = {x.shape}:\n', x) + repr_layers = set(repr_layers) + hidden_representations = {} + if 0 in repr_layers: + hidden_representations[0] = x + + if need_head_weights: + attn_weights = [] + + # (B, T, E) => (T, B, E) + x = x.transpose(0, 1) + + if not padding_mask.any(): + padding_mask = None + + for layer_idx, layer in enumerate(self.layers): + x, attn = layer( + x, + self_attn_padding_mask=padding_mask, + need_head_weights=need_head_weights, + ) + if (layer_idx + 1) in repr_layers: + hidden_representations[layer_idx + 1] = x.transpose(0, 1) + if need_head_weights: + # (H, B, T, T) => (B, H, T, T) + 
attn_weights.append(attn.transpose(1, 0)) +# print(x.shape) # 73, 2, 1280 + x = self.emb_layer_norm_after(x) + x = x.transpose(0, 1) # (T, B, E) => (B, T, E) + + # last hidden representation should have layer norm applied + if (layer_idx + 1) in repr_layers: + hidden_representations[layer_idx + 1] = x + x_supervised = self.supervised_linear(x[:,0,:]) + x = self.lm_head(x) + + if return_representation: + result = {"logits": x, "logits_supervised": x_supervised, "representations": hidden_representations} + else: + result = {"logits": x, "logits_supervised": x_supervised} + if need_head_weights: + # attentions: B x L x H x T x T + attentions = torch.stack(attn_weights, 1) + if padding_mask is not None: + attention_mask = 1 - padding_mask.type_as(attentions) + attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2) + attentions = attentions * attention_mask[:, None, None, :, :] + if return_attentions: result["attentions"] = attentions + if return_contacts: + attentions_symm, contacts = self.contact_head(tokens, attentions) + result["contacts"] = contacts + if return_attentions_symm: result["attentions_symm"] = attentions_symm + + return result + + def predict_contacts(self, tokens): + return self(tokens, return_contacts=True)["contacts"] + + def predict_symmetric_attentions(self, tokens): + return self(tokens, return_contacts=True)["attentions_symm"] + + def predict_attentions(self, tokens): + return self(tokens, need_head_weights=True)["attentions"] + + def predict_representations(self, tokens): + return self(tokens, return_representation=True)['representations'] + + def predict_logits(self, tokens): + return self(tokens)['logits'] + + def predict_logits_supervised(self, tokens): + return self(tokens)['logits_supervised'] diff --git a/esm/model/msa_transformer.py b/esm/model/msa_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..08c99cd24e1e13a80d59928b573ea892cf27bef0 --- /dev/null +++ b/esm/model/msa_transformer.py @@ -0,0 +1,238 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + +from ..modules import ( + AxialTransformerLayer, + LearnedPositionalEmbedding, + RobertaLMHead, + ESM1bLayerNorm, + ContactPredictionHead, +) + +from ..axial_attention import RowSelfAttention, ColumnSelfAttention + + + +class MSATransformer(nn.Module): + @classmethod + def add_args(cls, parser): + # fmt: off + parser.add_argument( + "--num_layers", + default=12, + type=int, + metavar="N", + help="number of layers" + ) + parser.add_argument( + "--embed_dim", + default=768, + type=int, + metavar="N", + help="embedding dimension" + ) + parser.add_argument( + "--logit_bias", + action="store_true", + help="whether to apply bias to logits" + ) + parser.add_argument( + "--ffn_embed_dim", + default=3072, + type=int, + metavar="N", + help="embedding dimension for FFN", + ) + parser.add_argument( + "--attention_heads", + default=12, + type=int, + metavar="N", + help="number of attention heads", + ) + parser.add_argument( + "--dropout", + default=0.1, + type=float, + help="Dropout to apply." + ) + parser.add_argument( + "--attention_dropout", + default=0.1, + type=float, + help="Dropout to apply." + ) + parser.add_argument( + "--activation_dropout", + default=0.1, + type=float, + help="Dropout to apply." 
+ ) + parser.add_argument( + "--max_tokens_per_msa", + default=2 ** 14, + type=int, + help=( + "Used during inference to batch attention computations in a single " + "forward pass. This allows increased input sizes with less memory." + ), + ) + # fmt: on + + def __init__(self, args, alphabet): + super().__init__() + self.args = args + self.alphabet_size = len(alphabet) + self.padding_idx = alphabet.padding_idx + self.mask_idx = alphabet.mask_idx + self.cls_idx = alphabet.cls_idx + self.eos_idx = alphabet.eos_idx + self.prepend_bos = alphabet.prepend_bos + self.append_eos = alphabet.append_eos + + self.embed_tokens = nn.Embedding( + self.alphabet_size, self.args.embed_dim, padding_idx=self.padding_idx + ) + + if getattr(self.args, "embed_positions_msa", False): + emb_dim = getattr(self.args, "embed_positions_msa_dim", self.args.embed_dim) + self.msa_position_embedding = nn.Parameter( + 0.01 * torch.randn(1, 1024, 1, emb_dim), + requires_grad=True, + ) + else: + self.register_parameter("msa_position_embedding", None) + + self.dropout_module = nn.Dropout(self.args.dropout) + self.layers = nn.ModuleList( + [ + AxialTransformerLayer( + self.args.embed_dim, + self.args.ffn_embed_dim, + self.args.attention_heads, + self.args.dropout, + self.args.attention_dropout, + self.args.activation_dropout, + getattr(self.args, "max_tokens_per_msa", self.args.max_tokens), + ) + for _ in range(self.args.layers) + ] + ) + + self.contact_head = ContactPredictionHead( + self.args.layers * self.args.attention_heads, + self.prepend_bos, + self.append_eos, + eos_idx=self.eos_idx, + ) + self.embed_positions = LearnedPositionalEmbedding( + self.args.max_positions, + self.args.embed_dim, + self.padding_idx, + ) + self.emb_layer_norm_before = ESM1bLayerNorm(self.args.embed_dim) + self.emb_layer_norm_after = ESM1bLayerNorm(self.args.embed_dim) + self.lm_head = RobertaLMHead( + embed_dim=self.args.embed_dim, + output_dim=self.alphabet_size, + weight=self.embed_tokens.weight, + ) + + def forward(self, tokens, repr_layers=[], need_head_weights=False, return_contacts=False): + if return_contacts: + need_head_weights = True + + assert tokens.ndim == 3 + batch_size, num_alignments, seqlen = tokens.size() + padding_mask = tokens.eq(self.padding_idx) # B, R, C + if not padding_mask.any(): + padding_mask = None + + x = self.embed_tokens(tokens) + x += self.embed_positions(tokens.view(batch_size * num_alignments, seqlen)).view(x.size()) + if self.msa_position_embedding is not None: + if x.size(1) > 1024: + raise RuntimeError( + "Using model with MSA position embedding trained on maximum MSA " + f"depth of 1024, but received {x.size(1)} alignments." 
+ ) + x += self.msa_position_embedding[:, :num_alignments] + + x = self.emb_layer_norm_before(x) + + x = self.dropout_module(x) + + if padding_mask is not None: + x = x * (1 - padding_mask.unsqueeze(-1).type_as(x)) + + repr_layers = set(repr_layers) + hidden_representations = {} + if 0 in repr_layers: + hidden_representations[0] = x + + if need_head_weights: + row_attn_weights = [] + col_attn_weights = [] + + # B x R x C x D -> R x C x B x D + x = x.permute(1, 2, 0, 3) + + for layer_idx, layer in enumerate(self.layers): + x = layer( + x, + self_attn_padding_mask=padding_mask, + need_head_weights=need_head_weights, + ) + if need_head_weights: + x, col_attn, row_attn = x + # H x C x B x R x R -> B x H x C x R x R + col_attn_weights.append(col_attn.permute(2, 0, 1, 3, 4)) + # H x B x C x C -> B x H x C x C + row_attn_weights.append(row_attn.permute(1, 0, 2, 3)) + if (layer_idx + 1) in repr_layers: + hidden_representations[layer_idx + 1] = x.permute(2, 0, 1, 3) + + x = self.emb_layer_norm_after(x) + x = x.permute(2, 0, 1, 3) # R x C x B x D -> B x R x C x D + + # last hidden representation should have layer norm applied + if (layer_idx + 1) in repr_layers: + hidden_representations[layer_idx + 1] = x + x = self.lm_head(x) + + result = {"logits": x, "representations": hidden_representations} + if need_head_weights: + # col_attentions: B x L x H x C x R x R + col_attentions = torch.stack(col_attn_weights, 1) + # row_attentions: B x L x H x C x C + row_attentions = torch.stack(row_attn_weights, 1) + result["col_attentions"] = col_attentions + result["row_attentions"] = row_attentions + if return_contacts: + contacts = self.contact_head(tokens, row_attentions) + result["contacts"] = contacts + + return result + + def predict_contacts(self, tokens): + return self(tokens, return_contacts=True)["contacts"] + + @property + def num_layers(self): + return self.args.layers + + def max_tokens_per_msa_(self, value: int) -> None: + """The MSA Transformer automatically batches attention computations when + gradients are disabled to allow you to pass in larger MSAs at test time than + you can fit in GPU memory. By default this occurs when more than 2^14 tokens + are passed in the input MSA. You can set this value to infinity to disable + this behavior. + """ + for module in self.modules(): + if isinstance(module, (RowSelfAttention, ColumnSelfAttention)): + module.max_tokens_per_msa = value \ No newline at end of file diff --git a/esm/modules.py b/esm/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..5d7ee4741b23b6b88ead11c56ff20e978db678d3 --- /dev/null +++ b/esm/modules.py @@ -0,0 +1,419 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .multihead_attention import MultiheadAttention # noqa +from .axial_attention import ColumnSelfAttention, RowSelfAttention + + +def gelu(x): + """Implementation of the gelu activation function. + For information: OpenAI GPT's gelu is slightly different + (and gives slightly different results): + 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + """ + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + + +def symmetrize(x): + "Make layer symmetric in final two dimensions, used for contact prediction." 
+ return x + x.transpose(-1, -2) + + +def apc(x): + "Perform average product correct, used for contact prediction." + a1 = x.sum(-1, keepdims=True) + a2 = x.sum(-2, keepdims=True) + a12 = x.sum((-1, -2), keepdims=True) + + avg = a1 * a2 + avg.div_(a12) # in-place to reduce memory + normalized = x - avg + return normalized + + +class ESM1LayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-12, affine=True): + """Construct a layernorm layer in the TF style (eps inside the sqrt).""" + super().__init__() + self.hidden_size = (hidden_size,) if isinstance(hidden_size, int) else tuple(hidden_size) + self.eps = eps + self.affine = bool(affine) + if self.affine: + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.bias = nn.Parameter(torch.zeros(hidden_size)) + else: + self.weight, self.bias = None, None + + def forward(self, x): + dims = tuple(-(i + 1) for i in range(len(self.hidden_size))) + means = x.mean(dims, keepdim=True) + x_zeromean = x - means + variances = x_zeromean.pow(2).mean(dims, keepdim=True) + x = x_zeromean / torch.sqrt(variances + self.eps) + if self.affine: + x = (self.weight * x) + self.bias + return x + + +try: + from apex.normalization import FusedLayerNorm as _FusedLayerNorm + + class ESM1bLayerNorm(_FusedLayerNorm): + @torch.jit.unused + def forward(self, x): + if not x.is_cuda: + return super().forward(x) + else: + with torch.cuda.device(x.device): + return super().forward(x) + +except ImportError: + from torch.nn import LayerNorm as ESM1bLayerNorm + + +class TransformerLayer(nn.Module): + """Transformer layer block.""" + + def __init__( + self, + embed_dim, + ffn_embed_dim, + attention_heads, + add_bias_kv=True, + use_esm1b_layer_norm=False, + use_rotary_embeddings: bool = False, + ): + super().__init__() + self.embed_dim = embed_dim + self.ffn_embed_dim = ffn_embed_dim + self.attention_heads = attention_heads + self.use_rotary_embeddings = use_rotary_embeddings + self._init_submodules(add_bias_kv, use_esm1b_layer_norm) + + def _init_submodules(self, add_bias_kv, use_esm1b_layer_norm): + BertLayerNorm = ESM1bLayerNorm if use_esm1b_layer_norm else ESM1LayerNorm + + self.self_attn = MultiheadAttention( + self.embed_dim, + self.attention_heads, + add_bias_kv=add_bias_kv, + add_zero_attn=False, + use_rotary_embeddings=self.use_rotary_embeddings, + ) + self.self_attn_layer_norm = BertLayerNorm(self.embed_dim) + + self.fc1 = nn.Linear(self.embed_dim, self.ffn_embed_dim) + self.fc2 = nn.Linear(self.ffn_embed_dim, self.embed_dim) + + self.final_layer_norm = BertLayerNorm(self.embed_dim) + + def forward( + self, x, self_attn_mask=None, self_attn_padding_mask=None, need_head_weights=False + ): + residual = x + x = self.self_attn_layer_norm(x) + x, attn = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=True, + need_head_weights=need_head_weights, + attn_mask=self_attn_mask, + ) + x = residual + x + + residual = x + x = self.final_layer_norm(x) + x = gelu(self.fc1(x)) + x = self.fc2(x) + x = residual + x + #print(f'------{attn.half().dtype}-----') + + return x, attn#.half() ### + + +class AxialTransformerLayer(nn.Module): + """Implements an Axial MSA Transformer block.""" + + def __init__( + self, + embedding_dim: int = 768, + ffn_embedding_dim: int = 3072, + num_attention_heads: int = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + max_tokens_per_msa: int = 2**14, + ) -> None: + super().__init__() + + # Initialize parameters + self.embedding_dim = 
embedding_dim + self.dropout_prob = dropout + + row_self_attention = RowSelfAttention( + embedding_dim, + num_attention_heads, + dropout=dropout, + max_tokens_per_msa=max_tokens_per_msa, + ) + + column_self_attention = ColumnSelfAttention( + embedding_dim, + num_attention_heads, + dropout=dropout, + max_tokens_per_msa=max_tokens_per_msa, + ) + + feed_forward_layer = FeedForwardNetwork( + embedding_dim, + ffn_embedding_dim, + activation_dropout=activation_dropout, + max_tokens_per_msa=max_tokens_per_msa, + ) + + self.row_self_attention = self.build_residual(row_self_attention) + self.column_self_attention = self.build_residual(column_self_attention) + self.feed_forward_layer = self.build_residual(feed_forward_layer) + + def build_residual(self, layer: nn.Module): + return NormalizedResidualBlock( + layer, + self.embedding_dim, + self.dropout_prob, + ) + + def forward( + self, + x: torch.Tensor, + self_attn_mask: Optional[torch.Tensor] = None, + self_attn_padding_mask: Optional[torch.Tensor] = None, + need_head_weights: bool = False, + ): + """ + LayerNorm is applied either before or after the self-attention/ffn + modules similar to the original Transformer implementation. + """ + x, row_attn = self.row_self_attention( + x, + self_attn_mask=self_attn_mask, + self_attn_padding_mask=self_attn_padding_mask, + ) + x, column_attn = self.column_self_attention( + x, + self_attn_mask=self_attn_mask, + self_attn_padding_mask=self_attn_padding_mask, + ) + x = self.feed_forward_layer(x) + if need_head_weights: + return x, column_attn, row_attn + else: + return x + + +class LearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + Padding ids are ignored by either offsetting based on padding_idx + or by setting padding_idx to None and ensuring that the appropriate + position ids are passed to the forward function. 
+ """ + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int): + if padding_idx is not None: + num_embeddings_ = num_embeddings + padding_idx + 1 + else: + num_embeddings_ = num_embeddings + super().__init__(num_embeddings_, embedding_dim, padding_idx) + self.max_positions = num_embeddings + + def forward(self, input: torch.Tensor): + """Input is expected to be of size [bsz x seqlen].""" + if input.size(1) > self.max_positions: + raise ValueError( + f"Sequence length {input.size(1)} above maximum " + f" sequence length of {self.max_positions}" + ) + mask = input.ne(self.padding_idx).int() + positions = (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + self.padding_idx + return F.embedding( + positions, + self.weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) + + +class SinusoidalPositionalEmbedding(nn.Module): + def __init__(self, embed_dim, padding_idx, learned=False): + super().__init__() + self.embed_dim = embed_dim + self.padding_idx = padding_idx + self.register_buffer("_float_tensor", torch.FloatTensor(1)) + self.weights = None + + def forward(self, x): + bsz, seq_len = x.shape + max_pos = self.padding_idx + 1 + seq_len + if self.weights is None or max_pos > self.weights.size(0): + self.weights = self.get_embedding(max_pos) + self.weights = self.weights.type_as(self._float_tensor) + + positions = self.make_positions(x) + return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach() + + def make_positions(self, x): + mask = x.ne(self.padding_idx) + range_buf = torch.arange(x.size(1), device=x.device).expand_as(x) + self.padding_idx + 1 + positions = range_buf.expand_as(x) + return positions * mask.long() + self.padding_idx * (1 - mask.long()) + + def get_embedding(self, num_embeddings): + half_dim = self.embed_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) + emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) + if self.embed_dim % 2 == 1: + # zero pad + emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) + if self.padding_idx is not None: + emb[self.padding_idx, :] = 0 + return emb + + +class RobertaLMHead(nn.Module): + """Head for masked language modeling.""" + + def __init__(self, embed_dim, output_dim, weight): + super().__init__() + self.dense = nn.Linear(embed_dim, embed_dim) + self.layer_norm = ESM1bLayerNorm(embed_dim) + self.weight = weight + self.bias = nn.Parameter(torch.zeros(output_dim)) + + def forward(self, features): + x = self.dense(features) + x = gelu(x) + x = self.layer_norm(x) + # project back to size of vocabulary with bias + x = F.linear(x, self.weight) + self.bias + return x + + +class ContactPredictionHead(nn.Module): + """Performs symmetrization, apc, and computes a logistic regression on the output features""" + + def __init__( + self, + in_features: int, + prepend_bos: bool, + append_eos: bool, + bias=True, + eos_idx: Optional[int] = None, + ): + super().__init__() + self.in_features = in_features + self.prepend_bos = prepend_bos + self.append_eos = append_eos + if append_eos and eos_idx is None: + raise ValueError("Using an alphabet with eos token, but no eos token was passed in.") + self.eos_idx = eos_idx + self.regression = nn.Linear(in_features, 1, bias) + self.activation = nn.Sigmoid() + + def forward(self, tokens, attentions): + # remove 
eos token attentions + if self.append_eos: + eos_mask = tokens.ne(self.eos_idx).to(attentions) + eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2) + attentions = attentions * eos_mask[:, None, None, :, :] + attentions = attentions[..., :-1, :-1] + # remove cls token attentions + if self.prepend_bos: + attentions = attentions[..., 1:, 1:] + batch_size, layers, heads, seqlen, _ = attentions.size() + attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen) + + # features: B x C x T x T + attentions = attentions.to( + self.regression.weight.device + ) # attentions always float32, may need to convert to float16 + attentions = apc(symmetrize(attentions)) + attentions = attentions.permute(0, 2, 3, 1) + #print(f'----------{attentions.dtype, attentions.float().dtype}----') + return attentions.sum(dim=-1), self.activation(self.regression(attentions).squeeze(3))#float().to(self.regression.weight.device)).squeeze(3)) + + +class NormalizedResidualBlock(nn.Module): + def __init__( + self, + layer: nn.Module, + embedding_dim: int, + dropout: float = 0.1, + ): + super().__init__() + self.embedding_dim = embedding_dim + + self.layer = layer + self.dropout_module = nn.Dropout( + dropout, + ) + self.layer_norm = ESM1bLayerNorm(self.embedding_dim) + + def forward(self, x, *args, **kwargs): + residual = x + x = self.layer_norm(x) + outputs = self.layer(x, *args, **kwargs) + if isinstance(outputs, tuple): + x, *out = outputs + else: + x = outputs + out = None + + x = self.dropout_module(x) + x = residual + x + + if out is not None: + return (x,) + tuple(out) + else: + return x + + +class FeedForwardNetwork(nn.Module): + def __init__( + self, + embedding_dim: int, + ffn_embedding_dim: int, + activation_dropout: float = 0.1, + max_tokens_per_msa: int = 2**14, + ): + super().__init__() + self.embedding_dim = embedding_dim + self.ffn_embedding_dim = ffn_embedding_dim + self.max_tokens_per_msa = max_tokens_per_msa + self.activation_fn = nn.GELU() + self.activation_dropout_module = nn.Dropout( + activation_dropout, + ) + self.fc1 = nn.Linear(embedding_dim, ffn_embedding_dim) + self.fc2 = nn.Linear(ffn_embedding_dim, embedding_dim) + + def forward(self, x): + x = self.activation_fn(self.fc1(x)) + x = self.activation_dropout_module(x) + x = self.fc2(x) + return x diff --git a/esm/multihead_attention.py b/esm/multihead_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..9b0e156dd0d83420f891b7e83b0e6467955bf043 --- /dev/null +++ b/esm/multihead_attention.py @@ -0,0 +1,506 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
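# A small standalone sketch of the MultiheadAttention module defined below
# (sizes are illustrative); inputs follow the Time x Batch x Channel layout
# documented in its forward():
import torch
from esm.multihead_attention import MultiheadAttention

attn = MultiheadAttention(embed_dim=64, num_heads=4, use_rotary_embeddings=True)
x = torch.randn(10, 2, 64)                                        # (T, B, C)
out, weights = attn(query=x, key=x, value=x, need_head_weights=True)
# out: (T, B, C) attended values; weights: (num_heads, B, T, T) per-head attention maps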
+ +import math +from typing import Dict, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn import Parameter +from esm.rotary_embedding import RotaryEmbedding + +import uuid + + +def utils_softmax(x, dim: int, onnx_trace: bool = False): + if onnx_trace: + return F.softmax(x.float(), dim=dim) + else: + return F.softmax(x, dim=dim, dtype=torch.float32) + + +class FairseqIncrementalState(object): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.init_incremental_state() + + def init_incremental_state(self): + self._incremental_state_id = str(uuid.uuid4()) + + def _get_full_incremental_state_key(self, key: str) -> str: + return "{}.{}".format(self._incremental_state_id, key) + + def get_incremental_state( + self, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], + key: str, + ) -> Optional[Dict[str, Optional[Tensor]]]: + """Helper for getting incremental state for an nn.Module.""" + full_key = self._get_full_incremental_state_key(key) + if incremental_state is None or full_key not in incremental_state: + return None + return incremental_state[full_key] + + def set_incremental_state( + self, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], + key: str, + value: Dict[str, Optional[Tensor]], + ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]: + """Helper for setting incremental state for an nn.Module.""" + if incremental_state is not None: + full_key = self._get_full_incremental_state_key(key) + incremental_state[full_key] = value + return incremental_state + + +def with_incremental_state(cls): + cls.__bases__ = (FairseqIncrementalState,) + tuple( + b for b in cls.__bases__ if b != FairseqIncrementalState + ) + return cls + + +@with_incremental_state +class MultiheadAttention(nn.Module): + """Multi-headed attention. + See "Attention Is All You Need" for more details. 
+ """ + + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv: bool = False, + add_zero_attn: bool = False, + self_attention: bool = False, + encoder_decoder_attention: bool = False, + use_rotary_embeddings: bool = False, + ): + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim**-0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert not self.self_attention or self.qkv_same_dim, ( + "Self-attention requires query, key and " "value to be of the same size" + ) + + self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias) + self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + if add_bias_kv: + self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) + self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self.reset_parameters() + + self.onnx_trace = False + self.rot_emb = None + if use_rotary_embeddings: + self.rot_emb = RotaryEmbedding(dim=self.head_dim) + + self.enable_torch_version = False + if hasattr(F, "multi_head_attention_forward"): + self.enable_torch_version = True + else: + self.enable_torch_version = False + + def prepare_for_onnx_export_(self): + self.onnx_trace = True + + def reset_parameters(self): + if self.qkv_same_dim: + # Empirically observed the convergence to be much better with + # the scaled initialization + nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) + else: + nn.init.xavier_uniform_(self.k_proj.weight) + nn.init.xavier_uniform_(self.v_proj.weight) + nn.init.xavier_uniform_(self.q_proj.weight) + + nn.init.xavier_uniform_(self.out_proj.weight) + if self.out_proj.bias is not None: + nn.init.constant_(self.out_proj.bias, 0.0) + if self.bias_k is not None: + nn.init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + nn.init.xavier_normal_(self.bias_v) + + def forward( + self, + query, + key: Optional[Tensor], + value: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + need_weights: bool = True, + static_kv: bool = False, + attn_mask: Optional[Tensor] = None, + before_softmax: bool = False, + need_head_weights: bool = False, + ) -> Tuple[Tensor, Optional[Tensor]]: + """Input shape: Time x Batch x Channel + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). 
+ before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. + need_head_weights (bool, optional): return the attention + weights for each head. Implies *need_weights*. Default: + return the average attention weights over all heads. + """ + if need_head_weights: + need_weights = True + + tgt_len, bsz, embed_dim = query.size() + assert embed_dim == self.embed_dim + assert list(query.size()) == [tgt_len, bsz, embed_dim] + + if ( + not self.rot_emb + and self.enable_torch_version + and not self.onnx_trace + and incremental_state is None + and not static_kv + # A workaround for quantization to work. Otherwise JIT compilation + # treats bias in linear module as method. + and not torch.jit.is_scripting() + and not need_head_weights + ): + assert key is not None and value is not None + return F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + torch.empty([0]), + torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + self.training, + key_padding_mask, + need_weights, + attn_mask, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + ) + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if saved_state is not None and "prev_key" in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + q = self.q_proj(query) + k = self.k_proj(query) + v = self.v_proj(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.q_proj(query) + if key is None: + assert value is None + k = v = None + else: + k = self.k_proj(key) + v = self.v_proj(key) + + else: + assert key is not None and value is not None + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + q *= self.scaling + + if self.bias_k is not None: + assert self.bias_v is not None + k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + key_padding_mask.new_zeros(key_padding_mask.size(0), 1), + ], + dim=1, + ) + + q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) + if k is not None: + k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) + if v is not None: + v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) + + if saved_state is not None: + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) + if "prev_key" in saved_state: + _prev_key = saved_state["prev_key"] + assert _prev_key is not None + prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + k = prev_key + else: + assert k is not None + k = torch.cat([prev_key, k], dim=1) + if "prev_value" in saved_state: + _prev_value = saved_state["prev_value"] + assert _prev_value is not None + prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + v = prev_value + else: + 
assert v is not None + v = torch.cat([prev_value, v], dim=1) + prev_key_padding_mask: Optional[Tensor] = None + if "prev_key_padding_mask" in saved_state: + prev_key_padding_mask = saved_state["prev_key_padding_mask"] + assert k is not None and v is not None + key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( + key_padding_mask=key_padding_mask, + prev_key_padding_mask=prev_key_padding_mask, + batch_size=bsz, + src_len=k.size(1), + static_kv=static_kv, + ) + + saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_key_padding_mask"] = key_padding_mask + # In this branch incremental_state is never None + assert incremental_state is not None + incremental_state = self._set_input_buffer(incremental_state, saved_state) + assert k is not None + src_len = k.size(1) + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + assert v is not None + src_len += 1 + k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) + v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask), + ], + dim=1, + ) + + if self.rot_emb: + q, k = self.rot_emb(q, k) + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = MultiheadAttention.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) + + assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + if self.onnx_trace: + attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1) + attn_weights += attn_mask + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf") + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if before_softmax: + return attn_weights, v + + attn_weights_float = utils_softmax(attn_weights, dim=-1, onnx_trace=self.onnx_trace) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = F.dropout( + attn_weights_float.type_as(attn_weights), + p=self.dropout, + training=self.training, + ) + assert v is not None + attn = torch.bmm(attn_probs, v) + assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + if self.onnx_trace and attn.size(1) == 1: + # when ONNX tracing a single decoder step (sequence length == 1) + # the transpose is a no-op copy before view, thus unnecessary + attn = attn.contiguous().view(tgt_len, bsz, embed_dim) + else: + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn = self.out_proj(attn) + attn_weights: Optional[Tensor] = None + if need_weights: + attn_weights = attn_weights_float.view( + bsz, self.num_heads, tgt_len, src_len + ).type_as(attn).transpose(1, 0) + if not need_head_weights: + # average attention weights over 
heads + attn_weights = attn_weights.mean(dim=0) + + return attn, attn_weights + + @staticmethod + def _append_prev_key_padding_mask( + key_padding_mask: Optional[Tensor], + prev_key_padding_mask: Optional[Tensor], + batch_size: int, + src_len: int, + static_kv: bool, + ) -> Optional[Tensor]: + # saved key padding masks have shape (bsz, seq_len) + if prev_key_padding_mask is not None and static_kv: + new_key_padding_mask = prev_key_padding_mask + elif prev_key_padding_mask is not None and key_padding_mask is not None: + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1 + ) + # During incremental decoding, as the padding token enters and + # leaves the frame, there will be a time when prev or current + # is None + elif prev_key_padding_mask is not None: + filler = torch.zeros( + (batch_size, src_len - prev_key_padding_mask.size(1)), + device=prev_key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), filler.float()], dim=1 + ) + elif key_padding_mask is not None: + filler = torch.zeros( + (batch_size, src_len - key_padding_mask.size(1)), + device=key_padding_mask.device, + ) + new_key_padding_mask = torch.cat([filler.float(), key_padding_mask.float()], dim=1) + else: + new_key_padding_mask = prev_key_padding_mask + return new_key_padding_mask + + @torch.jit.export + def reorder_incremental_state( + self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], new_order: Tensor + ): + """Reorder buffered internal state (for incremental generation).""" + input_buffer = self._get_input_buffer(incremental_state) + if input_buffer is not None: + for k in input_buffer.keys(): + input_buffer_k = input_buffer[k] + if input_buffer_k is not None: + if self.encoder_decoder_attention and input_buffer_k.size(0) == new_order.size( + 0 + ): + break + input_buffer[k] = input_buffer_k.index_select(0, new_order) + incremental_state = self._set_input_buffer(incremental_state, input_buffer) + return incremental_state + + def _get_input_buffer( + self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] + ) -> Dict[str, Optional[Tensor]]: + result = self.get_incremental_state(incremental_state, "attn_state") + if result is not None: + return result + else: + empty_result: Dict[str, Optional[Tensor]] = {} + return empty_result + + def _set_input_buffer( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + buffer: Dict[str, Optional[Tensor]], + ): + return self.set_incremental_state(incremental_state, "attn_state", buffer) + + def apply_sparse_mask(attn_weights, tgt_len: int, src_len: int, bsz: int): + return attn_weights + + def upgrade_state_dict_named(self, state_dict, name): + prefix = name + "." 
if name != "" else "" + items_to_add = {} + keys_to_remove = [] + for k in state_dict.keys(): + if k.endswith(prefix + "in_proj_weight"): + # in_proj_weight used to be q + k + v with same dimensions + dim = int(state_dict[k].shape[0] / 3) + items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim] + items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim] + items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :] + + keys_to_remove.append(k) + + k_bias = prefix + "in_proj_bias" + if k_bias in state_dict.keys(): + dim = int(state_dict[k].shape[0] / 3) + items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim] + items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][dim : 2 * dim] + items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :] + + keys_to_remove.append(prefix + "in_proj_bias") + + for k in keys_to_remove: + del state_dict[k] + + for key, value in items_to_add.items(): + state_dict[key] = value \ No newline at end of file diff --git a/esm/pretrained.py b/esm/pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..3375f049e82087bd58319b0b27551ec33942e7b4 --- /dev/null +++ b/esm/pretrained.py @@ -0,0 +1,378 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import re +import urllib +import warnings +from argparse import Namespace +from pathlib import Path + +import torch + +import esm +from esm.model.esm2 import ESM2 + + +def _has_regression_weights(model_name): + """Return whether we expect / require regression weights; + Right now that is all models except ESM-1v and ESM-IF""" + return not ("esm1v" in model_name or "esm_if" in model_name) + + +def load_model_and_alphabet(model_name): + if model_name.endswith(".pt"): # treat as filepath + return load_model_and_alphabet_local(model_name) + else: + return load_model_and_alphabet_hub(model_name) + + +def load_hub_workaround(url): + try: + data = torch.hub.load_state_dict_from_url(url, progress=False, map_location="cpu") + except RuntimeError: + # Pytorch version issue - see https://github.com/pytorch/pytorch/issues/43106 + fn = Path(url).name + data = torch.load( + f"{torch.hub.get_dir()}/checkpoints/{fn}", + map_location="cpu", + ) + except urllib.error.HTTPError as e: + raise Exception(f"Could not load {url}, check if you specified a correct model name?") + return data + + +def load_regression_hub(model_name): + url = f"https://dl.fbaipublicfiles.com/fair-esm/regression/{model_name}-contact-regression.pt" + regression_data = load_hub_workaround(url) + return regression_data + + +def _download_model_and_regression_data(model_name): + url = f"https://dl.fbaipublicfiles.com/fair-esm/models/{model_name}.pt" + model_data = load_hub_workaround(url) + if _has_regression_weights(model_name): + regression_data = load_regression_hub(model_name) + else: + regression_data = None + return model_data, regression_data + + +def load_model_and_alphabet_hub(model_name): + model_data, regression_data = _download_model_and_regression_data(model_name) + return load_model_and_alphabet_core(model_name, model_data, regression_data) + + +def load_model_and_alphabet_local(model_location): + """Load from local path. 
The regression weights need to be co-located""" + model_location = Path(model_location) + model_data = torch.load(str(model_location), map_location="cpu") + model_name = model_location.stem + if _has_regression_weights(model_name): + regression_location = str(model_location.with_suffix("")) + "-contact-regression.pt" + regression_data = torch.load(regression_location, map_location="cpu") + else: + regression_data = None + return load_model_and_alphabet_core(model_name, model_data, regression_data) + + +def has_emb_layer_norm_before(model_state): + """Determine whether layer norm needs to be applied before the encoder""" + return any(k.startswith("emb_layer_norm_before") for k, param in model_state.items()) + + +def _load_model_and_alphabet_core_v1(model_data): + import esm # since esm.inverse_folding is imported below, you actually have to re-import esm here + + alphabet = esm.Alphabet.from_architecture(model_data["args"].arch) + + if model_data["args"].arch == "roberta_large": + # upgrade state dict + pra = lambda s: "".join(s.split("encoder_")[1:] if "encoder" in s else s) + prs1 = lambda s: "".join(s.split("encoder.")[1:] if "encoder" in s else s) + prs2 = lambda s: "".join( + s.split("sentence_encoder.")[1:] if "sentence_encoder" in s else s + ) + model_args = {pra(arg[0]): arg[1] for arg in vars(model_data["args"]).items()} + model_state = {prs1(prs2(arg[0])): arg[1] for arg in model_data["model"].items()} + model_state["embed_tokens.weight"][alphabet.mask_idx].zero_() # For token drop + model_args["emb_layer_norm_before"] = has_emb_layer_norm_before(model_state) + model_type = esm.ProteinBertModel + + elif model_data["args"].arch == "protein_bert_base": + + # upgrade state dict + pra = lambda s: "".join(s.split("decoder_")[1:] if "decoder" in s else s) + prs = lambda s: "".join(s.split("decoder.")[1:] if "decoder" in s else s) + model_args = {pra(arg[0]): arg[1] for arg in vars(model_data["args"]).items()} + model_state = {prs(arg[0]): arg[1] for arg in model_data["model"].items()} + model_type = esm.ProteinBertModel + elif model_data["args"].arch == "msa_transformer": + + # upgrade state dict + pra = lambda s: "".join(s.split("encoder_")[1:] if "encoder" in s else s) + prs1 = lambda s: "".join(s.split("encoder.")[1:] if "encoder" in s else s) + prs2 = lambda s: "".join( + s.split("sentence_encoder.")[1:] if "sentence_encoder" in s else s + ) + prs3 = lambda s: s.replace("row", "column") if "row" in s else s.replace("column", "row") + model_args = {pra(arg[0]): arg[1] for arg in vars(model_data["args"]).items()} + model_state = {prs1(prs2(prs3(arg[0]))): arg[1] for arg in model_data["model"].items()} + if model_args.get("embed_positions_msa", False): + emb_dim = model_state["msa_position_embedding"].size(-1) + model_args["embed_positions_msa_dim"] = emb_dim # initial release, bug: emb_dim==1 + + model_type = esm.MSATransformer + + elif "invariant_gvp" in model_data["args"].arch: + import esm.inverse_folding + + model_type = esm.inverse_folding.gvp_transformer.GVPTransformerModel + model_args = vars(model_data["args"]) # convert Namespace -> dict + + def update_name(s): + # Map the module names in checkpoints trained with internal code to + # the updated module names in open source code + s = s.replace("W_v", "embed_graph.embed_node") + s = s.replace("W_e", "embed_graph.embed_edge") + s = s.replace("embed_scores.0", "embed_confidence") + s = s.replace("embed_score.", "embed_graph.embed_confidence.") + s = s.replace("seq_logits_projection.", "") + s = 
s.replace("embed_ingraham_features", "embed_dihedrals") + s = s.replace("embed_gvp_in_local_frame.0", "embed_gvp_output") + s = s.replace("embed_features_in_local_frame.0", "embed_gvp_input_features") + return s + + model_state = { + update_name(sname): svalue + for sname, svalue in model_data["model"].items() + if "version" not in sname + } + + else: + raise ValueError("Unknown architecture selected") + + model = model_type( + Namespace(**model_args), + alphabet, + ) + + return model, alphabet, model_state + + +def _load_model_and_alphabet_core_v2(model_data): + def upgrade_state_dict(state_dict): + """Removes prefixes 'model.encoder.sentence_encoder.' and 'model.encoder.'.""" + prefixes = ["encoder.sentence_encoder.", "encoder."] + pattern = re.compile("^" + "|".join(prefixes)) + state_dict = {pattern.sub("", name): param for name, param in state_dict.items()} + return state_dict + + cfg = model_data["cfg"]["model"] + state_dict = model_data["model"] + state_dict = upgrade_state_dict(state_dict) + alphabet = esm.data.Alphabet.from_architecture("ESM-1b") + model = ESM2( + num_layers=cfg.encoder_layers, + embed_dim=cfg.encoder_embed_dim, + attention_heads=cfg.encoder_attention_heads, + alphabet=alphabet, + token_dropout=cfg.token_dropout, + ) + return model, alphabet, state_dict + + +def load_model_and_alphabet_core(model_name, model_data, regression_data=None): + if regression_data is not None: + model_data["model"].update(regression_data["model"]) + + if model_name.startswith("esm2"): + model, alphabet, model_state = _load_model_and_alphabet_core_v2(model_data) + else: + model, alphabet, model_state = _load_model_and_alphabet_core_v1(model_data) + + expected_keys = set(model.state_dict().keys()) + found_keys = set(model_state.keys()) + + if regression_data is None: + expected_missing = {"contact_head.regression.weight", "contact_head.regression.bias"} + error_msgs = [] + missing = (expected_keys - found_keys) - expected_missing + if missing: + error_msgs.append(f"Missing key(s) in state_dict: {missing}.") + unexpected = found_keys - expected_keys + if unexpected: + error_msgs.append(f"Unexpected key(s) in state_dict: {unexpected}.") + + if error_msgs: + raise RuntimeError( + "Error(s) in loading state_dict for {}:\n\t{}".format( + model.__class__.__name__, "\n\t".join(error_msgs) + ) + ) + if expected_missing - found_keys: + warnings.warn( + "Regression weights not found, predicting contacts will not produce correct results." + ) + + model.load_state_dict(model_state, strict=regression_data is not None) + + return model, alphabet + + +def esm1_t34_670M_UR50S(): + """34 layer transformer model with 670M params, trained on Uniref50 Sparse. + Returns a tuple of (Model, Alphabet). + """ + return load_model_and_alphabet_hub("esm1_t34_670M_UR50S") + + +def esm1_t34_670M_UR50D(): + """34 layer transformer model with 670M params, trained on Uniref50 Dense. + Returns a tuple of (Model, Alphabet). + """ + return load_model_and_alphabet_hub("esm1_t34_670M_UR50D") + + +def esm1_t34_670M_UR100(): + """34 layer transformer model with 670M params, trained on Uniref100. + Returns a tuple of (Model, Alphabet). + """ + return load_model_and_alphabet_hub("esm1_t34_670M_UR100") + + +def esm1_t12_85M_UR50S(): + """12 layer transformer model with 85M params, trained on Uniref50 Sparse. + Returns a tuple of (Model, Alphabet). + """ + return load_model_and_alphabet_hub("esm1_t12_85M_UR50S") + + +def esm1_t6_43M_UR50S(): + """6 layer transformer model with 43M params, trained on Uniref50 Sparse. 
+ Returns a tuple of (Model, Alphabet). + """ + return load_model_and_alphabet_hub("esm1_t6_43M_UR50S") + + +def esm1b_t33_650M_UR50S(): + """33 layer transformer model with 650M params, trained on Uniref50 Sparse. + This is our best performing model, which will be described in a future publication. + Returns a tuple of (Model, Alphabet). + """ + return load_model_and_alphabet_hub("esm1b_t33_650M_UR50S") + + +def esm_msa1_t12_100M_UR50S(): + warnings.warn( + "This model had a minor bug in the positional embeddings, " + "please use ESM-MSA-1b: esm.pretrained.esm_msa1b_t12_100M_UR50S()", + ) + return load_model_and_alphabet_hub("esm_msa1_t12_100M_UR50S") + + +def esm_msa1b_t12_100M_UR50S(): + return load_model_and_alphabet_hub("esm_msa1b_t12_100M_UR50S") + + +def esm1v_t33_650M_UR90S(): + """33 layer transformer model with 650M params, trained on Uniref90. + This is model 1 of a 5 model ensemble. + Returns a tuple of (Model, Alphabet). + """ + return load_model_and_alphabet_hub("esm1v_t33_650M_UR90S_1") + + +def esm1v_t33_650M_UR90S_1(): + """33 layer transformer model with 650M params, trained on Uniref90. + This is model 1 of a 5 model ensemble. + Returns a tuple of (Model, Alphabet). + """ + return load_model_and_alphabet_hub("esm1v_t33_650M_UR90S_1") + + +def esm1v_t33_650M_UR90S_2(): + """33 layer transformer model with 650M params, trained on Uniref90. + This is model 2 of a 5 model ensemble. + Returns a tuple of (Model, Alphabet). + """ + return load_model_and_alphabet_hub("esm1v_t33_650M_UR90S_2") + + +def esm1v_t33_650M_UR90S_3(): + """33 layer transformer model with 650M params, trained on Uniref90. + This is model 3 of a 5 model ensemble. + Returns a tuple of (Model, Alphabet). + """ + return load_model_and_alphabet_hub("esm1v_t33_650M_UR90S_3") + + +def esm1v_t33_650M_UR90S_4(): + """33 layer transformer model with 650M params, trained on Uniref90. + This is model 4 of a 5 model ensemble. + Returns a tuple of (Model, Alphabet). + """ + return load_model_and_alphabet_hub("esm1v_t33_650M_UR90S_4") + + +def esm1v_t33_650M_UR90S_5(): + """33 layer transformer model with 650M params, trained on Uniref90. + This is model 5 of a 5 model ensemble. + Returns a tuple of (Model, Alphabet). + """ + return load_model_and_alphabet_hub("esm1v_t33_650M_UR90S_5") + + +def esm_if1_gvp4_t16_142M_UR50(): + """Inverse folding model with 142M params, with 4 GVP-GNN layers, 8 + Transformer encoder layers, and 8 Transformer decoder layers, trained on + CATH structures and 12 million alphafold2 predicted structures from UniRef50 + sequences. + Returns a tuple of (Model, Alphabet). + """ + return load_model_and_alphabet_hub("esm_if1_gvp4_t16_142M_UR50") + + +def esm2_t6_8M_UR50D(): + """6 layer ESM-2 model with 8M params, trained on UniRef50. + Returns a tuple of (Model, Alphabet). + """ + return load_model_and_alphabet_hub("esm2_t6_8M_UR50D") + + +def esm2_t12_35M_UR50D(): + """12 layer ESM-2 model with 35M params, trained on UniRef50. + Returns a tuple of (Model, Alphabet). + """ + return load_model_and_alphabet_hub("esm2_t12_35M_UR50D") + + +def esm2_t30_150M_UR50D(): + """30 layer ESM-2 model with 150M params, trained on UniRef50. + Returns a tuple of (Model, Alphabet). + """ + return load_model_and_alphabet_hub("esm2_t30_150M_UR50D") + + +def esm2_t33_650M_UR50D(): + """33 layer ESM-2 model with 650M params, trained on UniRef50. + Returns a tuple of (Model, Alphabet). 
+ """ + return load_model_and_alphabet_hub("esm2_t33_650M_UR50D") + + +def esm2_t36_3B_UR50D(): + """36 layer ESM-2 model with 3B params, trained on UniRef50. + Returns a tuple of (Model, Alphabet). + """ + return load_model_and_alphabet_hub("esm2_t36_3B_UR50D") + + +def esm2_t48_15B_UR50D(): + """48 layer ESM-2 model with 15B params, trained on UniRef50. + If you have OOM while loading this model, please refer to README + on how to employ FSDP and ZeRO CPU offloading + Returns a tuple of (Model, Alphabet). + """ + return load_model_and_alphabet_hub("esm2_t48_15B_UR50D") \ No newline at end of file diff --git a/esm/rotary_embedding.py b/esm/rotary_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..496eda0e756edb9d8a2605dc388a2ad78c97011d --- /dev/null +++ b/esm/rotary_embedding.py @@ -0,0 +1,69 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch + + +def rotate_half(x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(x, cos, sin): + cos = cos[:, : x.shape[-2], :] + sin = sin[:, : x.shape[-2], :] + + return (x * cos) + (rotate_half(x) * sin) + + +class RotaryEmbedding(torch.nn.Module): + """ + The rotary position embeddings from RoFormer_ (Su et. al). + A crucial insight from the method is that the query and keys are + transformed by rotation matrices which depend on the relative positions. + Other implementations are available in the Rotary Transformer repo_ and in + GPT-NeoX_, GPT-NeoX was an inspiration + .. _RoFormer: https://arxiv.org/abs/2104.09864 + .. _repo: https://github.com/ZhuiyiTechnology/roformer + .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox + .. 
warning: Please note that this embedding is not registered on purpose, as it is transformative + (it does not create the embedding dimension) and will likely be picked up (imported) on a ad-hoc basis + """ + + def __init__(self, dim: int, *_, **__): + super().__init__() + # Generate and save the inverse frequency buffer (non trainable) + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + + self._seq_len_cached = None + self._cos_cached = None + self._sin_cached = None + + def _update_cos_sin_tables(self, x, seq_dimension=1): + seq_len = x.shape[seq_dimension] + + # Reset the tables if the sequence length has changed, + # or if we're on a new device (possibly due to tracing for instance) + if seq_len != self._seq_len_cached or self._cos_cached.device != x.device: + self._seq_len_cached = seq_len + t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + + self._cos_cached = emb.cos()[None, :, :] + self._sin_cached = emb.sin()[None, :, :] + + return self._cos_cached, self._sin_cached + + def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2) + + return ( + apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached), + apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached), + ) \ No newline at end of file diff --git a/esm/version.py b/esm/version.py new file mode 100644 index 0000000000000000000000000000000000000000..201c6a6b47375952380733ed928c09b83647043a --- /dev/null +++ b/esm/version.py @@ -0,0 +1,6 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +version = "1.0.2" diff --git a/requirements.txt b/requirements.txt index ed7d77aba9db69b749b936adc77d6085e6479651..c583cb20cf817619627081a5647a3e8c50d2d517 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,4 @@ biopython==1.81 torch==2.0.1 numpy -pandas -tqdm -fair-esm @ git+https://github.com/facebookresearch/esm.git@900251ba3e2b7cdc06b44b10dfa3a0c1dd49752b \ No newline at end of file +pandas \ No newline at end of file
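A minimal sketch of how the entry points in the vendored esm/pretrained.py above are typically exercised. It assumes network access to the dl.fbaipublicfiles.com URLs hardcoded in that file and the standard upstream ESM-2 forward signature (repr_layers=[...] returning a "representations" dict); the sequence, its label, and the variable names are made-up placeholders, not part of this diff.

import torch
from esm import pretrained

# Downloads esm2_t6_8M_UR50D.pt (plus its contact-regression weights, per
# _has_regression_weights) into the torch hub cache and builds the model.
model, alphabet = pretrained.esm2_t6_8M_UR50D()
model.eval()

batch_converter = alphabet.get_batch_converter()
data = [("seq1", "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ")]  # hypothetical input
labels, strs, tokens = batch_converter(data)

with torch.no_grad():
    out = model(tokens, repr_layers=[6])    # layer 6 = final layer of the 6-layer model
per_token = out["representations"][6]       # (batch, tokens incl. BOS/EOS, embed_dim)

For checkpoints already on disk, load_model_and_alphabet("path/to/model.pt") takes the local-file branch instead and expects any "-contact-regression.pt" file to sit next to the checkpoint.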
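The RotaryEmbedding module added in esm/rotary_embedding.py can also be sanity-checked in isolation. The sketch below relies only on what is visible in the file itself: forward takes query and key tensors whose second-to-last dimension is the sequence length and whose last dimension equals the (even) dim passed to the constructor; the tensor shapes are arbitrary placeholders.

import torch
from esm.rotary_embedding import RotaryEmbedding

bsz_times_heads, seq_len, head_dim = 4, 50, 16      # placeholder shapes
q = torch.randn(bsz_times_heads, seq_len, head_dim)
k = torch.randn(bsz_times_heads, seq_len, head_dim)

rot = RotaryEmbedding(dim=head_dim)   # head_dim must be even: rotate_half chunks it in two
q_rot, k_rot = rot(q, k)              # positions are encoded by rotating q and k,
                                      # not by adding a positional embedding table
assert q_rot.shape == q.shape and k_rot.shape == k.shape

Because MultiheadAttention.forward applies self.rot_emb(q, k) immediately before the q·kᵀ bmm, relative position information enters the attention logits without any additive positional embedding.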