Spaces:
Sleeping
Sleeping
"""DDPM scheduler for the simplex diffusion model.""" | |
import math | |
from dataclasses import dataclass | |
from typing import Optional, Tuple, Union | |
import numpy as np | |
import torch | |
from diffusers import DDPMScheduler | |
from diffusers.configuration_utils import register_to_config | |
from diffusers.schedulers.scheduling_ddpm import DDPMSchedulerOutput | |
from diffusers.utils import BaseOutput | |
class SimplexDDPMSchedulerOutput(BaseOutput): | |
""" | |
Output class for the scheduler's step function output. | |
Args: | |
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): | |
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the | |
denoising loop. | |
projected_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, vocab_size)`): | |
The projected logits sample (x_{0}) based on the model output from the current timestep. | |
""" | |
prev_sample: torch.FloatTensor | |
projected_logits: Optional[torch.FloatTensor] = None | |
def betas_for_alpha_bar( | |
num_diffusion_timesteps, device, max_beta=0.999, improved_ddpm=False | |
): | |
""" | |
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of | |
(1-beta) over time from t = [0,1]. | |
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up | |
to that part of the diffusion process. | |
Args: | |
num_diffusion_timesteps (`int`): the number of betas to produce. | |
max_beta (`float`): the maximum beta to use; use values lower than 1 to | |
prevent singularities. | |
Returns: | |
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs | |
""" | |
def default_alpha_bar(time_step): | |
return math.cos((time_step + 1e-4) / (1 + 1e-4) * math.pi / 2) ** 2 | |
if improved_ddpm: | |
# Implements eqn. 17 in https://arxiv.org/pdf/2102.09672.pdf. | |
alpha_bar = lambda x: ( # noqa: E731 | |
default_alpha_bar(x) / default_alpha_bar(0.0) | |
) | |
alphas_cumprod = [] | |
else: | |
alpha_bar = default_alpha_bar | |
betas = [] | |
for i in range(num_diffusion_timesteps): | |
t1 = i / num_diffusion_timesteps | |
t2 = (i + 1) / num_diffusion_timesteps | |
alpha_bar_t1 = alpha_bar(t1) | |
betas.append(min(1 - alpha_bar(t2) / alpha_bar_t1, max_beta)) | |
if improved_ddpm: | |
alphas_cumprod.append(alpha_bar_t1) | |
# TODO(rabeeh): maybe this cause memory issue. | |
betas = torch.tensor(betas, dtype=torch.float32, device=device) | |
if improved_ddpm: | |
return betas, torch.tensor( | |
alphas_cumprod, dtype=torch.torch.float32, device=device | |
) | |
return betas | |
class SimplexDDPMScheduler(DDPMScheduler): | |
def __init__( | |
self, | |
device, | |
simplex_value: float, | |
num_train_timesteps: int = 1000, | |
num_inference_timesteps: int = 1000, | |
beta_start: float = 0.0001, | |
beta_end: float = 0.02, | |
beta_schedule: str = "linear", | |
trained_betas: Optional[np.ndarray] = None, | |
variance_type: str = "fixed_small", | |
clip_sample: bool = False, | |
): | |
if trained_betas is not None: | |
self.betas = torch.from_numpy(trained_betas) | |
elif beta_schedule == "linear": | |
self.betas = torch.linspace( | |
beta_start, | |
beta_end, | |
num_train_timesteps, | |
dtype=torch.float32, | |
device=device, | |
) | |
elif beta_schedule == "scaled_linear": | |
# this schedule is very specific to the latent diffusion model. | |
self.betas = ( | |
torch.linspace( | |
beta_start**0.5, | |
beta_end**0.5, | |
num_train_timesteps, | |
dtype=torch.float32, | |
device=device, | |
) | |
** 2 | |
) | |
elif beta_schedule == "squaredcos_cap_v2": | |
# Glide cosine schedule | |
self.betas = betas_for_alpha_bar(num_train_timesteps, device=device) | |
elif beta_schedule == "squaredcos_improved_ddpm": | |
self.betas, self.alphas_cumprod = betas_for_alpha_bar( | |
num_train_timesteps, device=device, improved_ddpm=True | |
) | |
elif beta_schedule == "sigmoid": | |
# GeoDiff sigmoid schedule | |
betas = torch.linspace(-6, 6, num_train_timesteps, device=device) | |
self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start | |
else: | |
raise NotImplementedError( | |
f"{beta_schedule} does is not implemented for {self.__class__}" | |
) | |
if beta_schedule == "squaredcos_improved_ddpm": | |
self.alphas = None | |
else: | |
self.alphas = 1.0 - self.betas | |
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) | |
self.one = torch.tensor(1.0, device=device) | |
# standard deviation of the initial noise distribution | |
self.init_noise_sigma = 1.0 | |
# setable values | |
self.num_inference_steps = None | |
# TODO(rabeeh): if memory issue, we can not add this to GPU and convert them iteratively. | |
self.timesteps = torch.from_numpy( | |
np.arange(0, num_train_timesteps)[::-1].copy() | |
).to(device=device) | |
self.variance_type = variance_type | |
def step( | |
self, | |
projected_logits: torch.FloatTensor, | |
timestep: int, | |
t_prev: int, # previous timestep. recall we are in backward process, so this is the next timestep. | |
noise: torch.FloatTensor, | |
generator=None, | |
) -> Union[DDPMSchedulerOutput, Tuple]: | |
""" | |
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion | |
process from the learned model outputs (most often the predicted noise). | |
Args: | |
projected_logits (`torch.FloatTensor`): projected logits from the diffusion model. | |
timestep (`int`): current discrete timestep in the diffusion chain. | |
noise (`torch.FloatTensor`): a random noise with simplex_value standard deviation. | |
generator: random number generator. | |
Returns: | |
[`~schedulers.scheduling_utils.DDPMSchedulerOutput`] resulted values. | |
""" | |
t = timestep[0, 0].item() | |
# 1. compute alphas, betas | |
alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one | |
# 3. Clip "predicted x_0" | |
if self.config.clip_sample: | |
projected_logits = torch.clamp(projected_logits, -1, 1) | |
# See algorithm 2 in Figure 3 in https://arxiv.org/pdf/2210.17432.pdf. | |
predicted_logits_coeff = alpha_prod_t_prev ** (0.5) | |
noise_coeff = (1 - alpha_prod_t_prev) ** (0.5) | |
pred_prev_sample = ( | |
predicted_logits_coeff * projected_logits + noise_coeff * noise | |
) | |
return SimplexDDPMSchedulerOutput( | |
prev_sample=pred_prev_sample, projected_logits=projected_logits | |
) | |
def add_noise( | |
self, | |
original_samples: torch.FloatTensor, | |
noise: torch.FloatTensor, | |
timesteps: torch.IntTensor, | |
) -> torch.FloatTensor: | |
timesteps = timesteps.long() | |
# if same shape, we have per-token timesteps | |
if timesteps.shape == noise.shape[:2]: | |
alphas_cumprod_timesteps = self.alphas_cumprod[timesteps][:, :, None] | |
else: | |
alphas_cumprod_timesteps = self.alphas_cumprod[timesteps].view(-1, 1, 1) | |
sqrt_alpha_prod = alphas_cumprod_timesteps**0.5 | |
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod_timesteps) ** 0.5 | |
noisy_samples = ( | |
sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise | |
) | |
return noisy_samples | |
class TokenWiseSimplexDDPMScheduler(DDPMScheduler): | |
def __init__( | |
self, | |
device, | |
simplex_value: float, | |
num_train_timesteps: int = 1000, | |
num_inference_timesteps: int = 1000, | |
beta_start: float = 0.0001, | |
beta_end: float = 0.02, | |
beta_schedule: str = "linear", | |
trained_betas: Optional[np.ndarray] = None, | |
variance_type: str = "fixed_small", | |
clip_sample: bool = False, | |
multiply_factor: float = 1.0, | |
): | |
if trained_betas is not None: | |
self.betas = torch.from_numpy(trained_betas) | |
elif beta_schedule == "linear": | |
self.betas = torch.linspace( | |
beta_start, | |
beta_end, | |
num_train_timesteps, | |
dtype=torch.float32, | |
device=device, | |
) | |
elif beta_schedule == "scaled_linear": | |
# this schedule is very specific to the latent diffusion model. | |
self.betas = ( | |
torch.linspace( | |
beta_start**0.5, | |
beta_end**0.5, | |
num_train_timesteps, | |
dtype=torch.float32, | |
device=device, | |
) | |
** 2 | |
) | |
elif beta_schedule == "squaredcos_cap_v2": | |
# Glide cosine schedule | |
self.betas = betas_for_alpha_bar(num_train_timesteps, device=device) | |
elif beta_schedule == "squaredcos_improved_ddpm": | |
self.betas, self.alphas_cumprod = betas_for_alpha_bar( | |
num_train_timesteps, device=device, improved_ddpm=True | |
) | |
elif beta_schedule == "sigmoid": | |
# GeoDiff sigmoid schedule | |
betas = torch.linspace(-6, 6, num_train_timesteps, device=device) | |
self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start | |
else: | |
raise NotImplementedError( | |
f"{beta_schedule} does is not implemented for {self.__class__}" | |
) | |
if beta_schedule == "squaredcos_improved_ddpm": | |
self.alphas = None | |
else: | |
self.alphas = 1.0 - self.betas | |
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) | |
self.one = torch.tensor(1.0, device=device) | |
self.alphas_cumprod = multiply_factor * self.alphas_cumprod | |
# standard deviation of the initial noise distribution | |
self.init_noise_sigma = 1.0 | |
# setable values | |
self.num_inference_steps = None | |
# TODO(rabeeh): if memory issue, we can not add this to GPU and convert them iteratively. | |
self.timesteps = torch.from_numpy( | |
np.arange(0, num_train_timesteps)[::-1].copy() | |
).to(device=device) | |
self.variance_type = variance_type | |
def step( | |
self, | |
projected_logits: torch.FloatTensor, | |
timestep: int, | |
t_prev: int, # previous timestep. recall we are in backward process, so this is the next timestep. | |
noise: torch.FloatTensor, | |
generator=None, | |
) -> Union[DDPMSchedulerOutput, Tuple]: | |
""" | |
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion | |
process from the learned model outputs (most often the predicted noise). | |
Args: | |
projected_logits (`torch.FloatTensor`): projected logits from the diffusion model. | |
timestep (`int`): current discrete timestep in the diffusion chain. | |
noise (`torch.FloatTensor`): a random noise with simplex_value standard deviation. | |
generator: random number generator. | |
Returns: | |
[`~schedulers.scheduling_utils.DDPMSchedulerOutput`] resulted values. | |
""" | |
position_timestep = timestep.long() | |
# 1. compute alphas, betas | |
# index into alphas cumprod | |
alphas_cumprods = [] | |
for i, pos_timestep in enumerate(position_timestep): | |
alphas_cumprods.append( | |
torch.where(pos_timestep > 0, self.alphas_cumprod[t_prev[i]], self.one) | |
) | |
# alphas_cumprods has dim: [batch, positions, timesteps] | |
alpha_prod_t_prev = torch.stack(alphas_cumprods, dim=0)[:, :, None] | |
# current_timesteps = current_timesteps.unsqueeze(-1) | |
# now, we can use gather! | |
# alpha_prod_t_prev = torch.where(current_timesteps > 0, alphas_cumprods.gather(dim=-1, index=current_timesteps.long()), self.one) | |
# alpha_prod_t_prev = self.alphas_cumprod[position_timestep][t - 1] if t > 0 else self.one | |
# alpha_prod_t_prev = alpha_prod_t_prev | |
# 3. Clip "predicted x_0" | |
if self.config.clip_sample: | |
projected_logits = torch.clamp(projected_logits, -1, 1) | |
# See algorithm 2 in Figure 3 in https://arxiv.org/pdf/2210.17432.pdf. | |
predicted_logits_coeff = alpha_prod_t_prev ** (0.5) | |
noise_coeff = (1 - alpha_prod_t_prev) ** (0.5) | |
pred_prev_sample = ( | |
predicted_logits_coeff * projected_logits + noise_coeff * noise | |
) | |
return SimplexDDPMSchedulerOutput( | |
prev_sample=pred_prev_sample, projected_logits=projected_logits | |
) | |
def add_noise( | |
self, | |
original_samples: torch.FloatTensor, | |
noise: torch.FloatTensor, | |
timesteps: torch.IntTensor, | |
) -> torch.FloatTensor: | |
# if same shape, we have per-token timesteps | |
if timesteps.shape == noise.shape[:2]: | |
alphas_cumprod_timesteps = self.alphas_cumprod[timesteps.long()][:, :, None] | |
else: | |
alphas_cumprod_timesteps = self.alphas_cumprod[timesteps.long()].view( | |
-1, 1, 1 | |
) | |
sqrt_alpha_prod = alphas_cumprod_timesteps**0.5 | |
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod_timesteps) ** 0.5 | |
noisy_samples = ( | |
sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise | |
) | |
return noisy_samples | |