# spritediffuser / app.py
# debisoft's picture
# ddpm
# ddd5ebc
import gradio as gr
from typing import Dict, Tuple
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
import numpy as np
from IPython.display import HTML
from diffusion_utilities import *
from PIL import Image as im
#openai.api_key = os.getenv('OPENAI_API_KEY')
class ContextUnet(nn.Module):
    """U-Net style network that maps (image, timestep, context) to an image-shaped output.

    The samplers in this file treat the output as the predicted noise.
    The building blocks used here (ResidualConvBlock, UnetDown, UnetUp,
    EmbedFC) are not defined in this file; presumably they come from the
    star import of ``diffusion_utilities`` -- confirm there for exact shapes.
    """

    def __init__(self, in_channels, n_feat=256, n_cfeat=10, height=28):
        """Build all layers.

        Args:
            in_channels: number of channels of the input (and output) image.
            n_feat: base number of intermediate feature maps.
            n_cfeat: length of the context vector.
            height: image height; width is assumed equal. The original
                author notes it must be divisible by 4 (28, 24, 20, 16...).
        """
        super().__init__()

        self.in_channels = in_channels
        self.n_feat = n_feat
        self.n_cfeat = n_cfeat
        self.h = height

        # Stem convolution applied to the raw input.
        self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True)

        # Two-level down-sampling path.
        self.down1 = UnetDown(n_feat, n_feat)
        self.down2 = UnetDown(n_feat, 2 * n_feat)

        # Pool the deepest feature map and apply an activation.
        # (A commented-out earlier variant used AvgPool2d(7).)
        self.to_vec = nn.Sequential(nn.AvgPool2d(4), nn.GELU())

        # One embedding per up-sampling level, for timestep and for context.
        self.timeembed1 = EmbedFC(1, 2 * n_feat)
        self.timeembed2 = EmbedFC(1, 1 * n_feat)
        self.contextembed1 = EmbedFC(n_cfeat, 2 * n_feat)
        self.contextembed2 = EmbedFC(n_cfeat, 1 * n_feat)

        # Up-sampling path: a transposed convolution whose kernel and stride
        # are both height // 4, followed by two UnetUp stages.
        quarter = self.h // 4
        self.up0 = nn.Sequential(
            nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, quarter, quarter),
            nn.GroupNorm(8, 2 * n_feat),
            nn.ReLU(),
        )
        self.up1 = UnetUp(4 * n_feat, n_feat)
        self.up2 = UnetUp(2 * n_feat, n_feat)

        # Output head: reduce 2*n_feat maps back to in_channels.
        # Conv2d positional args are (in, out, kernel_size, stride, padding).
        self.out = nn.Sequential(
            nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1),
            nn.GroupNorm(8, n_feat),
            nn.ReLU(),
            nn.Conv2d(n_feat, self.in_channels, 3, 1, 1),
        )

    def forward(self, x, t, c=None):
        """Run the network.

        Args:
            x: input image batch; presumably (batch, in_channels, h, w).
            t: timestep tensor fed to the 1-input time embeddings; the
                samplers in this file pass a (1, 1, 1, 1) tensor holding
                i / timesteps.
            c: optional context tensor for the n_cfeat-input context
                embeddings. When None, an all-zero (batch, n_cfeat) tensor
                is used instead.

        Returns:
            The output of the final convolution head.
        """
        stem = self.init_conv(x)

        # Encoder; both results are reused as skip connections below.
        skip1 = self.down1(stem)
        skip2 = self.down2(skip1)
        bottleneck = self.to_vec(skip2)

        if c is None:
            # Match dtype/device of the feature tensor.
            c = torch.zeros(stem.shape[0], self.n_cfeat).to(stem)

        # Embeddings reshaped to (batch, channels, 1, 1) so they broadcast
        # over the spatial dimensions.
        ctx_wide = self.contextembed1(c).view(-1, self.n_feat * 2, 1, 1)
        time_wide = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)
        ctx_narrow = self.contextembed2(c).view(-1, self.n_feat, 1, 1)
        time_narrow = self.timeembed2(t).view(-1, self.n_feat, 1, 1)

        # Decoder: context scales the features, time is added.
        dec0 = self.up0(bottleneck)
        dec1 = self.up1(ctx_wide * dec0 + time_wide, skip2)
        dec2 = self.up2(ctx_narrow * dec1 + time_narrow, skip1)

        return self.out(torch.cat([dec2, stem], dim=1))
# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------

# Diffusion schedule: number of steps and the two endpoints of the beta ramp.
timesteps = 1000
beta1 = 1e-4
beta2 = 0.02

# Network / runtime settings.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
n_feat = 64  # 64 hidden dimension feature
n_cfeat = 5  # context vector is of size 5
height = 16  # 16x16 image
save_dir = './weights/'

