import os
import time
import pickle
import random
import re

import torch
import torch.nn
import torch.utils.data
import torch_geometric.utils as geoutils
from torch_geometric.loader import DataLoader

from rdkit import RDLogger
from rdkit import Chem  # Used below for SMILES/mol handling; imported explicitly in case `utils` does not re-export it.
from rdkit.Chem.Scaffolds import MurckoScaffold

from utils import *
from models import Generator, Generator2, simple_disc
from new_dataloader import DruggenDataset
from loss import discriminator_loss, generator_loss, discriminator2_loss, generator2_loss
from training_data import load_data

# import wandb

torch.set_num_threads(5)
RDLogger.DisableLog('rdApp.*')
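# NOTE: `sim_reward`, `logging`, `mol_sample`, `fraction_valid`, `fraction_unique`
# and `novelty` are not imported by name here; they are assumed to arrive through
# `from utils import *` (the last three being MOSES-style metrics, per the message
# printed at the end of inference()).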
class Trainer(object):
    """Trainer for training and testing DrugGEN."""

    def __init__(self, config):
        """Initialize configurations."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.submodel = config.submodel
        self.inference_model = config.inference_model

        # Data loader.
        self.raw_file = config.raw_file                # SMILES-containing text file (full path) for the first dataset.
        self.drug_raw_file = config.drug_raw_file      # SMILES-containing text file (full path) for the second dataset.
        self.dataset_file = config.dataset_file        # Dataset file name for the first GAN; contains a large number of molecules.
        self.drugs_dataset_file = config.drug_dataset_file  # Drug dataset file name for the second GAN; contains drug molecules only (here, AKT1 inhibitors).
        self.inf_raw_file = config.inf_raw_file        # Inference counterpart of raw_file (full path).
        self.inf_drug_raw_file = config.inf_drug_raw_file    # Inference counterpart of drug_raw_file (full path).
        self.inf_dataset_file = config.inf_dataset_file      # Inference counterpart of dataset_file.
        self.inf_drugs_dataset_file = config.inf_drug_dataset_file  # Inference counterpart of drug_dataset_file.
        self.mol_data_dir = config.mol_data_dir        # Directory where the dataset files are stored.
        self.drug_data_dir = config.drug_data_dir      # Directory where the drug dataset files are stored.
        self.dataset_name = self.dataset_file.split(".")[0]
        self.drugs_name = self.drugs_dataset_file.split(".")[0]
        self.max_atom = config.max_atom                # The model generates molecules in one shot, so the maximum atom count must be specified.
        self.features = config.features                # Boolean; False uses atom types as the only node features.
                                                       # Additional node features can be added; see new_dataloader.py, line 102.
        self.batch_size = config.batch_size            # Batch size for training.
        # Dataset for the first GAN: a custom PyG dataset class that can build a
        # molecular-graph dataset from any SMILES file. Non-isomeric SMILES are
        # suggested but not required. Graphs are stored as sparse matrices for
        # computational and memory efficiency.
        self.dataset = DruggenDataset(self.mol_data_dir,
                                      self.dataset_file,
                                      self.raw_file,
                                      self.max_atom,
                                      self.features)
        self.loader = DataLoader(self.dataset,
                                 shuffle=True,
                                 batch_size=self.batch_size,
                                 drop_last=True)  # PyG dataloader for the first GAN.

        # Dataset for the second GAN, built the same way from the drug SMILES file.
        self.drugs = DruggenDataset(self.drug_data_dir,
                                    self.drugs_dataset_file,
                                    self.drug_raw_file,
                                    self.max_atom,
                                    self.features)
        self.drugs_loader = DataLoader(self.drugs,
                                       shuffle=True,
                                       batch_size=self.batch_size,
                                       drop_last=True)  # PyG dataloader for the second GAN.
        # Atom and bond type dimensions for the construction of the model.
        self.atom_decoders = self.decoder_load("atom")  # Atom type decoders for the first GAN,
        # e.g. 0: 0, 1: 6 (C), 2: 7 (N), 3: 8 (O), 4: 9 (F).
        self.bond_decoders = self.decoder_load("bond")  # Bond type decoders for the first GAN,
        # e.g. 0: no bond, 1: single, 2: double, 3: triple, 4: aromatic.
        self.m_dim = len(self.atom_decoders) if not self.features else int(self.loader.dataset[0].x.shape[1])  # Atom type dimension.
        self.b_dim = len(self.bond_decoders)  # Bond type dimension.
        self.vertexes = int(self.loader.dataset[0].x.shape[0])  # Number of nodes in the graph.
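        # With the example decoders above, m_dim = b_dim = 5; `vertexes` should
        # equal `max_atom`, since graphs are padded to the maximum atom count for
        # one-shot generation.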
        self.drugs_atom_decoders = self.drug_decoder_load("atom")  # Atom type decoders for the second GAN,
        # e.g. 0: 0, 1: 6 (C), 2: 7 (N), 3: 8 (O), 4: 9 (F).
        self.drugs_bond_decoders = self.drug_decoder_load("bond")  # Bond type decoders for the second GAN,
        # e.g. 0: no bond, 1: single, 2: double, 3: triple, 4: aromatic.
        self.drugs_m_dim = len(self.drugs_atom_decoders) if not self.features else int(self.drugs_loader.dataset[0].x.shape[1])  # Atom type dimension.
        self.drugs_b_dim = len(self.drugs_bond_decoders)  # Bond type dimension.
        self.drug_vertexes = int(self.drugs_loader.dataset[0].x.shape[0])  # Number of nodes in the graph.
        # Transformer and convolution configurations.
        self.act = config.act
        self.z_dim = config.z_dim
        self.lambda_gp = config.lambda_gp
        self.dim = config.dim
        self.depth = config.depth
        self.heads = config.heads
        self.mlp_ratio = config.mlp_ratio
        self.dec_depth = config.dec_depth
        self.dec_heads = config.dec_heads
        self.dec_dim = config.dec_dim
        self.dis_select = config.dis_select
        """self.la = config.la
        self.la2 = config.la2
        self.gcn_depth = config.gcn_depth
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim"""
        """# PNA config
        self.agg = config.aggregators
        self.sca = config.scalers
        self.pna_in_ch = config.pna_in_ch
        self.pna_out_ch = config.pna_out_ch
        self.edge_dim = config.edge_dim
        self.towers = config.towers
        self.pre_lay = config.pre_lay
        self.post_lay = config.post_lay
        self.pna_layer_num = config.pna_layer_num
        self.graph_add = config.graph_add"""
        # Training configurations.
        self.epoch = config.epoch
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.g2_lr = config.g2_lr
        self.d2_lr = config.d2_lr
        self.dropout = config.dropout
        self.dec_dropout = config.dec_dropout
        self.n_critic = config.n_critic
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.resume_iters = config.resume_iters
        self.warm_up_steps = config.warm_up_steps

        # Test configurations.
        self.num_test_epoch = config.num_test_epoch
        self.test_iters = config.test_iters
        self.inference_sample_num = config.inference_sample_num

        # Directories.
        self.log_dir = config.log_dir
        self.sample_dir = config.sample_dir
        self.model_save_dir = config.model_save_dir
        self.result_dir = config.result_dir

        # Step size.
        self.log_step = config.log_sample_step
        self.clipping_value = config.clipping_value

        # Miscellaneous.
        self.mode = config.mode
        self.noise_strength_0 = torch.nn.Parameter(torch.zeros([]))
        self.noise_strength_1 = torch.nn.Parameter(torch.zeros([]))
        self.noise_strength_2 = torch.nn.Parameter(torch.zeros([]))
        self.noise_strength_3 = torch.nn.Parameter(torch.zeros([]))
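        # These noise-strength scalars are free nn.Parameters: they are not
        # registered with any module or passed to an optimizer in this class, so
        # they remain zero unless trained elsewhere.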
        self.init_type = config.init_type
        self.build_model()
    def build_model(self):
        """Create generators and discriminators."""
        ''' Generator is based on a Transformer encoder:
            @ g_conv_dim: dimensions of the MLP layers before the Transformer encoder
            @ vertexes: maximum length of generated molecules (atom count)
            @ b_dim: number of bond types
            @ m_dim: number of atom types (or number of features used)
            @ dropout: dropout probability
            @ dim: hidden dimension of the Transformer encoder
            @ depth: number of Transformer layers
            @ heads: number of multi-head attention heads
            @ mlp_ratio: read-out layer dimension of the Transformer
            @ drop_rate: deprecated
            @ tra_conv: whether the module creates output for a TransformerConv discriminator
        '''
        self.G = Generator(self.z_dim,
                           self.act,
                           self.vertexes,
                           self.b_dim,
                           self.m_dim,
                           self.dropout,
                           dim=self.dim,
                           depth=self.depth,
                           heads=self.heads,
                           mlp_ratio=self.mlp_ratio,
                           submodel=self.submodel)
        self.G2 = Generator2(self.dim,
                             self.dec_dim,
                             self.dec_depth,
                             self.dec_heads,
                             self.mlp_ratio,
                             self.dec_dropout,
                             self.drugs_m_dim,
                             self.drugs_b_dim,
                             self.submodel)
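        # Note the cascade between the two GANs: at train time the (node, edge)
        # output of the first generator is fed into G2/D2 (see the
        # discriminator2_loss/generator2_loss calls in train()), so GAN2 refines
        # GAN1's graphs toward the drug distribution.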
        ''' Discriminator implementation with PNA:
            @ deg: degree distribution of the data (created with the _genDegree() function)
            @ agg: aggregators used in PNA
            @ sca: scalers used in PNA
            @ pna_in_ch: first PNA hidden dimension
            @ pna_out_ch: last PNA hidden dimension
            @ edge_dim: edge hidden dimension
            @ towers: number of towers (splits the hidden dimension into parallel processes)
            @ pre_lay: pre-transformation layer
            @ post_lay: post-transformation layer
            @ pna_layer_num: number of PNA layers
            @ graph_add: global pooling layer selection
        '''
        ''' Discriminator implementation with graph convolution:
            @ d_conv_dim: convolution dimensions for the GCN
            @ m_dim: number of atom types (or number of features used)
            @ b_dim: number of bond types
            @ dropout: dropout probability
        '''
        ''' Discriminator implementation with MLP:
            @ act: activation function for the MLP
            @ m_dim: number of atom types (or number of features used)
            @ b_dim: number of bond types
            @ dropout: dropout probability
            @ vertexes: maximum length of generated molecules (atom count)
        '''
        # self.D = Discriminator_old(self.d_conv_dim, self.m_dim, self.b_dim, self.dropout, self.gcn_depth)
        self.D2 = simple_disc("tanh", self.drugs_m_dim, self.drug_vertexes, self.drugs_b_dim)
        self.D = simple_disc("tanh", self.m_dim, self.vertexes, self.b_dim)
        self.V = simple_disc("tanh", self.m_dim, self.vertexes, self.b_dim)
        self.V2 = simple_disc("tanh", self.drugs_m_dim, self.drug_vertexes, self.drugs_b_dim)
        ''' Optimizers for G1, G2, D1, D2 and the reward networks V1, V2:
            AdamW is used; GAN1 and GAN2 have separate learning rates but share
            the same beta1/beta2 values.
        '''
        self.g_optimizer = torch.optim.AdamW(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
        self.g2_optimizer = torch.optim.AdamW(self.G2.parameters(), self.g2_lr, [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.AdamW(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])
        self.d2_optimizer = torch.optim.AdamW(self.D2.parameters(), self.d2_lr, [self.beta1, self.beta2])
        self.v_optimizer = torch.optim.AdamW(self.V.parameters(), self.d_lr, [self.beta1, self.beta2])
        self.v2_optimizer = torch.optim.AdamW(self.V2.parameters(), self.d2_lr, [self.beta1, self.beta2])
        ''' Learning rate schedulers (currently disabled):
            would reduce the learning rate when the loss plateaus.
        '''
        # self.scheduler_g = ReduceLROnPlateau(self.g_optimizer, mode='min', factor=0.5, patience=10, min_lr=0.00001)
        # self.scheduler_d = ReduceLROnPlateau(self.d_optimizer, mode='min', factor=0.5, patience=10, min_lr=0.00001)
        # self.scheduler_v = ReduceLROnPlateau(self.v_optimizer, mode='min', factor=0.5, patience=10, min_lr=0.00001)
        # self.scheduler_g2 = ReduceLROnPlateau(self.g2_optimizer, mode='min', factor=0.5, patience=10, min_lr=0.00001)
        # self.scheduler_d2 = ReduceLROnPlateau(self.d2_optimizer, mode='min', factor=0.5, patience=10, min_lr=0.00001)
        # self.scheduler_v2 = ReduceLROnPlateau(self.v2_optimizer, mode='min', factor=0.5, patience=10, min_lr=0.00001)

        self.print_network(self.G, 'G')
        self.print_network(self.D, 'D')
        self.print_network(self.G2, 'G2')
        self.print_network(self.D2, 'D2')

        self.G.to(self.device)
        self.D.to(self.device)
        self.V.to(self.device)
        self.V2.to(self.device)
        self.G2.to(self.device)
        self.D2.to(self.device)
        # self.modules_of_the_model = (self.G, self.D, self.G2, self.D2)
"""for p in self.G.parameters(): | |
if p.dim() > 1: | |
if self.init_type == 'uniform': | |
torch.nn.init.xavier_uniform_(p) | |
elif self.init_type == 'normal': | |
torch.nn.init.xavier_normal_(p) | |
elif self.init_type == 'random_normal': | |
torch.nn.init.normal_(p, 0.0, 0.02) | |
for p in self.G2.parameters(): | |
if p.dim() > 1: | |
if self.init_type == 'uniform': | |
torch.nn.init.xavier_uniform_(p) | |
elif self.init_type == 'normal': | |
torch.nn.init.xavier_normal_(p) | |
elif self.init_type == 'random_normal': | |
torch.nn.init.normal_(p, 0.0, 0.02) | |
if self.dis_select == "conv": | |
for p in self.D.parameters(): | |
if p.dim() > 1: | |
if self.init_type == 'uniform': | |
torch.nn.init.xavier_uniform_(p) | |
elif self.init_type == 'normal': | |
torch.nn.init.xavier_normal_(p) | |
elif self.init_type == 'random_normal': | |
torch.nn.init.normal_(p, 0.0, 0.02) | |
if self.dis_select == "conv": | |
for p in self.D2.parameters(): | |
if p.dim() > 1: | |
if self.init_type == 'uniform': | |
torch.nn.init.xavier_uniform_(p) | |
elif self.init_type == 'normal': | |
torch.nn.init.xavier_normal_(p) | |
elif self.init_type == 'random_normal': | |
torch.nn.init.normal_(p, 0.0, 0.02)""" | |
    def decoder_load(self, dictionary_name):
        ''' Load the atom or bond decoder dictionary for the first dataset. '''
        with open("DrugGEN/data/decoders/" + dictionary_name + "_" + self.dataset_name + ".pkl", "rb") as f:
            return pickle.load(f)

    def drug_decoder_load(self, dictionary_name):
        ''' Load the atom or bond decoder dictionary for the drug dataset. '''
        with open("DrugGEN/data/decoders/" + dictionary_name + "_" + self.drugs_name + ".pkl", "rb") as f:
            return pickle.load(f)
    def print_network(self, model, name):
        """Print out the network information."""
        num_params = sum(p.numel() for p in model.parameters())
        print(model)
        print(name)
        print("The number of parameters: {}".format(num_params))
    def restore_model(self, epoch, iteration, model_directory):
        """Restore the trained generators (discriminator restoring is disabled)."""
        print('Loading the trained models from epoch / iteration {}-{}...'.format(epoch, iteration))
        G_path = os.path.join(model_directory, '{}-{}-G.ckpt'.format(epoch, iteration))
        self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
        # D_path = os.path.join(model_directory, '{}-{}-D.ckpt'.format(epoch, iteration))
        # self.D.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage))
        G2_path = os.path.join(model_directory, '{}-{}-G2.ckpt'.format(epoch, iteration))
        self.G2.load_state_dict(torch.load(G2_path, map_location=lambda storage, loc: storage))
        # D2_path = os.path.join(model_directory, '{}-{}-D2.ckpt'.format(epoch, iteration))
        # self.D2.load_state_dict(torch.load(D2_path, map_location=lambda storage, loc: storage))
    def save_model(self, model_directory, idx, i):
        """Save G/D checkpoints; G2/D2 are saved only for submodels that train the second GAN."""
        G_path = os.path.join(model_directory, '{}-{}-G.ckpt'.format(idx + 1, i + 1))
        D_path = os.path.join(model_directory, '{}-{}-D.ckpt'.format(idx + 1, i + 1))
        torch.save(self.G.state_dict(), G_path)
        torch.save(self.D.state_dict(), D_path)
        if self.submodel != "NoTarget" and self.submodel != "CrossLoss":
            G2_path = os.path.join(model_directory, '{}-{}-G2.ckpt'.format(idx + 1, i + 1))
            D2_path = os.path.join(model_directory, '{}-{}-D2.ckpt'.format(idx + 1, i + 1))
            torch.save(self.G2.state_dict(), G2_path)
            torch.save(self.D2.state_dict(), D2_path)
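    # Checkpoints follow the "{epoch}-{iteration}-{module}.ckpt" pattern expected
    # by restore_model(); inference() instead loads "{submodel}-G.ckpt" and
    # "{submodel}-G2.ckpt" from the inference_model directory.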
    def reset_grad(self):
        """Reset the gradient buffers."""
        self.g_optimizer.zero_grad()
        self.v_optimizer.zero_grad()
        self.g2_optimizer.zero_grad()
        self.v2_optimizer.zero_grad()
        self.d_optimizer.zero_grad()
        self.d2_optimizer.zero_grad()
    def gradient_penalty(self, y, x):
        """Compute the WGAN-GP gradient penalty: (L2_norm(dy/dx) - 1)**2."""
        weight = torch.ones(y.size(), requires_grad=False).to(self.device)
        dydx = torch.autograd.grad(outputs=y,
                                   inputs=x,
                                   grad_outputs=weight,
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]
        dydx = dydx.view(dydx.size(0), -1)
        gradient_penalty = ((dydx.norm(2, dim=1) - 1) ** 2).mean()
        return gradient_penalty
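    # A sketch of the WGAN-GP pattern this helper supports (names hypothetical;
    # the actual interpolation happens inside the loss functions imported from
    # loss.py, which receive self.gradient_penalty and self.lambda_gp):
    #   eps = torch.rand(real.size(0), 1, 1, device=self.device)
    #   x_int = (eps * real + (1 - eps) * fake).requires_grad_(True)
    #   penalty = self.lambda_gp * self.gradient_penalty(self.D(x_int), x_int)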
    def train(self):
        ''' The training script starts here. '''
        # wandb.config = {'beta2': 0.999}
        # wandb.init(project="DrugGEN2", entity="atabeyunlu")

        # Define sampling paths and create the logger.
        self.arguments = "{}_glr{}_dlr{}_g2lr{}_d2lr{}_dim{}_depth{}_heads{}_decdepth{}_decheads{}_ncritic{}_batch{}_epoch{}_warmup{}_dataset{}_dropout{}".format(
            self.submodel, self.g_lr, self.d_lr, self.g2_lr, self.d2_lr, self.dim, self.depth, self.heads,
            self.dec_depth, self.dec_heads, self.n_critic, self.batch_size, self.epoch, self.warm_up_steps,
            self.dataset_name, self.dropout)
        self.model_directory = os.path.join(self.model_save_dir, self.arguments)
        self.sample_directory = os.path.join(self.sample_dir, self.arguments)
        self.log_path = os.path.join(self.log_dir, "{}.txt".format(self.arguments))
        if not os.path.exists(self.model_directory):
            os.makedirs(self.model_directory)
        if not os.path.exists(self.sample_directory):
            os.makedirs(self.sample_directory)

        # Reference SMILES and protein (AKT1) data used for logging and rewards.
        with open("DrugGEN/data/chembl_train.smi", "r") as f:
            full_smiles = f.read().splitlines()
        with open("DrugGEN/data/akt_train.smi", "r") as f:
            drug_smiles = f.read().splitlines()
        drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
        drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
        fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
        akt1_human_adj = torch.load("DrugGEN/data/akt/AKT1_human_adj.pt").reshape(1, -1).to(self.device).float()
        akt1_human_annot = torch.load("DrugGEN/data/akt/AKT1_human_annot.pt").reshape(1, -1).to(self.device).float()
        # Start training.
        print('Start training...')
        self.start_time = time.time()
        for idx in range(self.epoch):
            # =================================================================================== #
            #                             1. Preprocess input data                                #
            # =================================================================================== #
            dataloader_iterator = iter(self.drugs_loader)
            for i, data in enumerate(self.loader):
                try:
                    drugs = next(dataloader_iterator)
                except StopIteration:
                    # Restart the drug loader when it is exhausted.
                    dataloader_iterator = iter(self.drugs_loader)
                    drugs = next(dataloader_iterator)

                # Preprocess both datasets.
                bulk_data = load_data(data,
                                      drugs,
                                      self.batch_size,
                                      self.device,
                                      self.b_dim,
                                      self.m_dim,
                                      self.drugs_b_dim,
                                      self.drugs_m_dim,
                                      self.z_dim,
                                      self.vertexes)
                drug_graphs, real_graphs, a_tensor, x_tensor, drugs_a_tensor, drugs_x_tensor, z, z_edge, z_node = bulk_data
                if self.submodel == "CrossLoss":
                    GAN1_input_e = drugs_a_tensor
                    GAN1_input_x = drugs_x_tensor
                    GAN1_disc_e = a_tensor
                    GAN1_disc_x = x_tensor
                elif self.submodel == "Ligand":
                    GAN1_input_e = a_tensor
                    GAN1_input_x = x_tensor
                    GAN1_disc_e = a_tensor
                    GAN1_disc_x = x_tensor
                    GAN2_input_e = drugs_a_tensor
                    GAN2_input_x = drugs_x_tensor
                    GAN2_disc_e = drugs_a_tensor
                    GAN2_disc_x = drugs_x_tensor
                elif self.submodel == "Prot":
                    GAN1_input_e = a_tensor
                    GAN1_input_x = x_tensor
                    GAN1_disc_e = a_tensor
                    GAN1_disc_x = x_tensor
                    GAN2_input_e = akt1_human_adj
                    GAN2_input_x = akt1_human_annot
                    GAN2_disc_e = drugs_a_tensor
                    GAN2_disc_x = drugs_x_tensor
                elif self.submodel == "RL":
                    GAN1_input_e = z_edge
                    GAN1_input_x = z_node
                    GAN1_disc_e = a_tensor
                    GAN1_disc_x = x_tensor
                    GAN2_input_e = drugs_a_tensor
                    GAN2_input_x = drugs_x_tensor
                    GAN2_disc_e = drugs_a_tensor
                    GAN2_disc_x = drugs_x_tensor
                elif self.submodel == "NoTarget":
                    GAN1_input_e = z_edge
                    GAN1_input_x = z_node
                    GAN1_disc_e = a_tensor
                    GAN1_disc_x = x_tensor
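                # Routing summary:
                #   CrossLoss: GAN1 maps drug graphs to first-dataset graphs (second GAN unused).
                #   Ligand:    GAN1 reconstructs first-dataset graphs; GAN2 targets drug graphs.
                #   Prot:      like Ligand, but GAN2's input is the AKT1 adjacency/annotation tensors.
                #   RL:        GAN1 starts from noise (z_edge/z_node); GAN2 targets drug graphs.
                #   NoTarget:  GAN1 starts from noise; the second GAN is not trained.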
                # =================================================================================== #
                #                             2. Train the discriminator                              #
                # =================================================================================== #
                loss = {}
                self.reset_grad()

                # Compute the discriminator loss (WGAN-GP).
                node, edge, d_loss = discriminator_loss(self.G,
                                                        self.D,
                                                        real_graphs,
                                                        GAN1_disc_e,
                                                        GAN1_disc_x,
                                                        self.batch_size,
                                                        self.device,
                                                        self.gradient_penalty,
                                                        self.lambda_gp,
                                                        GAN1_input_e,
                                                        GAN1_input_x)
                d_total = d_loss
                if self.submodel != "NoTarget" and self.submodel != "CrossLoss":
                    d2_loss = discriminator2_loss(self.G2,
                                                  self.D2,
                                                  drug_graphs,
                                                  edge,
                                                  node,
                                                  self.batch_size,
                                                  self.device,
                                                  self.gradient_penalty,
                                                  self.lambda_gp,
                                                  GAN2_input_e,
                                                  GAN2_input_x)
                    d_total = d_loss + d2_loss
                loss["d_total"] = d_total.item()
                d_total.backward()
                self.d_optimizer.step()
                if self.submodel != "NoTarget" and self.submodel != "CrossLoss":
                    self.d2_optimizer.step()
                # =================================================================================== #
                #                              3. Train the generators                                #
                # =================================================================================== #
                self.reset_grad()
                generator_output = generator_loss(self.G,
                                                  self.D,
                                                  self.V,
                                                  GAN1_input_e,
                                                  GAN1_input_x,
                                                  self.batch_size,
                                                  sim_reward,
                                                  self.dataset.matrices2mol_drugs,
                                                  fps_r,
                                                  self.submodel)
                g_loss, fake_mol, g_edges_hat_sample, g_nodes_hat_sample, node, edge = generator_output
                self.reset_grad()
                g_total = g_loss
                if self.submodel != "NoTarget" and self.submodel != "CrossLoss":
                    output = generator2_loss(self.G2,
                                             self.D2,
                                             self.V2,
                                             edge,
                                             node,
                                             self.batch_size,
                                             sim_reward,
                                             self.dataset.matrices2mol_drugs,
                                             fps_r,
                                             GAN2_input_e,
                                             GAN2_input_x,
                                             self.submodel)
                    g2_loss, fake_mol_g, dr_g_edges_hat_sample, dr_g_nodes_hat_sample = output
                    g_total = g_loss + g2_loss

                loss["g_total"] = g_total.item()
                g_total.backward()
                self.g_optimizer.step()
                if self.submodel != "NoTarget" and self.submodel != "CrossLoss":
                    self.g2_optimizer.step()
                if self.submodel == "RL":
                    # The reward networks are only updated in the RL submodel.
                    self.v_optimizer.step()
                    self.v2_optimizer.step()
                if (i + 1) % self.log_step == 0:
                    logging(self.log_path, self.start_time, fake_mol, full_smiles, i, idx, loss, 1, self.sample_directory)
                    mol_sample(self.sample_directory, "GAN1", fake_mol, g_edges_hat_sample.detach(), g_nodes_hat_sample.detach(), idx, i)
                    if self.submodel != "NoTarget" and self.submodel != "CrossLoss":
                        logging(self.log_path, self.start_time, fake_mol_g, drug_smiles, i, idx, loss, 2, self.sample_directory)
                        mol_sample(self.sample_directory, "GAN2", fake_mol_g, dr_g_edges_hat_sample.detach(), dr_g_nodes_hat_sample.detach(), idx, i)

            if (idx + 1) % 10 == 0:
                self.save_model(self.model_directory, idx, i)
                print("model saved at epoch {} and iteration {}".format(idx, i))
    def inference(self):
        """Generate molecules with the trained generators and report MOSES-style metrics."""
        # Load the trained generators.
        self.G.to(self.device)
        # self.D.to(self.device)
        self.G2.to(self.device)
        # self.D2.to(self.device)
        G_path = os.path.join(self.inference_model, '{}-G.ckpt'.format(self.submodel))
        self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
        G2_path = os.path.join(self.inference_model, '{}-G2.ckpt'.format(self.submodel))
        self.G2.load_state_dict(torch.load(G2_path, map_location=lambda storage, loc: storage))

        with open("DrugGEN/data/akt_test.smi", "r") as f:
            drug_smiles = f.read().splitlines()
        drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
        drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
        fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
        akt1_human_adj = torch.load("DrugGEN/data/akt/AKT1_human_adj.pt").reshape(1, -1).to(self.device).float()
        akt1_human_annot = torch.load("DrugGEN/data/akt/AKT1_human_annot.pt").reshape(1, -1).to(self.device).float()

        self.G.eval()
        # self.D.eval()
        self.G2.eval()
        # self.D2.eval()
        self.inf_batch_size = 256
        # Inference datasets and loaders, built the same way as the training ones.
        self.inf_dataset = DruggenDataset(self.mol_data_dir,
                                          self.inf_dataset_file,
                                          self.inf_raw_file,
                                          self.max_atom,
                                          self.features)
        self.inf_loader = DataLoader(self.inf_dataset,
                                     shuffle=True,
                                     batch_size=self.inf_batch_size,
                                     drop_last=True)  # PyG dataloader for the first GAN.
        self.inf_drugs = DruggenDataset(self.drug_data_dir,
                                        self.inf_drugs_dataset_file,
                                        self.inf_drug_raw_file,
                                        self.max_atom,
                                        self.features)
        self.inf_drugs_loader = DataLoader(self.inf_drugs,
                                           shuffle=True,
                                           batch_size=self.inf_batch_size,
                                           drop_last=True)  # PyG dataloader for the second GAN.
        start_time = time.time()
        # metric_calc_mol = []
        metric_calc_dr = []
        date = time.time()
        if not os.path.exists("DrugGEN/experiments/inference/{}".format(self.submodel)):
            os.makedirs("DrugGEN/experiments/inference/{}".format(self.submodel))

        with torch.inference_mode():
            # Iterate the inference loaders built above.
            dataloader_iterator = iter(self.inf_drugs_loader)
            for i, data in enumerate(self.inf_loader):
                try:
                    drugs = next(dataloader_iterator)
                except StopIteration:
                    dataloader_iterator = iter(self.inf_drugs_loader)
                    drugs = next(dataloader_iterator)

                # Preprocess both datasets.
                bulk_data = load_data(data,
                                      drugs,
                                      self.inf_batch_size,
                                      self.device,
                                      self.b_dim,
                                      self.m_dim,
                                      self.drugs_b_dim,
                                      self.drugs_m_dim,
                                      self.z_dim,
                                      self.vertexes)
                drug_graphs, real_graphs, a_tensor, x_tensor, drugs_a_tensor, drugs_x_tensor, z, z_edge, z_node = bulk_data
                if self.submodel == "CrossLoss":
                    GAN1_input_e = a_tensor
                    GAN1_input_x = x_tensor
                    GAN1_disc_e = drugs_a_tensor
                    GAN1_disc_x = drugs_x_tensor
                    GAN2_input_e = drugs_a_tensor
                    GAN2_input_x = drugs_x_tensor
                    GAN2_disc_e = a_tensor
                    GAN2_disc_x = x_tensor
                elif self.submodel == "Ligand":
                    GAN1_input_e = a_tensor
                    GAN1_input_x = x_tensor
                    GAN1_disc_e = a_tensor
                    GAN1_disc_x = x_tensor
                    GAN2_input_e = drugs_a_tensor
                    GAN2_input_x = drugs_x_tensor
                    GAN2_disc_e = drugs_a_tensor
                    GAN2_disc_x = drugs_x_tensor
                elif self.submodel == "Prot":
                    GAN1_input_e = a_tensor
                    GAN1_input_x = x_tensor
                    GAN1_disc_e = a_tensor
                    GAN1_disc_x = x_tensor
                    GAN2_input_e = akt1_human_adj
                    GAN2_input_x = akt1_human_annot
                    GAN2_disc_e = drugs_a_tensor
                    GAN2_disc_x = drugs_x_tensor
                elif self.submodel == "RL":
                    GAN1_input_e = z_edge
                    GAN1_input_x = z_node
                    GAN1_disc_e = a_tensor
                    GAN1_disc_x = x_tensor
                    GAN2_input_e = drugs_a_tensor
                    GAN2_input_x = drugs_x_tensor
                    GAN2_disc_e = drugs_a_tensor
                    GAN2_disc_x = drugs_x_tensor
                elif self.submodel == "NoTarget":
                    GAN1_input_e = z_edge
                    GAN1_input_x = z_node
                    GAN1_disc_e = a_tensor
                    GAN1_disc_x = x_tensor
                # =================================================================================== #
                #                                  2. GAN1 Inference                                  #
                # =================================================================================== #
                generator_output = generator_loss(self.G,
                                                  self.D,
                                                  self.V,
                                                  GAN1_input_e,
                                                  GAN1_input_x,
                                                  self.inf_batch_size,
                                                  sim_reward,
                                                  self.dataset.matrices2mol_drugs,
                                                  fps_r,
                                                  self.submodel)
                _, fake_mol, _, _, node, edge = generator_output

                # =================================================================================== #
                #                                  3. GAN2 Inference                                  #
                # =================================================================================== #
                output = generator2_loss(self.G2,
                                         self.D2,
                                         self.V2,
                                         edge,
                                         node,
                                         self.inf_batch_size,
                                         sim_reward,
                                         self.dataset.matrices2mol_drugs,
                                         fps_r,
                                         GAN2_input_e,
                                         GAN2_input_x,
                                         self.submodel)
                _, fake_mol_g, _, _ = output
                inference_drugs = [Chem.MolToSmiles(mol) for mol in fake_mol_g if mol is not None]
                # inference_smiles = [Chem.MolToSmiles(mol) for mol in fake_mol]
                print("molecule batch {} inferred".format(i))
                with open("DrugGEN/experiments/inference/{}/inference_drugs.txt".format(self.submodel), "a") as f:
                    for molecule in inference_drugs:
                        f.write(molecule)
                        f.write("\n")
                        metric_calc_dr.append(molecule)
                if i == 120:
                    break

        et = time.time() - start_time
        print("Inference lasted {:.2f} seconds.".format(et))
        print("Metrics calculation started using MOSES.")
        # Metrics are computed over all generated molecules accumulated in metric_calc_dr.
        print("Validity: ", fraction_valid(metric_calc_dr), "\n")
        print("Uniqueness: ", fraction_unique(metric_calc_dr), "\n")
        print("Novelty: ", novelty(metric_calc_dr, drug_smiles), "\n")
        print("Metrics are calculated.")