rule-guided-music / scripts /cfg_sample.py
yjhuangcd
First commit
9965bf6
"""
Generate a large batch of image samples from a model and save them as a large
numpy array. This can be used to produce samples for FID evaluation.
"""
import argparse
import os
import numpy as np
import torch as th
import torch.distributed as dist
from guided_diffusion import dist_util, midi_util, logger
from guided_diffusion.dit import DiT_models
from guided_diffusion.script_util import (
NUM_CLASSES,
model_and_diffusion_defaults,
create_diffusion,
create_model_and_diffusion,
add_dict_to_argparser,
args_to_dict,
)
from load_utils import load_model
def main():
args = create_argparser().parse_args()
comm = dist_util.setup_dist(port=args.port)
logger.configure(args=args, comm=comm)
logger.log("creating model and diffusion...")
model = DiT_models[args.model](
input_size=args.image_size,
in_channels=args.in_channels,
num_classes=args.num_classes,
learn_sigma=args.learn_sigma,
)
diffusion = create_diffusion(
learn_sigma=args.learn_sigma,
diffusion_steps=args.diffusion_steps,
noise_schedule=args.noise_schedule,
timestep_respacing=args.timestep_respacing,
use_kl=args.use_kl,
predict_xstart=args.predict_xstart,
rescale_timesteps=args.rescale_timesteps,
rescale_learned_sigmas=args.rescale_learned_sigmas,
)
model.load_state_dict(
dist_util.load_state_dict(args.model_path, map_location="cpu"), strict=False
)
model.to(dist_util.dev())
if args.use_fp16:
model.convert_to_fp16()
model.eval()
def model_fn(x, t, y=None):
y_null = th.tensor([args.num_classes] * args.batch_size, device=dist_util.dev())
if args.class_cond:
if args.cfg:
return (1 + args.w) * model(x, t, y) - args.w * model(x, t, y_null)
else:
return model(x, t, y)
else:
return model(x, t, y_null)
# create embed model
embed_model = load_model(args.embed_model_name, args.embed_model_ckpt)
embed_model.to(dist_util.dev())
embed_model.eval()
logger.log("sampling...")
all_images = []
all_labels = []
save_dir = os.path.join(logger.get_dir(), f"gen_cls_{args.class_label}{args.save_name}")
os.makedirs(os.path.expanduser(save_dir), exist_ok=True)
print(save_dir)
while len(all_images) * args.batch_size < args.num_samples:
model_kwargs = {}
if args.class_cond:
# only generate one class
classes = th.ones(size=(args.batch_size,), device=dist_util.dev(), dtype=th.int) * args.class_label
# randomly select classes
# classes = th.randint(
# low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev()
# )
model_kwargs["y"] = classes
sample_fn = (
diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop
)
sample = sample_fn(
model_fn,
(args.batch_size, args.in_channels, args.image_size[0], args.image_size[1]),
clip_denoised=args.clip_denoised,
model_kwargs=model_kwargs,
device=dist_util.dev(),
progress=True
)
sample = midi_util.decode_sample_for_midi(sample, embed_model=embed_model,
scale_factor=args.scale_factor, threshold=-0.95)
gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())]
dist.all_gather(gathered_samples, sample) # gather not supported with NCCL
all_images.extend([sample.cpu().numpy() for sample in gathered_samples])
if args.class_cond:
gathered_labels = [
th.zeros_like(classes) for _ in range(dist.get_world_size())
]
dist.all_gather(gathered_labels, classes)
all_labels.extend([labels.cpu().numpy() for labels in gathered_labels])
logger.log(f"created {len(all_images) * args.batch_size} samples")
arr = np.concatenate(all_images, axis=0)
if arr.shape[-1] == 1: # no pedal, need shape B x 128 x 1024
arr = arr.squeeze(axis=-1)
else: # with pedal, need shape: B x 2 x 128 x 1024
arr = arr.transpose(0, 3, 1, 2)
arr = arr[: args.num_samples]
if args.class_cond:
label_arr = np.concatenate(all_labels, axis=0)
label_arr = label_arr[: args.num_samples]
if dist.get_rank() == 0:
if args.class_cond:
midi_util.save_piano_roll_midi(arr, save_dir, args.fs, y=label_arr)
else:
midi_util.save_piano_roll_midi(arr, save_dir, args.fs)
dist.barrier()
logger.log("sampling complete")
def create_argparser():
defaults = dict(
project="music-sampling",
save_name="", # optional save name to indicate iter
dir="",
model="DiTRotary_XL_8", # DiT model names
embed_model_name='kl/f8-all-onset',
embed_model_ckpt='taming-transformers/checkpoints/all_onset/epoch_14.ckpt',
clip_denoised=False,
num_samples=128,
batch_size=16,
use_ddim=False,
model_path="",
scale_factor=1.,
fs=100,
num_classes=3,
class_label=1, # class to generate
cfg=False,
w=4., # for cfg
training=False,
port=None,
)
defaults.update(model_and_diffusion_defaults())
parser = argparse.ArgumentParser()
add_dict_to_argparser(parser, defaults)
return parser
if __name__ == "__main__":
main()