|
""" |
|
Train a diffusion model on images. |
|
""" |
|
|
|
import argparse |
|
import os |
|
import os.path as osp |
|
|
|
from guided_diffusion import dist_util, logger |
|
from guided_diffusion.dit import DiT_models |
|
from guided_diffusion.resample import create_named_schedule_sampler |
|
from guided_diffusion.script_util import ( |
|
model_and_diffusion_defaults, |
|
create_diffusion, |
|
args_to_dict, |
|
add_dict_to_argparser, |
|
) |
|
from guided_diffusion.train_util import TrainLoop |
|
from guided_diffusion.pr_datasets_all import load_data |
|
from load_utils import load_model |
|
from mpi4py import MPI |
|
from absl import app |
|
from absl import flags |
|
from absl.flags import argparse_flags |
|
|
|
|
|
def main(args):
    """Train a diffusion model on piano-roll images.

    Sets up distributed training, builds the DiT model and diffusion process,
    creates train/test data loaders, and runs the training loop until
    completion. ``args`` is the namespace produced by :func:`parse_flags`.
    """
    comm = dist_util.setup_dist(port=args.port)
    logger.configure(args=args, comm=comm)

    logger.log("creating model and diffusion...")

    def build_model():
        # The training model and the evaluation model share the exact same
        # architecture; build both through one helper to keep them in sync.
        return DiT_models[args.model](
            input_size=args.image_size,
            in_channels=args.in_channels,
            num_classes=args.num_classes,
            learn_sigma=args.learn_sigma,
        )

    model = build_model()
    diffusion = create_diffusion(
        learn_sigma=args.learn_sigma,
        diffusion_steps=args.diffusion_steps,
        noise_schedule=args.noise_schedule,
        timestep_respacing=args.timestep_respacing,
        use_kl=args.use_kl,
        predict_xstart=args.predict_xstart,
        rescale_timesteps=args.rescale_timesteps,
        rescale_learned_sigmas=args.rescale_learned_sigmas,
    )
    model.to(dist_util.dev())

    eval_model = build_model()
    eval_model.to(dist_util.dev())

    # Optional pretrained autoencoder used to embed inputs into latent space.
    # None when no embed model is configured; TrainLoop must accept None.
    embed_model = None
    if args.embed_model_name is not None:
        embed_model = load_model(args.embed_model_name, args.embed_model_ckpt)
        # The embedding model is frozen during diffusion training, so its
        # training loss module is unnecessary — drop it to save memory.
        del embed_model.loss
        embed_model.to(dist_util.dev())
        embed_model.eval()

    schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)

    logger.log("creating data loader...")

    def build_loader(split):
        # ``args.data_dir`` is a CSV path prefix; the split suffix selects
        # the train or test manifest. The batch size is divided by
        # ``encode_rep`` because each sample is expanded during encoding.
        return load_data(
            data_dir=args.data_dir + "_" + split + ".csv",
            batch_size=args.batch_size // args.encode_rep,
            class_cond=args.class_cond,
            image_size=args.pr_image_size,
        )

    data = build_loader("train")
    eval_data = build_loader("test")

    logger.log("training...")
    TrainLoop(
        model=model,
        eval_model=eval_model,
        diffusion=diffusion,
        data=data,
        batch_size=args.batch_size,
        microbatch=args.microbatch,
        lr=args.lr,
        ema_rate=args.ema_rate,
        log_interval=args.log_interval,
        save_interval=args.save_interval,
        resume_checkpoint=args.resume_checkpoint,
        embed_model=embed_model,
        use_fp16=args.use_fp16,
        fp16_scale_growth=args.fp16_scale_growth,
        schedule_sampler=schedule_sampler,
        weight_decay=args.weight_decay,
        lr_anneal_steps=args.lr_anneal_steps,
        eval_data=eval_data,
        # Evaluation runs at the same cadence as checkpointing.
        eval_interval=args.save_interval,
        eval_sample_batch_size=16,
        total_num_gpus=MPI.COMM_WORLD.Get_size(),
        eval_sample_use_ddim=False,
        eval_sample_clip_denoised=args.eval_sample_clip_denoised,
        in_channels=args.in_channels,
        fs=args.fs,
        pedal=args.pedal,
        scale_factor=args.scale_factor,
        num_classes=args.num_classes,
        microbatch_encode=args.microbatch_encode,
        encode_rep=args.encode_rep,
        shift_size=args.shift_size,
    ).run_loop()
|
|
|
|
|
def parse_flags(argv):
    """Build the argument parser and parse ``argv`` (minus the program name).

    Default values for all training hyperparameters are declared here and
    merged with the shared model/diffusion defaults before being registered
    on the parser. Returns the parsed argument namespace.
    """
    arg_parser = argparse_flags.ArgumentParser(description='An argparse + app.run example')
    default_flags = {
        "project": "music-guided",
        "dir": "",
        "data_dir": "",
        "model": "DiTRotary_XL_8",
        "schedule_sampler": "uniform",
        "lr": 1e-4,
        "weight_decay": 0.0,
        "lr_anneal_steps": 0,
        "batch_size": 1,
        "encode_rep": 4,
        "shift_size": 2,
        "microbatch": -1,
        "ema_rate": "0.9999",
        "log_interval": 10,
        "save_interval": 10000,
        "resume_checkpoint": "",
        "use_fp16": False,
        "fp16_scale_growth": 1e-3,
        "embed_model_name": "kl/f8-all-onset",
        "embed_model_ckpt": "taming-transformers/checkpoints/all_onset/epoch_14.ckpt",
        "eval_sample_clip_denoised": False,
        "scale_factor": 1.,
        "fs": 100,
        "pedal": False,
        "num_classes": 0,
        "microbatch_encode": -1,
        "pr_image_size": 1024,
        "training": True,
        "ngc": False,
        "port": None,
    }
    # Shared model/diffusion defaults extend (and may override) the above.
    default_flags.update(model_and_diffusion_defaults())
    add_dict_to_argparser(arg_parser, default_flags)
    # argv[0] is the program name; argparse expects only the arguments.
    return arg_parser.parse_args(argv[1:])
|
|
|
|
|
if __name__ == "__main__":

    # absl's app.run parses flags via parse_flags, then invokes main(args).
    app.run(main, flags_parser=parse_flags)
|
|