# PreMode / analysis / Hsu.et.al.git / src / esm_embeddings.py
# gzhong's picture
# Upload folder using huggingface_hub
# 7718235 verified
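'''
Extracts embeddings from ESM models.

Runs an ESM model over every sequence in a FASTA file, averages the
final-layer token representations of each sequence, and writes the sequence
labels and the averaged representations to an output directory.
'''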
import argparse
from collections import defaultdict
import os
import pathlib
import numpy as np
import pandas as pd
import torch
from esm import Alphabet, FastaBatchedDataset, ProteinBertModel, pretrained, BatchConverter
from utils import read_fasta, save
# Cross-entropy loss that keeps one value per element instead of reducing
# to a scalar. NOTE(review): not referenced by the functions visible in this
# part of the file -- confirm whether it is still needed.
criterion = torch.nn.CrossEntropyLoss(reduction="none")
def _str2bool(value):
    """Convert a command-line string such as ``true``/``False``/``0`` to a bool."""
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"true", "t", "yes", "y", "1"}:
        return True
    # The empty string stays falsy, as it was under the previous ``type=bool``.
    if normalized in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean value (true/false), got {value!r}"
    )


def create_parser():
    """Build the command-line parser for the embedding-extraction script.

    Positional arguments:
        fasta_file: FASTA file on which to extract representations.
        wt_fasta_file: FASTA file for the wild type.
        output_dir: directory that receives the outputs.

    Options:
        --model_location: model location (string) passed to the ESM loader.
        --save_hidden: boolean given as text (``True``/``False``,
            ``yes``/``no``, ``1``/``0``); defaults to ``False``.
        --toks_per_batch: maximum batch size in tokens (default 4096).
        --nogpu: do not use the GPU even if one is available.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(
        description="Extract per-token representations and model outputs for sequences in a FASTA file"  # noqa
    )

    parser.add_argument(
        "fasta_file",
        type=pathlib.Path,
        help="FASTA file on which to extract representations",
    )
    parser.add_argument(
        "wt_fasta_file",
        type=pathlib.Path,
        help="FASTA file for WT",
    )
    parser.add_argument(
        "output_dir",
        type=pathlib.Path,
        help="output dir",
    )

    parser.add_argument(
        "--model_location",
        type=str,
        help="model location",
        default="/mnt/esm_weights/esm1b/esm1b_t33_650M_UR50S.pt",
    )
    # ``type=bool`` would call bool("False"), which is True for any non-empty
    # string, so the value is parsed explicitly instead.
    parser.add_argument(
        "--save_hidden", type=_str2bool, default=False, help="whether to save rep"
    )
    parser.add_argument(
        "--toks_per_batch", type=int, default=4096, help="maximum batch size"
    )
    parser.add_argument(
        "--nogpu", action="store_true", help="Do not use GPU even if available"
    )
    return parser
def _masked_mean_pool(hidden, toks, padding_idx):
    """Average ``hidden`` over the token axis, ignoring padding positions."""
    # [B, T, 1] mask that is True wherever the token is not padding.
    notpad = (toks != padding_idx).unsqueeze(2)
    # Divide by the number of non-padding tokens of each sequence rather than
    # by the padded length T; otherwise a sequence's embedding is scaled down
    # by however much padding its batch happens to contain.
    n_tokens = notpad.sum(dim=1).clamp(min=1)
    return (hidden * notpad).sum(dim=1) / n_tokens


def main(args):
    """Compute one averaged final-layer embedding per sequence and save it.

    Loads the model from ``args.model_location``, batches the sequences of
    ``args.fasta_file`` by token count (``args.toks_per_batch``), runs the
    model without gradients, and averages the last layer's token
    representations over the non-padding positions of every sequence.

    Outputs, written to ``args.output_dir`` (created if missing), in the
    order the batches are processed:
        labels.npy: the sequence labels, written as plain text with
            ``np.savetxt`` (one label per line) despite the extension.
        rep.npy: the averaged representations, written by ``utils.save``.

    NOTE(review): tokens other than padding that the batch converter may add
    are still part of the average -- confirm this is intended.

    Args:
        args: parsed command-line namespace from ``create_parser``.

    Raises:
        ValueError: if the FASTA file yields no sequences.
    """
    model, alphabet = pretrained.load_model_and_alphabet(args.model_location)
    batch_converter = alphabet.get_batch_converter()
    padding_idx = alphabet.padding_idx
    model.eval()

    # Decide once whether to run on the GPU; reused for every batch.
    use_gpu = torch.cuda.is_available() and not args.nogpu
    if use_gpu:
        model = model.cuda()
        print("Transferred model to GPU")

    dataset = FastaBatchedDataset.from_file(args.fasta_file)
    batches = dataset.get_batch_indices(args.toks_per_batch, extra_toks_per_seq=1)
    data_loader = torch.utils.data.DataLoader(
        dataset, collate_fn=batch_converter, batch_sampler=batches
    )
    print(f"Read {args.fasta_file} with {len(dataset)} sequences")

    repr_layers = [model.num_layers]  # extract last layer
    label_vals = []
    avg_rep_vals = []

    with torch.no_grad():
        for batch_idx, (labels, _strs, toks) in enumerate(data_loader):
            print(
                f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)"
            )
            if use_gpu:
                toks = toks.to(device="cuda", non_blocking=True)

            out = model(toks, repr_layers=repr_layers, return_contacts=False)
            # [B, T, E]
            final_layer = out["representations"][model.num_layers]
            avg_rep = _masked_mean_pool(final_layer, toks, padding_idx)
            avg_rep_vals.append(avg_rep.to(device="cpu").numpy())
            label_vals.append(labels)

    if not avg_rep_vals:
        # np.concatenate would fail on an empty list with an opaque message.
        raise ValueError(f"No sequences found in {args.fasta_file}")

    args.output_dir.mkdir(parents=True, exist_ok=True)
    avg_rep_vals = np.concatenate(avg_rep_vals, axis=0)
    label_vals = np.concatenate(label_vals)
    np.savetxt(os.path.join(args.output_dir, 'labels.npy'),
               label_vals, fmt="%s")
    save(os.path.join(args.output_dir, 'rep.npy'), avg_rep_vals)
# Script entry point: parse the command line and run the extraction.
if __name__ == "__main__":
    main(create_parser().parse_args())