Spaces:
Running
on
Zero
Running
on
Zero
# -------------------------------------------------------- | |
# What Matters When Repurposing Diffusion Models for General Dense Perception Tasks? (https://arxiv.org/abs/2403.06090) | |
# Github source: https://github.com/aim-uofa/GenPercept | |
# Copyright (c) 2024, Advanced Intelligent Machines (AIM) | |
# Licensed under The BSD 2-Clause License [see LICENSE for details] | |
# By Guangkai Xu | |
# Based on Marigold, diffusers codebases | |
# https://github.com/prs-eth/marigold | |
# https://github.com/huggingface/diffusers | |
# -------------------------------------------------------- | |
import torch | |
from typing import List, Optional, Tuple, Union | |
import numpy as np | |
from diffusers import DDIMScheduler, DDPMScheduler | |
from diffusers.configuration_utils import ConfigMixin, register_to_config | |
def rescale_zero_terminal_snr(betas):
    """
    Rescale a beta schedule so that its terminal signal-to-noise ratio is zero.

    Implements Algorithm 1 of https://arxiv.org/pdf/2305.08891.pdf: the square
    root of the cumulative alpha product is shifted so that its last entry
    becomes 0, then scaled so that its first entry keeps its original value,
    and the result is converted back into per-step betas.

    Args:
        betas (`torch.FloatTensor`):
            the betas that the scheduler is being initialized with.

    Returns:
        `torch.FloatTensor`: rescaled betas (same length as the input) whose
        cumulative alpha product is zero at the final timestep.
    """
    sqrt_abar = torch.cumprod(1.0 - betas, dim=0).sqrt()

    # Remember the endpoints before they are modified.
    first = sqrt_abar[0].clone()
    last = sqrt_abar[-1].clone()

    # Shift so the final entry is exactly 0, then rescale so the first entry is unchanged.
    sqrt_abar = (sqrt_abar - last) * (first / (first - last))

    # Undo the sqrt, then undo the cumulative product: alpha_t = abar_t / abar_{t-1}, alpha_0 = abar_0.
    abar = sqrt_abar ** 2
    step_alphas = torch.cat([abar[:1], abar[1:] / abar[:-1]])
    return 1 - step_alphas
