File size: 5,494 Bytes
9965bf6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
"""
Train a diffusion model on images.
"""

import argparse
import os
import os.path as osp

from guided_diffusion import dist_util, logger
from guided_diffusion.dit import DiT_models
from guided_diffusion.resample import create_named_schedule_sampler
from guided_diffusion.script_util import (
    model_and_diffusion_defaults,
    create_diffusion,
    args_to_dict,
    add_dict_to_argparser,
)
from guided_diffusion.train_util import TrainLoop
from guided_diffusion.pr_datasets_all import load_data
from load_utils import load_model
from mpi4py import MPI
from absl import app
from absl import flags
from absl.flags import argparse_flags


def _create_dit(args):
    """Instantiate the DiT architecture selected by ``args.model``.

    The training model and the eval model are built from the same
    arguments, so both go through this single construction path.
    """
    return DiT_models[args.model](
        input_size=args.image_size,
        in_channels=args.in_channels,
        num_classes=args.num_classes,
        learn_sigma=args.learn_sigma,
    )


def main(args):
    """Build the DiT model, diffusion process and data loaders, then train.

    Args:
        args: Parsed command-line namespace produced by ``parse_flags``.

    Raises:
        ValueError: If ``args.encode_rep`` is not a positive integer, or if
            ``args.batch_size // args.encode_rep`` is 0 (the data loaders
            would otherwise be asked for empty batches).
    """
    # Validate before any distributed setup so a bad config fails fast with
    # a clear message instead of a ZeroDivisionError or a 0-sized loader.
    if args.encode_rep < 1:
        raise ValueError(
            f"encode_rep must be a positive integer, got {args.encode_rep}"
        )
    # The loaders read batch_size // encode_rep samples per step; TrainLoop
    # receives the full batch_size and encode_rep.
    # NOTE(review): presumably TrainLoop recombines encoded excerpts to reach
    # batch_size — confirm there.
    loader_batch_size = args.batch_size // args.encode_rep
    if loader_batch_size < 1:
        raise ValueError(
            f"batch_size ({args.batch_size}) must be >= encode_rep "
            f"({args.encode_rep}); otherwise the data loaders get batch size 0"
        )

    comm = dist_util.setup_dist(port=args.port)
    logger.configure(args=args, comm=comm)

    logger.log("creating model and diffusion...")
    model = _create_dit(args)
    diffusion = create_diffusion(
        learn_sigma=args.learn_sigma,
        diffusion_steps=args.diffusion_steps,
        noise_schedule=args.noise_schedule,
        timestep_respacing=args.timestep_respacing,
        use_kl=args.use_kl,
        predict_xstart=args.predict_xstart,
        rescale_timesteps=args.rescale_timesteps,
        rescale_learned_sigmas=args.rescale_learned_sigmas,
    )
    model.to(dist_util.dev())
    # A second instance of the same architecture is used for the eval loss,
    # which needs the EMA parameters rather than the live training weights.
    eval_model = _create_dit(args)
    eval_model.to(dist_util.dev())

    # Optional embedding model. The flag arrives as a string from the CLI, so
    # an empty value or "none" also disables it (a real None is only possible
    # when main is called programmatically).
    embed_model = None
    embed_name = args.embed_model_name
    if embed_name is not None and str(embed_name).strip().lower() not in ("", "none"):
        embed_model = load_model(embed_name, args.embed_model_ckpt)
        # NOTE(review): the loss sub-module is removed before use; presumably
        # it is training-only and unused here — confirm.
        del embed_model.loss
        embed_model.to(dist_util.dev())
        embed_model.eval()

    schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)

    logger.log("creating data loader...")
    # data_dir is a path prefix; the train/test CSV suffixes are appended here.
    data = load_data(
        data_dir=args.data_dir + "_train.csv",
        batch_size=loader_batch_size,
        class_cond=args.class_cond,
        image_size=args.pr_image_size,
    )

    eval_data = load_data(
        data_dir=args.data_dir + "_test.csv",
        batch_size=loader_batch_size,
        class_cond=args.class_cond,
        image_size=args.pr_image_size,
    )

    logger.log("training...")
    TrainLoop(
        model=model,
        eval_model=eval_model,
        diffusion=diffusion,
        data=data,
        batch_size=args.batch_size,
        microbatch=args.microbatch,
        lr=args.lr,
        ema_rate=args.ema_rate,
        log_interval=args.log_interval,
        save_interval=args.save_interval,
        resume_checkpoint=args.resume_checkpoint,
        embed_model=embed_model,
        use_fp16=args.use_fp16,
        fp16_scale_growth=args.fp16_scale_growth,
        schedule_sampler=schedule_sampler,
        weight_decay=args.weight_decay,
        lr_anneal_steps=args.lr_anneal_steps,
        eval_data=eval_data,
        eval_interval=args.save_interval,  # evaluate on the checkpoint cadence
        eval_sample_batch_size=16,
        total_num_gpus=MPI.COMM_WORLD.Get_size(),
        eval_sample_use_ddim=False,
        eval_sample_clip_denoised=args.eval_sample_clip_denoised,  # do not clip when training on latent space
        in_channels=args.in_channels,
        fs=args.fs,  # piano-roll frame rate used when saving samples (flag default 100)
        pedal=args.pedal,
        scale_factor=args.scale_factor,  # need to manually set scale_factor when resume
        num_classes=args.num_classes,  # whether to use class_cond in sampling
        microbatch_encode=args.microbatch_encode,
        encode_rep=args.encode_rep,
        shift_size=args.shift_size,
    ).run_loop()