# Training settings (not used by any code visible in this file).
batch_size = 1000
n_epoch = 512
lrate = 1e-3

# DDPM noise schedule. Each tensor has timesteps + 1 entries (indices 0..timesteps):
#   b_t  - betas, a linear ramp from beta1 to beta2
#   a_t  - 1 - beta
#   ab_t - running product of a_t, accumulated in log space
b_t = torch.linspace(0, 1, timesteps + 1, device=device) * (beta2 - beta1) + beta1
a_t = 1 - b_t
ab_t = a_t.log().cumsum(dim=0).exp()
ab_t[0] = 1  # force the cumulative product at index 0 to exactly 1

# Build the network with the settings above and move it to the chosen device.
nn_model = ContextUnet(
    in_channels=3,
    n_feat=n_feat,
    n_cfeat=n_cfeat,
    height=height,
).to(device)
# DDIM denoising step
def denoise_ddim(x, t, t_prev, pred_noise):
    """Move a noisy sample from schedule index ``t`` to ``t_prev`` (DDIM update).

    Args:
        x: current noisy sample tensor.
        t: schedule index of ``x``; used to index the module-level ``ab_t``.
        t_prev: schedule index to move to; also indexes ``ab_t``.
        pred_noise: the network's noise prediction for ``x``.

    Returns:
        The rescaled noise-free estimate plus the predicted noise re-scaled
        for ``t_prev``. No random noise is added.
    """
    alpha_bar_now = ab_t[t]
    alpha_bar_prev = ab_t[t_prev]

    # Strip the predicted noise from x and rescale to the target noise level.
    rescale = alpha_bar_prev.sqrt() / alpha_bar_now.sqrt()
    signal_part = rescale * (x - (1 - alpha_bar_now).sqrt() * pred_noise)

    # Re-introduce the predicted noise at the target level.
    noise_part = (1 - alpha_bar_prev).sqrt() * pred_noise

    return signal_part + noise_part
# Restore the checkpoint trained without context and switch to inference mode.
# NOTE(review): further down this file loads ft_context_model_31.pth into the
# same nn_model before any sampler is called, so these weights appear to be
# overwritten without ever being used -- confirm whether this load is needed.
nn_model.load_state_dict(
    torch.load(f"{save_dir}/model_31.pth", map_location=device)
)
nn_model.eval()
print("Loaded in Model without context")
# sample quickly using DDIM
@torch.no_grad()
def sample_ddim(n_sample, n=20):
    """Generate images with the DDIM sampler, without context.

    Args:
        n_sample: number of images to generate.
        n: requested number of sampling steps. The stride through the
            schedule is ``timesteps // n``, so the actual number of steps can
            be slightly larger than ``n`` when ``timesteps`` is not divisible
            by ``n``.

    Returns:
        (samples, intermediate): the final sample tensor (on ``device``) and
        a NumPy array stacking the sample after every step.

    Raises:
        ValueError: if ``n`` is not between 1 and ``timesteps``.
    """
    # n == 0 would divide by zero and n > timesteps would give a zero stride.
    if not 1 <= n <= timesteps:
        raise ValueError(f"n must be between 1 and {timesteps}, got {n}")

    # x_T ~ N(0, 1), sample initial noise
    samples = torch.randn(n_sample, 3, height, height).to(device)

    # array to keep track of generated steps for plotting
    intermediate = []
    step_size = timesteps // n
    for i in range(timesteps, 0, -step_size):
        print(f'sampling timestep {i:3d}', end='\r')

        # time is passed to the network normalised to (0, 1], shape (1, 1, 1, 1)
        t = torch.tensor([i / timesteps])[:, None, None, None].to(device)

        eps = nn_model(samples, t)  # predict noise e_(x_t,t)

        # Clamp the target index at 0: when timesteps is not a multiple of
        # step_size the last stride would go negative, and a negative index
        # would silently wrap around to the noisy end of the schedule.
        t_prev = max(i - step_size, 0)
        samples = denoise_ddim(samples, i, t_prev, eps)
        intermediate.append(samples.detach().cpu().numpy())

    intermediate = np.stack(intermediate)
    return samples, intermediate
