"""
Generate a large batch of image samples from a model and save them as a large
numpy array. This can be used to produce samples for FID evaluation.
"""
import argparse
import os

import numpy as np
import torch as th
import torch.distributed as dist
import torch.nn.functional as F

from guided_diffusion import dist_util, midi_util, logger
from guided_diffusion.dit import DiT_models
from guided_diffusion.script_util import (
    NUM_CLASSES,
    model_and_diffusion_defaults,
    create_diffusion,
    create_model_and_diffusion,
    add_dict_to_argparser,
    args_to_dict,
)
from load_utils import load_model
from guided_diffusion.pr_datasets import FUNC_DICT

def main():
    args = create_argparser().parse_args()
    comm = dist_util.setup_dist(port=args.port)
    logger.configure(args=args, comm=comm)

    logger.log("creating model and diffusion...")
    model = DiT_models[args.model](
        input_size=args.image_size,
        in_channels=args.in_channels,
        num_classes=args.num_classes,
        learn_sigma=args.learn_sigma,
        patchify=args.patchify,
    )
    diffusion = create_diffusion(
        learn_sigma=args.learn_sigma,
        diffusion_steps=args.diffusion_steps,
        noise_schedule=args.noise_schedule,
        timestep_respacing=args.timestep_respacing,
        use_kl=args.use_kl,
        predict_xstart=args.predict_xstart,
        rescale_timesteps=args.rescale_timesteps,
        rescale_learned_sigmas=args.rescale_learned_sigmas,
    )
    model.load_state_dict(
        dist_util.load_state_dict(args.model_path, map_location="cpu"), strict=False
    )
    model.to(dist_util.dev())
    if args.use_fp16:
        model.convert_to_fp16()
    model.eval()
logger.log("loading classifier...")
classifier = DiT_models[args.classifier_model](
input_size=args.image_size, # classifier trained on latents for now, so has the same img size as diffusion
in_channels=args.in_channels,
num_classes=args.classifier_num_classes,
patchify=args.classifier_patchify,
)
classifier.load_state_dict(
dist_util.load_state_dict(args.classifier_path, map_location="cpu")
)
classifier.to(dist_util.dev())
if args.classifier_use_fp16:
classifier.convert_to_fp16()
classifier.eval()
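
    # Classifier guidance: each cond_fn below returns
    #   classifier_scale * grad_x log p(rule | x_t),
    # which the sampler uses to shift the predicted mean at every denoising step.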
    # y is a dummy input for cond_fn; rule is the real conditioning input
    def cond_fn_xentropy(x, t, y=None, rule=None):
        # cross-entropy cond_fn for categorical rules
        assert rule is not None
        with th.enable_grad():
            x_in = x.detach().requires_grad_(True)
            logits = classifier(x_in, t)
            log_probs = F.log_softmax(logits, dim=-1)
            selected = log_probs[range(len(logits)), rule.view(-1)]
            return th.autograd.grad(selected.sum(), x_in)[0] * args.classifier_scale
    def cond_fn_mse(x, t, y=None, rule=None, var=0.1):
        # MSE cond_fn: Gaussian log-likelihood with a preset variance `var`
        assert rule is not None
        with th.enable_grad():
            x_in = x.detach().requires_grad_(True)
            logits = classifier(x_in, t)
            log_probs = -1 / (2 * var) * F.mse_loss(logits, rule, reduction="none").sum(dim=-1)
            return th.autograd.grad(log_probs.sum(), x_in)[0] * args.classifier_scale
    def model_fn(x, t, y=None, rule=None):
        # y has to be the composer class; rule is a dummy input here
        y_null = th.tensor([args.num_classes] * args.batch_size, device=dist_util.dev())
        if args.class_cond:
            if args.cfg:
                # classifier-free guidance: (1 + w) * conditional - w * unconditional
                return (1 + args.w) * model(x, t, y) - args.w * model(x, t, y_null)
            else:
                return model(x, t, y)
        else:
            return model(x, t, y_null)
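
    # The embed model is the latent autoencoder used to decode sampled latents
    # back to piano rolls; dividing by `scale_factor` before decoding presumably
    # undoes the latent scaling applied during training (assumption).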
    # create embed model
    embed_model = load_model(args.embed_model_name, args.embed_model_ckpt)
    embed_model.to(dist_util.dev())
    embed_model.eval()
logger.log("sampling...")
all_images = []
all_labels = []
while len(all_images) * args.batch_size < args.num_samples:
model_kwargs = {}
if args.class_cond:
# only generate one class
classes = th.ones(size=(args.batch_size,), device=dist_util.dev(), dtype=th.int) * 1
# randomly select classes
# classes = th.randint(
# low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev()
# )
model_kwargs["y"] = classes
        if args.rule is not None:
            if args.rule == 'pitch_hist':
                # rule_label = th.zeros(size=(args.batch_size, 12), device=dist_util.dev())
                # Sparsified key profiles: only the tonic-triad pitch classes of the
                # full Krumhansl-Kessler profiles (kept commented below) are retained.
                major_profile = th.tensor(
                    # [6.35, 2.23, 3.48, 2.33, 4.38, 4.09, 2.52, 5.19, 2.39, 3.66, 2.29, 2.88],
                    [6.35, 0., 0., 0., 4.38, 0., 0., 5.19, 0., 0., 0., 0.],
                    device=dist_util.dev())
                minor_profile = th.tensor(
                    # [6.33, 2.68, 3.52, 5.38, 2.60, 3.53, 2.54, 4.75, 3.98, 2.69, 3.34, 3.17],  # C minor
                    # [3.17, 6.33, 2.68, 3.52, 5.38, 2.60, 3.53, 2.54, 4.75, 3.98, 2.69, 3.34],  # C# minor
                    [6.33, 0., 0., 5.38, 0., 0., 0., 4.75, 0., 0., 0., 0.],
                    device=dist_util.dev())
                # half-major / half-minor batches:
                # rule_label_c_maj = major_profile.repeat(args.batch_size // 2, 1)
                # rule_label_c_min = minor_profile.repeat(args.batch_size // 2, 1)
                # rule_label = th.concat((rule_label_c_maj, rule_label_c_min), dim=0)
                if args.major:
                    rule_label = major_profile.repeat(args.batch_size, 1)
                else:
                    rule_label = minor_profile.repeat(args.batch_size, 1)
                # normalize to a pitch-class distribution
                rule_label = rule_label / (th.sum(rule_label, dim=-1, keepdim=True) + 1e-12)
            else:
                # NOTE: labels for other rules (e.g. note_density) are not constructed here yet
                rule_label = None
            model_kwargs["rule"] = rule_label
        sample_fn = (
            diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop
        )
        if args.rule in ['']:  # left for other category-based rules
            cond_fn_used = cond_fn_xentropy
        elif args.rule in ['pitch_hist', 'note_density']:
            cond_fn_used = cond_fn_mse
        else:
            cond_fn_used = None
        sample = sample_fn(
            model_fn,
            (args.batch_size, args.in_channels, args.image_size[0], args.image_size[1]),
            clip_denoised=args.clip_denoised,
            model_kwargs=model_kwargs,
            device=dist_util.dev(),
            cond_fn=cond_fn_used,
            progress=True,
        )
        sample = sample / args.scale_factor

        image_size_h = args.image_size[0]
        image_size_w = args.image_size[1]
        if image_size_h != image_size_w:  # transposed for raster col
            sample = sample.permute(0, 1, 3, 2)  # vertical axis means pitch after transpose

        save_dir = os.path.join(logger.get_dir(), "generated_samples")
        os.makedirs(os.path.expanduser(save_dir), exist_ok=True)
        with th.no_grad():
            if embed_model is not None:
                if image_size_h != image_size_w:
                    sample = th.chunk(sample, image_size_h // image_size_w, dim=-1)  # B x C x H x W
                    sample = th.concat(sample, dim=0)  # 1st second for all batches, 2nd second for all batches, ...
                # th.save(sample, save_dir + '/latents.pt')  # uncomment to save the sampled latents
                if args.get_KL:
                    sample = embed_model.decode(sample)
                else:
                    sample = embed_model.decode_diff(sample)
                if image_size_h != image_size_w:
                    sample = th.concat(th.chunk(sample, image_size_h // image_size_w, dim=0), dim=-1)
        sample[sample <= -0.95] = -1.  # heuristic thresholding of the background
        sample = ((sample + 1) * 63.5).clamp(0, 127).to(th.uint8)
        # todo: comment out the two lines above to test without thresholding:
        # sample[sample <= -1.] = -1.
        # sample = (sample + 1) * 63.5
        sample = sample.permute(0, 2, 3, 1)
        sample = sample.contiguous()
        gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered_samples, sample)  # gather is not supported with NCCL
        all_images.extend([s.cpu().numpy() for s in gathered_samples])
        if args.class_cond:
            gathered_labels = [
                th.zeros_like(classes) for _ in range(dist.get_world_size())
            ]
            dist.all_gather(gathered_labels, classes)
            all_labels.extend([labels.cpu().numpy() for labels in gathered_labels])
        logger.log(f"created {len(all_images) * args.batch_size} samples")
    arr = np.concatenate(all_images, axis=0).squeeze()
    arr = arr[: args.num_samples]

    if args.rule is not None:
        # measure the distance between each generated sample's rule value and the target
        generated_samples = th.from_numpy(arr) / 63.5 - 1
        loss = []
        rule_label = rule_label.cpu()
        for i in range(generated_samples.shape[0]):
            gen_rule = FUNC_DICT[args.rule](generated_samples[i][None][None])
            # all rows of rule_label are identical, so the broadcast MSE equals
            # the per-sample MSE against the single target profile
            loss.append(F.mse_loss(rule_label, gen_rule).item())
        mean_loss = th.tensor(loss).mean()
        print(f"loss: {mean_loss}")  # distance between generated rule and target rule
    if args.class_cond:
        label_arr = np.concatenate(all_labels, axis=0)
        label_arr = label_arr[: args.num_samples]
    if dist.get_rank() == 0:
        if args.class_cond:
            midi_util.save_piano_roll_midi(arr, save_dir, args.fs, y=label_arr)
        else:
            midi_util.save_piano_roll_midi(arr, save_dir, args.fs)
    dist.barrier()
    logger.log("sampling complete")


def create_argparser():
    defaults = dict(
        project="music-sampling",
        dir="",
        model="DiT-B/4",  # DiT model name
        embed_model_name='kl/f8',
        embed_model_ckpt='kl-f8-4-6/steps_42k.ckpt',
        classifier_use_fp16=False,
        classifier_model='DiTRel-XS/2-cls',
        classifier_num_classes=12,
        classifier_patchify=True,
        classifier_path='loggings/classifier-pitch-hist/ditrel-noise/model004999.pt',
        clip_denoised=False,
        num_samples=128,
        batch_size=16,
        use_ddim=False,
        model_path="",
        get_KL=True,
        scale_factor=1.,
        patchify=True,
        fs=100,
        num_classes=0,
        cfg=False,
        w=4.,  # guidance weight for cfg
        classifier_scale=1.0,
        rule=None,
        major=True,
        port=None,
    )
    defaults.update(model_and_diffusion_defaults())
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, defaults)
    return parser
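
# Note: `add_dict_to_argparser` exposes every key in `defaults` as a command-line
# flag, so any of the values above can be overridden, e.g. `--rule pitch_hist`.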
if __name__ == "__main__":
main()