import argparse import os from omegaconf import OmegaConf import wandb from trainer import DiffusionTrainer, GANTrainer, ODETrainer, ScoreDistillationTrainer def main(): parser = argparse.ArgumentParser() parser.add_argument("--config_path", type=str, required=True) parser.add_argument("--no_save", action="store_true") parser.add_argument("--no_visualize", action="store_true") parser.add_argument("--logdir", type=str, default="", help="Path to the directory to save logs") parser.add_argument("--wandb-save-dir", type=str, default="", help="Path to the directory to save wandb logs") parser.add_argument("--disable-wandb", action="store_true") args = parser.parse_args() config = OmegaConf.load(args.config_path) default_config = OmegaConf.load("configs/default_config.yaml") config = OmegaConf.merge(default_config, config) config.no_save = args.no_save config.no_visualize = args.no_visualize # get the filename of config_path config_name = os.path.basename(args.config_path).split(".")[0] config.config_name = config_name config.logdir = args.logdir config.wandb_save_dir = args.wandb_save_dir config.disable_wandb = args.disable_wandb if config.trainer == "diffusion": trainer = DiffusionTrainer(config) elif config.trainer == "gan": trainer = GANTrainer(config) elif config.trainer == "ode": trainer = ODETrainer(config) elif config.trainer == "score_distillation": trainer = ScoreDistillationTrainer(config) trainer.train() wandb.finish() if __name__ == "__main__": main()