import os
import time
import torch.nn
import torch
from utils import *
from models import Generator, Generator2, simple_disc
import torch_geometric.utils as geoutils
# import wandb
import re
from torch_geometric.loader import DataLoader
from new_dataloader import DruggenDataset
import torch.utils.data
from rdkit import Chem, RDLogger  # Chem is used below for SMILES/Mol handling.
import pickle
from rdkit.Chem.Scaffolds import MurckoScaffold
torch.set_num_threads(5)
RDLogger.DisableLog('rdApp.*')
from loss import discriminator_loss, generator_loss, discriminator2_loss, generator2_loss
from training_data import load_data
import random
class Trainer(object):
"""Trainer for training and testing DrugGEN."""
    def __init__(self, config):
        """Initialize configurations."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.submodel = config.submodel
self.inference_model = config.inference_model
# Data loader.
        self.raw_file = config.raw_file  # Text file of SMILES strings for the first dataset (write the full path).
        self.drug_raw_file = config.drug_raw_file  # Text file of SMILES strings for the second dataset (write the full path).
        self.dataset_file = config.dataset_file  # Dataset file name for the first GAN; contains a large number of molecules.
        self.drugs_dataset_file = config.drug_dataset_file  # Dataset file name for the second GAN; contains drug molecules only (in this case, AKT1 inhibitors).
        self.inf_raw_file = config.inf_raw_file  # Text file of SMILES strings for the first dataset at inference (write the full path).
        self.inf_drug_raw_file = config.inf_drug_raw_file  # Text file of SMILES strings for the second dataset at inference (write the full path).
        self.inf_dataset_file = config.inf_dataset_file  # Inference dataset file name for the first GAN; contains a large number of molecules.
        self.inf_drugs_dataset_file = config.inf_drug_dataset_file  # Inference dataset file name for the second GAN; contains drug molecules only (in this case, AKT1 inhibitors).
self.mol_data_dir = config.mol_data_dir # Directory where the dataset files are stored.
self.drug_data_dir = config.drug_data_dir # Directory where the drug dataset files are stored.
self.dataset_name = self.dataset_file.split(".")[0]
self.drugs_name = self.drugs_dataset_file.split(".")[0]
        self.max_atom = config.max_atom  # The model uses one-shot generation, so the maximum atom count per molecule must be specified.
        self.features = config.features  # Boolean; if False, only atom types are used as node features. Additional node features can be added; see new_dataloader.py line 102.
self.batch_size = config.batch_size # Batch size for training.
self.dataset = DruggenDataset(self.mol_data_dir,
self.dataset_file,
self.raw_file,
self.max_atom,
self.features) # Dataset for the first GAN. Custom dataset class from PyG parent class.
        # It can create a molecular graph dataset from any SMILES file.
        # Non-isomeric SMILES are suggested but not required.
        # Graphs use a sparse matrix representation for computational and speed efficiency.
self.loader = DataLoader(self.dataset,
shuffle=True,
batch_size=self.batch_size,
drop_last=True) # PyG dataloader for the first GAN.
self.drugs = DruggenDataset(self.drug_data_dir,
self.drugs_dataset_file,
self.drug_raw_file,
self.max_atom,
self.features) # Dataset for the second GAN. Custom dataset class from PyG parent class.
        # It can create a molecular graph dataset from any SMILES file.
        # Non-isomeric SMILES are suggested but not required.
        # Graphs use a sparse matrix representation for computational and speed efficiency.
self.drugs_loader = DataLoader(self.drugs,
shuffle=True,
batch_size=self.batch_size,
drop_last=True) # PyG dataloader for the second GAN.
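        # A hedged sketch of what one batch looks like (assuming standard PyG
        # Batch semantics; the exact fields come from DruggenDataset):
        #   batch = next(iter(self.loader))
        #   batch.x           # node features, shape [num_nodes_in_batch, feature_dim]
        #   batch.edge_index  # sparse connectivity, shape [2, num_edges]
        #   batch.edge_attr   # bond-type features, one row per edge
        # load_data() (training_data.py) later densifies these into per-molecule
        # adjacency/annotation tensors before they reach the models.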
# Atom and bond type dimensions for the construction of the model.
        self.atom_decoders = self.decoder_load("atom")  # Atom type decoders for the first GAN,
        # e.g. 0: 0, 1: 6 (C), 2: 7 (N), 3: 8 (O), 4: 9 (F)
        self.bond_decoders = self.decoder_load("bond")  # Bond type decoders for the first GAN,
        # e.g. 0: no bond, 1: single, 2: double, 3: triple, 4: aromatic
self.m_dim = len(self.atom_decoders) if not self.features else int(self.loader.dataset[0].x.shape[1]) # Atom type dimension.
self.b_dim = len(self.bond_decoders) # Bond type dimension.
self.vertexes = int(self.loader.dataset[0].x.shape[0]) # Number of nodes in the graph.
        self.drugs_atom_decoders = self.drug_decoder_load("atom")  # Atom type decoders for the second GAN,
        # e.g. 0: 0, 1: 6 (C), 2: 7 (N), 3: 8 (O), 4: 9 (F)
        self.drugs_bond_decoders = self.drug_decoder_load("bond")  # Bond type decoders for the second GAN,
        # e.g. 0: no bond, 1: single, 2: double, 3: triple, 4: aromatic
self.drugs_m_dim = len(self.drugs_atom_decoders) if not self.features else int(self.drugs_loader.dataset[0].x.shape[1]) # Atom type dimension.
self.drugs_b_dim = len(self.drugs_bond_decoders) # Bond type dimension.
self.drug_vertexes = int(self.drugs_loader.dataset[0].x.shape[0]) # Number of nodes in the graph.
# Transformer and Convolution configurations.
self.act = config.act
self.z_dim = config.z_dim
self.lambda_gp = config.lambda_gp
self.dim = config.dim
self.depth = config.depth
self.heads = config.heads
self.mlp_ratio = config.mlp_ratio
self.dec_depth = config.dec_depth
self.dec_heads = config.dec_heads
self.dec_dim = config.dec_dim
self.dis_select = config.dis_select
"""self.la = config.la
self.la2 = config.la2
self.gcn_depth = config.gcn_depth
self.g_conv_dim = config.g_conv_dim
self.d_conv_dim = config.d_conv_dim"""
"""# PNA config
self.agg = config.aggregators
self.sca = config.scalers
self.pna_in_ch = config.pna_in_ch
self.pna_out_ch = config.pna_out_ch
self.edge_dim = config.edge_dim
self.towers = config.towers
self.pre_lay = config.pre_lay
self.post_lay = config.post_lay
self.pna_layer_num = config.pna_layer_num
self.graph_add = config.graph_add"""
# Training configurations.
self.epoch = config.epoch
self.g_lr = config.g_lr
self.d_lr = config.d_lr
self.g2_lr = config.g2_lr
self.d2_lr = config.d2_lr
self.dropout = config.dropout
self.dec_dropout = config.dec_dropout
self.n_critic = config.n_critic
self.beta1 = config.beta1
self.beta2 = config.beta2
self.resume_iters = config.resume_iters
self.warm_up_steps = config.warm_up_steps
# Test configurations.
self.num_test_epoch = config.num_test_epoch
self.test_iters = config.test_iters
self.inference_sample_num = config.inference_sample_num
# Directories.
self.log_dir = config.log_dir
self.sample_dir = config.sample_dir
self.model_save_dir = config.model_save_dir
self.result_dir = config.result_dir
# Step size.
self.log_step = config.log_sample_step
self.clipping_value = config.clipping_value
# Miscellaneous.
self.mode = config.mode
self.noise_strength_0 = torch.nn.Parameter(torch.zeros([]))
self.noise_strength_1 = torch.nn.Parameter(torch.zeros([]))
self.noise_strength_2 = torch.nn.Parameter(torch.zeros([]))
self.noise_strength_3 = torch.nn.Parameter(torch.zeros([]))
self.init_type = config.init_type
self.build_model()
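    # A hedged sketch of the config object this class expects. The real
    # defaults live in the repository's argument parser; all values below are
    # hypothetical placeholders, not recommendations:
    #   from argparse import Namespace
    #   config = Namespace(submodel="Ligand", mode="train", batch_size=128,
    #                      max_atom=45, features=False, z_dim=16, ...)
    #   trainer = Trainer(config)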
def build_model(self):
"""Create generators and discriminators."""
        ''' The generator is based on a Transformer encoder:
        @ z_dim: dimension of the latent noise input
        @ vertexes: maximum number of atoms in generated molecules
        @ b_dim: number of bond types
        @ m_dim: number of atom types (or number of node features used)
        @ dropout: dropout probability
        @ dim: hidden dimension of the Transformer encoder
        @ depth: number of Transformer encoder layers
        @ heads: number of multi-head attention heads
        @ mlp_ratio: read-out layer dimension of the Transformer
        @ drop_rate: deprecated
        @ tra_conv: whether the module creates output for a TransformerConv discriminator
        '''
self.G = Generator(self.z_dim,
self.act,
self.vertexes,
self.b_dim,
self.m_dim,
self.dropout,
dim=self.dim,
depth=self.depth,
heads=self.heads,
mlp_ratio=self.mlp_ratio,
submodel = self.submodel)
self.G2 = Generator2(self.dim,
self.dec_dim,
self.dec_depth,
self.dec_heads,
self.mlp_ratio,
self.dec_dropout,
self.drugs_m_dim,
self.drugs_b_dim,
self.submodel)
''' Discriminator implementation with PNA:
@ deg: Degree distribution based on used data. (Created with _genDegree() function)
@ agg: aggregators used in PNA
@ sca: scalers used in PNA
@ pna_in_ch: First PNA hidden dimension
@ pna_out_ch: Last PNA hidden dimension
@ edge_dim: Edge hidden dimension
@ towers: Number of towers (Splitting the hidden dimension to multiple parallel processes)
@ pre_lay: Pre-transformation layer
@ post_lay: Post-transformation layer
@ pna_layer_num: number of PNA layers
@ graph_add: global pooling layer selection
'''
''' Discriminator implementation with Graph Convolution:
@ d_conv_dim: convolution dimensions for GCN
@ m_dim: number of atom types (or number of features used)
@ b_dim: number of bond types
        @ dropout: dropout probability
'''
''' Discriminator implementation with MLP:
@ act: Activation function for MLP
@ m_dim: number of atom types (or number of features used)
@ b_dim: number of bond types
        @ dropout: dropout probability
        @ vertexes: maximum number of atoms in generated molecules
'''
#self.D = Discriminator_old(self.d_conv_dim, self.m_dim , self.b_dim, self.dropout, self.gcn_depth)
self.D2 = simple_disc("tanh", self.drugs_m_dim, self.drug_vertexes, self.drugs_b_dim)
self.D = simple_disc("tanh", self.m_dim, self.vertexes, self.b_dim)
self.V = simple_disc("tanh", self.m_dim, self.vertexes, self.b_dim)
self.V2 = simple_disc("tanh", self.drugs_m_dim, self.drug_vertexes, self.drugs_b_dim)
        ''' Optimizers for G1, G2, D1, and D2:
        AdamW is used for all networks; learning rates differ per network while beta1/beta2 are shared.
        '''
self.g_optimizer = torch.optim.AdamW(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
self.g2_optimizer = torch.optim.AdamW(self.G2.parameters(), self.g2_lr, [self.beta1, self.beta2])
self.d_optimizer = torch.optim.AdamW(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])
self.d2_optimizer = torch.optim.AdamW(self.D2.parameters(), self.d2_lr, [self.beta1, self.beta2])
self.v_optimizer = torch.optim.AdamW(self.V.parameters(), self.d_lr, [self.beta1, self.beta2])
self.v2_optimizer = torch.optim.AdamW(self.V2.parameters(), self.d2_lr, [self.beta1, self.beta2])
        ''' Learning rate schedulers (currently disabled):
        ReduceLROnPlateau would reduce the learning rate when the loss plateaus.
        '''
#self.scheduler_g = ReduceLROnPlateau(self.g_optimizer, mode='min', factor=0.5, patience=10, min_lr=0.00001)
#self.scheduler_d = ReduceLROnPlateau(self.d_optimizer, mode='min', factor=0.5, patience=10, min_lr=0.00001)
#self.scheduler_v = ReduceLROnPlateau(self.v_optimizer, mode='min', factor=0.5, patience=10, min_lr=0.00001)
#self.scheduler_g2 = ReduceLROnPlateau(self.g2_optimizer, mode='min', factor=0.5, patience=10, min_lr=0.00001)
#self.scheduler_d2 = ReduceLROnPlateau(self.d2_optimizer, mode='min', factor=0.5, patience=10, min_lr=0.00001)
#self.scheduler_v2 = ReduceLROnPlateau(self.v2_optimizer, mode='min', factor=0.5, patience=10, min_lr=0.00001)
self.print_network(self.G, 'G')
self.print_network(self.D, 'D')
self.print_network(self.G2, 'G2')
self.print_network(self.D2, 'D2')
self.G.to(self.device)
self.D.to(self.device)
self.V.to(self.device)
self.V2.to(self.device)
self.G2.to(self.device)
self.D2.to(self.device)
#self.V2.to(self.device)
#self.modules_of_the_model = (self.G, self.D, self.G2, self.D2)
"""for p in self.G.parameters():
if p.dim() > 1:
if self.init_type == 'uniform':
torch.nn.init.xavier_uniform_(p)
elif self.init_type == 'normal':
torch.nn.init.xavier_normal_(p)
elif self.init_type == 'random_normal':
torch.nn.init.normal_(p, 0.0, 0.02)
for p in self.G2.parameters():
if p.dim() > 1:
if self.init_type == 'uniform':
torch.nn.init.xavier_uniform_(p)
elif self.init_type == 'normal':
torch.nn.init.xavier_normal_(p)
elif self.init_type == 'random_normal':
torch.nn.init.normal_(p, 0.0, 0.02)
if self.dis_select == "conv":
for p in self.D.parameters():
if p.dim() > 1:
if self.init_type == 'uniform':
torch.nn.init.xavier_uniform_(p)
elif self.init_type == 'normal':
torch.nn.init.xavier_normal_(p)
elif self.init_type == 'random_normal':
torch.nn.init.normal_(p, 0.0, 0.02)
if self.dis_select == "conv":
for p in self.D2.parameters():
if p.dim() > 1:
if self.init_type == 'uniform':
torch.nn.init.xavier_uniform_(p)
elif self.init_type == 'normal':
torch.nn.init.xavier_normal_(p)
elif self.init_type == 'random_normal':
torch.nn.init.normal_(p, 0.0, 0.02)"""
def decoder_load(self, dictionary_name):
''' Loading the atom and bond decoders'''
with open("DrugGEN/data/decoders/" + dictionary_name + "_" + self.dataset_name + '.pkl', 'rb') as f:
return pickle.load(f)
def drug_decoder_load(self, dictionary_name):
''' Loading the atom and bond decoders'''
with open("DrugGEN/data/decoders/" + dictionary_name +"_" + self.drugs_name +'.pkl', 'rb') as f:
return pickle.load(f)
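    # Hedged sketch of how the pickled decoders are used downstream: they map
    # categorical indices back to chemistry, roughly
    #   atomic_num = self.atom_decoders[node_idx]  # e.g. 1 -> 6 (carbon)
    #   bond_type = self.bond_decoders[edge_idx]   # e.g. 1 -> single bond
    # The full matrix-to-RDKit-Mol conversion lives in
    # DruggenDataset.matrices2mol_drugs (new_dataloader.py).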
def print_network(self, model, name):
"""Print out the network information."""
num_params = 0
for p in model.parameters():
num_params += p.numel()
print(model)
print(name)
print("The number of parameters: {}".format(num_params))
def restore_model(self, epoch, iteration, model_directory):
"""Restore the trained generator and discriminator."""
print('Loading the trained models from epoch / iteration {}-{}...'.format(epoch, iteration))
G_path = os.path.join(model_directory, '{}-{}-G.ckpt'.format(epoch, iteration))
#D_path = os.path.join(model_directory, '{}-{}-D.ckpt'.format(epoch, iteration))
self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
#self.D.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage))
G2_path = os.path.join(model_directory, '{}-{}-G2.ckpt'.format(epoch, iteration))
#D2_path = os.path.join(model_directory, '{}-{}-D2.ckpt'.format(epoch, iteration))
self.G2.load_state_dict(torch.load(G2_path, map_location=lambda storage, loc: storage))
#self.D2.load_state_dict(torch.load(D2_path, map_location=lambda storage, loc: storage))
    def save_model(self, model_directory, idx, i):
        G_path = os.path.join(model_directory, '{}-{}-G.ckpt'.format(idx + 1, i + 1))
        D_path = os.path.join(model_directory, '{}-{}-D.ckpt'.format(idx + 1, i + 1))
torch.save(self.G.state_dict(), G_path)
torch.save(self.D.state_dict(), D_path)
if self.submodel != "NoTarget" and self.submodel != "CrossLoss":
G2_path = os.path.join(model_directory, '{}-{}-G2.ckpt'.format(idx+1,i+1))
D2_path = os.path.join(model_directory, '{}-{}-D2.ckpt'.format(idx+1,i+1))
torch.save(self.G2.state_dict(), G2_path)
torch.save(self.D2.state_dict(), D2_path)
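    # Checkpoint naming convention: "{epoch}-{iteration}-G.ckpt" (likewise for
    # D, G2, D2). For example, save_model(dir, 9, 2999) writes "10-3000-G.ckpt",
    # which restore_model(10, 3000, dir) can later load.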
def reset_grad(self):
"""Reset the gradient buffers."""
self.g_optimizer.zero_grad()
self.v_optimizer.zero_grad()
self.g2_optimizer.zero_grad()
self.v2_optimizer.zero_grad()
self.d_optimizer.zero_grad()
self.d2_optimizer.zero_grad()
def gradient_penalty(self, y, x):
"""Compute gradient penalty: (L2_norm(dy/dx) - 1)**2."""
        weight = torch.ones(y.size(), requires_grad=False, device=self.device)
dydx = torch.autograd.grad(outputs=y,
inputs=x,
grad_outputs=weight,
retain_graph=True,
create_graph=True,
only_inputs=True)[0]
dydx = dydx.view(dydx.size(0), -1)
gradient_penalty = ((dydx.norm(2, dim=1) - 1) ** 2).mean()
return gradient_penalty
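    # Usage sketch (hedged): a typical WGAN-GP discriminator step evaluates the
    # penalty on random interpolates between real and fake samples. `real` and
    # `fake` below are hypothetical flattened graph tensors; the actual call
    # sites are discriminator_loss/discriminator2_loss in loss.py.
    #   eps = torch.rand(real.size(0), 1, device=self.device)
    #   interp = (eps * real + (1 - eps) * fake).requires_grad_(True)
    #   gp = self.gradient_penalty(self.D(interp), interp)
    #   d_loss = self.D(fake).mean() - self.D(real).mean() + self.lambda_gp * gp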
def train(self):
        ''' The training loop starts here. '''
#wandb.config = {'beta2': 0.999}
#wandb.init(project="DrugGEN2", entity="atabeyunlu")
# Defining sampling paths and creating logger
self.arguments = "{}_glr{}_dlr{}_g2lr{}_d2lr{}_dim{}_depth{}_heads{}_decdepth{}_decheads{}_ncritic{}_batch{}_epoch{}_warmup{}_dataset{}_dropout{}".format(self.submodel,self.g_lr,self.d_lr,self.g2_lr,self.d2_lr,self.dim,self.depth,self.heads,self.dec_depth,self.dec_heads,self.n_critic,self.batch_size,self.epoch,self.warm_up_steps,self.dataset_name,self.dropout)
self.model_directory= os.path.join(self.model_save_dir,self.arguments)
self.sample_directory=os.path.join(self.sample_dir,self.arguments)
self.log_path = os.path.join(self.log_dir, "{}.txt".format(self.arguments))
if not os.path.exists(self.model_directory):
os.makedirs(self.model_directory)
if not os.path.exists(self.sample_directory):
os.makedirs(self.sample_directory)
        # Reference SMILES and protein data used for rewards and logging.
        full_smiles = open("DrugGEN/data/chembl_train.smi", 'r').read().splitlines()
        drug_smiles = open("DrugGEN/data/akt_train.smi", 'r').read().splitlines()
drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
akt1_human_adj = torch.load("DrugGEN/data/akt/AKT1_human_adj.pt").reshape(1,-1).to(self.device).float()
akt1_human_annot = torch.load("DrugGEN/data/akt/AKT1_human_annot.pt").reshape(1,-1).to(self.device).float()
# Start training.
print('Start training...')
self.start_time = time.time()
for idx in range(self.epoch):
# =================================================================================== #
# 1. Preprocess input data #
# =================================================================================== #
# Load the data
dataloader_iterator = iter(self.drugs_loader)
for i, data in enumerate(self.loader):
try:
drugs = next(dataloader_iterator)
except StopIteration:
dataloader_iterator = iter(self.drugs_loader)
drugs = next(dataloader_iterator)
                # Preprocess both datasets.
bulk_data = load_data(data,
drugs,
self.batch_size,
self.device,
self.b_dim,
self.m_dim,
self.drugs_b_dim,
self.drugs_m_dim,
self.z_dim,
self.vertexes)
drug_graphs, real_graphs, a_tensor, x_tensor, drugs_a_tensor, drugs_x_tensor, z, z_edge, z_node = bulk_data
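                # Input routing per submodel: GAN*_input_* tensors feed the
                # generators, while GAN*_disc_* tensors are the real samples
                # shown to the discriminators. a_tensor/x_tensor come from the
                # first (molecule) dataset, drugs_a_tensor/drugs_x_tensor from
                # the drug dataset, and z_edge/z_node are pure noise.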
if self.submodel == "CrossLoss":
GAN1_input_e = drugs_a_tensor
GAN1_input_x = drugs_x_tensor
GAN1_disc_e = a_tensor
GAN1_disc_x = x_tensor
elif self.submodel == "Ligand":
GAN1_input_e = a_tensor
GAN1_input_x = x_tensor
GAN1_disc_e = a_tensor
GAN1_disc_x = x_tensor
GAN2_input_e = drugs_a_tensor
GAN2_input_x = drugs_x_tensor
GAN2_disc_e = drugs_a_tensor
GAN2_disc_x = drugs_x_tensor
elif self.submodel == "Prot":
GAN1_input_e = a_tensor
GAN1_input_x = x_tensor
GAN1_disc_e = a_tensor
GAN1_disc_x = x_tensor
GAN2_input_e = akt1_human_adj
GAN2_input_x = akt1_human_annot
GAN2_disc_e = drugs_a_tensor
GAN2_disc_x = drugs_x_tensor
elif self.submodel == "RL":
GAN1_input_e = z_edge
GAN1_input_x = z_node
GAN1_disc_e = a_tensor
GAN1_disc_x = x_tensor
GAN2_input_e = drugs_a_tensor
GAN2_input_x = drugs_x_tensor
GAN2_disc_e = drugs_a_tensor
GAN2_disc_x = drugs_x_tensor
elif self.submodel == "NoTarget":
GAN1_input_e = z_edge
GAN1_input_x = z_node
GAN1_disc_e = a_tensor
GAN1_disc_x = x_tensor
# =================================================================================== #
# 2. Train the discriminator #
# =================================================================================== #
loss = {}
self.reset_grad()
# Compute discriminator loss.
node, edge, d_loss = discriminator_loss(self.G,
self.D,
real_graphs,
GAN1_disc_e,
GAN1_disc_x,
self.batch_size,
self.device,
self.gradient_penalty,
self.lambda_gp,
GAN1_input_e,
GAN1_input_x)
d_total = d_loss
if self.submodel != "NoTarget" and self.submodel != "CrossLoss":
d2_loss = discriminator2_loss(self.G2,
self.D2,
drug_graphs,
edge,
node,
self.batch_size,
self.device,
self.gradient_penalty,
self.lambda_gp,
GAN2_input_e,
GAN2_input_x)
d_total = d_loss + d2_loss
loss["d_total"] = d_total.item()
d_total.backward()
self.d_optimizer.step()
if self.submodel != "NoTarget" and self.submodel != "CrossLoss":
self.d2_optimizer.step()
self.reset_grad()
generator_output = generator_loss(self.G,
self.D,
self.V,
GAN1_input_e,
GAN1_input_x,
self.batch_size,
sim_reward,
self.dataset.matrices2mol_drugs,
fps_r,
self.submodel)
g_loss, fake_mol, g_edges_hat_sample, g_nodes_hat_sample, node, edge = generator_output
self.reset_grad()
g_total = g_loss
if self.submodel != "NoTarget" and self.submodel != "CrossLoss":
output = generator2_loss(self.G2,
self.D2,
self.V2,
edge,
node,
self.batch_size,
sim_reward,
self.dataset.matrices2mol_drugs,
fps_r,
GAN2_input_e,
GAN2_input_x,
self.submodel)
g2_loss, fake_mol_g, dr_g_edges_hat_sample, dr_g_nodes_hat_sample = output
g_total = g_loss + g2_loss
loss["g_total"] = g_total.item()
g_total.backward()
self.g_optimizer.step()
if self.submodel != "NoTarget" and self.submodel != "CrossLoss":
self.g2_optimizer.step()
if self.submodel == "RL":
self.v_optimizer.step()
self.v2_optimizer.step()
if (i+1) % self.log_step == 0:
logging(self.log_path, self.start_time, fake_mol, full_smiles, i, idx, loss, 1,self.sample_directory)
mol_sample(self.sample_directory,"GAN1",fake_mol, g_edges_hat_sample.detach(), g_nodes_hat_sample.detach(), idx, i)
if self.submodel != "NoTarget" and self.submodel != "CrossLoss":
logging(self.log_path, self.start_time, fake_mol_g, drug_smiles, i, idx, loss, 2,self.sample_directory)
mol_sample(self.sample_directory,"GAN2",fake_mol_g, dr_g_edges_hat_sample.detach(), dr_g_nodes_hat_sample.detach(), idx, i)
if (idx+1) % 10 == 0:
self.save_model(self.model_directory,idx,i)
print("model saved at epoch {} and iteration {}".format(idx,i))
def inference(self):
# Load the trained generator.
self.G.to(self.device)
#self.D.to(self.device)
self.G2.to(self.device)
#self.D2.to(self.device)
G_path = os.path.join(self.inference_model, '{}-G.ckpt'.format(self.submodel))
self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
G2_path = os.path.join(self.inference_model, '{}-G2.ckpt'.format(self.submodel))
self.G2.load_state_dict(torch.load(G2_path, map_location=lambda storage, loc: storage))
drug_smiles = [line for line in open("DrugGEN/data/akt_test.smi", 'r').read().splitlines()]
drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
akt1_human_adj = torch.load("DrugGEN/data/akt/AKT1_human_adj.pt").reshape(1,-1).to(self.device).float()
akt1_human_annot = torch.load("DrugGEN/data/akt/AKT1_human_annot.pt").reshape(1,-1).to(self.device).float()
self.G.eval()
#self.D.eval()
self.G2.eval()
#self.D2.eval()
        self.inf_batch_size = 256
self.inf_dataset = DruggenDataset(self.mol_data_dir,
self.inf_dataset_file,
self.inf_raw_file,
self.max_atom,
self.features) # Dataset for the first GAN. Custom dataset class from PyG parent class.
        # It can create a molecular graph dataset from any SMILES file.
        # Non-isomeric SMILES are suggested but not required.
        # Graphs use a sparse matrix representation for computational and speed efficiency.
self.inf_loader = DataLoader(self.inf_dataset,
shuffle=True,
batch_size=self.inf_batch_size,
drop_last=True) # PyG dataloader for the first GAN.
self.inf_drugs = DruggenDataset(self.drug_data_dir,
self.inf_drugs_dataset_file,
self.inf_drug_raw_file,
self.max_atom,
self.features) # Dataset for the second GAN. Custom dataset class from PyG parent class.
        # It can create a molecular graph dataset from any SMILES file.
        # Non-isomeric SMILES are suggested but not required.
        # Graphs use a sparse matrix representation for computational and speed efficiency.
self.inf_drugs_loader = DataLoader(self.inf_drugs,
shuffle=True,
batch_size=self.inf_batch_size,
drop_last=True) # PyG dataloader for the second GAN.
start_time = time.time()
#metric_calc_mol = []
metric_calc_dr = []
date = time.time()
if not os.path.exists("DrugGEN/experiments/inference/{}".format(self.submodel)):
os.makedirs("DrugGEN/experiments/inference/{}".format(self.submodel))
with torch.inference_mode():
            dataloader_iterator = iter(self.inf_drugs_loader)
            for i, data in enumerate(self.inf_loader):
try:
drugs = next(dataloader_iterator)
except StopIteration:
                    dataloader_iterator = iter(self.inf_drugs_loader)
drugs = next(dataloader_iterator)
                # Preprocess both datasets.
bulk_data = load_data(data,
drugs,
                                      self.inf_batch_size,
self.device,
self.b_dim,
self.m_dim,
self.drugs_b_dim,
self.drugs_m_dim,
self.z_dim,
self.vertexes)
drug_graphs, real_graphs, a_tensor, x_tensor, drugs_a_tensor, drugs_x_tensor, z, z_edge, z_node = bulk_data
if self.submodel == "CrossLoss":
GAN1_input_e = a_tensor
GAN1_input_x = x_tensor
GAN1_disc_e = drugs_a_tensor
GAN1_disc_x = drugs_x_tensor
GAN2_input_e = drugs_a_tensor
GAN2_input_x = drugs_x_tensor
GAN2_disc_e = a_tensor
GAN2_disc_x = x_tensor
elif self.submodel == "Ligand":
GAN1_input_e = a_tensor
GAN1_input_x = x_tensor
GAN1_disc_e = a_tensor
GAN1_disc_x = x_tensor
GAN2_input_e = drugs_a_tensor
GAN2_input_x = drugs_x_tensor
GAN2_disc_e = drugs_a_tensor
GAN2_disc_x = drugs_x_tensor
elif self.submodel == "Prot":
GAN1_input_e = a_tensor
GAN1_input_x = x_tensor
GAN1_disc_e = a_tensor
GAN1_disc_x = x_tensor
GAN2_input_e = akt1_human_adj
GAN2_input_x = akt1_human_annot
GAN2_disc_e = drugs_a_tensor
GAN2_disc_x = drugs_x_tensor
elif self.submodel == "RL":
GAN1_input_e = z_edge
GAN1_input_x = z_node
GAN1_disc_e = a_tensor
GAN1_disc_x = x_tensor
GAN2_input_e = drugs_a_tensor
GAN2_input_x = drugs_x_tensor
GAN2_disc_e = drugs_a_tensor
GAN2_disc_x = drugs_x_tensor
elif self.submodel == "NoTarget":
GAN1_input_e = z_edge
GAN1_input_x = z_node
GAN1_disc_e = a_tensor
GAN1_disc_x = x_tensor
# =================================================================================== #
# 2. GAN1 Inference #
# =================================================================================== #
generator_output = generator_loss(self.G,
self.D,
self.V,
GAN1_input_e,
GAN1_input_x,
                                                  self.inf_batch_size,
sim_reward,
self.dataset.matrices2mol_drugs,
fps_r,
self.submodel)
_, fake_mol, _, _, node, edge = generator_output
# =================================================================================== #
# 3. GAN2 Inference #
# =================================================================================== #
output = generator2_loss(self.G2,
self.D2,
self.V2,
edge,
node,
                                         self.inf_batch_size,
sim_reward,
self.dataset.matrices2mol_drugs,
fps_r,
GAN2_input_e,
GAN2_input_x,
self.submodel)
_, fake_mol_g, _, _ = output
                inference_drugs = [Chem.MolToSmiles(mol) for mol in fake_mol_g if mol is not None]
                #inference_smiles = [Chem.MolToSmiles(mol) for mol in fake_mol]
                print("Molecule batch {} inferred.".format(i))
                with open("DrugGEN/experiments/inference/{}/inference_drugs.txt".format(self.submodel), "a") as f:
                    for smiles in inference_drugs:
                        f.write(smiles)
                        f.write("\n")
                        metric_calc_dr.append(smiles)
if i == 120:
break
        et = time.time() - start_time
        print("Inference took {:.2f} seconds.".format(et))
        print("Metrics calculation started using MOSES.")
        print("Validity: ", fraction_valid(metric_calc_dr), "\n")
        print("Uniqueness: ", fraction_unique(metric_calc_dr), "\n")
        print("Novelty: ", novelty(metric_calc_dr, drug_smiles), "\n")
        print("Metrics are calculated.")