import os
from time import time

import dgl
import numpy as np
import pandas as pd
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.utils.data import DataLoader

from UltraFlow import dataset, commons, losses


class DefaultRunner(object):
    def __init__(self, train_set, val_set, test_set, finetune_val_set, model, optimizer, scheduler, config):
        self.train_set = train_set
        self.val_set = val_set
        self.test_set = test_set
        self.finetune_val_set = finetune_val_set
        self.config = config
        self.device = config.train.device
        self.batch_size = self.config.train.batch_size
        self._model = model
        self._optimizer = optimizer
        self._scheduler = scheduler
        self.best_matric = 0
        self.start_epoch = 0
        if self.device.type == 'cuda':
            self._model = self._model.cuda(self.device)
        self.get_loss_fn()

    def save(self, checkpoint, epoch=None, ddp=False, var_list={}):
        # Collect everything needed to resume training; under DDP the wrapped module's
        # weights are saved so the checkpoint stays loadable without DDP.
        state = {
            **var_list,
            "model": self._model.state_dict() if not ddp else self._model.module.state_dict(),
            "optimizer": self._optimizer.state_dict(),
            "scheduler": self._scheduler.state_dict(),
            "config": self.config,
        }
        epoch = str(epoch) if epoch is not None else ''
        checkpoint = os.path.join(checkpoint, 'checkpoint%s' % epoch)
        torch.save(state, checkpoint)

    def load(self, checkpoint, epoch=None, load_optimizer=False, load_scheduler=False):
        epoch = str(epoch) if epoch is not None else ''
        checkpoint = os.path.join(checkpoint, 'checkpoint%s' % epoch)
        print("Load checkpoint from %s" % checkpoint)

        state = torch.load(checkpoint, map_location=self.device)
        self._model.load_state_dict(state["model"])
        # self._model.load_state_dict(state["model"], strict=False)
        self.best_matric = state['best_matric']
        self.start_epoch = state['cur_epoch'] + 1

        if load_optimizer:
            self._optimizer.load_state_dict(state["optimizer"])
            if self.device.type == 'cuda':
                # Move optimizer state tensors to the configured GPU. A separate loop
                # variable is used so the checkpoint dict `state` is not shadowed
                # before the scheduler is restored below.
                for opt_state in self._optimizer.state.values():
                    for k, v in opt_state.items():
                        if isinstance(v, torch.Tensor):
                            opt_state[k] = v.cuda(self.device)

        if load_scheduler:
            self._scheduler.load_state_dict(state["scheduler"])

    def get_loss_fn(self):
        self.loss_fn = nn.MSELoss()
        if self.config.train.pretrain_ranking_loss == 'pairwise_v2':
            self.ranking_fn = losses.pair_wise_ranking_loss_v2(self.config).to(self.device)

    def trans_device(self, batch):
        return [x if isinstance(x, list) else x.to(self.device) for x in batch]
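    # Checkpoint layout produced by save()/load() above (illustrative only; 'ckpts'
    # stands in for whatever config.train.save_path points to):
    #   runner.save('ckpts', epoch=5)      -> ckpts/checkpoint5
    #   runner.save('ckpts', 'latest')     -> ckpts/checkpointlatest
    #   runner.save('ckpts', 'best_loss')  -> ckpts/checkpointbest_loss
    #   runner.load('ckpts', 'latest', load_optimizer=True, load_scheduler=True)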
    @torch.no_grad()
    def evaluate_pairwsie_pdbbind(self, split, verbose=0, logger=None, visualize=True):
        test_set = getattr(self, "%s_set" % split)
        dataloader = DataLoader(test_set, batch_size=self.config.train.batch_size, shuffle=False,
                                collate_fn=dataset.collate_pdbbind_affinity_multi_task_v2,
                                num_workers=self.config.train.num_workers)

        y_preds, y_preds_IC50, y_preds_K = torch.tensor([]).to(self.device), torch.tensor([]).to(self.device), torch.tensor([]).to(self.device)
        y, y_IC50, y_K = torch.tensor([]).to(self.device), torch.tensor([]).to(self.device), torch.tensor([]).to(self.device)

        eval_start = time()
        model = self._model
        model.eval()
        for batch in dataloader:
            if self.device.type == "cuda":
                batch = self.trans_device(batch)

            if self.config.train.encoder_ablation != 'interact':
                (regression_loss_IC50, regression_loss_K), \
                (affinity_pred_IC50, affinity_pred_K), \
                (affinity_IC50, affinity_K) = model(batch, ASRP=False)
            else:
                # Interaction-module ablation: read out ligand/protein graph features
                # directly and score the concatenated complex representation.
                node_feats_lig, node_feats_pro = model(batch, ASRP=False)
                bg_lig, bg_prot, bg_inter, labels, _, ass_des, IC50_f, K_f = batch
                bg_lig.ndata['h'] = node_feats_lig
                bg_prot.ndata['h'] = node_feats_pro
                lig_g_feats = dgl.readout_nodes(bg_lig, 'h', op=self.config.train.interact_ablate_op)
                pro_g_feats = dgl.readout_nodes(bg_prot, 'h', op=self.config.train.interact_ablate_op)
                complex_feats = torch.cat([lig_g_feats, pro_g_feats], dim=1)
                (regression_loss_IC50, regression_loss_K), \
                (affinity_pred_IC50, affinity_pred_K), \
                (affinity_IC50, affinity_K) = self.interact_ablation_model(complex_feats, labels, IC50_f, K_f)

            affinity_pred = torch.cat([affinity_pred_IC50, affinity_pred_K], dim=0)
            affinity = torch.cat([affinity_IC50, affinity_K], dim=0)

            y_preds_IC50 = torch.cat([y_preds_IC50, affinity_pred_IC50])
            y_preds_K = torch.cat([y_preds_K, affinity_pred_K])
            y_preds = torch.cat([y_preds, affinity_pred])
            y_IC50 = torch.cat([y_IC50, affinity_IC50])
            y_K = torch.cat([y_K, affinity_K])
            y = torch.cat([y, affinity])

        metrics_dict = commons.get_sbap_regression_metric_dict(np.array(y.cpu()), np.array(y_preds.cpu()))
        result_str = commons.get_matric_output_str(metrics_dict)
        result_str = f'{split} total ' + result_str

        if len(y_IC50) > 0:
            metrics_dict_IC50 = commons.get_sbap_regression_metric_dict(np.array(y_IC50.cpu()), np.array(y_preds_IC50.cpu()))
            result_str += '| IC50 ' + commons.get_matric_output_str(metrics_dict_IC50)

        if len(y_K) > 0:
            metrics_dict_K = commons.get_sbap_regression_metric_dict(np.array(y_K.cpu()), np.array(y_preds_K.cpu()))
            result_str += '| K ' + commons.get_matric_output_str(metrics_dict_K)

        result_str += ' | Time: %.4f' % (time() - eval_start)

        if verbose:
            if logger is not None:
                logger.info(result_str)
            else:
                print(result_str)

        return metrics_dict['RMSE'], metrics_dict['MAE'], metrics_dict['SD'], metrics_dict['Pearson'], metrics_dict['Spearman']

    @torch.no_grad()
    def evaluate_pairwsie(self, split, verbose=0, logger=None, visualize=True):
        """
        Evaluate the model.

        Parameters:
            split (str): split to evaluate. Can be ``train``, ``val`` or ``test``.
        """
        if split not in ['train', 'val', 'test']:
            raise ValueError('split should be either train, val, or test.')

        test_set = getattr(self, "%s_set" % split)
        relation_preds = torch.tensor([]).to(self.device)
        relations = torch.tensor([]).to(self.device)
        y_preds = torch.tensor([]).to(self.device)
        ys = torch.tensor([]).to(self.device)

        eval_start = time()
        model = self._model
        model.eval()
        for batch in test_set:
            if self.device.type == "cuda":
                batch = self.trans_device(batch)

            y_pred, x_output, ranking_assay_embedding = model(batch)

            # Build every ordered pair (i, j), i != j, within the batch; the first half
            # of pair_index are the "A" items and the second half the "B" items.
            n = x_output.shape[0]
            pair_a_index, pair_b_index = [], []
            for i in range(n):
                pair_a_index.extend([i] * (n - 1))
                pair_b_index.extend([j for j in range(n) if i != j])
            pair_index = pair_a_index + pair_b_index

            _, relation, relation_pred = self.ranking_fn(x_output[pair_index], batch[-3][pair_index],
                                                         ranking_assay_embedding[pair_index])

            relation_preds = torch.cat([relation_preds, relation_pred])
            relations = torch.cat([relations, relation])
            y_preds = torch.cat([y_preds, y_pred])
            ys = torch.cat([ys, batch[-3]])

        acc = (sum(relation_preds == relations) / (len(relation_preds))).cpu().item()
        result_str = 'valid acc: {:.4f}'.format(acc)

        np_y = np.array(ys.cpu())
        np_f = np.array(y_preds.cpu())
        regression_metrics_dict = commons.get_sbap_regression_metric_dict(np_y, np_f)
        result_str += commons.get_matric_output_str(regression_metrics_dict)
        result_str += ' | Time: %.4f' % (time() - eval_start)

        if verbose:
            if logger is not None:
                logger.info(result_str)
            else:
                print(result_str)

        return acc
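    # Illustration of the pair construction used by the evaluators above and below
    # (not executed anywhere): for a batch of n = 3 items the loop yields
    #   pair_a_index = [0, 0, 1, 1, 2, 2]
    #   pair_b_index = [1, 2, 0, 2, 0, 1]
    # so pair_index = pair_a_index + pair_b_index has length 2 * n * (n - 1) = 12,
    # and splitting the gathered scores in half recovers the (A, B) members of all
    # n * (n - 1) ordered pairs.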
    @torch.no_grad()
    def evaluate_pointwise(self, split, verbose=0, logger=None, visualize=True):
        """
        Evaluate the model.

        Parameters:
            split (str): split to evaluate. Can be ``train``, ``val`` or ``test``.
        """
        if split not in ['train', 'val', 'test']:
            raise ValueError('split should be either train, val, or test.')

        test_set = getattr(self, "%s_set" % split)
        relation_preds = torch.tensor([]).to(self.device)
        relations = torch.tensor([]).to(self.device)
        y_preds = torch.tensor([]).to(self.device)
        ys = torch.tensor([]).to(self.device)

        eval_start = time()
        model = self._model
        model.eval()
        for batch in test_set:
            if self.device.type == "cuda":
                batch = self.trans_device(batch)

            y_pred, x_output, _ = model(batch)

            # Same all-pairs construction as evaluate_pairwsie, but the relations are
            # derived directly from the pointwise scores instead of the ranking head.
            n = x_output.shape[0]
            pair_a_index, pair_b_index = [], []
            for i in range(n):
                pair_a_index.extend([i] * (n - 1))
                pair_b_index.extend([j for j in range(n) if i != j])
            pair_index = pair_a_index + pair_b_index

            score_pred = y_pred[pair_index]
            score_target = batch[-3][pair_index]
            batch_repeat_num = len(score_pred)
            batch_size = batch_repeat_num // 2
            pred_A, target_A, pred_B, target_B = score_pred[:batch_size], score_target[:batch_size], \
                                                 score_pred[batch_size:], score_target[batch_size:]

            relation_pred = torch.zeros(pred_A.size(), dtype=torch.long, device=pred_A.device)
            relation_pred[(pred_A - pred_B) > 0.0] = 1
            relation = torch.zeros(target_A.size(), dtype=torch.long, device=target_A.device)
            relation[(target_A - target_B) > 0.0] = 1

            relation_preds = torch.cat([relation_preds, relation_pred])
            relations = torch.cat([relations, relation])
            y_preds = torch.cat([y_preds, y_pred])
            ys = torch.cat([ys, batch[-3]])

        acc = (sum(relation_preds == relations) / (len(relation_preds))).cpu().item()
        result_str = 'valid acc: {:.4f}'.format(acc)

        np_y = np.array(ys.cpu())
        np_f = np.array(y_preds.cpu())
        regression_metrics_dict = commons.get_sbap_regression_metric_dict(np_y, np_f)
        result_str += commons.get_matric_output_str(regression_metrics_dict)
        result_str += ' | Time: %.4f' % (time() - eval_start)

        if verbose:
            if logger is not None:
                logger.info(result_str)
            else:
                print(result_str)

        return acc, regression_metrics_dict['RMSE']

    def train(self, ddp=False):
        # Dispatch to the training routine selected by the sampling method
        # (pairwise_v1 / pointwise) and the multi-task setting (off / IC50KdKi / IC50K).
        if self.config.train.pretrain_sampling_method == 'pairwise_v1':
            if not self.config.train.multi_task:
                print('begin pairwise_v1 training')
                self.train_pairwise_v1(ddp=ddp)
            elif self.config.train.multi_task == 'IC50KdKi':
                print('begin pairwise_v1 multi-task training IC50/Kd/Ki')
                self.train_pairwise_v1_multi_task(ddp=ddp)
            elif self.config.train.multi_task == 'IC50K':
                print('begin pairwise_v1 multi-task training IC50/K')
                self.train_pairwise_v1_multi_task_v2(ddp=ddp)
        elif self.config.train.pretrain_sampling_method == 'pointwise':
            if not self.config.train.multi_task:
                print('begin pointwise training')
                self.train_pointwise(ddp=ddp)
            elif self.config.train.multi_task == 'IC50KdKi':
                print('begin pointwise multi-task training IC50/Kd/Ki')
                self.train_pointwise_multi_task(ddp=ddp)
            elif self.config.train.multi_task == 'IC50K':
                print('begin pointwise multi-task training IC50/K')
                self.train_pointwise_multi_task_v2(ddp=ddp)
    def train_pairwise_v1(self, verbose=1, ddp=False):
        self.logger = self.config.logger
        train_start = time()
        num_epochs = self.config.train.pretrain_epochs

        if ddp and self.config.train.use_memory_efficient_dataset != 'v1':
            train_sampler = torch.utils.data.distributed.DistributedSampler(self.train_set)
            dataloader = DataLoader(self.train_set, batch_size=self.config.train.batch_size, drop_last=True,
                                    collate_fn=dataset.collate_affinity_pair_wise,
                                    num_workers=self.config.train.num_workers, sampler=train_sampler)
        else:
            dataloader = DataLoader(self.train_set, batch_size=self.config.train.batch_size, drop_last=True,
                                    shuffle=self.config.train.shuffle,
                                    collate_fn=dataset.collate_affinity_pair_wise,
                                    num_workers=self.config.train.num_workers)

        model = self._model
        if self.logger is not None:
            self.logger.info(self.config)
            self.logger.info('trainable params in model: {:.2f}M'.format(
                sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6))
            self.logger.info('start training...')

        train_losses = []
        val_matric = []
        best_matric = self.best_matric
        best_loss = 1000000
        start_epoch = self.start_epoch
        early_stop = 0

        for epoch in range(num_epochs):
            # train
            model.train()
            epoch_start = time()
            batch_losses, batch_regression_losses, batch_ranking_losses = [], [], []
            batch_cnt = 0
            if ddp and self.config.train.use_memory_efficient_dataset != 'v1':
                dataloader.sampler.set_epoch(epoch)

            for batch in dataloader:
                batch_cnt += 1
                if self.device.type == "cuda":
                    batch = self.trans_device(batch)

                y_pred, x_output, ranking_assay_embedding = model(batch)
                y_pred_num = len(y_pred)
                assert y_pred_num % 2 == 0

                if self.config.train.pairwise_two_tower_regression_loss:
                    regression_loss = self.loss_fn(y_pred, batch[-3])
                else:
                    regression_loss = self.loss_fn(y_pred[:y_pred_num // 2], batch[-3][:y_pred_num // 2])

                ranking_loss, _, _ = self.ranking_fn(x_output, batch[-3], ranking_assay_embedding)
                pretrain_loss = self.config.train.pretrain_ranking_loss_lambda * ranking_loss + \
                                self.config.train.pretrain_regression_loss_lambda * regression_loss

                if not pretrain_loss.requires_grad:
                    raise RuntimeError("loss doesn't require grad")

                self._optimizer.zero_grad()
                pretrain_loss.backward()
                self._optimizer.step()

                batch_ranking_losses.append(self.config.train.pretrain_ranking_loss_lambda * ranking_loss.item())
                batch_regression_losses.append(self.config.train.pretrain_regression_loss_lambda * regression_loss.item())
                batch_losses.append(pretrain_loss.item())

            train_losses.append(sum(batch_losses))

            if self.logger is not None:
                self.logger.info('Epoch: %d | Pretrain Loss: %.4f | Regression Loss: %.4f | Ranking Loss: %.4f | Lr: %.4f | Time: %.4f' % (
                    epoch + start_epoch, sum(batch_losses), sum(batch_regression_losses), sum(batch_ranking_losses),
                    self._optimizer.param_groups[0]['lr'], time() - epoch_start))

            if (not ddp) or (ddp and dist.get_rank() == 0):
                # evaluate
                if self.config.train.eval:
                    eval_acc = self.evaluate_pairwsie('val', verbose=1, logger=self.logger)
                    val_matric.append(eval_acc)
                    if val_matric[-1] > best_matric:
                        early_stop = 0
                        best_matric = val_matric[-1]
                        if self.config.train.save:
                            print('saving checkpoint')
                            val_list = {
                                'cur_epoch': epoch + start_epoch,
                                'best_matric': best_matric,
                            }
                            self.save(self.config.train.save_path, epoch + start_epoch, ddp, val_list)

                if self.config.train.scheduler.type == "plateau":
                    self._scheduler.step(train_losses[-1])
                else:
                    self._scheduler.step()

                val_list = {
                    'cur_epoch': epoch + start_epoch,
                    'best_matric': best_matric,
                }
                self.save(self.config.train.save_path, 'latest', ddp, val_list)

                if sum(batch_losses) < best_loss:
                    best_loss = sum(batch_losses)
                    self.save(self.config.train.save_path, 'best_loss', ddp, val_list)

            torch.cuda.empty_cache()

            if epoch % self.config.train.pretrain_regression_loss_lambda_degrade_epoch == 0:
                self.config.train.pretrain_regression_loss_lambda *= self.config.train.pretrain_regression_loss_lambda_degrade_ratio

        self.best_matric = best_matric
        self.start_epoch = start_epoch + num_epochs
        print('optimization finished.')
        print('Total time elapsed: %.5fs' % (time() - train_start))
    def train_pairwise_v1_multi_task(self, verbose=1, ddp=False):
        self.logger = self.config.logger
        train_start = time()
        num_epochs = self.config.train.pretrain_epochs

        if ddp and self.config.train.use_memory_efficient_dataset != 'v1':
            train_sampler = torch.utils.data.distributed.DistributedSampler(self.train_set)
            dataloader = DataLoader(self.train_set, batch_size=self.config.train.batch_size, drop_last=True,
                                    collate_fn=dataset.collate_affinity_pair_wise_multi_task,
                                    num_workers=self.config.train.num_workers, sampler=train_sampler)
        else:
            dataloader = DataLoader(self.train_set, batch_size=self.config.train.batch_size, drop_last=True,
                                    shuffle=self.config.train.shuffle,
                                    collate_fn=dataset.collate_affinity_pair_wise_multi_task,
                                    num_workers=self.config.train.num_workers)

        model = self._model
        if self.logger is not None:
            self.logger.info(self.config)
            self.logger.info('trainable params in model: {:.2f}M'.format(
                sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6))
            self.logger.info('start training...')

        train_losses = []
        val_matric = []
        best_matric = self.best_matric
        best_loss = 1000000
        start_epoch = self.start_epoch
        early_stop = 0

        for epoch in range(num_epochs):
            # train
            model.train()
            epoch_start = time()
            batch_losses, batch_regression_losses, batch_ranking_losses = [], [], []
            batch_regression_ic50_losses, batch_regression_kd_losses, batch_regression_ki_losses = [], [], []
            batch_ranking_ic50_losses, batch_ranking_kd_losses, batch_ranking_ki_losses = [], [], []
            batch_cnt = 0
            if ddp and self.config.train.use_memory_efficient_dataset != 'v1':
                dataloader.sampler.set_epoch(epoch)

            for batch in dataloader:
                batch_cnt += 1
                if self.device.type == "cuda":
                    batch = self.trans_device(batch)

                (regression_loss_IC50, regression_loss_Kd, regression_loss_Ki), \
                (ranking_loss_IC50, ranking_loss_Kd, ranking_loss_Ki), \
                (affinity_pred_IC50, affinity_pred_Kd, affinity_pred_Ki), \
                (relation_pred_IC50, relation_pred_Kd, relation_pred_Ki), \
                (affinity_IC50, affinity_Kd, affinity_Ki), \
                (relation_IC50, relation_Kd, relation_Ki) = model(batch)

                regression_loss = self.config.train.pretrain_mtl_IC50_lambda * regression_loss_IC50 + \
                                  self.config.train.pretrain_mtl_Kd_lambda * regression_loss_Kd + \
                                  self.config.train.pretrain_mtl_Ki_lambda * regression_loss_Ki

                ranking_loss = self.config.train.pretrain_mtl_IC50_lambda * ranking_loss_IC50 + \
                               self.config.train.pretrain_mtl_Kd_lambda * ranking_loss_Kd + \
                               self.config.train.pretrain_mtl_Ki_lambda * ranking_loss_Ki

                pretrain_loss = self.config.train.pretrain_ranking_loss_lambda * ranking_loss + \
                                self.config.train.pretrain_regression_loss_lambda * regression_loss

                if not pretrain_loss.requires_grad:
                    raise RuntimeError("loss doesn't require grad")

                self._optimizer.zero_grad()
                pretrain_loss.backward()
                self._optimizer.step()

                batch_ranking_losses.append(self.config.train.pretrain_ranking_loss_lambda * ranking_loss.item())
                batch_regression_losses.append(self.config.train.pretrain_regression_loss_lambda * regression_loss.item())
                batch_losses.append(pretrain_loss.item())

                batch_regression_ic50_losses.append(self.config.train.pretrain_mtl_IC50_lambda * regression_loss_IC50.item())
                batch_regression_kd_losses.append(self.config.train.pretrain_mtl_Kd_lambda * regression_loss_Kd.item())
                batch_regression_ki_losses.append(self.config.train.pretrain_mtl_Ki_lambda * regression_loss_Ki.item())
                batch_ranking_ic50_losses.append(self.config.train.pretrain_mtl_IC50_lambda * ranking_loss_IC50.item())
                batch_ranking_kd_losses.append(self.config.train.pretrain_mtl_Kd_lambda * ranking_loss_Kd.item())
                batch_ranking_ki_losses.append(self.config.train.pretrain_mtl_Ki_lambda * ranking_loss_Ki.item())
            train_losses.append(sum(batch_losses))

            if self.logger is not None:
                self.logger.info('Epoch: %d | Pretrain Loss: %.4f | Regression Loss: %.4f | Ranking Loss: %.4f | '
                                 'Regression IC50 Loss: %.4f | Regression Kd Loss: %.4f | Regression Ki Loss: %.4f | '
                                 'Ranking IC50 Loss: %.4f | Ranking Kd Loss: %.4f | Ranking Ki Loss: %.4f | Lr: %.4f | Time: %.4f' % (
                                     epoch + start_epoch, sum(batch_losses), sum(batch_regression_losses), sum(batch_ranking_losses),
                                     sum(batch_regression_ic50_losses), sum(batch_regression_kd_losses), sum(batch_regression_ki_losses),
                                     sum(batch_ranking_ic50_losses), sum(batch_ranking_kd_losses), sum(batch_ranking_ki_losses),
                                     self._optimizer.param_groups[0]['lr'], time() - epoch_start))

            if (not ddp) or (ddp and dist.get_rank() == 0):
                if self.config.train.scheduler.type == "plateau":
                    self._scheduler.step(train_losses[-1])
                else:
                    self._scheduler.step()

                val_list = {
                    'cur_epoch': epoch + start_epoch,
                    'best_matric': best_matric,
                }
                self.save(self.config.train.save_path, 'latest', ddp, val_list)

            torch.cuda.empty_cache()

            if epoch % self.config.train.pretrain_regression_loss_lambda_degrade_epoch == 0:
                self.config.train.pretrain_regression_loss_lambda *= self.config.train.pretrain_regression_loss_lambda_degrade_ratio

        self.best_matric = best_matric
        self.start_epoch = start_epoch + num_epochs
        print('optimization finished.')
        print('Total time elapsed: %.5fs' % (time() - train_start))

    def train_pairwise_v1_multi_task_v2(self, verbose=1, ddp=False):
        self.logger = self.config.logger
        train_start = time()
        num_epochs = self.config.train.pretrain_epochs

        if ddp and self.config.train.use_memory_efficient_dataset != 'v1':
            train_sampler = torch.utils.data.distributed.DistributedSampler(self.train_set)
            dataloader = DataLoader(self.train_set, batch_size=self.config.train.batch_size, drop_last=True,
                                    collate_fn=dataset.collate_affinity_pair_wise_multi_task_v2,
                                    num_workers=self.config.train.num_workers, sampler=train_sampler)
        else:
            dataloader = DataLoader(self.train_set, batch_size=self.config.train.batch_size, drop_last=True,
                                    shuffle=self.config.train.shuffle,
                                    collate_fn=dataset.collate_affinity_pair_wise_multi_task_v2,
                                    num_workers=self.config.train.num_workers)

        model = self._model
        if self.logger is not None:
            self.logger.info(self.config)
            self.logger.info('trainable params in model: {:.2f}M'.format(
                sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6))
            self.logger.info('start training...')

        train_losses = []
        val_matric = []
        best_matric = self.best_matric
        best_loss = 1000000
        start_epoch = self.start_epoch
        early_stop = 0

        for epoch in range(num_epochs):
            # train
            model.train()
            epoch_start = time()
            batch_losses, batch_regression_losses, batch_ranking_losses = [], [], []
            batch_regression_ic50_losses, batch_regression_k_losses = [], []
            batch_ranking_ic50_losses, batch_ranking_k_losses = [], []
            batch_cnt = 0
            if ddp and self.config.train.use_memory_efficient_dataset != 'v1':
                dataloader.sampler.set_epoch(epoch)

            for batch in dataloader:
                batch_cnt += 1
                if self.device.type == "cuda":
                    batch = self.trans_device(batch)

                (regression_loss_IC50, regression_loss_K), \
                (ranking_loss_IC50, ranking_loss_K), \
                (affinity_pred_IC50, affinity_pred_K), \
                (relation_pred_IC50, relation_pred_K), \
                (affinity_IC50, affinity_K), \
                (relation_IC50, relation_K) = model(batch)

                regression_loss = self.config.train.pretrain_mtl_IC50_lambda * regression_loss_IC50 + \
                                  self.config.train.pretrain_mtl_K_lambda * regression_loss_K

                ranking_loss = self.config.train.pretrain_mtl_IC50_lambda * ranking_loss_IC50 + \
                               self.config.train.pretrain_mtl_K_lambda * ranking_loss_K
                pretrain_loss = self.config.train.pretrain_ranking_loss_lambda * ranking_loss + \
                                self.config.train.pretrain_regression_loss_lambda * regression_loss

                if not pretrain_loss.requires_grad:
                    raise RuntimeError("loss doesn't require grad")

                self._optimizer.zero_grad()
                pretrain_loss.backward()
                self._optimizer.step()

                batch_ranking_losses.append(self.config.train.pretrain_ranking_loss_lambda * ranking_loss.item())
                batch_regression_losses.append(self.config.train.pretrain_regression_loss_lambda * regression_loss.item())
                batch_losses.append(pretrain_loss.item())

                batch_regression_ic50_losses.append(self.config.train.pretrain_mtl_IC50_lambda * regression_loss_IC50.item())
                batch_regression_k_losses.append(self.config.train.pretrain_mtl_K_lambda * regression_loss_K.item())
                batch_ranking_ic50_losses.append(self.config.train.pretrain_mtl_IC50_lambda * ranking_loss_IC50.item())
                batch_ranking_k_losses.append(self.config.train.pretrain_mtl_K_lambda * ranking_loss_K.item())

            train_losses.append(sum(batch_losses))

            if self.logger is not None:
                self.logger.info('Epoch: %d | Pretrain Loss: %.4f | Regression Loss: %.4f | Ranking Loss: %.4f | '
                                 'Regression IC50 Loss: %.4f | Regression K Loss: %.4f | '
                                 'Ranking IC50 Loss: %.4f | Ranking K Loss: %.4f | Lr: %.4f | Time: %.4f' % (
                                     epoch + start_epoch, sum(batch_losses), sum(batch_regression_losses), sum(batch_ranking_losses),
                                     sum(batch_regression_ic50_losses), sum(batch_regression_k_losses),
                                     sum(batch_ranking_ic50_losses), sum(batch_ranking_k_losses),
                                     self._optimizer.param_groups[0]['lr'], time() - epoch_start))

            if (not ddp) or (ddp and dist.get_rank() == 0):
                if self.config.train.scheduler.type == "plateau":
                    self._scheduler.step(train_losses[-1])
                else:
                    self._scheduler.step()

                val_list = {
                    'cur_epoch': epoch + start_epoch,
                    'best_matric': best_matric,
                }
                self.save(self.config.train.save_path, 'latest', ddp, val_list)

                # evaluate
                if self.config.train.eval:
                    # evaluate_pairwsie_pdbbind returns (RMSE, MAE, SD, Pearson, Spearman);
                    # the last entry (Spearman) is tracked as the validation metric.
                    eval_metrics = self.evaluate_pairwsie_pdbbind('finetune_val', verbose=1, logger=self.logger)
                    val_matric.append(eval_metrics[-1])
                    if val_matric[-1] > best_matric:
                        best_matric = val_matric[-1]
                        if self.config.train.save:
                            print('saving checkpoint')
                            val_list = {
                                'cur_epoch': epoch + start_epoch,
                                'best_matric': best_matric,
                            }
                            self.save(self.config.train.save_path, 'best_finetune_valid', ddp, val_list)

            torch.cuda.empty_cache()

            if epoch % self.config.train.pretrain_regression_loss_lambda_degrade_epoch == 0:
                self.config.train.pretrain_regression_loss_lambda *= self.config.train.pretrain_regression_loss_lambda_degrade_ratio

        self.best_matric = best_matric
        self.start_epoch = start_epoch + num_epochs
        print('optimization finished.')
        print('Total time elapsed: %.5fs' % (time() - train_start))
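    # In both pairwise multi-task trainers above, the pretraining objective is
    #   pretrain_loss = lambda_rank * (lambda_IC50 * L_rank_IC50 + lambda_K * L_rank_K)
    #                 + lambda_reg  * (lambda_IC50 * L_reg_IC50  + lambda_K * L_reg_K)
    # (with a Kd/Ki split in place of K for the IC50KdKi variant), and lambda_reg is
    # multiplied by pretrain_regression_loss_lambda_degrade_ratio every
    # pretrain_regression_loss_lambda_degrade_epoch epochs.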
    def train_pointwise(self, verbose=1, ddp=False):
        self.logger = self.config.logger
        train_start = time()
        num_epochs = self.config.train.pretrain_epochs

        if ddp and self.config.train.use_memory_efficient_dataset != 'v1':
            train_sampler = torch.utils.data.distributed.DistributedSampler(self.train_set)
            dataloader = DataLoader(self.train_set, batch_size=self.config.train.batch_size,
                                    collate_fn=dataset.collate_pdbbind_affinity,
                                    num_workers=self.config.train.num_workers, drop_last=True, sampler=train_sampler)
        else:
            dataloader = DataLoader(self.train_set, batch_size=self.config.train.batch_size,
                                    shuffle=self.config.train.shuffle,
                                    collate_fn=dataset.collate_pdbbind_affinity,
                                    num_workers=self.config.train.num_workers, drop_last=True)

        model = self._model
        if self.logger is not None:
            self.logger.info(self.config)
            self.logger.info('trainable params in model: {:.2f}M'.format(
                sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6))
            self.logger.info('start training...')

        train_losses = []
        val_matric = []
        best_matric = self.best_matric
        best_loss = 1000000
        start_epoch = self.start_epoch
        early_stop = 0

        for epoch in range(num_epochs):
            # train
            model.train()
            epoch_start = time()
            batch_losses, batch_regression_losses = [], []
            if ddp and self.config.train.use_memory_efficient_dataset != 'v1':
                dataloader.sampler.set_epoch(epoch)

            for batch in dataloader:
                if self.device.type == "cuda":
                    batch = self.trans_device(batch)

                y_pred, x_output, _ = model(batch)
                regression_loss = self.loss_fn(y_pred, batch[-3])
                pretrain_loss = regression_loss

                if not pretrain_loss.requires_grad:
                    raise RuntimeError("loss doesn't require grad")

                self._optimizer.zero_grad()
                pretrain_loss.backward()
                self._optimizer.step()

                batch_losses.append(pretrain_loss.item())
                batch_regression_losses.append(regression_loss.item())

            train_losses.append(sum(batch_losses))

            if self.logger is not None:
                self.logger.info('Epoch: %d | Pretrain Loss: %.4f | Regression Loss: %.4f | Ranking Loss: %.4f | Lr: %.4f | Time: %.4f' % (
                    epoch + start_epoch, sum(batch_losses), sum(batch_regression_losses), 0.0,
                    self._optimizer.param_groups[0]['lr'], time() - epoch_start))

            if (not ddp) or (ddp and dist.get_rank() == 0):
                # evaluate
                if self.config.train.eval:
                    eval_acc, eval_rmse = self.evaluate_pointwise('val', verbose=1, logger=self.logger)
                    val_matric.append(eval_acc)

                if self.config.train.scheduler.type == "plateau":
                    self._scheduler.step(train_losses[-1])
                else:
                    self._scheduler.step()

                val_list = {
                    'cur_epoch': epoch + start_epoch,
                    'best_matric': best_matric,
                }
                self.save(self.config.train.save_path, 'latest', ddp, val_list)

                if sum(batch_losses) < best_loss:
                    best_loss = sum(batch_losses)
                    self.save(self.config.train.save_path, 'best_loss', ddp, val_list)

                # guard against an empty history when evaluation is disabled
                if val_matric and val_matric[-1] > best_matric:
                    best_matric = val_matric[-1]
                    if self.config.train.save:
                        print('saving checkpoint')
                        val_list = {
                            'cur_epoch': epoch + start_epoch,
                            'best_matric': best_matric,
                        }
                        self.save(self.config.train.save_path, epoch + start_epoch, ddp, val_list)

            torch.cuda.empty_cache()

            if epoch % self.config.train.pretrain_regression_loss_lambda_degrade_epoch == 0:
                self.config.train.pretrain_regression_loss_lambda *= self.config.train.pretrain_regression_loss_lambda_degrade_ratio

        self.best_matric = best_matric
        self.start_epoch = start_epoch + num_epochs
        print('optimization finished.')
        print('Total time elapsed: %.5fs' % (time() - train_start))

    def train_pointwise_multi_task(self, verbose=1, ddp=False):
        self.logger = self.config.logger
        train_start = time()
        num_epochs = self.config.train.pretrain_epochs

        if ddp and self.config.train.use_memory_efficient_dataset != 'v1':
            train_sampler = torch.utils.data.distributed.DistributedSampler(self.train_set)
            dataloader = DataLoader(self.train_set, batch_size=self.config.train.batch_size,
                                    collate_fn=dataset.collate_pdbbind_affinity_multi_task,
                                    num_workers=self.config.train.num_workers, drop_last=True, sampler=train_sampler)
        else:
            dataloader = DataLoader(self.train_set, batch_size=self.config.train.batch_size,
                                    shuffle=self.config.train.shuffle,
                                    collate_fn=dataset.collate_pdbbind_affinity_multi_task,
                                    num_workers=self.config.train.num_workers, drop_last=True)

        model = self._model
        if self.logger is not None:
            self.logger.info(self.config)
            self.logger.info('trainable params in model: {:.2f}M'.format(
                sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6))
            self.logger.info('start training...')

        train_losses = []
        val_matric = []
        best_matric = self.best_matric
        best_loss = 1000000
        start_epoch = self.start_epoch
        early_stop = 0

        for epoch in range(num_epochs):
            # train
            model.train()
            epoch_start = time()
            batch_losses, batch_regression_losses = [], []
            batch_regression_ic50_losses, batch_regression_kd_losses, batch_regression_ki_losses = [], [], []
            if ddp and self.config.train.use_memory_efficient_dataset != 'v1':
                dataloader.sampler.set_epoch(epoch)

            for batch in dataloader:
                if self.device.type == "cuda":
                    batch = self.trans_device(batch)

                (regression_loss_IC50, regression_loss_Kd, regression_loss_Ki), \
                (affinity_pred_IC50, affinity_pred_Kd, affinity_pred_Ki), \
                (affinity_IC50, affinity_Kd, affinity_Ki) = model(batch, ASRP=False)

                pretrain_loss = self.config.train.pretrain_mtl_IC50_lambda * regression_loss_IC50 + \
                                self.config.train.pretrain_mtl_Kd_lambda * regression_loss_Kd + \
                                self.config.train.pretrain_mtl_Ki_lambda * regression_loss_Ki

                if not pretrain_loss.requires_grad:
                    raise RuntimeError("loss doesn't require grad")

                self._optimizer.zero_grad()
                pretrain_loss.backward()
                self._optimizer.step()

                batch_losses.append(pretrain_loss.item())
                batch_regression_ic50_losses.append(self.config.train.pretrain_mtl_IC50_lambda * regression_loss_IC50.item())
                batch_regression_kd_losses.append(self.config.train.pretrain_mtl_Kd_lambda * regression_loss_Kd.item())
                batch_regression_ki_losses.append(self.config.train.pretrain_mtl_Ki_lambda * regression_loss_Ki.item())

            train_losses.append(sum(batch_losses))

            if self.logger is not None:
                self.logger.info('Epoch: %d | Pretrain Loss: %.4f | '
                                 'Regression IC50 Loss: %.4f | Regression Kd Loss: %.4f | Regression Ki Loss: %.4f | '
                                 'Lr: %.4f | Time: %.4f' % (
                                     epoch + start_epoch, sum(batch_losses),
                                     sum(batch_regression_ic50_losses), sum(batch_regression_kd_losses), sum(batch_regression_ki_losses),
                                     self._optimizer.param_groups[0]['lr'], time() - epoch_start))

            if (not ddp) or (ddp and dist.get_rank() == 0):
                if self.config.train.scheduler.type == "plateau":
                    self._scheduler.step(train_losses[-1])
                else:
                    self._scheduler.step()

                val_list = {
                    'cur_epoch': epoch + start_epoch,
                    'best_matric': best_matric,
                }
                self.save(self.config.train.save_path, 'latest', ddp, val_list)

            torch.cuda.empty_cache()

            if epoch % self.config.train.pretrain_regression_loss_lambda_degrade_epoch == 0:
                self.config.train.pretrain_regression_loss_lambda *= self.config.train.pretrain_regression_loss_lambda_degrade_ratio

        self.best_matric = best_matric
        self.start_epoch = start_epoch + num_epochs
        print('optimization finished.')
        print('Total time elapsed: %.5fs' % (time() - train_start))

    def train_pointwise_multi_task_v2(self, verbose=1, ddp=False):
        self.logger = self.config.logger
        train_start = time()
        num_epochs = self.config.train.pretrain_epochs

        if ddp and self.config.train.use_memory_efficient_dataset != 'v1':
            train_sampler = torch.utils.data.distributed.DistributedSampler(self.train_set)
            dataloader = DataLoader(self.train_set, batch_size=self.config.train.batch_size,
                                    collate_fn=dataset.collate_pdbbind_affinity_multi_task_v2,
                                    num_workers=self.config.train.num_workers, drop_last=True, sampler=train_sampler)
        else:
            dataloader = DataLoader(self.train_set, batch_size=self.config.train.batch_size,
                                    shuffle=self.config.train.shuffle,
                                    collate_fn=dataset.collate_pdbbind_affinity_multi_task_v2,
                                    num_workers=self.config.train.num_workers, drop_last=True)

        model = self._model
        if self.logger is not None:
            self.logger.info(self.config)
            self.logger.info('trainable params in model: {:.2f}M'.format(
                sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6))
            self.logger.info('start training...')

        train_losses = []
        val_matric = []
        best_matric = self.best_matric
        best_loss = 1000000
        start_epoch = self.start_epoch
        early_stop = 0

        for epoch in range(num_epochs):
            # train
            model.train()
            epoch_start = time()
            batch_losses, batch_regression_losses = [], []
            batch_regression_ic50_losses, batch_regression_k_losses = [], []
            if ddp and self.config.train.use_memory_efficient_dataset != 'v1':
                dataloader.sampler.set_epoch(epoch)

            for batch in dataloader:
                if self.device.type == "cuda":
                    batch = self.trans_device(batch)

                (regression_loss_IC50, regression_loss_K), \
                (affinity_pred_IC50, affinity_pred_K), \
                (affinity_IC50, affinity_K) = model(batch, ASRP=False)

                pretrain_loss = self.config.train.pretrain_mtl_IC50_lambda * regression_loss_IC50 + \
                                self.config.train.pretrain_mtl_K_lambda * regression_loss_K

                if not pretrain_loss.requires_grad:
                    raise RuntimeError("loss doesn't require grad")

                self._optimizer.zero_grad()
                pretrain_loss.backward()
                self._optimizer.step()

                batch_losses.append(pretrain_loss.item())
                batch_regression_ic50_losses.append(self.config.train.pretrain_mtl_IC50_lambda * regression_loss_IC50.item())
                batch_regression_k_losses.append(self.config.train.pretrain_mtl_K_lambda * regression_loss_K.item())

            train_losses.append(sum(batch_losses))

            if self.logger is not None:
                self.logger.info('Epoch: %d | Pretrain Loss: %.4f | '
                                 'Regression IC50 Loss: %.4f | Regression K Loss: %.4f | '
                                 'Lr: %.4f | Time: %.4f' % (
                                     epoch + start_epoch, sum(batch_losses),
                                     sum(batch_regression_ic50_losses), sum(batch_regression_k_losses),
                                     self._optimizer.param_groups[0]['lr'], time() - epoch_start))

            if (not ddp) or (ddp and dist.get_rank() == 0):
                if self.config.train.scheduler.type == "plateau":
                    self._scheduler.step(train_losses[-1])
                else:
                    self._scheduler.step()

                val_list = {
                    'cur_epoch': epoch + start_epoch,
                    'best_matric': best_matric,
                }
                self.save(self.config.train.save_path, 'latest', ddp, val_list)

            torch.cuda.empty_cache()

            if epoch % self.config.train.pretrain_regression_loss_lambda_degrade_epoch == 0:
                self.config.train.pretrain_regression_loss_lambda *= self.config.train.pretrain_regression_loss_lambda_degrade_ratio

        self.best_matric = best_matric
        self.start_epoch = start_epoch + num_epochs
        print('optimization finished.')
        print('Total time elapsed: %.5fs' % (time() - train_start))
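
# --- Minimal usage sketch (illustrative only, not executed by the library) -----
# `load_config`, `build_splits`, and `build_model` are hypothetical helpers and
# `config.train.lr` is an assumed field; only the DefaultRunner, torch.optim and
# torch.optim.lr_scheduler calls below are taken from this file / PyTorch.
if __name__ == '__main__':
    import torch.optim as optim

    config = load_config('configs/pretrain.yml')                           # hypothetical helper
    train_set, val_set, test_set, finetune_val_set = build_splits(config)  # hypothetical helper
    model = build_model(config)                                            # hypothetical helper

    optimizer = optim.Adam(model.parameters(), lr=config.train.lr)         # assumed lr field
    # config.train.scheduler.type is expected to be "plateau" for this scheduler choice
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min')

    runner = DefaultRunner(train_set, val_set, test_set, finetune_val_set,
                           model, optimizer, scheduler, config)
    # runner.load(config.train.save_path, 'latest', load_optimizer=True)   # optional resume
    runner.train(ddp=False)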