"Adapted from https://github.com/SongweiGe/TATS"
from __future__ import absolute_import, division, print_function, annotations
import os
import socket
from typing import Optional, Union
import sys

# Make the companion project importable (path is relative to the CWD):
sys.path.insert(1, "../diffusion-chem-fm")
import hydra
from omegaconf import DictConfig, open_dict
from train.get_dataset import get_dataset
from vq_gan_3d.model import VQGAN
from trainer import Trainer

# Torch
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.distributed import init_process_group, destroy_process_group

# Get MPI:
try:
    from mpi4py import MPI

    WITH_DDP = True
    LOCAL_RANK = os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK', '0')
    SIZE = MPI.COMM_WORLD.Get_size()
    RANK = MPI.COMM_WORLD.Get_rank()
    WITH_CUDA = torch.cuda.is_available()
    DEVICE = 'gpu' if WITH_CUDA else 'cpu'
    # PyTorch will look for these when initializing the process group:
    os.environ['RANK'] = str(RANK)
    os.environ['WORLD_SIZE'] = str(SIZE)
    # -----------------------------------------------------------
    # NOTE: Get the hostname of the master node and broadcast it
    # to all other nodes; every rank needs the master address to
    # rendezvous:
    # -----------------------------------------------------------
    MASTER_ADDR = socket.gethostname() if RANK == 0 else None
    MASTER_ADDR = MPI.COMM_WORLD.bcast(MASTER_ADDR, root=0)
    os.environ['MASTER_ADDR'] = MASTER_ADDR
    os.environ['MASTER_PORT'] = str(2345)
except (ImportError, ModuleNotFoundError):
    WITH_DDP = False
    WITH_CUDA = torch.cuda.is_available()
    SIZE = 1
    RANK = 0
    LOCAL_RANK = 0
    MASTER_ADDR = 'localhost'
    # Export the same variables so init_process_group('env://') still
    # works in the single-process fallback:
    os.environ['RANK'] = str(RANK)
    os.environ['WORLD_SIZE'] = str(SIZE)
    os.environ['MASTER_ADDR'] = MASTER_ADDR
    os.environ['MASTER_PORT'] = str(2345)


def ddp_setup(
        rank: Union[int, str],
        world_size: Union[int, str],
        backend: Optional[str] = None,
) -> None:
    """Initialize the default process group for DDP training."""
    if backend is None:
        # NCCL is the recommended backend for CUDA tensors; Gloo works
        # everywhere, including CPU-only runs.
        backend = 'nccl' if WITH_CUDA else 'gloo'
    init_process_group(
        str(backend),
        rank=int(rank),
        world_size=int(world_size),
        init_method='env://',
    )
    if WITH_CUDA:
        torch.cuda.set_device(int(LOCAL_RANK))
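
# NOTE: init_method='env://' reads RANK, WORLD_SIZE, MASTER_ADDR and
# MASTER_PORT from the environment; the MPI bootstrap above exports all
# four before ddp_setup() is called.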


def load_train_objs(cfg):
    # Datasets: the sampler shards the data across ranks, so the
    # DataLoader itself must not shuffle.
    train_dataset, val_dataset, _ = get_dataset(cfg)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=cfg.model.batch_size,
        num_workers=cfg.model.num_workers,
        sampler=DistributedSampler(
            train_dataset,
            num_replicas=SIZE,
            rank=RANK,
        ),
        shuffle=False,
        pin_memory=True,
    )
    val_dataloader = None
    # Model
    model = VQGAN(cfg)
    # Optimizers
    opt_ae = torch.optim.Adam(
        list(model.encoder.parameters()) +
        list(model.decoder.parameters()) +
        list(model.pre_vq_conv.parameters()) +
        list(model.post_vq_conv.parameters()) +
        list(model.codebook.parameters()) +
        list(model.perceptual_model.parameters()),
        lr=cfg.model.lr, betas=(0.5, 0.9))
    opt_disc = torch.optim.Adam(
        list(model.image_discriminator.parameters()),
        lr=cfg.model.lr, betas=(0.5, 0.9))
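    # Two optimizers because VQ-GAN training alternates generator
    # (autoencoder) and discriminator updates; betas=(0.5, 0.9) is a
    # common Adam setting for adversarial losses.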
    return [train_dataloader, val_dataloader], model, [opt_ae, opt_disc]


def run(cfg, save_every: int, total_epochs: int,
        save_checkpoint_path: str, load_checkpoint_path: str):
    # Let ddp_setup pick the backend (NCCL with CUDA, Gloo otherwise):
    ddp_setup(RANK, SIZE)
    print('### DDP Info ###')
    print('WORLD SIZE:', SIZE)
    print('LOCAL_RANK:', LOCAL_RANK)
    print('RANK:', RANK)
    print('')
    datasets, model, optimizers = load_train_objs(cfg)
    trainer = Trainer(
        model,
        datasets,
        optimizers,
        save_every,
        save_checkpoint_path,
        load_checkpoint_path,
        cfg,
    )
    trainer.train(total_epochs)
    destroy_process_group()


@hydra.main(config_path='../config', config_name='base_cfg', version_base=None)
def main(cfg: DictConfig):
    # Automatically adjust the learning rate:
    bs, base_lr, ngpu, accumulate = (cfg.model.batch_size, cfg.model.lr,
                                     cfg.model.gpus,
                                     cfg.model.accumulate_grad_batches)
    with open_dict(cfg):
        cfg.model.lr = accumulate * (ngpu / 8.) * (bs / 4.) * base_lr
        cfg.model.default_root_dir = os.path.join(
            cfg.model.default_root_dir, cfg.dataset.name,
            cfg.model.default_root_dir_postfix)
    print('Setting learning rate to {:.2e} = {} (accumulate_grad_batches) '
          '* {} (num_gpus/8) * {} (batchsize/4) * {:.2e} (base_lr)'.format(
              cfg.model.lr, accumulate, ngpu / 8, bs / 4, base_lr))
    if not os.path.isdir(cfg.model.save_checkpoint_path):
        # makedirs handles nested paths; plain mkdir fails when the
        # parent directory does not exist yet.
        os.makedirs(cfg.model.save_checkpoint_path, exist_ok=True)
    run(
        cfg,
        save_every=cfg.model.checkpoint_every,
        total_epochs=cfg.model.max_epochs,
        save_checkpoint_path=cfg.model.save_checkpoint_path,
        load_checkpoint_path=cfg.model.resume_from_checkpoint,
    )


if __name__ == '__main__':
    main()
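
# Example launch (illustrative; the script name is an assumption):
#   mpiexec -n 8 python train_vqgan.py model.gpus=8 model.batch_size=2
# Hydra command-line overrides replace the matching fields loaded from
# ../config/base_cfg.yaml.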