Spaces:
Running
on
L4
Running
on
L4
File size: 1,910 Bytes
81ecb2b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
# Modified from OpenAI's diffusion repos
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
from . import gaussian_diffusion as gd
from .respace import SpacedDiffusion, space_timesteps
def create_diffusion(
    timestep_respacing,
    noise_schedule="linear",
    use_kl=False,
    sigma_small=False,
    parameterization="eps",
    learn_sigma=True,
    rescale_learned_sigmas=False,
    diffusion_steps=1000
):
    """Assemble a ``SpacedDiffusion`` from a handful of configuration flags.

    Args:
        timestep_respacing: Respacing spec handed to ``space_timesteps``.
            ``None`` or ``""`` is replaced by ``[diffusion_steps]``.
        noise_schedule: Schedule name passed to
            ``gd.get_named_beta_schedule``.
        use_kl: When truthy, selects ``gd.LossType.RESCALED_KL``; this takes
            precedence over ``rescale_learned_sigmas``.
        sigma_small: Consulted only when ``learn_sigma`` is falsy; picks
            ``FIXED_SMALL`` instead of ``FIXED_LARGE`` as the variance type.
        parameterization: One of ``"eps"``, ``"xstart"`` or ``"v"``; mapped
            to ``gd.ModelMeanType.EPSILON`` / ``START_X`` / ``VELOCITY``.
        learn_sigma: When truthy, the variance type is ``LEARNED_RANGE``.
        rescale_learned_sigmas: When truthy (and ``use_kl`` is falsy),
            selects ``gd.LossType.RESCALED_MSE``; otherwise ``MSE`` is used.
        diffusion_steps: Step count passed to both the beta schedule and
            ``space_timesteps``.

    Returns:
        The object produced by calling ``SpacedDiffusion`` with the
        resolved settings.

    Raises:
        NotImplementedError: If ``parameterization`` matches none of the
            supported names.
    """
    schedule_betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)

    # KL wins over rescaled MSE; plain MSE is the fallback.
    chosen_loss = (
        gd.LossType.RESCALED_KL
        if use_kl
        else gd.LossType.RESCALED_MSE
        if rescale_learned_sigmas
        else gd.LossType.MSE
    )

    # A missing/empty spec is replaced by a single section of all steps.
    spacing_spec = timestep_respacing
    if spacing_spec is None or spacing_spec == "":
        spacing_spec = [diffusion_steps]

    # Ordered scan so comparisons happen in the same sequence as an if/elif
    # chain, and only the matching enum member is ever looked up.
    mean_type_names = (
        ("eps", "EPSILON"),
        ("xstart", "START_X"),
        ("v", "VELOCITY"),
    )
    for accepted_name, member_name in mean_type_names:
        if parameterization == accepted_name:
            mean_kind = getattr(gd.ModelMeanType, member_name)
            break
    else:
        raise NotImplementedError("Model Mean Type {} is not supported!".format(parameterization))

    retained_steps = space_timesteps(diffusion_steps, spacing_spec)

    # A learned variance overrides the fixed small/large choice.
    if learn_sigma:
        var_kind = gd.ModelVarType.LEARNED_RANGE
    elif sigma_small:
        var_kind = gd.ModelVarType.FIXED_SMALL
    else:
        var_kind = gd.ModelVarType.FIXED_LARGE

    return SpacedDiffusion(
        use_timesteps=retained_steps,
        betas=schedule_betas,
        model_mean_type=mean_kind,
        model_var_type=var_kind,
        loss_type=chosen_loss
    )
|