def parse_flags(argv):
    """Parse the training script's command-line flags.

    Used as absl ``app.run``'s ``flags_parser``. ``argv[0]`` is the program
    name and is skipped.

    Args:
        argv: Full argument vector, including the program name.

    Returns:
        The parsed namespace, with one attribute per key in the script
        defaults below plus those from ``model_and_diffusion_defaults()``.
    """
    parser = argparse_flags.ArgumentParser(
        description="Train a diffusion model on images."
    )
    defaults = dict(
        project="music-guided",
        dir="",
        data_dir="",  # path prefix; "_train.csv" / "_test.csv" are appended
        model="DiTRotary_XL_8",  # DiT model names
        schedule_sampler="uniform",
        lr=1e-4,
        weight_decay=0.0,
        lr_anneal_steps=0,  # total steps, if set to be a positive number, lr will linearly decay
        batch_size=1,
        encode_rep=4,  # recombination factor for encoded excerpts; loaders read batch_size // encode_rep
        shift_size=2,  # need to be compatible with encode_rep to get effective batch size
        microbatch=-1,  # -1 disables microbatches
        ema_rate="0.9999",  # comma-separated list of EMA values
        log_interval=10,
        save_interval=10000,
        resume_checkpoint="",
        use_fp16=False,
        fp16_scale_growth=1e-3,
        embed_model_name="kl/f8-all-onset",
        embed_model_ckpt="taming-transformers/checkpoints/all_onset/epoch_14.ckpt",
        eval_sample_clip_denoised=False,
        scale_factor=1.,
        fs=100,  # saving piano roll with this fs
        pedal=False,
        num_classes=0,  # 0 is unconditional
        microbatch_encode=-1,
        pr_image_size=1024,
        training=True,
        ngc=False,  # whether to use dist setup on ngc
        # NOTE(review): how add_dict_to_argparser types a None default is
        # defined elsewhere — confirm port is converted as setup_dist expects.
        port=None,
    )
    # Model/diffusion defaults are merged last, so on any key collision they
    # take precedence over the script defaults above.
    defaults.update(model_and_diffusion_defaults())
    add_dict_to_argparser(parser, defaults)
    return parser.parse_args(argv[1:])


if __name__ == "__main__":
    # absl runs parse_flags on sys.argv and hands its result to main.
    app.run(main=main, flags_parser=parse_flags)