resshift / models /script_util.py
yuhj95's picture
Upload folder using huggingface_hub
4730cdc verified
raw
history blame
2.74 kB
import argparse
import inspect
from . import gaussian_diffusion as gd
from .respace import SpacedDiffusion, space_timesteps, SpacedDiffusionDDPM
def create_gaussian_diffusion(
        *,
        normalize_input,
        schedule_name,
        sf=4,
        min_noise_level=0.01,
        steps=1000,
        kappa=1,
        etas_end=0.99,
        schedule_kwargs=None,
        weighted_mse=False,
        predict_type='xstart',
        timestep_respacing=None,
        scale_factor=None,
        latent_flag=True,
        ):
    """Build a ``SpacedDiffusion`` driven by a named eta schedule.

    Args:
        normalize_input: Forwarded unchanged to ``SpacedDiffusion``.
        schedule_name: Name passed to ``gd.get_named_eta_schedule``.
        sf: Forwarded unchanged to ``SpacedDiffusion``.
        min_noise_level: Forwarded to ``gd.get_named_eta_schedule``.
        steps: Number of timesteps of the base schedule.
        kappa: Forwarded to both ``gd.get_named_eta_schedule`` and
            ``SpacedDiffusion``.
        etas_end: Forwarded to ``gd.get_named_eta_schedule``.
        schedule_kwargs: Extra options passed as ``kwargs=`` to
            ``gd.get_named_eta_schedule`` (may be ``None``).
        weighted_mse: Use ``LossType.WEIGHTED_MSE`` when true, else
            ``LossType.MSE``.
        predict_type: One of ``'xstart'``, ``'epsilon'``, ``'epsilon_scale'``
            or ``'residual'``; selects the ``gd.ModelMeanType`` member.
        timestep_respacing: Integer passed with ``steps`` to
            ``space_timesteps``. ``None`` substitutes ``steps``.
        scale_factor: Forwarded unchanged to ``SpacedDiffusion``.
        latent_flag: Forwarded unchanged to ``SpacedDiffusion``.

    Returns:
        The constructed ``SpacedDiffusion`` instance.

    Raises:
        TypeError: If ``timestep_respacing`` is neither ``None`` nor an ``int``.
        ValueError: If ``predict_type`` is not one of the supported names.
    """
    # Validate the cheap arguments before building the schedule so bad input
    # fails fast. An explicit raise (unlike ``assert``) also survives ``python -O``.
    if timestep_respacing is None:
        timestep_respacing = steps
    elif not isinstance(timestep_respacing, int):
        raise TypeError(
            'timestep_respacing must be an int or None, '
            f'got {type(timestep_respacing).__name__}'
        )

    # Map the user-facing name to the ModelMeanType member name. Only the
    # selected member is looked up later, exactly as the original if/elif did.
    mean_type_names = {
        'xstart': 'START_X',
        'epsilon': 'EPSILON',
        'epsilon_scale': 'EPSILON_SCALE',
        'residual': 'RESIDUAL',
    }
    if predict_type not in mean_type_names:
        raise ValueError(f'Unknown Predicted type: {predict_type}')

    sqrt_etas = gd.get_named_eta_schedule(
        schedule_name,
        num_diffusion_timesteps=steps,
        min_noise_level=min_noise_level,
        etas_end=etas_end,
        kappa=kappa,
        kwargs=schedule_kwargs,
    )
    model_mean_type = getattr(gd.ModelMeanType, mean_type_names[predict_type])

    return SpacedDiffusion(
        use_timesteps=space_timesteps(steps, timestep_respacing),
        sqrt_etas=sqrt_etas,
        kappa=kappa,
        model_mean_type=model_mean_type,
        loss_type=gd.LossType.WEIGHTED_MSE if weighted_mse else gd.LossType.MSE,
        scale_factor=scale_factor,
        normalize_input=normalize_input,
        sf=sf,
        latent_flag=latent_flag,
    )
def create_gaussian_diffusion_ddpm(
        *,
        beta_start,
        beta_end,
        sf=4,
        steps=1000,
        learn_sigma=False,
        sigma_small=False,
        noise_schedule="linear",
        predict_xstart=False,
        timestep_respacing=None,
        scale_factor=1.0,
        ):
    """Build a ``SpacedDiffusionDDPM`` driven by a named beta schedule.

    Args:
        beta_start: Forwarded to ``gd.get_named_beta_schedule``.
        beta_end: Forwarded to ``gd.get_named_beta_schedule``.
        sf: Forwarded unchanged to ``SpacedDiffusionDDPM``.
        steps: Number of timesteps of the base schedule.
        learn_sigma: When true, use ``ModelVarTypeDDPM.LEARNED_RANGE``; this
            takes precedence over ``sigma_small``.
        sigma_small: When ``learn_sigma`` is false, pick ``FIXED_SMALL``
            instead of ``FIXED_LARGE``.
        noise_schedule: Schedule name passed to ``gd.get_named_beta_schedule``.
        predict_xstart: Model mean type is ``START_X`` when true, else
            ``EPSILON``.
        timestep_respacing: Integer passed with ``steps`` to
            ``space_timesteps``. ``None`` substitutes ``steps``.
        scale_factor: Forwarded unchanged to ``SpacedDiffusionDDPM``.

    Returns:
        The constructed ``SpacedDiffusionDDPM`` instance.

    Raises:
        TypeError: If ``timestep_respacing`` is neither ``None`` nor an ``int``.
    """
    # Validate before building the schedule so bad input fails fast. An
    # explicit raise (unlike ``assert``) also survives ``python -O``.
    if timestep_respacing is None:
        timestep_respacing = steps
    elif not isinstance(timestep_respacing, int):
        raise TypeError(
            'timestep_respacing must be an int or None, '
            f'got {type(timestep_respacing).__name__}'
        )

    betas = gd.get_named_beta_schedule(noise_schedule, steps, beta_start, beta_end)

    if predict_xstart:
        model_mean_type = gd.ModelMeanType.START_X
    else:
        model_mean_type = gd.ModelMeanType.EPSILON

    # learn_sigma wins over sigma_small, matching the original nested ternary.
    if learn_sigma:
        model_var_type = gd.ModelVarTypeDDPM.LEARNED_RANGE
    elif sigma_small:
        model_var_type = gd.ModelVarTypeDDPM.FIXED_SMALL
    else:
        model_var_type = gd.ModelVarTypeDDPM.FIXED_LARGE

    return SpacedDiffusionDDPM(
        use_timesteps=space_timesteps(steps, timestep_respacing),
        betas=betas,
        model_mean_type=model_mean_type,
        model_var_type=model_var_type,
        scale_factor=scale_factor,
        sf=sf,
    )