import matplotlib
matplotlib.use('Agg')

import os
import sys
import yaml
from argparse import ArgumentParser
from time import gmtime, strftime
from shutil import copy

import torch

from frames_dataset import FramesDataset
from modules.inpainting_network import InpaintingNetwork
from modules.keypoint_detector import KPDetector
from modules.bg_motion_predictor import BGMotionPredictor
from modules.dense_motion import DenseMotionNetwork
from modules.avd_network import AVDNetwork
from train import train
from train_avd import train_avd
from reconstruction import reconstruction

if __name__ == "__main__":

    if sys.version_info[0] < 3:
        raise Exception("You must use Python 3 or higher. Recommended version is Python 3.9")

    parser = ArgumentParser()
    parser.add_argument("--config", default="config/vox-256.yaml", help="path to config")
    parser.add_argument("--mode", default="train", choices=["train", "reconstruction", "train_avd"])
    parser.add_argument("--log_dir", default='log', help="path to log into")
    parser.add_argument("--checkpoint", default=None, help="path to checkpoint to restore")
    parser.add_argument("--device_ids", default="0,1", type=lambda x: list(map(int, x.split(','))),
                        help="Names of the devices comma separated.")
    opt = parser.parse_args()

    with open(opt.config) as f:
        # PyYAML >= 6 requires an explicit Loader; FullLoader matches the previous default behaviour.
        config = yaml.load(f, Loader=yaml.FullLoader)

    # Log into the checkpoint's directory when restoring, otherwise create a timestamped directory.
    if opt.checkpoint is not None:
        log_dir = os.path.join(*os.path.split(opt.checkpoint)[:-1])
    else:
        log_dir = os.path.join(opt.log_dir, os.path.basename(opt.config).split('.')[0])
        log_dir += ' ' + strftime("%d_%m_%y_%H.%M.%S", gmtime())

    # Inpainting (generator) network
    inpainting = InpaintingNetwork(**config['model_params']['generator_params'],
                                   **config['model_params']['common_params'])
    if torch.cuda.is_available():
        cuda_device = torch.device('cuda:' + str(opt.device_ids[0]))
        inpainting.to(cuda_device)

    # Keypoint detector and dense motion network
    kp_detector = KPDetector(**config['model_params']['common_params'])
    dense_motion_network = DenseMotionNetwork(**config['model_params']['common_params'],
                                              **config['model_params']['dense_motion_params'])
    if torch.cuda.is_available():
        kp_detector.to(opt.device_ids[0])
        dense_motion_network.to(opt.device_ids[0])

    # Optional background motion predictor, enabled via the config
    bg_predictor = None
    if config['model_params']['common_params']['bg']:
        bg_predictor = BGMotionPredictor()
        if torch.cuda.is_available():
            bg_predictor.to(opt.device_ids[0])

    # The AVD network is only needed in train_avd mode
    avd_network = None
    if opt.mode == "train_avd":
        avd_network = AVDNetwork(num_tps=config['model_params']['common_params']['num_tps'],
                                 **config['model_params']['avd_network_params'])
        if torch.cuda.is_available():
            avd_network.to(opt.device_ids[0])

    dataset = FramesDataset(is_train=(opt.mode.startswith('train')), **config['dataset_params'])

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    if not os.path.exists(os.path.join(log_dir, os.path.basename(opt.config))):
        copy(opt.config, log_dir)

    if opt.mode == 'train':
        print("Training...")
        train(config, inpainting, kp_detector, bg_predictor, dense_motion_network,
              opt.checkpoint, log_dir, dataset)
    elif opt.mode == 'train_avd':
        print("Training Animation via Disentanglement...")
        train_avd(config, inpainting, kp_detector, bg_predictor, dense_motion_network, avd_network,
                  opt.checkpoint, log_dir, dataset)
    elif opt.mode == 'reconstruction':
        print("Reconstruction...")
        reconstruction(config, inpainting, kp_detector, bg_predictor, dense_motion_network,
                       opt.checkpoint, log_dir, dataset)
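
# --- Illustrative invocations (assumptions: this file is saved as run.py and a
# --- checkpoint path such as path/to/checkpoint.pth.tar exists; neither is defined here) ---
#
#   Train from scratch with the default VoxCeleb config on GPUs 0 and 1:
#       python run.py --config config/vox-256.yaml --mode train --device_ids 0,1
#
#   Evaluate reconstruction from a trained checkpoint (results go to the
#   checkpoint's directory, per the log_dir logic above):
#       python run.py --config config/vox-256.yaml --mode reconstruction \
#           --checkpoint path/to/checkpoint.pth.tar
#
#   Train the animation-via-disentanglement (AVD) network on top of a trained model:
#       python run.py --config config/vox-256.yaml --mode train_avd \
#           --checkpoint path/to/checkpoint.pth.tar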