# Restore the context checkpoint into nn_model and switch to inference mode.
# This replaces whatever weights nn_model held before this point.
nn_model.load_state_dict(
    torch.load(f"{save_dir}/ft_context_model_31.pth", map_location=device)
)
nn_model.eval()
print("Loaded in Context Model")
# fast sampling algorithm with context
@torch.no_grad()
def sample_ddim_context(n_sample, context, n=20):
    """Generate images with the DDIM sampler, conditioned on ``context``.

    Args:
        n_sample: number of images to generate.
        context: context tensor forwarded to the network as ``c``;
            presumably (n_sample, n_cfeat) -- confirm against the caller.
        n: requested number of sampling steps. The stride through the
            schedule is ``timesteps // n``, so the actual number of steps can
            be slightly larger than ``n`` when ``timesteps`` is not divisible
            by ``n``.

    Returns:
        (samples, intermediate): the final sample tensor (on ``device``) and
        a NumPy array stacking the sample after every step.

    Raises:
        ValueError: if ``n`` is not between 1 and ``timesteps``.
    """
    # n == 0 would divide by zero and n > timesteps would give a zero stride.
    if not 1 <= n <= timesteps:
        raise ValueError(f"n must be between 1 and {timesteps}, got {n}")

    # x_T ~ N(0, 1), sample initial noise
    samples = torch.randn(n_sample, 3, height, height).to(device)

    # array to keep track of generated steps for plotting
    intermediate = []
    step_size = timesteps // n
    for i in range(timesteps, 0, -step_size):
        print(f'sampling timestep {i:3d}', end='\r')

        # time is passed to the network normalised to (0, 1], shape (1, 1, 1, 1)
        t = torch.tensor([i / timesteps])[:, None, None, None].to(device)

        eps = nn_model(samples, t, c=context)  # predict noise e_(x_t,t)

        # Clamp the target index at 0: when timesteps is not a multiple of
        # step_size the last stride would go negative, and a negative index
        # would silently wrap around to the noisy end of the schedule.
        t_prev = max(i - step_size, 0)
        samples = denoise_ddim(samples, i, t_prev, eps)
        intermediate.append(samples.detach().cpu().numpy())

    intermediate = np.stack(intermediate)
    return samples, intermediate
# DDPM denoising step: subtract the predicted noise, then add fresh noise back
# in (the original author notes this is to avoid collapse).
def denoise_add_noise(x, t, pred_noise, z=None):
    """Apply one DDPM reverse step at schedule index ``t``.

    Args:
        x: current noisy sample tensor.
        t: schedule index; used to index the module-level ``b_t``, ``a_t``
            and ``ab_t``.
        pred_noise: the network's noise prediction for ``x``.
        z: noise to inject, scaled by sqrt(b_t[t]). When None, standard
            normal noise shaped like ``x`` is drawn. Callers in this file
            pass 0 on the last step so that nothing is added.

    Returns:
        The denoised mean plus the scaled injected noise.
    """
    injected = torch.randn_like(x) if z is None else z

    alpha = a_t[t]
    # How much of the predicted noise to remove at this index.
    removal = (1 - alpha) / (1 - ab_t[t]).sqrt()
    mean = (x - pred_noise * removal) / alpha.sqrt()

    return mean + b_t[t].sqrt() * injected
# standard (full-length) DDPM sampling
@torch.no_grad()
def sample_ddpm(n_sample, context, save_rate=20):
    """Generate images with the full DDPM sampler, one step per schedule index.

    Args:
        n_sample: number of images to generate.
        context: accepted but not used -- the network is called without a
            context argument here.
        save_rate: a snapshot is kept whenever the step index is a multiple
            of this value (in addition to the first step and every index
            below 8).

    Returns:
        (samples, intermediate): the final sample tensor (on ``device``) and
        a NumPy array stacking the kept snapshots.
    """
    # Start from pure Gaussian noise, x_T ~ N(0, 1).
    samples = torch.randn(n_sample, 3, height, height).to(device)

    snapshots = []
    for i in reversed(range(1, timesteps + 1)):
        print(f'sampling timestep {i:3d}', end='\r')

        # Normalised time, shaped (1, 1, 1, 1) for the network.
        t = torch.tensor([[[[i / timesteps]]]]).to(device)

        # Fresh noise to inject back in; none on the very last step (i == 1).
        if i > 1:
            z = torch.randn_like(samples)
        else:
            z = 0

        eps = nn_model(samples, t)  # predict noise e_(x_t,t)
        samples = denoise_add_noise(samples, i, eps, z)

        keep = i == timesteps or i < 8 or i % save_rate == 0
        if keep:
            snapshots.append(samples.detach().cpu().numpy())

    intermediate = np.stack(snapshots)
    return samples, intermediate