class DDPMSchedulerCustomized(DDPMScheduler):
    """
    `DDPMScheduler` variant that adds a ``"scaled_linear_power"`` beta schedule.

    ``"scaled_linear_power"`` interpolates linearly between
    ``beta_start ** (1 / power_beta_curve)`` and ``beta_end ** (1 / power_beta_curve)``
    and raises the result to ``power_beta_curve``. With ``power_beta_curve=2`` it
    reproduces ``"scaled_linear"``; with ``power_beta_curve=1`` it reproduces
    ``"linear"``.
    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        variance_type: str = "fixed_small",
        clip_sample: bool = True,
        prediction_type: str = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        clip_sample_range: float = 1.0,
        sample_max_value: float = 1.0,
        timestep_spacing: str = "leading",
        steps_offset: int = 0,
        rescale_betas_zero_snr: bool = False,
        power_beta_curve: float = 1.0,
    ):
        # `@register_to_config` records every argument above in `self.config`.
        # Arguments not read in this body (clip_sample, prediction_type, ...) are
        # presumably consumed by the parent class through `self.config`.
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "scaled_linear_power":
            # Generalization of "scaled_linear" with an arbitrary exponent.
            self.betas = (
                torch.linspace(
                    beta_start ** (1 / power_beta_curve),
                    beta_end ** (1 / power_beta_curve),
                    num_train_timesteps,
                    dtype=torch.float32,
                )
                ** power_beta_curve
            )
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule. Imported locally: this module does not import it at top level.
            from diffusers.schedulers.scheduling_ddpm import betas_for_alpha_bar

            self.betas = betas_for_alpha_bar(num_train_timesteps)
        elif beta_schedule == "sigmoid":
            # GeoDiff sigmoid schedule
            betas = torch.linspace(-6, 6, num_train_timesteps)
            self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
        else:
            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

        # Rescale for zero SNR
        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.one = torch.tensor(1.0)

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # setable values
        self.custom_timesteps = False
        self.num_inference_steps = None
        # Descending training timesteps; int64 pinned so the dtype is platform independent.
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))

        self.variance_type = variance_type
        # Kept as attributes for parity with DDIMSchedulerCustomized.
        self.power_beta_curve = power_beta_curve
        self.beta_schedule = beta_schedule

    def get_velocity(
        self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
    ) -> torch.FloatTensor:
        """
        Compute the v-prediction target
        ``sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * sample``.

        Args:
            sample: clean sample the velocity is computed for.
            noise: noise tensor, same shape as ``sample``.
            timesteps: indices into ``self.alphas_cumprod``; the gathered
                coefficients are flattened and then unsqueezed on the right until
                they have as many dims as ``sample`` (presumably one timestep per
                batch element — confirm with callers).

        Returns:
            The velocity tensor, broadcast to the shape of ``sample``.
        """
        # Move the cached schedule to the sample's device (persisted on self);
        # the dtype cast is a local copy only.
        self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
        alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
        timesteps = timesteps.to(sample.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(sample.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
        return velocity
class DDIMSchedulerCustomized(DDIMScheduler):
    """
    `DDIMScheduler` variant that adds a ``"scaled_linear_power"`` beta schedule
    and computes the DDIM variance from a product of per-step alphas.

    ``"scaled_linear_power"`` interpolates linearly between
    ``beta_start ** (1 / power_beta_curve)`` and ``beta_end ** (1 / power_beta_curve)``
    and raises the result to ``power_beta_curve``. With ``power_beta_curve=2`` it
    reproduces ``"scaled_linear"``; with ``power_beta_curve=1`` it reproduces
    ``"linear"``.
    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        clip_sample: bool = True,
        set_alpha_to_one: bool = True,
        steps_offset: int = 0,
        prediction_type: str = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        clip_sample_range: float = 1.0,
        sample_max_value: float = 1.0,
        timestep_spacing: str = "leading",
        rescale_betas_zero_snr: bool = False,
        power_beta_curve: float = 1.0,
    ):
        # `@register_to_config` records every argument above in `self.config`.
        # Arguments not read in this body (clip_sample, prediction_type, ...) are
        # presumably consumed by the parent class through `self.config`.
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "scaled_linear_power":
            # Generalization of "scaled_linear" with an arbitrary exponent.
            self.betas = (
                torch.linspace(
                    beta_start ** (1 / power_beta_curve),
                    beta_end ** (1 / power_beta_curve),
                    num_train_timesteps,
                    dtype=torch.float32,
                )
                ** power_beta_curve
            )
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule. Imported locally: this module does not import it at top level.
            from diffusers.schedulers.scheduling_ddim import betas_for_alpha_bar

            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

        # Rescale for zero SNR
        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # At every step in ddim, we are looking into the previous alphas_cumprod
        # For the final step, there is no previous alphas_cumprod because we are already at 0
        # `set_alpha_to_one` decides whether we set this parameter simply to one or
        # whether we use the final alpha of the "non-previous" one.
        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # setable values
        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
        self.beta_schedule = beta_schedule
        # Stored for every schedule (previously only for "scaled_linear_power").
        self.power_beta_curve = power_beta_curve

    def _get_variance(self, timestep, prev_timestep):
        """
        DDIM variance ``(1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev)``, where
        ``a`` is the cumulative alpha product.

        For a non-negative ``prev_timestep`` the ratio ``a_t / a_prev`` is taken as
        the product of the per-step alphas in ``(prev_timestep, timestep]``. That is
        mathematically equal to the ratio and avoids dividing two small cumulative
        products. For a negative ``prev_timestep`` (the last denoising step),
        ``a_prev`` is ``self.final_alpha_cumprod`` and the ratio is computed
        directly. Slicing ``self.alphas`` with a negative start would wrap around
        to the end of the schedule.

        Args:
            timestep: current training timestep index.
            prev_timestep: previous timestep index; may be negative on the final step.

        Returns:
            The variance as a tensor.
        """
        alpha_prod_t = self.alphas_cumprod[timestep]
        if prev_timestep >= 0:
            alpha_prod_t_prev = self.alphas_cumprod[prev_timestep]
            alpha_t_prev_to_t = torch.prod(self.alphas[(prev_timestep + 1):(timestep + 1)])
        else:
            alpha_prod_t_prev = self.final_alpha_cumprod
            alpha_t_prev_to_t = alpha_prod_t / alpha_prod_t_prev

        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_t_prev_to_t)
        return variance