|
"Adapted from https://github.com/SongweiGe/TATS" |
|
|
|
from __future__ import annotations
|
import os |
|
import socket |
|
|
|
from typing import Optional, Union
|
|
|
import sys |
|
# Make the sibling project importable relative to the working directory.
sys.path.insert(1, "../diffusion-chem-fm")
|
|
|
import hydra |
|
from omegaconf import DictConfig, open_dict |
|
from train.get_dataset import get_dataset |
|
from vq_gan_3d.model import VQGAN |
|
from trainer import Trainer |
|
|
|
|
|
import torch |
|
from torch.utils.data import DataLoader |
|
from torch.utils.data.distributed import DistributedSampler |
|
from torch.distributed import init_process_group, destroy_process_group |
|
|
|
|
|
|
|
try:
    # Bootstrap torch.distributed from an MPI launch: mpi4py supplies the
    # global rank and world size, and the environment variables set below
    # are what init_process_group(init_method='env://') reads.
    from mpi4py import MPI

    WITH_DDP = True
    LOCAL_RANK = os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK', '0')  # set by Open MPI
|
|
|
SIZE = MPI.COMM_WORLD.Get_size() |
|
RANK = MPI.COMM_WORLD.Get_rank() |
|
|
|
    WITH_CUDA = torch.cuda.is_available()
    DEVICE = 'cuda' if WITH_CUDA else 'cpu'
|
|
|
|
|
os.environ['RANK'] = str(RANK) |
|
os.environ['WORLD_SIZE'] = str(SIZE) |
|
|
|
|
|
|
|
|
|
|
|
    # Rank 0 broadcasts its hostname so every rank agrees on the rendezvous
    # address; the port is an arbitrary fixed value assumed free on that node.
    MASTER_ADDR = socket.gethostname() if RANK == 0 else None
    MASTER_ADDR = MPI.COMM_WORLD.bcast(MASTER_ADDR, root=0)
    os.environ['MASTER_ADDR'] = MASTER_ADDR
    os.environ['MASTER_PORT'] = str(2345)
|
|
|
except ImportError:
    # Fall back to single-process settings when mpi4py is unavailable, and
    # provide env:// defaults so init_process_group can still initialise.
    WITH_DDP = False
    WITH_CUDA = torch.cuda.is_available()
    SIZE = 1
    RANK = 0
    LOCAL_RANK = 0
    MASTER_ADDR = 'localhost'
    os.environ.setdefault('RANK', str(RANK))
    os.environ.setdefault('WORLD_SIZE', str(SIZE))
    os.environ.setdefault('MASTER_ADDR', MASTER_ADDR)
    os.environ.setdefault('MASTER_PORT', '2345')
|
|
|
|
|
def ddp_setup( |
|
rank: Union[int, str], |
|
world_size: Union[int, str], |
|
backend: Optional[str] = None, |
|
) -> None: |
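    """Initialise the default torch.distributed process group.

    The backend defaults to NCCL when CUDA is available and Gloo otherwise;
    rendezvous uses the MASTER_ADDR/MASTER_PORT environment variables set above.
    """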
|
    if backend is None:
        backend = 'nccl' if WITH_CUDA else 'gloo'

    init_process_group(
        str(backend),
        rank=int(rank),
        world_size=int(world_size),
        init_method='env://',
    )
    if WITH_CUDA:
        # Pin this process to its node-local GPU before any collective ops.
        torch.cuda.set_device(int(LOCAL_RANK))
|
|
|
|
|
def load_train_objs(cfg): |
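    """Build the train dataloader, the VQGAN model, and its two optimizers."""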
|
|
|
train_dataset, val_dataset, _ = get_dataset(cfg) |
|
train_dataloader = DataLoader( |
|
train_dataset, |
|
batch_size=cfg.model.batch_size, |
|
num_workers=cfg.model.num_workers, |
|
sampler=DistributedSampler( |
|
train_dataset, |
|
num_replicas=SIZE, |
|
rank=RANK, |
|
), |
|
        shuffle=False,  # the DistributedSampler already shuffles and shards
|
pin_memory=True |
|
) |
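    # No distributed validation loader is built yet; the Trainer receives
    # None in its place.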
|
val_dataloader = None |
|
|
|
|
|
model = VQGAN(cfg) |
|
|
|
|
|
    # Two Adam optimizers, following the VQ-GAN setup: one for the
    # autoencoder/generator side (perceptual_model included, as in the
    # original code) and one for the image discriminator.
    ae_params = (
        list(model.encoder.parameters())
        + list(model.decoder.parameters())
        + list(model.pre_vq_conv.parameters())
        + list(model.post_vq_conv.parameters())
        + list(model.codebook.parameters())
        + list(model.perceptual_model.parameters())
    )
    opt_ae = torch.optim.Adam(ae_params, lr=cfg.model.lr, betas=(0.5, 0.9))
    opt_disc = torch.optim.Adam(
        model.image_discriminator.parameters(),
        lr=cfg.model.lr,
        betas=(0.5, 0.9),
    )
|
|
|
return [train_dataloader, val_dataloader], model, [opt_ae, opt_disc] |
|
|
|
|
|
def run(cfg, save_every: int, total_epochs: int, save_checkpoint_path: str, load_checkpoint_path: str): |
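    """Initialise DDP, build the training objects, train for `total_epochs`, and tear down."""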
|
    ddp_setup(RANK, SIZE)  # backend auto-selected: NCCL on GPU, Gloo on CPU
|
|
|
print('### DDP Info ###') |
|
print('WORLD SIZE:', SIZE) |
|
print('LOCAL_RANK:', LOCAL_RANK) |
|
print('RANK:', RANK) |
|
print('') |
|
|
|
datasets, model, optimizers = load_train_objs(cfg) |
|
|
|
trainer = Trainer( |
|
model, |
|
datasets, |
|
optimizers, |
|
save_every, |
|
save_checkpoint_path, |
|
load_checkpoint_path, |
|
cfg |
|
) |
|
trainer.train(total_epochs) |
|
|
|
    destroy_process_group()  # tear down DDP so every rank exits cleanly
|
|
|
|
|
@hydra.main(config_path='../config', config_name='base_cfg', version_base=None) |
|
def main(cfg: DictConfig): |
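    """Hydra entry point: scale the base LR, set up paths, and launch training.

    Typically launched under MPI so the rank/world-size bootstrap at import
    time takes effect, e.g. ``mpiexec -n <ranks> python <this_script>`` plus
    any Hydra overrides (launcher and rank count are environment-specific).
    """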
|
|
|
    bs = cfg.model.batch_size
    base_lr = cfg.model.lr
    ngpu = cfg.model.gpus
    accumulate = cfg.model.accumulate_grad_batches
|
|
|
    with open_dict(cfg):
        # Linear LR scaling: base LR scaled by gradient accumulation steps,
        # GPU count (relative to 8), and batch size (relative to 4).
        cfg.model.lr = accumulate * (ngpu / 8.) * (bs / 4.) * base_lr
        cfg.model.default_root_dir = os.path.join(
            cfg.model.default_root_dir,
            cfg.dataset.name,
            cfg.model.default_root_dir_postfix,
        )
    print(
        f"Setting learning rate to {cfg.model.lr:.2e} = "
        f"{accumulate} (accumulate_grad_batches) * "
        f"{ngpu / 8} (num_gpus/8) * "
        f"{bs / 4} (batchsize/4) * "
        f"{base_lr:.2e} (base_lr)"
    )
|
|
|
    # Create the checkpoint directory, including parents; exist_ok avoids
    # races when multiple ranks reach this point together.
    os.makedirs(cfg.model.save_checkpoint_path, exist_ok=True)
|
|
|
run( |
|
cfg, |
|
save_every=cfg.model.checkpoint_every, |
|
total_epochs=cfg.model.max_epochs, |
|
save_checkpoint_path=cfg.model.save_checkpoint_path, |
|
load_checkpoint_path=cfg.model.resume_from_checkpoint |
|
) |
|
|
|
|
|
if __name__ == '__main__': |
|
main() |
|
|