'''
Diffusion utilities.

The forward process sequentially adds Gaussian noise to images;
forward_diffusion_sample() produces the noisy image at any timestep in
closed form. sample_timestep() and sample_plot_image() run the learned
reverse process, and show_tensor_image() maps tensors back to viewable
images. Precomputed schedule values are imported from precomputes.
'''

import numpy as np
import torch
from matplotlib import pyplot as plt
from torchvision import transforms

from precomputes import (
    betas,
    sqrt_recip_alphas,
    sqrt_alphas_cumulative_products,
    sqrt_one_minus_alphas_cumulative_products,
    posterior_variance,
)
|
|
def get_index_from_list(vals, t, x_shape):
    """
    Gather the schedule values in vals at the timesteps t and reshape the
    result to (batch_size, 1, 1, ...) so it broadcasts over an image batch
    of shape x_shape.
    """
    batch_size = t.shape[0]
    out = vals.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
|
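# A minimal sketch of what get_index_from_list does (illustrative values,
# not from precomputes):
#
#   vals = torch.linspace(0.0, 1.0, 300)               # e.g. a 300-step schedule
#   t = torch.tensor([0, 150, 299])                    # one timestep per sample
#   coeff = get_index_from_list(vals, t, (3, 3, 64, 64))
#   coeff.shape                                        # torch.Size([3, 1, 1, 1])
#
# The trailing singleton dims let coeff broadcast over a (3, 3, 64, 64) batch.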
|
def forward_diffusion_sample(x_0, t, device="cpu"):
    """
    Return the noisy version x_t of an image batch x_0 at timesteps t,
    together with the noise that was added, using the closed form

        x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise

    so any timestep can be sampled directly, without iterating.
    """
    noise = torch.randn_like(x_0)
    sqrt_alphas_cumulative_products_t = get_index_from_list(
        sqrt_alphas_cumulative_products, t, x_0.shape
    )
    sqrt_one_minus_alphas_cumulative_products_t = get_index_from_list(
        sqrt_one_minus_alphas_cumulative_products, t, x_0.shape
    )
    return (
        sqrt_alphas_cumulative_products_t.to(device) * x_0.to(device)
        + sqrt_one_minus_alphas_cumulative_products_t.to(device) * noise.to(device),
        noise.to(device),
    )
|
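# A minimal usage sketch for the forward process, assuming the schedules in
# precomputes were built for T = 300 steps:
#
#   x_0 = torch.randn(4, 3, 64, 64)                    # stand-in image batch
#   t = torch.randint(0, 300, (4,), dtype=torch.long)  # random step per sample
#   x_t, noise = forward_diffusion_sample(x_0, t)
#
# During training, the model is asked to predict `noise` from (x_t, t).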
|
@torch.no_grad()
def sample_timestep(model, x, t):
    """
    Calls the model to predict the noise in the image and returns
    the denoised image.
    Applies noise to this image, if we are not in the last step yet.
    """
    betas_t = get_index_from_list(betas, t, x.shape)
    sqrt_one_minus_alphas_cumulative_products_t = get_index_from_list(
        sqrt_one_minus_alphas_cumulative_products, t, x.shape
    )
    sqrt_recip_alphas_t = get_index_from_list(sqrt_recip_alphas, t, x.shape)

    # Current image minus the (scaled) noise the model predicts in it.
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumulative_products_t
    )
    posterior_variance_t = get_index_from_list(posterior_variance, t, x.shape)

    # t is a one-element tensor here; sampling runs with batch size 1.
    if t == 0:
        # Last step: return the mean without adding noise.
        return model_mean
    else:
        noise = torch.randn_like(x)
        return model_mean + torch.sqrt(posterior_variance_t) * noise
|
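# For reference, sample_timestep() implements the reverse-step mean from
# Ho et al. (2020), "Denoising Diffusion Probabilistic Models":
#
#   mu_theta(x_t, t) = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta(x_t, t))
#
# with x_{t-1} ~ N(mu_theta, posterior_variance_t) for t > 0.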
|
@torch.no_grad()
def sample_plot_image(model, IMG_SIZE=64, device="cpu", T=300):
    """
    Starting from pure noise, run the full reverse process and plot
    every (T / 10)-th intermediate image.
    """
    img_size = IMG_SIZE
    img = torch.randn((1, 3, img_size, img_size), device=device)
    plt.figure(figsize=(15, 15))
    plt.axis('off')
    num_images = 10
    stepsize = int(T / num_images)

    # Denoise step by step, from t = T - 1 down to t = 0.
    for i in reversed(range(T)):
        t = torch.full((1,), i, device=device, dtype=torch.long)
        img = sample_timestep(model, img, t)
        # Keep pixel values in the range the model was trained on.
        img = torch.clamp(img, -1.0, 1.0)
        if i % stepsize == 0:
            plt.subplot(1, num_images, int(i / stepsize) + 1)
            show_tensor_image(img.detach().cpu())

    return img
|
|
def show_tensor_image(image):
    """Undo the training normalization and display the tensor with matplotlib."""
    reverse_transforms = transforms.Compose([
        transforms.Lambda(lambda t: (t + 1) / 2),         # [-1, 1] -> [0, 1]
        transforms.Lambda(lambda t: t.permute(1, 2, 0)),  # CHW -> HWC
        transforms.Lambda(lambda t: t * 255.),
        transforms.Lambda(lambda t: t.numpy().astype(np.uint8)),
        transforms.ToPILImage(),
    ])

    # If a batch is passed, show its first image.
    if len(image.shape) == 4:
        image = image[0, :, :, :]
    plt.imshow(reverse_transforms(image))
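

# A minimal sketch of tying everything together. The model class and
# checkpoint path below are placeholders, not part of this module:
#
#   if __name__ == "__main__":
#       device = "cuda" if torch.cuda.is_available() else "cpu"
#       model = MyUNet().to(device)                       # hypothetical UNet
#       model.load_state_dict(torch.load("model.pt"))     # hypothetical weights
#       model.eval()
#       sample_plot_image(model, IMG_SIZE=64, device=device, T=300)
#       plt.show()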