|
from .DiffAE_support_config import * |
|
from .DiffAE_diffusion_diffusion import SpacedDiffusionBeatGans as Sampler |
|
from .DiffAE_model_unet_autoenc import BeatGANsAutoencModel |
|
|
|
from torch.cuda import amp |
|
|
|
|
|
def render_uncondition(conf: TrainConfig, |
|
model: BeatGANsAutoencModel, |
|
x_T, |
|
sampler: Sampler, |
|
latent_sampler: Sampler, |
|
conds_mean=None, |
|
conds_std=None, |
|
clip_latent_noise: bool = False): |
|
device = x_T.device |
|
if conf.train_mode == TrainMode.diffusion: |
|
assert conf.model_type.can_sample() |
|
return sampler.sample(model=model, noise=x_T) |
|
elif conf.train_mode.is_latent_diffusion(): |
|
model: BeatGANsAutoencModel |
|
if conf.train_mode == TrainMode.latent_diffusion: |
|
latent_noise = torch.randn(len(x_T), conf.style_ch, device=device) |
|
else: |
|
raise NotImplementedError() |
|
|
|
if clip_latent_noise: |
|
latent_noise = latent_noise.clip(-1, 1) |
|
|
|
cond = latent_sampler.sample( |
|
model=model.latent_net, |
|
noise=latent_noise, |
|
clip_denoised=conf.latent_clip_sample, |
|
) |
|
|
|
if conf.latent_znormalize: |
|
cond = cond * conds_std.to(device) + conds_mean.to(device) |
|
|
|
|
|
return sampler.sample(model=model, noise=x_T, cond=cond) |
|
else: |
|
raise NotImplementedError() |
|
|
|
|
|
def render_condition( |
|
conf: TrainConfig, |
|
model: BeatGANsAutoencModel, |
|
x_T, |
|
sampler: Sampler, |
|
x_start=None, |
|
cond=None, |
|
): |
|
if conf.train_mode == TrainMode.diffusion: |
|
assert conf.model_type.has_autoenc() |
|
|
|
if cond is None: |
|
cond = model.encode(x_start) |
|
return sampler.sample(model=model, |
|
noise=x_T, |
|
model_kwargs={'cond': cond}) |
|
else: |
|
raise NotImplementedError() |
|
|