""" Train a diffusion model on images. """ import argparse import os import os.path as osp from guided_diffusion import dist_util, logger from guided_diffusion.dit import DiT_models from guided_diffusion.resample import create_named_schedule_sampler from guided_diffusion.script_util import ( model_and_diffusion_defaults, create_diffusion, args_to_dict, add_dict_to_argparser, ) from guided_diffusion.train_util import TrainLoop from guided_diffusion.pr_datasets_all import load_data from load_utils import load_model from mpi4py import MPI from absl import app from absl import flags from absl.flags import argparse_flags def main(args): comm = dist_util.setup_dist(port=args.port) logger.configure(args=args, comm=comm) logger.log("creating model and diffusion...") model = DiT_models[args.model]( input_size=args.image_size, in_channels=args.in_channels, num_classes=args.num_classes, learn_sigma=args.learn_sigma, ) diffusion = create_diffusion( learn_sigma=args.learn_sigma, diffusion_steps=args.diffusion_steps, noise_schedule=args.noise_schedule, timestep_respacing=args.timestep_respacing, use_kl=args.use_kl, predict_xstart=args.predict_xstart, rescale_timesteps=args.rescale_timesteps, rescale_learned_sigmas=args.rescale_learned_sigmas, ) model.to(dist_util.dev()) # create model architecture for eval loss, need to use ema params eval_model = DiT_models[args.model]( input_size=args.image_size, in_channels=args.in_channels, num_classes=args.num_classes, learn_sigma=args.learn_sigma, ) eval_model.to(dist_util.dev()) # create embed model if args.embed_model_name is not None: embed_model = load_model(args.embed_model_name, args.embed_model_ckpt) del embed_model.loss embed_model.to(dist_util.dev()) embed_model.eval() schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion) logger.log("creating data loader...") data = load_data( data_dir=args.data_dir + "_train.csv", batch_size=args.batch_size // args.encode_rep, class_cond=args.class_cond, image_size=args.pr_image_size, ) eval_data = load_data( data_dir=args.data_dir + "_test.csv", batch_size=args.batch_size // args.encode_rep, class_cond=args.class_cond, image_size=args.pr_image_size, ) logger.log("training...") TrainLoop( model=model, eval_model=eval_model, diffusion=diffusion, data=data, batch_size=args.batch_size, microbatch=args.microbatch, lr=args.lr, ema_rate=args.ema_rate, log_interval=args.log_interval, save_interval=args.save_interval, resume_checkpoint=args.resume_checkpoint, embed_model=embed_model if args.embed_model_name is not None else None, use_fp16=args.use_fp16, fp16_scale_growth=args.fp16_scale_growth, schedule_sampler=schedule_sampler, weight_decay=args.weight_decay, lr_anneal_steps=args.lr_anneal_steps, eval_data=eval_data, eval_interval=args.save_interval, eval_sample_batch_size=16, total_num_gpus = MPI.COMM_WORLD.Get_size(), eval_sample_use_ddim=False, eval_sample_clip_denoised=args.eval_sample_clip_denoised, # do not clip when training on latent space in_channels=args.in_channels, fs=args.fs, # hard code 100 for embed pedal=args.pedal, scale_factor=args.scale_factor, # need to manually set scale_factor when resume num_classes=args.num_classes, # whether to use class_cond in sampling microbatch_encode=args.microbatch_encode, encode_rep=args.encode_rep, shift_size=args.shift_size, ).run_loop() def parse_flags(argv): parser = argparse_flags.ArgumentParser(description='An argparse + app.run example') defaults = dict( project="music-guided", dir="", data_dir="", model="DiTRotary_XL_8", # DiT model names schedule_sampler="uniform", lr=1e-4, weight_decay=0.0, lr_anneal_steps=0, # total steps, if set to be a positive number, lr will linearly decay batch_size=1, encode_rep=4, # whether to use recombination of encoded excerpts shift_size=2, # need to be compatible with encode_rep to get effective batch size microbatch=-1, # -1 disables microbatches ema_rate="0.9999", # comma-separated list of EMA values log_interval=10, save_interval=10000, resume_checkpoint="", use_fp16=False, fp16_scale_growth=1e-3, embed_model_name="kl/f8-all-onset", embed_model_ckpt="taming-transformers/checkpoints/all_onset/epoch_14.ckpt", eval_sample_clip_denoised=False, scale_factor=1., fs=100, # saving piano roll with this fs pedal=False, num_classes=0, # 0 is unconditional microbatch_encode=-1, pr_image_size=1024, training=True, ngc=False, # whether to use dist setup on ngc port=None, ) defaults.update(model_and_diffusion_defaults()) add_dict_to_argparser(parser, defaults) return parser.parse_args(argv[1:]) if __name__ == "__main__": app.run(main, flags_parser=parse_flags)