import os
import time
import random
import argparse
import os.path as osp
import numpy as np
import torch
from torch import nn
from torch_geometric.loader import DataLoader
import wandb
from rdkit import Chem, RDLogger
from rdkit.Chem import AllChem
torch.set_num_threads(5)
RDLogger.DisableLog('rdApp.*')  # Silence RDKit parser warnings.
from src.util.utils import *
from src.model.models import Generator, Discriminator, simple_disc
from src.data.dataset import DruggenDataset
from src.data.utils import get_encoders_decoders, load_molecules
from src.model.loss import discriminator_loss, generator_loss
class Train(object):
"""Trainer for DrugGEN."""
def __init__(self, config):
if config.set_seed:
np.random.seed(config.seed)
random.seed(config.seed)
torch.manual_seed(config.seed)
torch.cuda.manual_seed_all(config.seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
os.environ["PYTHONHASHSEED"] = str(config.seed)
print(f'Using seed {config.seed}')
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Initialize configurations
self.submodel = config.submodel
# Data loader.
        self.raw_file = config.raw_file  # Full path to the text file of SMILES strings for the training dataset.
        self.drug_raw_file = config.drug_raw_file  # Full path to the text file of SMILES strings for the drug dataset.
        self.max_atom = config.max_atom  # Maximum number of atoms per molecule; generation is one-shot at this fixed size.
        # Automatically infer dataset file names from the raw file names:
        # take the base name without its extension, then append max_atom and a .pt extension.
        raw_file_base = osp.splitext(osp.basename(self.raw_file))[0]
        drug_raw_file_base = osp.splitext(osp.basename(self.drug_raw_file))[0]
        self.dataset_file = f"{raw_file_base}{self.max_atom}.pt"
        self.drugs_dataset_file = f"{drug_raw_file_base}{self.max_atom}.pt"
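        # Example (hypothetical names): raw_file="data/chembl_train.smi" with max_atom=45
        # yields dataset_file="chembl_train45.pt" under mol_data_dir.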
self.mol_data_dir = config.mol_data_dir # Directory where the dataset files are stored.
self.drug_data_dir = config.drug_data_dir # Directory where the drug dataset files are stored.
        self.dataset_name = osp.splitext(self.dataset_file)[0]
        self.drugs_dataset_name = osp.splitext(self.drugs_dataset_file)[0]
        self.features = config.features  # Whether to use additional node features; if False, only atom types are used.
        # Additional node features can be added; please check new_dataloader.py Line 102.
self.batch_size = config.batch_size # Batch size for training.
self.parallel = config.parallel
# Get atom and bond encoders/decoders
atom_encoder, atom_decoder, bond_encoder, bond_decoder = get_encoders_decoders(
self.raw_file,
self.drug_raw_file,
self.max_atom
)
self.atom_encoder = atom_encoder
self.atom_decoder = atom_decoder
self.bond_encoder = bond_encoder
self.bond_decoder = bond_decoder
self.dataset = DruggenDataset(self.mol_data_dir,
self.dataset_file,
self.raw_file,
self.max_atom,
self.features,
atom_encoder=atom_encoder,
atom_decoder=atom_decoder,
bond_encoder=bond_encoder,
bond_decoder=bond_decoder)
self.loader = DataLoader(self.dataset,
shuffle=True,
batch_size=self.batch_size,
drop_last=True) # PyG dataloader for the GAN.
self.drugs = DruggenDataset(self.drug_data_dir,
self.drugs_dataset_file,
self.drug_raw_file,
self.max_atom,
self.features,
atom_encoder=atom_encoder,
atom_decoder=atom_decoder,
bond_encoder=bond_encoder,
bond_decoder=bond_decoder)
self.drugs_loader = DataLoader(self.drugs,
shuffle=True,
batch_size=self.batch_size,
drop_last=True) # PyG dataloader for the second GAN.
self.m_dim = len(self.atom_decoder) if not self.features else int(self.loader.dataset[0].x.shape[1]) # Atom type dimension.
self.b_dim = len(self.bond_decoder) # Bond type dimension.
self.vertexes = int(self.loader.dataset[0].x.shape[0]) # Number of nodes in the graph.
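        # These dimensions are inferred from the first processed graph; graphs are
        # presumably padded to a fixed size of max_atom nodes, so vertexes == max_atom.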
# Model configurations.
self.act = config.act
self.lambda_gp = config.lambda_gp
self.dim = config.dim
self.depth = config.depth
self.heads = config.heads
self.mlp_ratio = config.mlp_ratio
self.ddepth = config.ddepth
self.ddropout = config.ddropout
# Training configurations.
self.epoch = config.epoch
self.g_lr = config.g_lr
self.d_lr = config.d_lr
self.dropout = config.dropout
self.beta1 = config.beta1
self.beta2 = config.beta2
# Directories.
self.log_dir = config.log_dir
self.sample_dir = config.sample_dir
self.model_save_dir = config.model_save_dir
# Step size.
self.log_step = config.log_sample_step
# resume training
self.resume = config.resume
self.resume_epoch = config.resume_epoch
self.resume_iter = config.resume_iter
self.resume_directory = config.resume_directory
# wandb configuration
self.use_wandb = config.use_wandb
self.online = config.online
self.exp_name = config.exp_name
# Arguments for the model.
self.arguments = "{}_{}_glr{}_dlr{}_dim{}_depth{}_heads{}_batch{}_epoch{}_dataset{}_dropout{}".format(self.exp_name, self.submodel, self.g_lr, self.d_lr, self.dim, self.depth, self.heads, self.batch_size, self.epoch, self.dataset_name, self.dropout)
self.build_model(self.model_save_dir, self.arguments)
def build_model(self, model_save_dir, arguments):
"""Create generators and discriminators."""
        ''' Generator is based on a Transformer Encoder:
            @ act: activation function for the MLP layers
            @ vertexes: maximum number of atoms in generated molecules
            @ b_dim: number of bond types
            @ m_dim: number of atom types (or number of node features used)
            @ dropout: dropout probability
            @ dim: hidden dimension of the Transformer Encoder
            @ depth: number of Transformer layers
            @ heads: number of multi-head attention heads
            @ mlp_ratio: read-out layer dimension of the Transformer
            @ drop_rate: deprecated
            @ tra_conv: whether the module creates output for a TransformerConv discriminator
        '''
self.G = Generator(self.act,
self.vertexes,
self.b_dim,
self.m_dim,
self.dropout,
dim=self.dim,
depth=self.depth,
heads=self.heads,
mlp_ratio=self.mlp_ratio)
        ''' Discriminator is based on a Transformer Encoder:
            @ act: activation function for the MLP layers
            @ vertexes: maximum number of atoms in generated molecules
            @ b_dim: number of bond types
            @ m_dim: number of atom types (or number of node features used)
            @ dropout: dropout probability
            @ dim: hidden dimension of the Transformer Encoder
            @ depth: number of Transformer layers
            @ heads: number of multi-head attention heads
            @ mlp_ratio: read-out layer dimension of the Transformer'''
self.D = Discriminator(self.act,
self.vertexes,
self.b_dim,
self.m_dim,
self.ddropout,
dim=self.dim,
depth=self.ddepth,
heads=self.heads,
mlp_ratio=self.mlp_ratio)
self.g_optimizer = torch.optim.AdamW(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
self.d_optimizer = torch.optim.AdamW(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])
network_path = os.path.join(model_save_dir, arguments)
self.print_network(self.G, 'G', network_path)
self.print_network(self.D, 'D', network_path)
if self.parallel and torch.cuda.device_count() > 1:
print(f"Using {torch.cuda.device_count()} GPUs!")
self.G = nn.DataParallel(self.G)
self.D = nn.DataParallel(self.D)
self.G.to(self.device)
self.D.to(self.device)
def print_network(self, model, name, save_dir):
"""Print out the network information."""
num_params = 0
for p in model.parameters():
num_params += p.numel()
if not os.path.exists(save_dir):
os.makedirs(save_dir)
network_path = os.path.join(save_dir, "{}_modules.txt".format(name))
with open(network_path, "w+") as file:
for module in model.modules():
file.write(f"{module.__class__.__name__}:\n")
print(module.__class__.__name__)
for n, param in module.named_parameters():
if param is not None:
file.write(f" - {n}: {param.size()}\n")
print(f" - {n}: {param.size()}")
                break  # Only the top-level module is needed; it already yields every named parameter.
file.write(f"Total number of parameters: {num_params}\n")
print(f"Total number of parameters: {num_params}\n\n")
def restore_model(self, epoch, iteration, model_directory):
"""Restore the trained generator and discriminator."""
print('Loading the trained models from epoch / iteration {}-{}...'.format(epoch, iteration))
G_path = os.path.join(model_directory, '{}-{}-G.ckpt'.format(epoch, iteration))
D_path = os.path.join(model_directory, '{}-{}-D.ckpt'.format(epoch, iteration))
self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
self.D.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage))
    def save_model(self, model_directory, idx, i):
        """Save generator and discriminator checkpoints at epoch idx, iteration i."""
        G_path = os.path.join(model_directory, '{}-{}-G.ckpt'.format(idx+1, i+1))
        D_path = os.path.join(model_directory, '{}-{}-D.ckpt'.format(idx+1, i+1))
torch.save(self.G.state_dict(), G_path)
torch.save(self.D.state_dict(), D_path)
def reset_grad(self):
"""Reset the gradient buffers."""
self.g_optimizer.zero_grad()
self.d_optimizer.zero_grad()
def train(self, config):
        """Run the main training loop."""
if self.use_wandb:
mode = 'online' if self.online else 'offline'
else:
mode = 'disabled'
kwargs = {'name': self.exp_name, 'project': 'druggen', 'config': config,
'settings': wandb.Settings(_disable_stats=True), 'reinit': True, 'mode': mode, 'save_code': True}
wandb.init(**kwargs)
wandb.save(os.path.join(self.model_save_dir, self.arguments, "G_modules.txt"))
wandb.save(os.path.join(self.model_save_dir, self.arguments, "D_modules.txt"))
self.model_directory = os.path.join(self.model_save_dir, self.arguments)
self.sample_directory = os.path.join(self.sample_dir, self.arguments)
self.log_path = os.path.join(self.log_dir, "{}.txt".format(self.arguments))
if not os.path.exists(self.model_directory):
os.makedirs(self.model_directory)
if not os.path.exists(self.sample_directory):
os.makedirs(self.sample_directory)
        # SMILES data for metrics calculation.
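        # Morgan fingerprints (radius 2, 1024 bits, ECFP4-like) of the reference drugs,
        # presumably compared against generated molecules (e.g. Tanimoto similarity) in logging().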
        with open(self.drug_raw_file, 'r') as f:
            drug_smiles = f.read().splitlines()
        drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
        drug_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in drug_mols if x is not None]
if self.resume:
self.restore_model(self.resume_epoch, self.resume_iter, self.resume_directory)
# Start training.
print('Start training...')
self.start_time = time.time()
for idx in range(self.epoch):
# =================================================================================== #
# 1. Preprocess input data #
# =================================================================================== #
# Load the data
dataloader_iterator = iter(self.drugs_loader)
wandb.log({"epoch": idx})
for i, data in enumerate(self.loader):
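                # The drug set is usually smaller than the molecule set, so its loader
                # is cycled: once exhausted, it is restarted for the next batch.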
try:
drugs = next(dataloader_iterator)
except StopIteration:
dataloader_iterator = iter(self.drugs_loader)
drugs = next(dataloader_iterator)
wandb.log({"iter": i})
# Preprocess both dataset
real_graphs, a_tensor, x_tensor = load_molecules(
data=data,
batch_size=self.batch_size,
device=self.device,
b_dim=self.b_dim,
m_dim=self.m_dim,
)
drug_graphs, drugs_a_tensor, drugs_x_tensor = load_molecules(
data=drugs,
batch_size=self.batch_size,
device=self.device,
b_dim=self.b_dim,
m_dim=self.m_dim,
)
# Training configuration.
GEN_node = x_tensor # Generator input node features (annotation matrix of real molecules)
GEN_edge = a_tensor # Generator input edge features (adjacency matrix of real molecules)
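                # Which samples the discriminator treats as "real" depends on the submodel:
                # DrugGEN feeds it known drug molecules, pushing the generator toward the
                # drug distribution; NoTarget feeds it the training molecules themselves.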
if self.submodel == "DrugGEN":
DISC_node = drugs_x_tensor # Discriminator input node features (annotation matrix of drug molecules)
DISC_edge = drugs_a_tensor # Discriminator input edge features (adjacency matrix of drug molecules)
elif self.submodel == "NoTarget":
DISC_node = x_tensor # Discriminator input node features (annotation matrix of real molecules)
DISC_edge = a_tensor # Discriminator input edge features (adjacency matrix of real molecules)
# =================================================================================== #
# 2. Train the GAN #
# =================================================================================== #
loss = {}
self.reset_grad()
# Compute discriminator loss.
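                # discriminator_loss presumably implements a WGAN-style critic objective,
                # with a gradient penalty term weighted by lambda_gp.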
node, edge, d_loss = discriminator_loss(self.G,
self.D,
DISC_edge,
DISC_node,
GEN_edge,
GEN_node,
self.batch_size,
self.device,
self.lambda_gp)
d_total = d_loss
wandb.log({"d_loss": d_total.item()})
loss["d_total"] = d_total.item()
d_total.backward()
self.d_optimizer.step()
self.reset_grad()
# Compute generator loss.
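                # generator_loss presumably scores generated graphs with D; node_sample and
                # edge_sample are the sampled matrices used below for molecule reconstruction.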
generator_output = generator_loss(self.G,
self.D,
GEN_edge,
GEN_node,
self.batch_size)
g_loss, node, edge, node_sample, edge_sample = generator_output
g_total = g_loss
wandb.log({"g_loss": g_total.item()})
loss["g_total"] = g_total.item()
g_total.backward()
self.g_optimizer.step()
# Logging.
if (i+1) % self.log_step == 0:
                    logging(self.log_path, self.start_time, i, idx, loss, self.sample_directory,
                            drug_smiles, edge_sample, node_sample, self.dataset.matrices2mol,
                            self.dataset_name, a_tensor, x_tensor, drug_vecs)
                    mol_sample(self.sample_directory, edge_sample.detach(), node_sample.detach(),
                               idx, i, self.dataset.matrices2mol, self.dataset_name)
                    print("Samples saved at epoch {} and iteration {}".format(idx, i))
                    self.save_model(self.model_directory, idx, i)
                    print("Model saved at epoch {} and iteration {}".format(idx, i))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# Data configuration.
parser.add_argument('--raw_file', type=str, required=True)
parser.add_argument('--drug_raw_file', type=str, required=False, help='Required for DrugGEN model, optional for NoTarget')
parser.add_argument('--drug_data_dir', type=str, default='data')
parser.add_argument('--mol_data_dir', type=str, default='data')
    parser.add_argument('--features', action='store_true', help='use additional node features beyond atom types')
# Model configuration.
    parser.add_argument('--submodel', type=str, default="DrugGEN", help="Choose the model subtype: DrugGEN or NoTarget.", choices=['DrugGEN', 'NoTarget'])
parser.add_argument('--act', type=str, default="relu", help="Activation function for the model.", choices=['relu', 'tanh', 'leaky', 'sigmoid'])
    parser.add_argument('--max_atom', type=int, default=45, help='Maximum number of atoms per molecule.')
parser.add_argument('--dim', type=int, default=128, help='Dimension of the Transformer Encoder model for the GAN.')
    parser.add_argument('--depth', type=int, default=1, help="Depth of the generator's Transformer encoder.")
    parser.add_argument('--ddepth', type=int, default=1, help="Depth of the discriminator's Transformer encoder.")
    parser.add_argument('--heads', type=int, default=8, help='Number of heads for the MultiHeadAttention modules.')
parser.add_argument('--mlp_ratio', type=int, default=3, help='MLP ratio for the Transformer.')
parser.add_argument('--dropout', type=float, default=0., help='dropout rate')
parser.add_argument('--ddropout', type=float, default=0., help='dropout rate for the discriminator')
parser.add_argument('--lambda_gp', type=float, default=10, help='Gradient penalty lambda multiplier for the GAN.')
# Training configuration.
parser.add_argument('--batch_size', type=int, default=128, help='Batch size for the training.')
    parser.add_argument('--epoch', type=int, default=10, help='Number of training epochs.')
parser.add_argument('--g_lr', type=float, default=0.00001, help='learning rate for G')
parser.add_argument('--d_lr', type=float, default=0.00001, help='learning rate for D')
parser.add_argument('--beta1', type=float, default=0.9, help='beta1 for Adam optimizer')
parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for Adam optimizer')
parser.add_argument('--log_dir', type=str, default='experiments/logs')
parser.add_argument('--sample_dir', type=str, default='experiments/samples')
parser.add_argument('--model_save_dir', type=str, default='experiments/models')
    parser.add_argument('--log_sample_step', type=int, default=1000, help='Interval (in iterations) between logging/sampling steps during training.')
# Resume training.
    parser.add_argument('--resume', action='store_true', help='resume training')
parser.add_argument('--resume_epoch', type=int, default=None, help='resume training from this epoch')
parser.add_argument('--resume_iter', type=int, default=None, help='resume training from this step')
parser.add_argument('--resume_directory', type=str, default=None, help='load pretrained weights from this directory')
# Seed configuration.
parser.add_argument('--set_seed', action='store_true', help='set seed for reproducibility')
parser.add_argument('--seed', type=int, default=1, help='seed for reproducibility')
# wandb configuration.
parser.add_argument('--use_wandb', action='store_true', help='use wandb for logging')
parser.add_argument('--online', action='store_true', help='use wandb online')
parser.add_argument('--exp_name', type=str, default='druggen', help='experiment name')
parser.add_argument('--parallel', action='store_true', help='Parallelize training')
config = parser.parse_args()
# Check if drug_raw_file is provided when using DrugGEN model
if config.submodel == "DrugGEN" and not config.drug_raw_file:
parser.error("--drug_raw_file is required when using DrugGEN model")
# If using NoTarget model and drug_raw_file is not provided, use a dummy file
if config.submodel == "NoTarget" and not config.drug_raw_file:
config.drug_raw_file = "data/akt_train.smi" # Use a reference file for NoTarget model (AKT) (not used for training for ease of use and encoder/decoder's)
trainer = Train(config)
trainer.train(config)
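# Example invocation (hypothetical file paths):
#   python train.py --raw_file data/chembl_train.smi --drug_raw_file data/akt_train.smi \
#       --submodel DrugGEN --max_atom 45 --batch_size 128 --epoch 10 --use_wandb --online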