''' Fine-tunes ESM Transformer models with labelled data. '''
import argparse
import functools
from itertools import chain
import os
import pathlib
import random

import numpy as np
import pandas as pd
from sklearn.utils import shuffle
import torch

from esm import Alphabet, FastaBatchedDataset, ProteinBertModel
from esm import BatchConverter, pretrained

from utils import read_fasta
from utils.metric_utils import spearman, topk_mean, r2, hit_rate, aucroc, ndcg
from utils.esm_utils import RandomCropBatchConverter, CSVBatchedDataset

mse_criterion = torch.nn.MSELoss(reduction='mean')
ce_criterion = torch.nn.CrossEntropyLoss(reduction='none')


def create_parser():
    parser = argparse.ArgumentParser(
        description="Supervised finetuning for sequences in a FASTA file"  # noqa
    )
    parser.add_argument(
        "csv_file", type=pathlib.Path,
        help="csv file for labeled data",
    )
    parser.add_argument(
        "wt_fasta_file", type=pathlib.Path,
        help="fasta file for WT sequence",
    )
    parser.add_argument(
        "output_dir", type=pathlib.Path,
        help="output directory",
    )
    parser.add_argument(
        "--model_location", type=str,
        help="initial model location",
        default='/mnt/esm_weights/esm1b/esm1b_t33_650M_UR50S.pt',
    )
    parser.add_argument(
        "--epochs", type=int, default=50, help="number of epochs"
    )
    parser.add_argument(
        "--seed", type=int, default=0, help="random seed"
    )
    parser.add_argument(
        "--n_train", type=int, default=-1, help="training data size"
    )
    parser.add_argument(
        "--val_split", type=float, default=0.2, help="validation split"
    )
    parser.add_argument(
        "--n_test", type=int, default=-1, help="test data size"
    )
    parser.add_argument(
        "--toks_per_batch", type=int, default=512,
        help="maximum tokens per batch"
    )
    parser.add_argument(
        "--max_len", type=int, default=500, help="maximum seq len"
    )
    parser.add_argument(
        "--learning_rate", type=float, default=3e-5, help="learning rate"
    )
    parser.add_argument(
        "--save_model", dest="save_model", action="store_true"
    )
    parser.add_argument(
        "--train_on_single", dest="train_on_single", action="store_true"
    )
    parser.add_argument(
        "--train_on_all", dest="train_on_single", action="store_false"
    )
    parser.set_defaults(save_model=False, train_on_single=True)
    return parser


def step(model, labels, toks, wt_toks, mask_idx):
    """Runs one forward pass and returns the MSE loss against the labels."""
    labels = torch.tensor(labels)
    if torch.cuda.is_available():
        labels = labels.to(device="cuda", non_blocking=True)
    predictions = predict(model, toks, wt_toks, mask_idx)
    loss = mse_criterion(predictions, labels)
    return loss, predictions


def predict(model, toks, wt_toks, mask_idx):
    """Scores variants by the summed log-likelihood difference between
    mutant and wild-type tokens at the mutated (masked) positions."""
    if torch.cuda.is_available():
        toks = toks.to(device="cuda", non_blocking=True)
    wt_toks_rep = wt_toks.repeat(toks.shape[0], 1)
    # Mask every position where the variant differs from wild type.
    mask = (toks != wt_toks)
    masked_toks = torch.where(mask, mask_idx, toks)
    out = model(masked_toks, return_contacts=False)
    logits = out["logits"]
    logits_tr = logits.transpose(1, 2)  # [B, E, T]
    ce_loss_mut = ce_criterion(logits_tr, toks)  # [B, T]
    ce_loss_wt = ce_criterion(logits_tr, wt_toks_rep)  # [B, T]
    # Per-sequence sum of log p(mutant) - log p(WT) over masked positions.
    ll_diff_sum = torch.sum(
        (ce_loss_wt - ce_loss_mut) * mask, dim=1, keepdim=True)  # [B, 1]
    return ll_diff_sum[:, 0]


def main(args):
    model_data = torch.load(args.model_location, map_location='cpu')
    model, alphabet = pretrained.load_model_and_alphabet(args.model_location)
    # repr_layers = [model.num_layers]  # extract last layer
    # toplinear = torch.nn.Linear(model.args.embed_dim, 1)
    # toplinear = torch.nn.Linear(1, 1)
    batch_converter = BatchConverter(alphabet)
    mask_idx = torch.tensor(alphabet.mask_idx)

    args.output_dir.mkdir(parents=True, exist_ok=True)

    wt_seq = read_fasta(args.wt_fasta_file)[0]
    _, _, wt_toks = batch_converter([('WT', wt_seq)])
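    # Added guard (assumption): --max_len is parsed above but otherwise unused
    # in this script, so we take it as an intended bound on the WT sequence
    # length; ESM-1b itself accepts at most 1022 residues.
    if len(wt_seq) > args.max_len:
        print(f"Warning: WT sequence length {len(wt_seq)} exceeds "
              f"--max_len {args.max_len}")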
    if torch.cuda.is_available():
        model = model.cuda()
        # toplinear = toplinear.cuda()
        mask_idx = mask_idx.cuda()
        wt_toks = wt_toks.cuda()
        print("Transferred model to GPU")

    df_full = shuffle(pd.read_csv(args.csv_file), random_state=args.seed)
    print(f"Read {args.csv_file} with {len(df_full)} sequences")

    # Hold out the last 20% (or --n_test sequences) as the test set.
    if args.n_test == -1:
        args.n_test = int(len(df_full) * 0.2)
    df_test = df_full[-args.n_test:]
    df_trainval = df_full.drop(df_test.index)
    if args.train_on_single and 'n_mut' in df_full.columns:
        df_trainval = df_trainval[df_trainval.n_mut <= 1]
    if args.n_train == -1:
        args.n_train = int(len(df_trainval))
    if args.n_train > len(df_trainval):
        print("Insufficient data")
        return
    n_val = int(args.n_train * args.val_split)
    df_train = df_trainval[:args.n_train - n_val]
    df_val = df_trainval[args.n_train - n_val:args.n_train]

    train_dataset = CSVBatchedDataset.from_dataframe(df_train)
    val_dataset = CSVBatchedDataset.from_dataframe(df_val)
    test_dataset = CSVBatchedDataset.from_dataframe(df_test)
    train_batches = train_dataset.get_batch_indices(
        args.toks_per_batch, extra_toks_per_seq=1)
    val_batches = val_dataset.get_batch_indices(
        args.toks_per_batch, extra_toks_per_seq=1)
    test_batches = test_dataset.get_batch_indices(
        args.toks_per_batch, extra_toks_per_seq=1)
    train_data_loader = torch.utils.data.DataLoader(
        train_dataset, collate_fn=batch_converter, batch_sampler=train_batches)
    val_data_loader = torch.utils.data.DataLoader(
        val_dataset, collate_fn=batch_converter, batch_sampler=val_batches)
    test_data_loader = torch.utils.data.DataLoader(
        test_dataset, collate_fn=batch_converter, batch_sampler=test_batches)

    # optimizer = torch.optim.Adam(
    #     chain(model.parameters(), toplinear.parameters()),
    #     lr=args.learning_rate)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    train_loss = np.zeros(args.epochs + 1)
    val_loss = np.zeros(args.epochs + 1)
    val_spearman = np.zeros(args.epochs + 1)
    best_val_spearman = None
    for epoch in range(args.epochs + 1):
        # Train. Epoch 0 is evaluation-only: it records the pre-finetuning
        # baseline before any gradient updates.
        if epoch > 0:
            model.train()  # restore train mode after eval() in the prior epoch
            for batch_idx, (labels, strs, toks) in enumerate(train_data_loader):
                if batch_idx % 100 == 0:
                    print(
                        f"Processing {batch_idx + 1} of {len(train_batches)} "
                        f"batches ({toks.size(0)} sequences)"
                    )
                optimizer.zero_grad()
                loss, _ = step(model, labels, toks, wt_toks, mask_idx)
                loss.backward()
                optimizer.step()
                train_loss[epoch] += loss.to('cpu').item()
            train_loss[epoch] /= float(len(train_data_loader))

        # Validation
        model.eval()
        y_pred = []
        y_true = []
        with torch.no_grad():
            for batch_idx, (labels, strs, toks) in enumerate(val_data_loader):
                loss, predictions = step(
                    model, labels, toks, wt_toks, mask_idx)
                y_pred.append(predictions.to('cpu').numpy())
                y_true.append(labels)
                val_loss[epoch] += loss.to('cpu').item()
        val_loss[epoch] /= float(len(val_data_loader))
        print('epoch %d, train loss: %.3f, val loss: %.3f' % (
            epoch, train_loss[epoch], val_loss[epoch]))
        y_pred = np.concatenate(y_pred)
        y_true = np.concatenate(y_true)
        val_spearman[epoch] = spearman(y_pred, y_true)
        print(f'Val Spearman correlation {val_spearman[epoch]}')

        # Checkpoint whenever validation Spearman improves.
        if best_val_spearman is None or val_spearman[epoch] > best_val_spearman:
            best_val_spearman = val_spearman[epoch]
            model_data["model"] = model.state_dict()
            # model_data["toplinear"] = toplinear.state_dict()
            torch.save(model_data,
                       os.path.join(args.output_dir, 'model_data.pt'))

    np.savetxt(os.path.join(args.output_dir, 'loss_trajectory_train.npy'),
               train_loss)
    np.savetxt(os.path.join(args.output_dir, 'loss_trajectory_val.npy'),
               val_loss)
    np.savetxt(os.path.join(args.output_dir, 'spearman_trajectory_val.npy'),
               val_spearman)
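    # Convenience summary (added): report which epoch the saved checkpoint
    # came from; epoch 0 corresponds to the pre-finetuning baseline.
    best_epoch = int(np.argmax(val_spearman))
    print(f'Best val Spearman {val_spearman[best_epoch]:.3f} '
          f'at epoch {best_epoch}')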
    # Load best saved model
    model, alphabet = pretrained.load_model_and_alphabet(
        os.path.join(args.output_dir, 'model_data.pt'))
    if torch.cuda.is_available():
        model = model.cuda()
    model.eval()

    # Test
    y_pred = []
    y_true = []
    with torch.no_grad():
        for batch_idx, (labels, strs, toks) in enumerate(test_data_loader):
            predictions = predict(model, toks, wt_toks, mask_idx)
            y_pred.append(predictions.to('cpu').numpy())
            y_true.append(labels)
    y_pred = np.concatenate(y_pred)
    y_true = np.concatenate(y_true)
    print(f'Test Spearman correlation {spearman(y_pred, y_true)}')

    df_test = df_test.copy()
    df_test['pred'] = y_pred
    metric_fns = {
        'spearman': spearman,
        'ndcg': ndcg,
    }
    results_dict = {k: mf(df_test.pred.values, df_test.log_fitness.values)
                    for k, mf in metric_fns.items()}
    results_dict.update({
        'predictor': 'esm_finetune',
        'n_train': args.n_train,
        'seed': args.seed,
        'epochs': args.epochs,
    })
    # Per-mutation-count breakdown, up to 5 mutations.
    if 'n_mut' in df_test.columns:
        max_n_mut = min(df_test.n_mut.max(), 5)
        for j in range(1, max_n_mut + 1):
            y_pred = df_test[df_test.n_mut == j].pred.values
            y_true = df_test[df_test.n_mut == j].log_fitness.values
            results_dict.update({
                f'{k}_{j}mut': mf(y_pred, y_true)
                for k, mf in metric_fns.items()})
    # pd.DataFrame.append was removed in pandas 2.0; build the frame directly.
    results = pd.DataFrame([results_dict])
    results.to_csv(os.path.join(args.output_dir, 'metrics.csv'),
                   mode='w', index=False,
                   columns=sorted(results.columns.values))
    if not args.save_model:
        os.remove(os.path.join(args.output_dir, 'model_data.pt'))


if __name__ == "__main__":
    parser = create_parser()
    args = parser.parse_args()
    main(args)
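# Example invocation (illustrative only; the script, data, and output names
# below are assumptions, not files shipped with this repo):
#   python finetune_esm.py data/gb1.csv data/gb1_wt.fasta results/gb1_finetune \
#       --epochs 20 --n_train 240 --seed 0 --save_model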