"Adapted from https://github.com/SongweiGe/TATS"
from __future__ import absolute_import, division, print_function, annotations
import os
import socket
import sys
from typing import Optional, Union

# Make the sibling diffusion-chem-fm directory importable.
sys.path.insert(1, "../diffusion-chem-fm")
import hydra
from omegaconf import DictConfig, open_dict
from train.get_dataset import get_dataset
from vq_gan_3d.model import VQGAN
from trainer import Trainer
# Torch
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.distributed import init_process_group, destroy_process_group
# Get MPI:
try:
from mpi4py import MPI
WITH_DDP = True
LOCAL_RANK = os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK', '0')
SIZE = MPI.COMM_WORLD.Get_size()
RANK = MPI.COMM_WORLD.Get_rank()
WITH_CUDA = torch.cuda.is_available()
    DEVICE = 'gpu' if WITH_CUDA else 'cpu'
    # PyTorch's env:// rendezvous will look for these:
os.environ['RANK'] = str(RANK)
os.environ['WORLD_SIZE'] = str(SIZE)
# -----------------------------------------------------------
    # NOTE: Rank 0 gets the hostname of the master node and broadcasts it
    # to all other ranks; PyTorch needs MASTER_ADDR (and MASTER_PORT) for
    # the env:// rendezvous.
# -----------------------------------------------------------
MASTER_ADDR = socket.gethostname() if RANK == 0 else None
MASTER_ADDR = MPI.COMM_WORLD.bcast(MASTER_ADDR, root=0)
os.environ['MASTER_ADDR'] = MASTER_ADDR
os.environ['MASTER_PORT'] = str(2345)
except ImportError:
    # mpi4py is unavailable: fall back to single-process defaults so the
    # env:// rendezvous still works for a local run.
    WITH_DDP = False
    SIZE = 1
    RANK = 0
    LOCAL_RANK = 0
    MASTER_ADDR = 'localhost'
    WITH_CUDA = torch.cuda.is_available()
    os.environ['RANK'] = str(RANK)
    os.environ['WORLD_SIZE'] = str(SIZE)
    os.environ['MASTER_ADDR'] = MASTER_ADDR
    os.environ['MASTER_PORT'] = str(2345)
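
# A typical launch under Open MPI (which sets OMPI_COMM_WORLD_LOCAL_RANK)
# might look like the following; the script filename is an assumption:
#
#   mpiexec -n 8 python train.py
#
# Without mpi4py the script falls back to a single local process.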
def ddp_setup(
rank: Union[int, str],
world_size: Union[int, str],
backend: Optional[str] = None,
) -> None:
    """Initialize the default torch.distributed process group.

    Defaults to NCCL when CUDA is available and Gloo otherwise; rank and
    rendezvous information are read from the environment (env://).
    """
    if backend is None:
        backend = 'nccl' if WITH_CUDA else 'gloo'
    init_process_group(
        str(backend),
        rank=int(rank),
        world_size=int(world_size),
        init_method='env://',
    )
    if WITH_CUDA:
        # Bind this process to its local GPU; skip on CPU-only hosts.
        torch.cuda.set_device(int(LOCAL_RANK))
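
# With init_method='env://', torch.distributed reads RANK, WORLD_SIZE,
# MASTER_ADDR, and MASTER_PORT from the environment; all four are exported
# during the MPI setup (or its fallback) above before ddp_setup() runs.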
def load_train_objs(cfg):
    """Build the training dataloader, the VQ-GAN model, and its optimizers."""
    # Datasets: shuffling is delegated to the DistributedSampler below,
    # so the DataLoader itself keeps shuffle=False.
    train_dataset, val_dataset, _ = get_dataset(cfg)
train_dataloader = DataLoader(
train_dataset,
batch_size=cfg.model.batch_size,
num_workers=cfg.model.num_workers,
sampler=DistributedSampler(
train_dataset,
num_replicas=SIZE,
rank=RANK,
),
shuffle=False,
pin_memory=True
)
val_dataloader = None
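    # A validation loader is intentionally left out. If one were needed, a
    # sketch mirroring the training loader (same batch size, non-shuffling
    # DistributedSampler) might look like:
    #
    #   val_dataloader = DataLoader(
    #       val_dataset,
    #       batch_size=cfg.model.batch_size,
    #       num_workers=cfg.model.num_workers,
    #       sampler=DistributedSampler(val_dataset, num_replicas=SIZE,
    #                                  rank=RANK, shuffle=False),
    #       shuffle=False,
    #       pin_memory=True,
    #   )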
# model
model = VQGAN(cfg)
    # Optimizers: one Adam for the autoencoder path, one for the image
    # discriminator; betas=(0.5, 0.9) is a common choice for GAN training.
    opt_ae = torch.optim.Adam(list(model.encoder.parameters()) +
                              list(model.decoder.parameters()) +
                              list(model.pre_vq_conv.parameters()) +
                              list(model.post_vq_conv.parameters()) +
                              list(model.codebook.parameters()) +
                              list(model.perceptual_model.parameters()),
                              lr=cfg.model.lr, betas=(0.5, 0.9))
    opt_disc = torch.optim.Adam(list(model.image_discriminator.parameters()),
                                lr=cfg.model.lr, betas=(0.5, 0.9))
return [train_dataloader, val_dataloader], model, [opt_ae, opt_disc]
def run(cfg, save_every: int, total_epochs: int, save_checkpoint_path: str, load_checkpoint_path: str):
    # Let ddp_setup pick the backend: NCCL on GPU, Gloo on CPU.
    ddp_setup(RANK, SIZE)
print('### DDP Info ###')
print('WORLD SIZE:', SIZE)
print('LOCAL_RANK:', LOCAL_RANK)
print('RANK:', RANK)
print('')
datasets, model, optimizers = load_train_objs(cfg)
trainer = Trainer(
model,
datasets,
optimizers,
save_every,
save_checkpoint_path,
load_checkpoint_path,
cfg
)
trainer.train(total_epochs)
destroy_process_group()
@hydra.main(config_path='../config', config_name='base_cfg', version_base=None)
def main(cfg: DictConfig):
    # Scale the base learning rate linearly with the effective batch size
    # (gradient-accumulation steps x number of GPUs x per-GPU batch size).
    bs, base_lr, ngpu, accumulate = (cfg.model.batch_size, cfg.model.lr,
                                     cfg.model.gpus,
                                     cfg.model.accumulate_grad_batches)
    with open_dict(cfg):
        cfg.model.lr = accumulate * (ngpu / 8.) * (bs / 4.) * base_lr
        cfg.model.default_root_dir = os.path.join(
            cfg.model.default_root_dir, cfg.dataset.name,
            cfg.model.default_root_dir_postfix)
    print("Setting learning rate to {:.2e} = {} (accumulate_grad_batches) "
          "* {} (num_gpus/8) * {} (batchsize/4) * {:.2e} (base_lr)".format(
              cfg.model.lr, accumulate, ngpu / 8, bs / 4, base_lr))
    # Create the checkpoint directory (and any parents) if missing.
    os.makedirs(cfg.model.save_checkpoint_path, exist_ok=True)
run(
cfg,
save_every=cfg.model.checkpoint_every,
total_epochs=cfg.model.max_epochs,
save_checkpoint_path=cfg.model.save_checkpoint_path,
load_checkpoint_path=cfg.model.resume_from_checkpoint
)
if __name__ == '__main__':
main()
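
# Example invocation (Hydra override syntax; the concrete values are
# assumptions, but the keys are the ones this script reads from cfg):
#
#   python train.py model.gpus=8 model.batch_size=4 model.max_epochs=100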