# This script trains SMI-TED using both the encoder and decoder losses.

# Deep learning
import torch
from torch_optimizer.lamb import Lamb
from trainer import TrainerEncoderDecoder
# Parallel
from torch.utils.data.distributed import DistributedSampler
from torch.distributed import init_process_group, destroy_process_group
# Data
from utils import MoleculeModule, get_optim_groups
from torch.utils.data import DataLoader
# Standard library
import os
# Arguments
import args


def ddp_setup():
    init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
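
# Note (assumption): this script is expected to be launched with torchrun,
# which sets the LOCAL_RANK environment variable read above; running it
# without a distributed launcher will raise a KeyError.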


def load_train_objs(config):
    # load data
    train_loader = MoleculeModule(
        config.max_len,
        config.train_load,
        config.data_root
    )
    train_loader.setup()
    loader = DataLoader(
        train_loader.pubchem,
        batch_size=config.n_batch,
        pin_memory=True,
        shuffle=False,  # shuffling is handled by DistributedSampler
        collate_fn=train_loader.text_encoder.process,
        sampler=DistributedSampler(train_loader.pubchem),
        num_workers=config.n_workers
    )

    # load model
    if config.smi_ted_version == 'v1':
        from smi_ted_light.load import Smi_ted
    elif config.smi_ted_version == 'v2':
        from smi_ted_large.load import Smi_ted
    else:
        raise ValueError(f"Unknown smi_ted_version: {config.smi_ted_version}")
    model = Smi_ted(config, train_loader.get_vocab())
    model.apply(model._init_weights)

    # load optimizer
    optim_groupsE = get_optim_groups(model.encoder)
    optim_groupsD = get_optim_groups(model.decoder)
    optimizerE = Lamb(optim_groupsE, lr=config.lr_start * config.lr_multiplier, betas=(0.9, 0.99))
    optimizerD = torch.optim.Adam(optim_groupsD, lr=config.lr_decoder, betas=(0.9, 0.99))
    return loader, model, (optimizerE, optimizerD)
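
# Design note: the encoder is trained with LAMB at lr_start * lr_multiplier
# (LAMB is well suited to large-batch pretraining), while the decoder gets a
# separate Adam optimizer with its own learning rate. Returning the optimizers
# as a tuple is assumed to match the interface of TrainerEncoderDecoder, which
# steps the encoder and decoder losses independently.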


def main(
    config,
    save_every: int,
    total_epochs: int,
    save_checkpoint_path: str,
    load_checkpoint_path: str
):
    ddp_setup()

    # training objects
    train_data, model, optimizers = load_train_objs(config)

    # init trainer
    trainer = TrainerEncoderDecoder(
        model,
        train_data,
        optimizers,
        save_every,
        save_checkpoint_path,
        load_checkpoint_path,
        config
    )
    trainer.train(total_epochs)
    destroy_process_group()


if __name__ == '__main__':
    parser = args.get_parser()
    config = parser.parse_args()  # renamed from `args` to avoid shadowing the imported module
    main(
        config,
        config.checkpoint_every,
        config.max_epochs,
        save_checkpoint_path=config.save_checkpoint_path,
        load_checkpoint_path=config.load_checkpoint_path,
    )
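
# Usage sketch (hedged: the filename and flag spellings are assumptions derived
# from the argparse fields referenced above):
#   torchrun --nproc_per_node=4 train_model_ED.py \
#       --max_epochs 100 \
#       --checkpoint_every 1 \
#       --save_checkpoint_path ./checkpoints \
#       --load_checkpoint_path ./checkpoints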