|
from functools import partial |
|
|
|
import torch |
|
|
|
from opensora.registry import SCHEDULERS |
|
|
|
from . import gaussian_diffusion as gd |
|
from .respace import SpacedDiffusion, space_timesteps |
|
|
|
|
|
@SCHEDULERS.register_module("iddpm")
class IDDPM(SpacedDiffusion):
    """Spaced-diffusion scheduler registered in ``SCHEDULERS`` as ``"iddpm"``.

    The constructor turns a handful of user-facing switches into the beta
    schedule, timestep subset and mean/variance/loss enums handed to
    ``SpacedDiffusion``.  ``sample`` runs ``p_sample_loop`` through the
    classifier-free-guidance wrapper ``forward_with_cfg``.
    """

    def __init__(
        self,
        num_sampling_steps=None,
        timestep_respacing=None,
        noise_schedule="linear",
        use_kl=False,
        sigma_small=False,
        predict_xstart=False,
        learn_sigma=True,
        rescale_learned_sigmas=False,
        diffusion_steps=1000,
        cfg_scale=4.0,
    ):
        """Configure the underlying ``SpacedDiffusion``.

        Args:
            num_sampling_steps: If given, ``timestep_respacing`` must be
                ``None``; ``str(num_sampling_steps)`` is used as the
                respacing spec instead.
            timestep_respacing: Respacing spec forwarded to
                ``space_timesteps``.  ``None`` or ``""`` is replaced by
                ``[diffusion_steps]``.
            noise_schedule: Schedule name passed to
                ``gd.get_named_beta_schedule``.
            use_kl: Select ``LossType.RESCALED_KL`` (takes precedence over
                ``rescale_learned_sigmas``).
            sigma_small: With ``learn_sigma`` false, select
                ``ModelVarType.FIXED_SMALL`` instead of ``FIXED_LARGE``.
            predict_xstart: Select ``ModelMeanType.START_X`` instead of
                ``EPSILON``.
            learn_sigma: Select ``ModelVarType.LEARNED_RANGE``.
            rescale_learned_sigmas: Select ``LossType.RESCALED_MSE`` when
                ``use_kl`` is false; otherwise plain ``MSE`` is used.
            diffusion_steps: Number of steps of the base schedule.
            cfg_scale: Guidance scale stored on the instance and used by
                ``sample``.
        """
        betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)

        # Loss: the KL flag wins, then the rescaled-MSE flag, else plain MSE.
        if use_kl:
            loss_type = gd.LossType.RESCALED_KL
        else:
            loss_type = gd.LossType.RESCALED_MSE if rescale_learned_sigmas else gd.LossType.MSE

        # An explicit step count is shorthand for a respacing spec; the two
        # options are mutually exclusive.
        if num_sampling_steps is not None:
            assert timestep_respacing is None
            timestep_respacing = str(num_sampling_steps)
        if timestep_respacing is None or timestep_respacing == "":
            timestep_respacing = [diffusion_steps]
        use_timesteps = space_timesteps(diffusion_steps, timestep_respacing)

        mean_type = gd.ModelMeanType.START_X if predict_xstart else gd.ModelMeanType.EPSILON

        if learn_sigma:
            var_type = gd.ModelVarType.LEARNED_RANGE
        elif sigma_small:
            var_type = gd.ModelVarType.FIXED_SMALL
        else:
            var_type = gd.ModelVarType.FIXED_LARGE

        super().__init__(
            use_timesteps=use_timesteps,
            betas=betas,
            model_mean_type=mean_type,
            model_var_type=var_type,
            loss_type=loss_type,
        )

        self.cfg_scale = cfg_scale

    def sample(
        self,
        model,
        text_encoder,
        z_size,
        prompts,
        device,
        additional_args=None,
    ):
        """Sample one result per prompt using classifier-free guidance.

        Args:
            model: Network handed to ``forward_with_cfg`` (its ``forward``
                method is what gets called).
            text_encoder: Object providing ``encode(prompts)``, which must
                return a mutable mapping containing key ``"y"``, and
                ``null(n)`` for the unconditional embedding.
            z_size: Per-sample noise shape (unpacked after the batch dim).
            prompts: Sized collection of prompts; its length is the batch
                size.
            device: Device for the initial noise and the sampling loop.
            additional_args: Optional mapping merged into the model kwargs.

        Returns:
            The first of the two equal chunks (along dim 0) of the loop
            output, i.e. the half that was paired with the encoded prompts.
        """
        batch_size = len(prompts)

        # The same noise is used for the conditional and unconditional halves.
        noise = torch.randn(batch_size, *z_size, device=device)
        noise = torch.cat([noise, noise], dim=0)

        model_kwargs = text_encoder.encode(prompts)
        null_embedding = text_encoder.null(batch_size)
        # Prompt embeddings first, null embeddings second, matching the
        # layout of the doubled noise batch.
        model_kwargs["y"] = torch.cat([model_kwargs["y"], null_embedding], dim=0)
        if additional_args is not None:
            model_kwargs.update(additional_args)

        guided_forward = partial(forward_with_cfg, model, cfg_scale=self.cfg_scale)
        output = self.p_sample_loop(
            guided_forward,
            noise.shape,
            noise,
            clip_denoised=False,
            model_kwargs=model_kwargs,
            progress=True,
            device=device,
        )
        conditional, _ = output.chunk(2, dim=0)
        return conditional
|
|
|
|
|
def forward_with_cfg(model, x, timestep, y, cfg_scale, cfg_channel=3, **kwargs):
    """Run ``model`` on a doubled batch and apply classifier-free guidance.

    ``x`` is expected to hold a conditional half followed by an
    unconditional half along dim 0.  Only the first half of ``x`` is used:
    it is duplicated so both halves see identical inputs, while ``timestep``
    and ``y`` are passed through unchanged.

    Args:
        model: Object whose ``forward(x, timestep, y, **kwargs)`` returns a
            tensor, or a dict carrying that tensor under ``"x"``.
        x: Input batch; its length along dim 0 should be even.
        timestep: Forwarded to ``model.forward`` as-is.
        y: Conditioning forwarded to ``model.forward`` as-is.
        cfg_scale: Guidance scale ``s`` in ``uncond + s * (cond - uncond)``.
        cfg_channel: Number of leading output channels (dim 1) that receive
            guidance; the remaining channels are returned untouched.  The
            default of 3 preserves the previously hard-coded behaviour.
        **kwargs: Extra keyword arguments forwarded to ``model.forward``.

    Returns:
        Tensor with the same shape as the model output, where the first
        ``cfg_channel`` channels hold the guided prediction repeated for
        both halves of the batch.
    """
    half = x[: len(x) // 2]
    combined = torch.cat([half, half], dim=0)
    model_out = model.forward(combined, timestep, y, **kwargs)
    model_out = model_out["x"] if isinstance(model_out, dict) else model_out

    # Split channels into the guided part and the pass-through remainder.
    eps, rest = model_out[:, :cfg_channel], model_out[:, cfg_channel:]
    cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
    half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)

    # Both halves carry the same guided value so the batch layout is kept.
    eps = torch.cat([half_eps, half_eps], dim=0)
    return torch.cat([eps, rest], dim=1)
|
|