# PreMode / esm.inference.py
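"""
Precompute ESM-2 (esm2_t33_650M_UR50D) outputs for a table of protein sequences.

For each row of the input CSV (expected columns: 'uniprotID' and 'sequence'), the script
saves final-layer (33) per-token representations, language-model logits and, for sequences
short enough for contact prediction, the predicted contact map as .npy files in --outdir.
The model is sharded across four GPUs so that long sequences can be processed.
"""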
import pandas as pd
import numpy as np
import os
import esm
import torch
import argparse
os.environ['CUDA_LAUNCH_BLOCKING'] = '-1'
def precompute_sequence(transcript_id, sequence, esm_model, batch_converter, out_dir, device_id=0):
    if os.path.exists(os.path.join(out_dir, transcript_id + '.contacts.npy')):
        return
    else:
        print('begin precompute sequence for {}'.format(transcript_id))
    try:
        data = [(transcript_id, sequence)]
        _, _, toks = batch_converter(data)
    except Exception:
        # tokenization failed, e.g. the sequence contains unexpected characters
        print(f'failed to convert {transcript_id}, skip')
        return
    toks = toks.to(f'cuda:{device_id}')
    aa = toks.shape[1]
    if aa > 2250:
        # too long for single-GPU inference with contact prediction, skip
        print(f"{transcript_id} has {aa} amino acids, skip")
        return
    with torch.no_grad():
        out = esm_model(toks, repr_layers=[33], return_contacts=True, need_head_weights=False)
    representations = out["representations"][33][0].to(device='cpu').detach().numpy()
    # attentions would be batch x layers x heads x seqlen x seqlen
    # attentions = out["attentions"][0].to(device="cpu").detach().numpy()
    contacts = out['contacts'][0].to(device="cpu").detach().numpy()
    logits = out['logits'][0].to(device="cpu").detach().numpy()
    # note: the file name says layer.48, but these are the final-layer (33) representations of esm2_t33_650M
    np.save(
        f"{out_dir}/{transcript_id}.representations.layer.48.npy",
        representations,
    )
    np.save(
        f"{out_dir}/{transcript_id}.contacts.npy",
        contacts,
    )
    np.save(
        f"{out_dir}/{transcript_id}.logits.npy",
        logits,
    )
    return
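
# Note: main() below only calls precompute_sequence_multiple_gpus; the single-GPU variant
# above appears to be kept for runs where the whole model fits on one device (device_id
# selects which GPU to use).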

def precompute_sequence_multiple_gpus(transcript_id, sequence, esm_model, batch_converter, out_dir):
    if os.path.exists(os.path.join(out_dir, transcript_id + '.contacts.npy')):
        return
    else:
        print('begin precompute sequence for {}'.format(transcript_id))
    try:
        data = [(transcript_id, sequence)]
        _, _, toks = batch_converter(data)
    except Exception:
        # tokenization failed, e.g. the sequence contains unexpected characters
        print(f'failed to convert {transcript_id}, skip')
        return
    toks = toks.to('cuda:0')
    if toks.shape[1] > 30000:
        print(f"{transcript_id} has {toks.shape[1]} amino acids, don't proceed")
        return
    print(f"{transcript_id} has {toks.shape[1]} amino acids")
    # stacking per-layer attention maps for contact prediction is too memory-hungry
    # for very long sequences, so skip attentions and contacts above 5500 tokens
    if toks.shape[1] > 5500:
        need_head_weights = False
        return_contacts = False
    else:
        need_head_weights = True
        return_contacts = True
    with torch.no_grad():
        # manual re-implementation of the ESM-2 forward pass so that activations can be
        # moved between the GPUs holding different layer shards (see main() for the split)
        assert toks.ndim == 2
        padding_mask = toks.eq(esm_model.padding_idx)  # B, T
        x = esm_model.embed_scale * esm_model.embed_tokens(toks)
        if esm_model.token_dropout:
            x.masked_fill_((toks == esm_model.mask_idx).unsqueeze(-1), 0.0)
            # x: B x T x C
            mask_ratio_train = 0.15 * 0.8
            src_lengths = (~padding_mask).sum(-1)
            mask_ratio_observed = (toks == esm_model.mask_idx).sum(-1).to(x.dtype) / src_lengths
            x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]
        if padding_mask is not None:
            x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))
        repr_layers = {33}
        hidden_representations = {}
        if 0 in repr_layers:
            hidden_representations[0] = x
        if need_head_weights:
            attn_weights = []
        # (B, T, E) => (T, B, E)
        x = x.transpose(0, 1)
        if not padding_mask.any():
            padding_mask = None
        for layer_idx, layer in enumerate(esm_model.layers):
            # move activations to the GPU holding this layer (layers are sharded 9-per-GPU in main())
            x = x.to(f'cuda:{layer_idx // 9}')
            x, attn = layer(
                x,
                self_attn_padding_mask=padding_mask,
                need_head_weights=need_head_weights,
            )
            if (layer_idx + 1) in repr_layers:
                hidden_representations[layer_idx + 1] = x.transpose(0, 1)
            if need_head_weights:
                # (H, B, T, T) => (B, H, T, T)
                attn_weights.append(attn.transpose(1, 0).cpu())
        x = esm_model.emb_layer_norm_after(x)
        x = x.transpose(0, 1)  # (T, B, E) => (B, T, E)
        # last hidden representation should have layer norm applied
        if (layer_idx + 1) in repr_layers:
            hidden_representations[layer_idx + 1] = x
        # lm head is on cuda:0, x is on cuda:3 after the last layer shard
        x = esm_model.lm_head(x.to('cuda:0'))
        out = {"logits": x, "representations": hidden_representations}
        if need_head_weights:
            # attentions: B x L x H x T x T
            attentions = torch.stack(attn_weights, 1)
            if padding_mask is not None:
                attention_mask = 1 - padding_mask.type_as(attentions)
                attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
                attentions = attentions * attention_mask[:, None, None, :, :]
            out["attentions"] = attentions
            if return_contacts:
                contacts = esm_model.contact_head(toks, attentions)
                out["contacts"] = contacts
    representations = out["representations"][33][0].to(device='cpu').detach().numpy()
    logits = out['logits'][0].to(device="cpu").detach().numpy()
    # note: the file name says layer.48, but these are the final-layer (33) representations of esm2_t33_650M
    np.save(
        f"{out_dir}/{transcript_id}.representations.layer.48.npy",
        representations,
    )
    np.save(
        f"{out_dir}/{transcript_id}.logits.npy",
        logits,
    )
    if return_contacts:
        contacts = out['contacts'][0].to(device="cpu").detach().numpy()
        np.save(
            f"{out_dir}/{transcript_id}.contacts.npy",
            contacts,
        )
    return
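
# Each successful call to precompute_sequence_multiple_gpus writes up to three .npy files into out_dir:
#   <uniprotID>.representations.layer.48.npy  - layer-33 per-token embeddings (incl. BOS/EOS)
#   <uniprotID>.logits.npy                    - language-model logits for each token position
#   <uniprotID>.contacts.npy                  - predicted contact map (skipped above 5500 tokens)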

def main(file=None, outdir=None):
    model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
    model.eval()  # disable dropout for deterministic inference
    if torch.cuda.is_available():
        # manually split the model across 4 GPUs (assumes at least 4 visible devices)
        model.embed_tokens.to('cuda:0')
        for layer_idx, layer in enumerate(model.layers):
            layer.to(f'cuda:{layer_idx // 9}')
        model.emb_layer_norm_after.to('cuda:3')
        model.lm_head.to('cuda:0')
        model.contact_head.to('cpu')
        print("Transferred model to GPUs")
    # model = model.to(f'cuda:{rank}')
    if file is None:
        return
    files = pd.read_csv(file, index_col=0)
    os.makedirs(outdir, exist_ok=True)
    batch_converter = alphabet.get_batch_converter()  # build the tokenizer once and reuse it for every sequence
    for transcript_id, sequence in zip(files['uniprotID'], files['sequence']):
        precompute_sequence_multiple_gpus(transcript_id, sequence, model,
                                          batch_converter,
                                          outdir)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--file', type=str, default=None)
    parser.add_argument('--outdir', type=str, default=None)
    args = parser.parse_args()
    main(args.file, args.outdir)
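
# Example invocation (file names here are hypothetical; the CSV is read with index_col=0
# and must contain 'uniprotID' and 'sequence' columns):
#   python esm.inference.py --file sequences.csv --outdir esm_precomputed/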