# DiffBIR/model/gaussian_diffusion.py
from typing import Tuple
import torch
from torch import nn
import numpy as np


def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
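    """Build a noise (beta) schedule of length `n_timestep`.

    Supported schedules: "linear" (linspace in sqrt-space, then squared),
    "cosine" (squared-cosine schedule with offset `cosine_s`),
    "sqrt_linear" (plain linspace) and "sqrt" (square root of a linspace).
    Returns a float64 numpy array of shape (n_timestep,).
    """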
if schedule == "linear":
betas = (
np.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=np.float64) ** 2
)
elif schedule == "cosine":
timesteps = (
np.arange(n_timestep + 1, dtype=np.float64) / n_timestep + cosine_s
)
alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = np.cos(alphas) ** 2
alphas = alphas / alphas[0]
betas = 1 - alphas[1:] / alphas[:-1]
betas = np.clip(betas, a_min=0, a_max=0.999)
elif schedule == "sqrt_linear":
betas = np.linspace(linear_start, linear_end, n_timestep, dtype=np.float64)
elif schedule == "sqrt":
betas = np.linspace(linear_start, linear_end, n_timestep, dtype=np.float64) ** 0.5
else:
raise ValueError(f"schedule '{schedule}' unknown.")
return betas


def extract_into_tensor(a: torch.Tensor, t: torch.Tensor, x_shape: Tuple[int, ...]) -> torch.Tensor:
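    """Gather the values of `a` at the per-batch indices `t` and reshape the
    result to (b, 1, ..., 1) so it broadcasts against a tensor of shape `x_shape`."""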
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))


class Diffusion(nn.Module):
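    """Gaussian diffusion utilities: precomputes the noise schedule and exposes
    the forward process (`q_sample`), the v-prediction target (`get_v`) and the
    simple training loss (`p_losses`) for the "eps", "x0" and "v" parameterizations."""
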
def __init__(
self,
timesteps=1000,
beta_schedule="linear",
loss_type="l2",
linear_start=1e-4,
linear_end=2e-2,
cosine_s=8e-3,
parameterization="eps"
):
super().__init__()
self.num_timesteps = timesteps
self.beta_schedule = beta_schedule
self.linear_start = linear_start
self.linear_end = linear_end
self.cosine_s = cosine_s
        assert parameterization in ["eps", "x0", "v"], "currently only 'eps', 'x0' and 'v' parameterizations are supported"
self.parameterization = parameterization
self.loss_type = loss_type
        betas = make_beta_schedule(
            beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s
        )
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = np.sqrt(1. - alphas_cumprod)
self.betas = betas
self.register("sqrt_alphas_cumprod", sqrt_alphas_cumprod)
self.register("sqrt_one_minus_alphas_cumprod", sqrt_one_minus_alphas_cumprod)

    def register(self, name: str, value: np.ndarray) -> None:
self.register_buffer(name, torch.tensor(value, dtype=torch.float32))

    def q_sample(self, x_start, t, noise):
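        """Sample from the forward process q(x_t | x_0):
        sqrt(alpha_cumprod_t) * x_start + sqrt(1 - alpha_cumprod_t) * noise."""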
return (
extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)

    def get_v(self, x, noise, t):
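        """Compute the v-prediction target:
        v = sqrt(alpha_cumprod_t) * noise - sqrt(1 - alpha_cumprod_t) * x."""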
return (
extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
)

    def get_loss(self, pred, target, mean=True):
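        """Elementwise (or mean-reduced) L1/L2 loss between `pred` and `target`."""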
if self.loss_type == 'l1':
loss = (target - pred).abs()
if mean:
loss = loss.mean()
elif self.loss_type == 'l2':
if mean:
loss = torch.nn.functional.mse_loss(target, pred)
else:
loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
else:
            raise NotImplementedError(f"unknown loss type '{self.loss_type}'")
return loss

    def p_losses(self, model, x_start, t, cond):
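        """Diffuse `x_start` to timestep `t`, run the model on the noisy input
        and return the mean loss against the target of the chosen parameterization."""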
noise = torch.randn_like(x_start)
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
model_output = model(x_noisy, t, cond)
if self.parameterization == "x0":
target = x_start
elif self.parameterization == "eps":
target = noise
elif self.parameterization == "v":
target = self.get_v(x_start, noise, t)
else:
raise NotImplementedError()
loss_simple = self.get_loss(model_output, target, mean=False).mean()
return loss_simple
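

# ----------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the upstream module).
# It assumes a denoiser callable with the signature model(x_noisy, t, cond)
# returning a tensor of the same shape as x_noisy; the zero-returning stub
# below is a placeholder for a real network.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    diffusion = Diffusion(timesteps=1000, beta_schedule="linear", parameterization="eps")

    x_start = torch.randn(2, 3, 32, 32)                  # fake clean images
    t = torch.randint(0, diffusion.num_timesteps, (2,))  # random timesteps, one per sample

    def dummy_model(x, t, cond):
        # Placeholder denoiser: a real model would predict eps / x0 / v here.
        return torch.zeros_like(x)

    loss = diffusion.p_losses(dummy_model, x_start, t, cond=None)
    print(f"simple loss: {loss.item():.4f}")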