"""Distributed training entry point.

The script initialises a ``torch.distributed`` process group from the
environment (``init_method='env://'``), reads ``LOCAL_RANK`` and
``WORLD_SIZE`` from the environment (as set by a launcher such as
``torchrun``), builds the train/validation datasets selected by the config
flags and hands everything to ``utils.trainer.Trainer``.
"""
import os

import torch


def _rank_and_world_size():
    """Return ``(local_rank, world_size)`` parsed from the environment.

    Raises:
        KeyError: if ``LOCAL_RANK`` or ``WORLD_SIZE`` is not set.
        ValueError: if either variable is not an integer.
    """
    return int(os.environ['LOCAL_RANK']), int(os.environ['WORLD_SIZE'])


def _build_datasets(cfg):
    """Return ``(train_dataset, val_dataset)`` for the mode selected in *cfg*.

    The flags are checked in priority order: ``reproduce_dire``,
    ``only_eps``, ``only_img``; when none is set the distillation dataset
    is used.
    """
    # Imported lazily (like the other project imports in this file) so the
    # helper is self-contained and importing this module stays cheap.
    from dataset import (
        TMDireDataset,
        TMDistilDireDataset,
        TMEPSOnlyDataset,
        TMIMGOnlyDataset,
    )

    if cfg.reproduce_dire:
        return (
            TMDireDataset(cfg.dataset_root),
            TMDireDataset(cfg.dataset_test_root),
        )
    if cfg.only_eps:
        # NOTE(review): the validation split previously pointed at
        # cfg.dataset_root (the training data), unlike every other branch.
        # It now uses cfg.dataset_test_root for consistency -- confirm.
        return (
            TMEPSOnlyDataset(cfg.dataset_root),
            TMEPSOnlyDataset(cfg.dataset_test_root),
        )
    if cfg.only_img:
        return (
            TMIMGOnlyDataset(cfg.dataset_root, istrain=True),
            TMIMGOnlyDataset(cfg.dataset_test_root, istrain=False),
        )
    return (
        TMDistilDireDataset(cfg.dataset_root),
        TMDistilDireDataset(cfg.dataset_test_root),
    )


def main(run, cfg):
    """Build the data pipeline and run training on the current process.

    The default process group must already be initialised, because
    ``DistributedSampler`` queries it for rank and world size.

    Args:
        run: experiment-tracking run object forwarded to ``Trainer``; the
            script passes ``None`` on every process except local rank 0.
        cfg: configuration object providing ``reproduce_dire``,
            ``only_eps``, ``only_img``, ``dataset_root``,
            ``dataset_test_root``, ``batch_size``, ``kd`` and
            ``pretrained_weights``.
    """
    # Everything main() needs is imported or computed here rather than
    # relying on names that only exist when the file runs as a script.
    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler
    from utils.trainer import Trainer

    local_rank, world_size = _rank_and_world_size()

    dataset, val_dataset = _build_datasets(cfg)

    sampler = DistributedSampler(dataset)
    val_sampler = DistributedSampler(val_dataset)
    dataloader = DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        sampler=sampler,
        num_workers=2,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=cfg.batch_size,
        sampler=val_sampler,
        num_workers=2,
    )

    # Positional arguments follow Trainer's signature, which is defined in
    # utils.trainer and not visible from this file.
    trainer = Trainer(
        cfg, dataloader, val_loader, run, local_rank, True, world_size, cfg.kd
    )
    if cfg.pretrained_weights:
        trainer.load_networks(cfg.pretrained_weights)
    trainer.train()


if __name__ == "__main__":
    import torch.distributed as dist
    import wandb

    dist.init_process_group(backend='nccl', init_method='env://')
    local_rank, world_size = _rank_and_world_size()
    torch.cuda.set_device(local_rank)
    dist.barrier()

    # The config is imported only after the process group is up, matching
    # the original statement order.
    from utils.config import cfg

    run = None
    if local_rank == 0:
        run = wandb.init(
            project='dire-distill-truemedia',
            config=cfg,
            dir=cfg.exp_dir,
        )

    main(run, cfg)

    # Release the process group on a clean exit; guarded in case training
    # already tore it down.
    if dist.is_initialized():
        dist.destroy_process_group()