rule-guided-music / scripts /diffcollage_sample.py
yjhuangcd
First commit
9965bf6
"""
Generate a large batch of image samples from a model and save them as a large
numpy array. This can be used to produce samples for FID evaluation.
"""
import argparse
import os
import numpy as np
import torch as th
import torch.distributed as dist
from guided_diffusion import dist_util, midi_util, logger
from guided_diffusion.dit import DiT_models
from guided_diffusion.script_util import (
NUM_CLASSES,
model_and_diffusion_defaults,
create_diffusion,
create_model_and_diffusion,
add_dict_to_argparser,
args_to_dict,
)
from load_utils import load_model
import diff_collage as dc
def main():
args = create_argparser().parse_args()
comm = dist_util.setup_dist(port=args.port)
logger.configure(args=args, comm=comm)
logger.log("creating model and diffusion...")
model = DiT_models[args.model](
input_size=args.image_size,
in_channels=args.in_channels,
num_classes=args.num_classes,
learn_sigma=args.learn_sigma,
# patchify=args.patchify,
# half=True,
)
diffusion = create_diffusion(
learn_sigma=args.learn_sigma,
diffusion_steps=args.diffusion_steps,
noise_schedule=args.noise_schedule,
timestep_respacing=args.timestep_respacing,
use_kl=args.use_kl,
predict_xstart=args.predict_xstart,
rescale_timesteps=args.rescale_timesteps,
rescale_learned_sigmas=args.rescale_learned_sigmas,
)
model.load_state_dict(
dist_util.load_state_dict(args.model_path, map_location="cpu"), strict=False
)
model.to(dist_util.dev())
if args.use_fp16:
model.convert_to_fp16()
model.eval()
overlap_size = 64 # how much overlap
num_img = 3 # how many square images
img_shape = (4, 16, 128) # image shape # todo: add in args
def eps_fn(x, t, y=None, half=False):
return model(x.permute(0, 1, 3, 2), t, y=y, half=half).permute(0, 1, 3, 2) # since our backbone takes 128x16 as input
worker = dc.CondIndCircle(img_shape, eps_fn, num_img+1, overlap_size=overlap_size) # circle need one more num_img than linear
# worker = dc.CondIndSimple(img_shape, eps_fn, num_img, overlap_size=overlap_size)
model_long_fn = worker.eps_scalar_t_fn
def model_fn(x, t, y=None):
# takes in 4 x pitch x time, return 4 x pitch x time
y_null = th.tensor([args.num_classes] * args.batch_size, device=dist_util.dev())
if args.class_cond:
return (1 + args.w) * model_long_fn(x, t, y) - args.w * model_long_fn(x, t, y_null)
else:
return model_long_fn(x, t, y_null)
# create embed model
embed_model = load_model(args.embed_model_name, args.embed_model_ckpt)
embed_model.to(dist_util.dev())
embed_model.eval()
logger.log("sampling...")
all_images = []
all_labels = []
save_dir = os.path.join(logger.get_dir(), "generated_samples")
os.makedirs(os.path.expanduser(save_dir), exist_ok=True)
while len(all_images) * args.batch_size < args.num_samples:
model_kwargs = {}
if args.class_cond:
# only generate one class
classes = th.ones(size=(args.batch_size,), device=dist_util.dev(), dtype=th.int) * 0
# randomly select classes
# classes = th.randint(
# low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev()
# )
model_kwargs["y"] = classes
sample_fn = (
diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop
)
sample = sample_fn(
model_fn,
(args.batch_size, *worker.shape),
clip_denoised=args.clip_denoised,
model_kwargs=model_kwargs,
device=dist_util.dev(),
progress=True
)
sample = midi_util.decode_sample_for_midi(sample, embed_model=embed_model,
scale_factor=args.scale_factor, threshold=-0.95)
gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())]
dist.all_gather(gathered_samples, sample) # gather not supported with NCCL
all_images.extend([sample.cpu().numpy() for sample in gathered_samples])
if args.class_cond:
gathered_labels = [
th.zeros_like(classes) for _ in range(dist.get_world_size())
]
dist.all_gather(gathered_labels, classes)
all_labels.extend([labels.cpu().numpy() for labels in gathered_labels])
logger.log(f"created {len(all_images) * args.batch_size} samples")
arr = np.concatenate(all_images, axis=0)
if arr.shape[-1] == 1: # no pedal, need shape B x 128 x 1024
arr = arr.squeeze(axis=-1)
else: # with pedal, need shape: B x 2 x 128 x 1024
arr = arr.transpose(0, 3, 1, 2)
arr = arr[: args.num_samples]
if args.class_cond:
label_arr = np.concatenate(all_labels, axis=0)
label_arr = label_arr[: args.num_samples]
if dist.get_rank() == 0:
if args.class_cond:
midi_util.save_piano_roll_midi(arr, save_dir, args.fs, y=label_arr)
else:
midi_util.save_piano_roll_midi(arr, save_dir, args.fs)
dist.barrier()
logger.log("sampling complete")
def create_argparser():
defaults = dict(
project="music-sampling",
dir="",
model="DiT-B/4", # DiT model names
embed_model_name='kl/f8',
embed_model_ckpt='kl-f8-4-6/steps_42k.ckpt',
clip_denoised=False,
num_samples=128,
batch_size=16,
use_ddim=False,
model_path="",
get_KL=True,
scale_factor=1.,
patchify=True,
fs=100,
num_classes=0,
w=4., # for cfg
training=False,
port=None,
)
defaults.update(model_and_diffusion_defaults())
parser = argparse.ArgumentParser()
add_dict_to_argparser(parser, defaults)
return parser
if __name__ == "__main__":
main()