import math

import numpy as np
import torch
from einops import rearrange
from tqdm import tqdm

from .utils import get_tensor_items, exist


def get_named_beta_schedule(schedule_name, timesteps):
    if schedule_name == "linear":
        # Linear schedule from Ho et al. (2020), rescaled so the endpoints
        # match the original 1000-step formulation for any number of steps.
        scale = 1000 / timesteps
        beta_start = scale * 0.0001
        beta_end = scale * 0.02
        return torch.linspace(
            beta_start, beta_end, timesteps, dtype=torch.float32
        )
    elif schedule_name == "cosine":
        # Cosine schedule from Nichol & Dhariwal (2021); betas are clipped
        # at 0.999 to avoid singularities near the end of the schedule.
        alpha_bar = lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
        betas = []
        for i in range(timesteps):
            t1 = i / timesteps
            t2 = (i + 1) / timesteps
            betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), 0.999))
        return torch.tensor(betas, dtype=torch.float32)
    raise ValueError(f"unknown beta schedule: {schedule_name}")
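
# Illustrative check (not part of the original module): both schedules return
# `timesteps` betas strictly inside (0, 1), e.g.
#   betas = get_named_beta_schedule("cosine", 1000)
#   assert betas.shape == (1000,) and betas.min() > 0 and betas.max() <= 0.999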


class BaseDiffusion:

    def __init__(self, betas, percentile=None, gen_noise=torch.randn_like):
        self.betas = betas
        self.num_timesteps = betas.shape[0]

        alphas = 1. - betas
        self.alphas_cumprod = torch.cumprod(alphas, dim=0)
        self.alphas_cumprod_prev = torch.cat(
            [torch.ones(1, dtype=betas.dtype), self.alphas_cumprod[:-1]]
        )

        # coefficients of the marginal q(x_t | x_0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)

        # coefficients of the posterior q(x_{t-1} | x_t, x_0)
        self.posterior_mean_coef_1 = (
            torch.sqrt(self.alphas_cumprod_prev) * betas / (1. - self.alphas_cumprod)
        )
        self.posterior_mean_coef_2 = (
            torch.sqrt(alphas) * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)
        )
        self.posterior_variance = betas * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)
        # the posterior variance is zero at t = 0, so the first value is
        # replaced before taking the log to keep it finite
        self.posterior_log_variance = torch.log(
            torch.cat([self.posterior_variance[1].unsqueeze(0), self.posterior_variance[1:]])
        )

        self.percentile = percentile
        self.time_scale = 1000 // self.num_timesteps
        self.gen_noise = gen_noise

    def q_sample(self, x_start, t, noise=None):
        # forward process: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
        if noise is None:
            noise = self.gen_noise(x_start)
        sqrt_alphas_cumprod = get_tensor_items(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod = get_tensor_items(self.sqrt_one_minus_alphas_cumprod, t, noise.shape)
        x_t = sqrt_alphas_cumprod * x_start + sqrt_one_minus_alphas_cumprod * noise
        return x_t

    def p_sample_loop(
        self, model, shape, device, dtype, lowres_img, times=[979, 729, 479, 229]
    ):
        # Few-step sampling: start from pure noise, re-noise the current image
        # estimate at each timestep in `times`, and feed it through the model.
        img = torch.randn(*shape, device=device).to(dtype=dtype)
        times = times + [0,]
        times = list(zip(times[:-1], times[1:]))
        for time, prev_time in tqdm(times):
            time = torch.tensor([time] * shape[0], device=device)
            x_t = self.q_sample(img, time)
            img = model(x_t.to(dtype), time.to(dtype), lowres_img=lowres_img.to(dtype))
        return img

    def refine(self, model, img, **large_model_kwargs):
        # progressively re-noise the image at decreasing timesteps and pass it
        # through the model again to refine the current sample
        for time in tqdm([729, 479, 229]):
            time = torch.tensor([time,] * img.shape[0], device=img.device)
            x_t = self.q_sample(img, time)
            img = model(x_t, time.type(x_t.dtype), **large_model_kwargs)
        return img


def get_diffusion(conf):
    betas = get_named_beta_schedule(**conf.schedule_params)
    base_diffusion = BaseDiffusion(betas, **conf.diffusion_params)
    return base_diffusion
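

# Minimal usage sketch (illustrative, not part of the original file): `conf` is
# assumed to expose `schedule_params` and `diffusion_params` matching the
# signatures above; `model` and `lowres_img` are assumed to be provided by the
# caller.
#
#   from types import SimpleNamespace
#
#   conf = SimpleNamespace(
#       schedule_params={"schedule_name": "cosine", "timesteps": 1000},
#       diffusion_params={},
#   )
#   diffusion = get_diffusion(conf)
#   sample = diffusion.p_sample_loop(
#       model, shape=(1, 3, 256, 256), device="cuda", dtype=torch.float32,
#       lowres_img=lowres_img,
#   )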