Spaces:
Running
on
Zero
Running
on
Zero
import argparse | |
import os | |
from omegaconf import OmegaConf | |
import wandb | |
from trainer import DiffusionTrainer, GANTrainer, ODETrainer, ScoreDistillationTrainer | |
def main():
    """Parse CLI arguments, build the run config, and launch the selected trainer.

    The config file given by ``--config_path`` is merged on top of
    ``configs/default_config.yaml`` (values from the user config win), and the
    command-line switches are then copied onto the merged config. The trainer
    class is chosen by ``config.trainer``; after training completes,
    ``wandb.finish()`` is called.

    Raises:
        ValueError: If ``config.trainer`` is not one of the supported names.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_path", type=str, required=True)
    parser.add_argument("--no_save", action="store_true")
    parser.add_argument("--no_visualize", action="store_true")
    parser.add_argument("--logdir", type=str, default="", help="Path to the directory to save logs")
    parser.add_argument("--wandb-save-dir", type=str, default="", help="Path to the directory to save wandb logs")
    parser.add_argument("--disable-wandb", action="store_true")
    args = parser.parse_args()

    # User config overrides the defaults: later arguments to merge() win.
    config = OmegaConf.load(args.config_path)
    default_config = OmegaConf.load("configs/default_config.yaml")
    config = OmegaConf.merge(default_config, config)

    config.no_save = args.no_save
    config.no_visualize = args.no_visualize

    # Config name is the file name up to its first dot (e.g. "a/b/foo.yaml" -> "foo").
    config_name = os.path.basename(args.config_path).split(".")[0]
    config.config_name = config_name
    config.logdir = args.logdir
    config.wandb_save_dir = args.wandb_save_dir
    config.disable_wandb = args.disable_wandb

    # Map each supported trainer name to its class.
    trainer_classes = {
        "diffusion": DiffusionTrainer,
        "gan": GANTrainer,
        "ode": ODETrainer,
        "score_distillation": ScoreDistillationTrainer,
    }
    trainer_name = config.trainer
    trainer_cls = trainer_classes.get(trainer_name)
    if trainer_cls is None:
        # Without this check an unsupported name would leave `trainer` unbound
        # and fail later with an opaque UnboundLocalError.
        supported = ", ".join(sorted(trainer_classes))
        raise ValueError(
            f"Unknown trainer {trainer_name!r}; expected one of: {supported}"
        )

    trainer = trainer_cls(config)
    trainer.train()

    wandb.finish()
# Run the training entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()