@torch.no_grad()
def sample_ddpm_context(n_sample, timesteps, context, save_rate=20):
    """Generate images with the DDPM sampler, conditioned on ``context``.

    ``timesteps`` here is the number of sampling steps requested by the
    caller and shadows the module-level schedule length. Previously the loop
    counter was both normalised by this value and used as a raw index into
    the module-level schedule tensors, so any value other than the schedule
    length combined a network time input with the noise levels of a
    different schedule position (and values above the schedule length
    indexed out of range).

    The sampler now walks an evenly spaced subsequence of the full schedule
    and derives each step's alpha/beta from the ratio of the cumulative
    products at consecutive visited indices. When ``timesteps`` equals the
    schedule length the visited indices are simply every index in turn, and
    the time input and snapshot selection are the same as before.

    Args:
        n_sample: number of images to generate.
        timesteps: number of sampling steps, between 1 and the schedule
            length (``len(ab_t) - 1``).
        context: context tensor forwarded to the network as ``c``;
            presumably (n_sample, n_cfeat) -- confirm against the caller.
        save_rate: a snapshot is kept whenever the step counter is a multiple
            of this value (in addition to the first step and every counter
            below 8).

    Returns:
        (samples, intermediate): the final sample tensor (on ``device``) and
        a NumPy array stacking the kept snapshots.

    Raises:
        ValueError: if ``timesteps`` is outside the valid range.
    """
    # The schedule tensors hold one entry per index 0..T, so T = length - 1.
    total_steps = ab_t.shape[0] - 1
    timesteps = int(timesteps)
    if not 1 <= timesteps <= total_steps:
        raise ValueError(
            f"timesteps must be between 1 and {total_steps}, got {timesteps}"
        )

    # Evenly spaced schedule indices 0 = s[0] < s[1] < ... < s[timesteps] = T.
    # Integer arithmetic keeps them strictly increasing when timesteps <= T.
    schedule = [(k * total_steps) // timesteps for k in range(timesteps + 1)]

    # x_T ~ N(0, 1), sample initial noise
    samples = torch.randn(n_sample, 3, height, height).to(device)

    # array to keep track of generated steps for plotting
    intermediate = []
    for i in range(timesteps, 0, -1):
        print(f'sampling timestep {i:3d}', end='\r')

        idx, idx_prev = schedule[i], schedule[i - 1]

        # The network is told the position in the full schedule, normalised
        # to (0, 1] and shaped (1, 1, 1, 1).
        t = torch.tensor([idx / total_steps])[:, None, None, None].to(device)

        # sample some random noise to inject back in. For i = 1, don't add back in noise
        z = torch.randn_like(samples) if i > 1 else 0

        eps = nn_model(samples, t, c=context)  # predict noise e_(x_t,t)

        # Effective alpha/beta for the jump idx -> idx_prev.
        alpha = ab_t[idx] / ab_t[idx_prev]
        beta = 1 - alpha
        mean = (samples - eps * (beta / (1 - ab_t[idx]).sqrt())) / alpha.sqrt()
        samples = mean + beta.sqrt() * z

        if i % save_rate == 0 or i == timesteps or i < 8:
            intermediate.append(samples.detach().cpu().numpy())

    intermediate = np.stack(intermediate)
    return samples, intermediate
def greet(input):
    """Gradio callback: generate one image for a fixed context and return it.

    Args:
        input: the requested number of sampling steps, as entered in the
            "steps" textbox; anything ``int()`` accepts. (The name shadows
            the builtin but is kept so existing callers are unaffected.)

    Returns:
        The generated image after conversion by ``transform2``.

    Raises:
        ValueError: if ``input`` is not an integer or is outside
            1..``timesteps``.
    """
    steps = int(input)
    if not 1 <= steps <= timesteps:
        raise ValueError(f"steps must be between 1 and {timesteps}, got {steps}")

    image_count = 1

    # Context categories: hero, non-hero, food, spell, side-facing.
    # The same one-hot row (first category) is used for every image.
    one_hot_enc = np.array([1, 0, 0, 0, 0])
    mtx_2d = np.ones((image_count, 5)) * one_hot_enc
    ctx = torch.from_numpy(mtx_2d).to(device=device).float()

    samples, _ = sample_ddpm_context(image_count, steps, ctx)

    # Copy the result to host memory explicitly before handing it to NumPy;
    # relying on NumPy's implicit conversion fails for CUDA tensors.
    # Adding a leading axis gives (1, image_count, 3, h, w), and moving axis 2
    # to the end puts the channel axis last.
    final = samples.detach().cpu().numpy()[None]
    sx_gen_store = np.moveaxis(final, 2, 4)

    # norm_all / transform are not defined in this file (presumably from
    # diffusion_utilities); they are applied exactly as before.
    nsx_gen_store = norm_all(sx_gen_store, sx_gen_store.shape[0], image_count)
    return transform2(transform(nsx_gen_store[-1][0]))
# Converter used by greet() to turn its result into a PIL image.
transform2 = transforms.ToPILImage()

# Gradio UI: a single "steps" textbox (default 20) in, a single image out.
iface = gr.Interface(
    fn=greet,
    inputs=[gr.Textbox(label="steps", value=20)],
    outputs=[gr.Image(type="pil", width=64, label="Output Image")],
)
iface.launch()