|
import argparse |
|
import inspect |
|
|
|
from . import gaussian_diffusion as gd |
|
from .respace import SpacedDiffusion, space_timesteps, SpacedDiffusionDDPM |
|
|
|
def create_gaussian_diffusion(
    *,
    normalize_input,
    schedule_name,
    sf=4,
    min_noise_level=0.01,
    steps=1000,
    kappa=1,
    etas_end=0.99,
    schedule_kwargs=None,
    weighted_mse=False,
    predict_type='xstart',
    timestep_respacing=None,
    scale_factor=None,
    latent_flag=True,
):
    """Build a ``SpacedDiffusion`` configured from a named eta schedule.

    All arguments are keyword-only.

    Args:
        normalize_input: Forwarded unchanged to ``SpacedDiffusion``.
        schedule_name: Schedule name handed to
            ``gd.get_named_eta_schedule``.
        sf: Forwarded unchanged to ``SpacedDiffusion``
            (presumably the super-resolution scale -- TODO confirm).
        min_noise_level: Forwarded to ``gd.get_named_eta_schedule``.
        steps: Number of diffusion timesteps; used for the schedule and as
            the first argument of ``space_timesteps``.
        kappa: Forwarded to both ``gd.get_named_eta_schedule`` and
            ``SpacedDiffusion``.
        etas_end: Forwarded to ``gd.get_named_eta_schedule``.
        schedule_kwargs: Forwarded to ``gd.get_named_eta_schedule`` as its
            ``kwargs`` argument.
        weighted_mse: When truthy the loss type is
            ``gd.LossType.WEIGHTED_MSE``, otherwise ``gd.LossType.MSE``.
        predict_type: One of ``'xstart'``, ``'epsilon'``,
            ``'epsilon_scale'`` or ``'residual'``; selects the
            ``gd.ModelMeanType`` member.
        timestep_respacing: ``None`` to reuse ``steps``, otherwise an
            ``int`` passed to ``space_timesteps``.
        scale_factor: Forwarded unchanged to ``SpacedDiffusion``.
        latent_flag: Forwarded unchanged to ``SpacedDiffusion``.

    Returns:
        The constructed ``SpacedDiffusion`` instance.

    Raises:
        AssertionError: If ``timestep_respacing`` is neither ``None`` nor an
            ``int`` (only while assertions are enabled).
        ValueError: If ``predict_type`` is not one of the supported names.
    """
    sqrt_etas = gd.get_named_eta_schedule(
        schedule_name,
        num_diffusion_timesteps=steps,
        min_noise_level=min_noise_level,
        etas_end=etas_end,
        kappa=kappa,
        kwargs=schedule_kwargs,
    )

    if timestep_respacing is not None:
        assert isinstance(timestep_respacing, int)
    else:
        # No respacing requested: keep every timestep of the schedule.
        timestep_respacing = steps

    # Ordered (predict_type label, ModelMeanType member name) pairs. The member
    # is resolved lazily so only the selected attribute is ever looked up.
    mean_type_members = (
        ('xstart', 'START_X'),
        ('epsilon', 'EPSILON'),
        ('epsilon_scale', 'EPSILON_SCALE'),
        ('residual', 'RESIDUAL'),
    )
    for label, member_name in mean_type_members:
        if predict_type == label:
            model_mean_type = getattr(gd.ModelMeanType, member_name)
            break
    else:
        raise ValueError(f'Unknown Predicted type: {predict_type}')

    use_timesteps = space_timesteps(steps, timestep_respacing)

    if weighted_mse:
        loss_type = gd.LossType.WEIGHTED_MSE
    else:
        loss_type = gd.LossType.MSE

    return SpacedDiffusion(
        use_timesteps=use_timesteps,
        sqrt_etas=sqrt_etas,
        kappa=kappa,
        model_mean_type=model_mean_type,
        loss_type=loss_type,
        scale_factor=scale_factor,
        normalize_input=normalize_input,
        sf=sf,
        latent_flag=latent_flag,
    )
|
|
|
def create_gaussian_diffusion_ddpm(
    *,
    beta_start,
    beta_end,
    sf=4,
    steps=1000,
    learn_sigma=False,
    sigma_small=False,
    noise_schedule="linear",
    predict_xstart=False,
    timestep_respacing=None,
    scale_factor=1.0,
):
    """Build a ``SpacedDiffusionDDPM`` configured from a named beta schedule.

    All arguments are keyword-only.

    Args:
        beta_start: Forwarded to ``gd.get_named_beta_schedule``.
        beta_end: Forwarded to ``gd.get_named_beta_schedule``.
        sf: Forwarded unchanged to ``SpacedDiffusionDDPM``
            (presumably the super-resolution scale -- TODO confirm).
        steps: Number of diffusion timesteps; used for the schedule and as
            the first argument of ``space_timesteps``.
        learn_sigma: When truthy the variance type is
            ``gd.ModelVarTypeDDPM.LEARNED_RANGE`` and ``sigma_small`` is
            ignored.
        sigma_small: When ``learn_sigma`` is falsy, chooses
            ``FIXED_SMALL`` (truthy) or ``FIXED_LARGE`` (falsy).
        noise_schedule: Schedule name handed to
            ``gd.get_named_beta_schedule``.
        predict_xstart: When truthy the mean type is
            ``gd.ModelMeanType.START_X``, otherwise
            ``gd.ModelMeanType.EPSILON``.
        timestep_respacing: ``None`` to reuse ``steps``, otherwise an
            ``int`` passed to ``space_timesteps``.
        scale_factor: Forwarded unchanged to ``SpacedDiffusionDDPM``.

    Returns:
        The constructed ``SpacedDiffusionDDPM`` instance.

    Raises:
        AssertionError: If ``timestep_respacing`` is neither ``None`` nor an
            ``int`` (only while assertions are enabled).
    """
    betas = gd.get_named_beta_schedule(noise_schedule, steps, beta_start, beta_end)

    if timestep_respacing is not None:
        assert isinstance(timestep_respacing, int)
    else:
        # No respacing requested: keep every timestep of the schedule.
        timestep_respacing = steps

    use_timesteps = space_timesteps(steps, timestep_respacing)

    if predict_xstart:
        mean_type = gd.ModelMeanType.START_X
    else:
        mean_type = gd.ModelMeanType.EPSILON

    # A learned variance takes precedence over the fixed small/large choice.
    if learn_sigma:
        var_type = gd.ModelVarTypeDDPM.LEARNED_RANGE
    elif sigma_small:
        var_type = gd.ModelVarTypeDDPM.FIXED_SMALL
    else:
        var_type = gd.ModelVarTypeDDPM.FIXED_LARGE

    return SpacedDiffusionDDPM(
        use_timesteps=use_timesteps,
        betas=betas,
        model_mean_type=mean_type,
        model_var_type=var_type,
        scale_factor=scale_factor,
        sf=sf,
    )
|
|