import argparse
import os.path

def main(args):

    import json, time, os, sys, glob
    import shutil
    import warnings
    import numpy as np
    import torch
    from torch import optim
    from torch.utils.data import DataLoader
    import queue
    import copy
    import torch.nn as nn
    import torch.nn.functional as F
    import random
    import subprocess
    from concurrent.futures import ProcessPoolExecutor
    from utils import worker_init_fn, get_pdbs, loader_pdb, build_training_clusters, PDB_dataset, StructureDataset, StructureLoader
    from model_utils import featurize, loss_smoothed, loss_nll, get_std_opt, ProteinMPNN
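    # Gradient scaler for mixed-precision training and the target device (first CUDA GPU if available, else CPU).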
    scaler = torch.cuda.amp.GradScaler()
    device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
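    # Resolve the output folder (the path is passed through strftime, so it may contain time format codes)
    # and make sure it and the model_weights subfolder exist.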
    base_folder = time.strftime(args.path_for_outputs, time.localtime())
    if base_folder[-1] != '/':
        base_folder += '/'
    if not os.path.exists(base_folder):
        os.makedirs(base_folder)
    subfolders = ['model_weights']
    for subfolder in subfolders:
        if not os.path.exists(base_folder + subfolder):
            os.makedirs(base_folder + subfolder)
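    # Start a fresh log file unless training resumes from a previous checkpoint.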
    PATH = args.previous_checkpoint
    logfile = base_folder + 'log.txt'
    if not PATH:
        with open(logfile, 'w') as f:
            f.write('Epoch\tTrain\tValidation\n')
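    # Locations and filters for the preprocessed PDB training data.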
    data_path = args.path_for_training_data
    params = {
        "LIST"   : f"{data_path}/list.csv",
        "VAL"    : f"{data_path}/valid_clusters.txt",
        "TEST"   : f"{data_path}/test_clusters.txt",
        "DIR"    : f"{data_path}",
        "DATCUT" : "2030-Jan-01",
        "RESCUT" : args.rescut,  # resolution cutoff for PDBs
        "HOMO"   : 0.70  # min seq. id. to detect homo chains
    }
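    # Settings for the PyTorch DataLoaders that stream individual PDB entries to the parser workers.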
    LOAD_PARAM = {'batch_size': 1,
                  'shuffle': True,
                  'pin_memory': False,
                  'num_workers': 4}
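    # Debug mode shrinks the dataset and batch size for quick smoke tests.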
    if args.debug:
        args.num_examples_per_epoch = 50
        args.max_protein_length = 1000
        args.batch_size = 1000
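    # Split the sequence clusters into train/valid/test sets and wrap them in per-cluster PDB datasets.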
    train, valid, test = build_training_clusters(params, args.debug)
    train_set = PDB_dataset(list(train.keys()), loader_pdb, train, params)
    train_loader = torch.utils.data.DataLoader(train_set, worker_init_fn=worker_init_fn, **LOAD_PARAM)
    valid_set = PDB_dataset(list(valid.keys()), loader_pdb, valid, params)
    valid_loader = torch.utils.data.DataLoader(valid_set, worker_init_fn=worker_init_fn, **LOAD_PARAM)
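    # Instantiate ProteinMPNN with the requested width, depth, neighbor count, and backbone noise.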
    model = ProteinMPNN(node_features=args.hidden_dim,
                        edge_features=args.hidden_dim,
                        hidden_dim=args.hidden_dim,
                        num_encoder_layers=args.num_encoder_layers,
                        num_decoder_layers=args.num_decoder_layers,
                        k_neighbors=args.num_neighbors,
                        dropout=args.dropout,
                        augment_eps=args.backbone_noise)
    model.to(device)
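    # Restore model weights and training counters when resuming from a checkpoint.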
    if PATH:
        checkpoint = torch.load(PATH)
        total_step = checkpoint['step']  # restore the global step counter from the checkpoint
        epoch = checkpoint['epoch']  # restore the epoch counter from the checkpoint
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        total_step = 0
        epoch = 0
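    # Optimizer wrapper from model_utils; it is given the hidden size and the resume step,
    # since its learning-rate schedule depends on the step count.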
    optimizer = get_std_opt(model.parameters(), args.hidden_dim, total_step)
    if PATH:
        optimizer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
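    # Parse PDB entries in background worker processes. Each queue holds up to three pending
    # futures, so freshly parsed train/valid data dumps are ready before they are needed.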
    with ProcessPoolExecutor(max_workers=12) as executor:
        q = queue.Queue(maxsize=3)
        p = queue.Queue(maxsize=3)
        for i in range(3):
            q.put_nowait(executor.submit(get_pdbs, train_loader, 1, args.max_protein_length, args.num_examples_per_epoch))
            p.put_nowait(executor.submit(get_pdbs, valid_loader, 1, args.max_protein_length, args.num_examples_per_epoch))
        pdb_dict_train = q.get().result()
        pdb_dict_valid = p.get().result()

        dataset_train = StructureDataset(pdb_dict_train, truncate=None, max_length=args.max_protein_length)
        dataset_valid = StructureDataset(pdb_dict_valid, truncate=None, max_length=args.max_protein_length)
        loader_train = StructureLoader(dataset_train, batch_size=args.batch_size)
        loader_valid = StructureLoader(dataset_valid, batch_size=args.batch_size)
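        # Main training loop: each epoch trains on the current data dump and then validates.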
        reload_c = 0
        for e in range(args.num_epochs):
            t0 = time.time()
            e = epoch + e
            model.train()
            train_sum, train_weights = 0., 0.
            train_acc = 0.
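            # Periodically swap in a freshly parsed data dump and queue up the next one.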
            if e % args.reload_data_every_n_epochs == 0:
                if reload_c != 0:
                    pdb_dict_train = q.get().result()
                    dataset_train = StructureDataset(pdb_dict_train, truncate=None, max_length=args.max_protein_length)
                    loader_train = StructureLoader(dataset_train, batch_size=args.batch_size)
                    pdb_dict_valid = p.get().result()
                    dataset_valid = StructureDataset(pdb_dict_valid, truncate=None, max_length=args.max_protein_length)
                    loader_valid = StructureLoader(dataset_valid, batch_size=args.batch_size)
                    q.put_nowait(executor.submit(get_pdbs, train_loader, 1, args.max_protein_length, args.num_examples_per_epoch))
                    p.put_nowait(executor.submit(get_pdbs, valid_loader, 1, args.max_protein_length, args.num_examples_per_epoch))
                reload_c += 1
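            # One optimizer step per token batch; the loss is masked by mask * chain_M.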
            for _, batch in enumerate(loader_train):
                start_batch = time.time()
                X, S, mask, lengths, chain_M, residue_idx, mask_self, chain_encoding_all = featurize(batch, device)
                elapsed_featurize = time.time() - start_batch
                optimizer.zero_grad()
                mask_for_loss = mask * chain_M
                if args.mixed_precision:
                    with torch.cuda.amp.autocast():
                        log_probs = model(X, S, mask, chain_M, residue_idx, chain_encoding_all)
                        _, loss_av_smoothed = loss_smoothed(S, log_probs, mask_for_loss)
                    scaler.scale(loss_av_smoothed).backward()
                    if args.gradient_norm > 0.0:
                        total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradient_norm)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    log_probs = model(X, S, mask, chain_M, residue_idx, chain_encoding_all)
                    _, loss_av_smoothed = loss_smoothed(S, log_probs, mask_for_loss)
                    loss_av_smoothed.backward()
                    if args.gradient_norm > 0.0:
                        total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradient_norm)
                    optimizer.step()
                loss, loss_av, true_false = loss_nll(S, log_probs, mask_for_loss)
                train_sum += torch.sum(loss * mask_for_loss).cpu().data.numpy()
                train_acc += torch.sum(true_false * mask_for_loss).cpu().data.numpy()
                train_weights += torch.sum(mask_for_loss).cpu().data.numpy()
                total_step += 1
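            # Evaluate on the validation set with gradients disabled.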
            model.eval()
            with torch.no_grad():
                validation_sum, validation_weights = 0., 0.
                validation_acc = 0.
                for _, batch in enumerate(loader_valid):
                    X, S, mask, lengths, chain_M, residue_idx, mask_self, chain_encoding_all = featurize(batch, device)
                    log_probs = model(X, S, mask, chain_M, residue_idx, chain_encoding_all)
                    mask_for_loss = mask * chain_M
                    loss, loss_av, true_false = loss_nll(S, log_probs, mask_for_loss)
                    validation_sum += torch.sum(loss * mask_for_loss).cpu().data.numpy()
                    validation_acc += torch.sum(true_false * mask_for_loss).cpu().data.numpy()
                    validation_weights += torch.sum(mask_for_loss).cpu().data.numpy()
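            # Per-residue average loss, accuracy, and perplexity for the epoch, written to the log file.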
            train_loss = train_sum / train_weights
            train_accuracy = train_acc / train_weights
            train_perplexity = np.exp(train_loss)
            validation_loss = validation_sum / validation_weights
            validation_accuracy = validation_acc / validation_weights
            validation_perplexity = np.exp(validation_loss)

            train_perplexity_ = np.format_float_positional(np.float32(train_perplexity), unique=False, precision=3)
            validation_perplexity_ = np.format_float_positional(np.float32(validation_perplexity), unique=False, precision=3)
            train_accuracy_ = np.format_float_positional(np.float32(train_accuracy), unique=False, precision=3)
            validation_accuracy_ = np.format_float_positional(np.float32(validation_accuracy), unique=False, precision=3)

            t1 = time.time()
            dt = np.format_float_positional(np.float32(t1 - t0), unique=False, precision=1)
            with open(logfile, 'a') as f:
                f.write(f'epoch: {e+1}, step: {total_step}, time: {dt}, train: {train_perplexity_}, valid: {validation_perplexity_}, train_acc: {train_accuracy_}, valid_acc: {validation_accuracy_}\n')
            print(f'epoch: {e+1}, step: {total_step}, time: {dt}, train: {train_perplexity_}, valid: {validation_perplexity_}, train_acc: {train_accuracy_}, valid_acc: {validation_accuracy_}')
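            # Always overwrite the "last" checkpoint; keep a numbered snapshot every n epochs.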
            checkpoint_filename_last = base_folder + 'model_weights/epoch_last.pt'
            torch.save({
                'epoch': e+1,
                'step': total_step,
                'num_edges': args.num_neighbors,
                'noise_level': args.backbone_noise,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.optimizer.state_dict(),
            }, checkpoint_filename_last)

            if (e+1) % args.save_model_every_n_epochs == 0:
                checkpoint_filename = base_folder + 'model_weights/epoch{}_step{}.pt'.format(e+1, total_step)
                torch.save({
                    'epoch': e+1,
                    'step': total_step,
                    'num_edges': args.num_neighbors,
                    'noise_level': args.backbone_noise,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.optimizer.state_dict(),
                }, checkpoint_filename)

if __name__ == "__main__":
    argparser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    argparser.add_argument("--path_for_training_data", type=str, default="my_path/pdb_2021aug02", help="path for loading training data")
    argparser.add_argument("--path_for_outputs", type=str, default="./exp_020", help="path for logs and model weights")
    argparser.add_argument("--previous_checkpoint", type=str, default="", help="path for previous model weights, e.g. file.pt")
    argparser.add_argument("--num_epochs", type=int, default=200, help="number of epochs to train for")
    argparser.add_argument("--save_model_every_n_epochs", type=int, default=10, help="save model weights every n epochs")
    argparser.add_argument("--reload_data_every_n_epochs", type=int, default=2, help="reload training data every n epochs")
argparser.add_argument("--num_examples_per_epoch", type=int, default=1000000, help="number of training example to load for one epoch") | |
argparser.add_argument("--batch_size", type=int, default=10000, help="number of tokens for one batch") | |
argparser.add_argument("--max_protein_length", type=int, default=10000, help="maximum length of the protein complext") | |
argparser.add_argument("--hidden_dim", type=int, default=128, help="hidden model dimension") | |
argparser.add_argument("--num_encoder_layers", type=int, default=3, help="number of encoder layers") | |
argparser.add_argument("--num_decoder_layers", type=int, default=3, help="number of decoder layers") | |
argparser.add_argument("--num_neighbors", type=int, default=48, help="number of neighbors for the sparse graph") | |
argparser.add_argument("--dropout", type=float, default=0.1, help="dropout level; 0.0 means no dropout") | |
argparser.add_argument("--backbone_noise", type=float, default=0.2, help="amount of noise added to backbone during training") | |
argparser.add_argument("--rescut", type=float, default=3.5, help="PDB resolution cutoff") | |
argparser.add_argument("--debug", type=bool, default=False, help="minimal data loading for debugging") | |
argparser.add_argument("--gradient_norm", type=float, default=-1.0, help="clip gradient norm, set to negative to omit clipping") | |
argparser.add_argument("--mixed_precision", type=bool, default=True, help="train with mixed precision") | |

    args = argparser.parse_args()
    main(args)