# rule-guided-music/scripts/classifier_sample.py
# yjhuangcd
# First commit
# 9965bf6
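"""
Generate a large batch of samples from a DiT diffusion model, optionally
steered by a rule classifier through a guidance function, decode the sampled
latents with the embedding model, gather them into one large numpy array and
save the result as MIDI files. When a rule target is available, the distance
between the rule computed on the generated samples and the target is printed.
"""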
import argparse
import os
import numpy as np
import torch as th
import torch.distributed as dist
import torch.nn.functional as F
from guided_diffusion import dist_util, midi_util, logger
from guided_diffusion.dit import DiT_models
from guided_diffusion.script_util import (
NUM_CLASSES,
model_and_diffusion_defaults,
create_diffusion,
create_model_and_diffusion,
add_dict_to_argparser,
args_to_dict,
)
from load_utils import load_model
from guided_diffusion.pr_datasets import FUNC_DICT
def main():
    """Sample piano rolls from a diffusion model with optional rule guidance.

    Steps, in order:
      1. parse arguments, set up the distributed backend and the logger;
      2. build the DiT model, the diffusion process and the rule classifier,
         and load their checkpoints;
      3. sample batches until ``args.num_samples`` samples exist, decoding
         every batch with the embedding model and gathering it across ranks;
      4. if a rule target was built, print the mean MSE between the rule
         computed on the generated samples and that target;
      5. on rank 0, write the samples out as MIDI files.
    """
    args = create_argparser().parse_args()

    comm = dist_util.setup_dist(port=args.port)
    logger.configure(args=args, comm=comm)

    logger.log("creating model and diffusion...")
    model = DiT_models[args.model](
        input_size=args.image_size,
        in_channels=args.in_channels,
        num_classes=args.num_classes,
        learn_sigma=args.learn_sigma,
        patchify=args.patchify,
    )
    diffusion = create_diffusion(
        learn_sigma=args.learn_sigma,
        diffusion_steps=args.diffusion_steps,
        noise_schedule=args.noise_schedule,
        timestep_respacing=args.timestep_respacing,
        use_kl=args.use_kl,
        predict_xstart=args.predict_xstart,
        rescale_timesteps=args.rescale_timesteps,
        rescale_learned_sigmas=args.rescale_learned_sigmas,
    )
    model.load_state_dict(
        dist_util.load_state_dict(args.model_path, map_location="cpu"), strict=False
    )
    model.to(dist_util.dev())
    if args.use_fp16:
        model.convert_to_fp16()
    model.eval()

    logger.log("loading classifier...")
    classifier = DiT_models[args.classifier_model](
        # classifier trained on latents for now, so has the same img size as diffusion
        input_size=args.image_size,
        in_channels=args.in_channels,
        num_classes=args.classifier_num_classes,
        patchify=args.classifier_patchify,
    )
    classifier.load_state_dict(
        dist_util.load_state_dict(args.classifier_path, map_location="cpu")
    )
    classifier.to(dist_util.dev())
    if args.classifier_use_fp16:
        classifier.convert_to_fp16()
    classifier.eval()

    # In both guidance functions `y` is a dummy input; `rule` is the real one.
    def cond_fn_xentropy(x, t, y=None, rule=None):
        """Cross-entropy guidance.

        Returns the gradient w.r.t. ``x`` of the classifier log-probability
        selected by ``rule``, multiplied by ``args.classifier_scale``.
        """
        assert rule is not None
        with th.enable_grad():
            x_in = x.detach().requires_grad_(True)
            logits = classifier(x_in, t)
            log_probs = F.log_softmax(logits, dim=-1)
            selected = log_probs[range(len(logits)), rule.view(-1)]
            return th.autograd.grad(selected.sum(), x_in)[0] * args.classifier_scale

    def cond_fn_mse(x, t, y=None, rule=None, var=0.1):
        """MSE guidance with a preset variance ``var``.

        Returns the gradient w.r.t. ``x`` of ``-1/(2*var)`` times the squared
        error between the classifier output and ``rule`` (summed over the last
        dimension), multiplied by ``args.classifier_scale``.
        """
        assert rule is not None
        with th.enable_grad():
            x_in = x.detach().requires_grad_(True)
            logits = classifier(x_in, t)
            log_probs = -1 / (2 * var) * F.mse_loss(logits, rule, reduction="none").sum(dim=-1)
            return th.autograd.grad(log_probs.sum(), x_in)[0] * args.classifier_scale

    def model_fn(x, t, y=None, rule=None):
        """Denoiser wrapper: `y` has to be the composer label, `rule` is a dummy input.

        With ``class_cond`` and ``cfg`` the conditional and ``y_null`` outputs
        are mixed with weight ``args.w``; with ``class_cond`` only, the
        conditional output is returned; otherwise the model is called with
        ``y_null``.
        """
        # Size the label from the batch actually passed in rather than from
        # args.batch_size, so the two can never disagree.
        # NOTE(review): index `num_classes` is presumably the "no class" label -- confirm.
        y_null = th.tensor([args.num_classes] * x.shape[0], device=dist_util.dev())
        if not args.class_cond:
            return model(x, t, y_null)
        if args.cfg:
            return (1 + args.w) * model(x, t, y) - args.w * model(x, t, y_null)
        return model(x, t, y)

    # create embed model
    embed_model = load_model(args.embed_model_name, args.embed_model_ckpt)
    embed_model.to(dist_util.dev())
    embed_model.eval()

    # ---- loop-invariant setup ------------------------------------------------
    # The rule target is the same for every batch, so build it once. It stays
    # None when no rule is requested or when the rule has no target built here;
    # the evaluation after the loop relies on that to know whether it can run.
    rule_label = None
    if args.rule == 'pitch_hist':
        # Full profiles kept for reference:
        # major: [6.35, 2.23, 3.48, 2.33, 4.38, 4.09, 2.52, 5.19, 2.39, 3.66, 2.29, 2.88]
        # minor: [6.33, 2.68, 3.52, 5.38, 2.60, 3.53, 2.54, 4.75, 3.98, 2.69, 3.34, 3.17]  (C minor)
        # The profiles used below keep only three non-zero bins each.
        major_profile = th.tensor(
            [6.35, 0., 0., 0., 4.38, 0., 0., 5.19, 0., 0., 0., 0.],
            device=dist_util.dev())
        minor_profile = th.tensor(
            [6.33, 0., 0., 5.38, 0., 0., 0., 4.75, 0., 0., 0., 0.],
            device=dist_util.dev())
        profile = major_profile if args.major else minor_profile
        rule_label = profile.repeat(args.batch_size, 1)
        # Normalise every row to sum to one; the epsilon guards against a zero row.
        rule_label = rule_label / (th.sum(rule_label, dim=-1, keepdim=True) + 1e-12)

    sample_fn = (
        diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop
    )
    if args.rule in ['']:  # leave for other category based rules
        cond_fn_used = cond_fn_xentropy
    elif args.rule in ['pitch_hist', 'note_density']:
        # NOTE(review): no target is built above for 'note_density', so
        # cond_fn_mse would hit its `rule is not None` assertion -- confirm intent.
        cond_fn_used = cond_fn_mse
    else:
        cond_fn_used = None

    image_size_h = args.image_size[0]
    image_size_w = args.image_size[1]
    non_square = image_size_h != image_size_w
    n_chunks = image_size_h // image_size_w

    save_dir = os.path.join(logger.get_dir(), "generated_samples")
    os.makedirs(os.path.expanduser(save_dir), exist_ok=True)

    logger.log("sampling...")
    all_images = []
    all_labels = []
    while len(all_images) * args.batch_size < args.num_samples:
        model_kwargs = {}
        if args.class_cond:
            # only generate one class
            classes = th.ones(size=(args.batch_size,), device=dist_util.dev(), dtype=th.int) * 1
            # randomly select classes
            # classes = th.randint(
            #     low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev()
            # )
            model_kwargs["y"] = classes
        if args.rule is not None:
            model_kwargs["rule"] = rule_label

        sample = sample_fn(
            model_fn,
            (args.batch_size, args.in_channels, image_size_h, image_size_w),
            clip_denoised=args.clip_denoised,
            model_kwargs=model_kwargs,
            device=dist_util.dev(),
            cond_fn=cond_fn_used,
            progress=True
        )
        sample = sample / args.scale_factor

        if non_square:  # transposed for raster col
            sample = sample.permute(0, 1, 3, 2)  # vertical axis means pitch after transpose

        with th.no_grad():
            if embed_model is not None:
                if non_square:
                    # Cut the long axis into pieces and stack them on the batch axis
                    # for decoding: 1st piece for all batch, 2nd piece for all batch, ...
                    sample = th.chunk(sample, n_chunks, dim=-1)  # B x C x H x W
                    sample = th.concat(sample, dim=0)
                # th.save(sample, save_dir + '/latents.pt')  # todo: uncomment if wants to save sampled latents
                if args.get_KL:
                    sample = embed_model.decode(sample)
                else:
                    sample = embed_model.decode_diff(sample)
                if non_square:
                    # Undo the batch stacking: put the decoded pieces back side by side.
                    sample = th.concat(th.chunk(sample, n_chunks, dim=0), dim=-1)
            sample[sample <= -0.95] = -1.  # heuristic thresholding the background
            # Map [-1, 1] to [0, 127] and quantise.
            sample = ((sample + 1) * 63.5).clamp(0, 127).to(th.uint8)
            # # todo: comment above two lines to test no threshold
            # sample[sample <= -1.] = -1.
            # sample = (sample + 1) * 63.5
            sample = sample.permute(0, 2, 3, 1)
            sample = sample.contiguous()

        gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered_samples, sample)  # gather not supported with NCCL
        all_images.extend([gathered.cpu().numpy() for gathered in gathered_samples])
        if args.class_cond:
            gathered_labels = [
                th.zeros_like(classes) for _ in range(dist.get_world_size())
            ]
            dist.all_gather(gathered_labels, classes)
            all_labels.extend([labels.cpu().numpy() for labels in gathered_labels])
        logger.log(f"created {len(all_images) * args.batch_size} samples")

    arr = np.concatenate(all_images, axis=0).squeeze()
    arr = arr[: args.num_samples]

    # Test distance between generated samples and target. This is only possible
    # when a target exists; without the guard the script would crash here
    # (after all sampling is done, before anything is saved) whenever --rule is
    # unset or names a rule with no target.
    if rule_label is not None:
        generated_samples = th.from_numpy(arr) / 63.5 - 1  # back to [-1, 1]
        target = rule_label.cpu()
        rule_fn = FUNC_DICT[args.rule]
        loss = [
            F.mse_loss(target, rule_fn(generated[None][None])).item()
            for generated in generated_samples
        ]
        mean_loss = th.tensor(loss).mean()
        print(f"loss: {mean_loss}")  # distance between generated rule and target rule

    if args.class_cond:
        label_arr = np.concatenate(all_labels, axis=0)
        label_arr = label_arr[: args.num_samples]
    if dist.get_rank() == 0:
        if args.class_cond:
            midi_util.save_piano_roll_midi(arr, save_dir, args.fs, y=label_arr)
        else:
            midi_util.save_piano_roll_midi(arr, save_dir, args.fs)
    dist.barrier()
    logger.log("sampling complete")
def create_argparser():
    """Return the argument parser for the sampling script.

    The script-specific defaults are declared below. Entries returned by
    ``model_and_diffusion_defaults()`` are merged on top of them, so they
    take precedence whenever a key appears in both.
    """
    script_defaults = {
        "project": "music-sampling",
        "dir": "",
        "model": "DiT-B/4",  # DiT model names
        "embed_model_name": "kl/f8",
        "embed_model_ckpt": "kl-f8-4-6/steps_42k.ckpt",
        "classifier_use_fp16": False,
        "classifier_model": "DiTRel-XS/2-cls",
        "classifier_num_classes": 12,
        "classifier_patchify": True,
        "classifier_path": "loggings/classifier-pitch-hist/ditrel-noise/model004999.pt",
        "clip_denoised": False,
        "num_samples": 128,
        "batch_size": 16,
        "use_ddim": False,
        "model_path": "",
        "get_KL": True,
        "scale_factor": 1.,
        "patchify": True,
        "fs": 100,
        "num_classes": 0,
        "cfg": False,
        "w": 4.,  # for cfg
        "classifier_scale": 1.0,
        "rule": None,
        "major": True,
        "port": None,
    }
    # Same result as dict.update(): existing keys keep their position and take
    # the shared defaults' value, new keys are appended.
    merged_defaults = {**script_defaults, **model_and_diffusion_defaults()}

    arg_parser = argparse.ArgumentParser()
    add_dict_to_argparser(arg_parser, merged_defaults)
    return arg_parser
# Run the sampling pipeline only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()