|
""" |
|
Based on https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/gaussian_diffusion.py |
|
""" |
|
|
|
from typing import Any, Dict |
|
|
|
import numpy as np |
|
|
|
from .gaussian_diffusion import ( |
|
GaussianDiffusion, |
|
SpacedDiffusion, |
|
get_named_beta_schedule, |
|
space_timesteps, |
|
) |
|
|
|
BASE_DIFFUSION_CONFIG = { |
|
"channel_biases": [0.0, 0.0, 0.0, -1.0, -1.0, -1.0], |
|
"channel_scales": [2.0, 2.0, 2.0, 0.007843137255, 0.007843137255, 0.007843137255], |
|
"mean_type": "epsilon", |
|
"schedule": "cosine", |
|
"timesteps": 1024, |
|
} |
|
|
|
DIFFUSION_CONFIGS = { |
|
"base40M-imagevec": BASE_DIFFUSION_CONFIG, |
|
"base40M-textvec": BASE_DIFFUSION_CONFIG, |
|
"base40M-uncond": BASE_DIFFUSION_CONFIG, |
|
"base40M": BASE_DIFFUSION_CONFIG, |
|
"base300M": BASE_DIFFUSION_CONFIG, |
|
"base1B": BASE_DIFFUSION_CONFIG, |
|
"upsample": { |
|
"channel_biases": [0.0, 0.0, 0.0, -1.0, -1.0, -1.0], |
|
"channel_scales": [2.0, 2.0, 2.0, 0.007843137255, 0.007843137255, 0.007843137255], |
|
"mean_type": "epsilon", |
|
"schedule": "linear", |
|
"timesteps": 1024, |
|
}, |
|
} |
|
|
|
|
|
def diffusion_from_config(config: Dict[str, Any]) -> GaussianDiffusion: |
|
schedule = config["schedule"] |
|
steps = config["timesteps"] |
|
respace = config.get("respacing", None) |
|
mean_type = config.get("mean_type", "epsilon") |
|
betas = get_named_beta_schedule(schedule, steps) |
|
channel_scales = config.get("channel_scales", None) |
|
channel_biases = config.get("channel_biases", None) |
|
if channel_scales is not None: |
|
channel_scales = np.array(channel_scales) |
|
if channel_biases is not None: |
|
channel_biases = np.array(channel_biases) |
|
kwargs = dict( |
|
betas=betas, |
|
model_mean_type=mean_type, |
|
model_var_type="learned_range", |
|
loss_type="mse", |
|
channel_scales=channel_scales, |
|
channel_biases=channel_biases, |
|
) |
|
if respace is None: |
|
return GaussianDiffusion(**kwargs) |
|
else: |
|
return SpacedDiffusion(use_timesteps=space_timesteps(steps, respace), **kwargs) |
|
|