File size: 2,541 Bytes
02e480f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# This code uses both encoder and decoder losses.
#
#

# Deep learning
import torch
from torch_optimizer.lamb import Lamb
from trainer import TrainerEncoderDecoder

# Parallel
from torch.utils.data.distributed import DistributedSampler
from torch.distributed import init_process_group, destroy_process_group

# Data
from utils import MoleculeModule, get_optim_groups
from torch.utils.data import DataLoader

# Standard library (os) and command-line argument definitions (args)
import os
import args


def ddp_setup():
    """Bind this process to its GPU and join the default NCCL process group.

    The GPU index is read from the ``LOCAL_RANK`` environment variable
    (presumably exported by the launcher, e.g. ``torchrun`` -- confirm against
    the launch script).

    Raises:
        KeyError: If ``LOCAL_RANK`` is not set in the environment.
        ValueError: If ``LOCAL_RANK`` is not an integer string.
    """
    # Select the device *before* creating the process group so that the NCCL
    # backend is initialised with this rank's own GPU already current, instead
    # of whatever the default device happens to be.
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    init_process_group(backend="nccl")


def load_train_objs(config):
    """Build the data loader, the model and the two optimizers for training.

    Args:
        config: Run configuration object. Attributes read here are
            ``smi_ted_version``, ``max_len``, ``train_load``, ``data_root``,
            ``n_batch``, ``n_workers``, ``lr_start``, ``lr_multiplier`` and
            ``lr_decoder``; the whole object is also handed to ``Smi_ted``.

    Returns:
        A tuple ``(loader, model, (optimizerE, optimizerD))`` where
        ``optimizerE`` is a Lamb optimizer over ``model.encoder`` and
        ``optimizerD`` is an Adam optimizer over ``model.decoder``.

    Raises:
        ValueError: If ``config.smi_ted_version`` is neither ``'v1'`` nor
            ``'v2'``.
    """
    # Validate the model version first: otherwise an unsupported value only
    # surfaces after the dataset has been loaded, as an unbound ``Smi_ted``.
    version = config.smi_ted_version
    if version not in ('v1', 'v2'):
        raise ValueError(
            f"Unsupported smi_ted_version {version!r}; expected 'v1' or 'v2'."
        )

    # load data
    train_loader = MoleculeModule(
        config.max_len,
        config.train_load,
        config.data_root
    )
    train_loader.setup()

    loader = DataLoader(
        train_loader.pubchem,
        batch_size=config.n_batch,
        pin_memory=True,
        # A sampler is supplied, so DataLoader-level shuffling must stay off;
        # sample ordering is left to the DistributedSampler.
        shuffle=False,
        collate_fn=train_loader.text_encoder.process,
        sampler=DistributedSampler(train_loader.pubchem),
        num_workers=config.n_workers
    )

    # load model ('v1' -> light variant, 'v2' -> large variant)
    if version == 'v1':
        from smi_ted_light.load import Smi_ted
    else:
        from smi_ted_large.load import Smi_ted

    model = Smi_ted(config, train_loader.get_vocab())
    model.apply(model._init_weights)

    # load optimizer: encoder and decoder are optimised separately
    optim_groupsE = get_optim_groups(model.encoder)
    optim_groupsD = get_optim_groups(model.decoder)
    optimizerE = Lamb(optim_groupsE, lr=config.lr_start*config.lr_multiplier, betas=(0.9, 0.99))
    optimizerD = torch.optim.Adam(optim_groupsD, lr=config.lr_decoder, betas=(0.9, 0.99))

    return loader, model, (optimizerE, optimizerD)


def main(
        config,
        save_every: int,
        total_epochs: int,
        save_checkpoint_path: str,
        load_checkpoint_path: str
    ):
    """Run one distributed training job from setup to teardown.

    Args:
        config: Run configuration, forwarded to ``load_train_objs`` and to the
            trainer.
        save_every: Forwarded to ``TrainerEncoderDecoder`` (named
            ``checkpoint_every`` on the command line).
        total_epochs: Passed to ``trainer.train``.
        save_checkpoint_path: Forwarded to ``TrainerEncoderDecoder``.
        load_checkpoint_path: Forwarded to ``TrainerEncoderDecoder``.
    """
    # join the process group before any training object is created
    ddp_setup()

    # data loader, model and (encoder, decoder) optimizers
    loader, network, optimizer_pair = load_train_objs(config)

    # positional order expected by the trainer
    trainer_inputs = (
        network,
        loader,
        optimizer_pair,
        save_every,
        save_checkpoint_path,
        load_checkpoint_path,
        config,
    )
    TrainerEncoderDecoder(*trainer_inputs).train(total_epochs)

    # leave the process group once training has finished
    destroy_process_group()


if __name__ == '__main__':
    # Keep the parsed namespace under its own name: assigning it to ``args``
    # would rebind (and hide) the imported ``args`` module.
    cli_args = args.get_parser().parse_args()
    main(
        config=cli_args,
        save_every=cli_args.checkpoint_every,
        total_epochs=cli_args.max_epochs,
        save_checkpoint_path=cli_args.save_checkpoint_path,
        load_checkpoint_path=cli_args.load_checkpoint_path,
    )