import os
from typing import Dict, Tuple

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from IPython.display import HTML
from matplotlib.animation import FuncAnimation, PillowWriter
from PIL import Image as im
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision.utils import save_image, make_grid
from tqdm import tqdm

from diffusion_utilities import *


class ContextUnet(nn.Module):
    """Context-conditioned U-Net that predicts the noise present in a diffusion sample.

    Args:
        in_channels: number of channels of the input image.
        n_feat: number of feature maps at the first level (2 * n_feat at the bottleneck).
        n_cfeat: size of the context vector.
        height: image side length; assumes h == w and must be divisible by 4.
    """

    def __init__(self, in_channels, n_feat=256, n_cfeat=10, height=28):  # cfeat - context features
        super(ContextUnet, self).__init__()

        # number of input channels, number of intermediate feature maps and context size
        self.in_channels = in_channels
        self.n_feat = n_feat
        self.n_cfeat = n_cfeat
        self.h = height  # assume h == w. must be divisible by 4, so 28,24,20,16...

        # Initialize the initial convolutional layer
        self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True)

        # Down-sampling path with two levels. Each UnetDown is expected to halve the
        # spatial size: down2 presumably has to be h/4 x h/4 so it lines up with
        # up0's h/4 x h/4 output inside up1.
        self.down1 = UnetDown(n_feat, n_feat)      # expected (batch, n_feat, h/2, w/2)
        self.down2 = UnetDown(n_feat, 2 * n_feat)  # expected (batch, 2*n_feat, h/4, w/4)

        # Pool the h/4 x h/4 bottleneck down to 1x1. Derived from `height` instead of the
        # previously hard-coded 4 (16x16 images) / 7 (28x28 images); AvgPool2d has no
        # weights, so existing state dicts still load.
        self.to_vec = nn.Sequential(nn.AvgPool2d(self.h // 4), nn.GELU())

        # Embed the timestep and context labels with a one-layer fully connected neural network
        self.timeembed1 = EmbedFC(1, 2 * n_feat)
        self.timeembed2 = EmbedFC(1, 1 * n_feat)
        self.contextembed1 = EmbedFC(n_cfeat, 2 * n_feat)
        self.contextembed2 = EmbedFC(n_cfeat, 1 * n_feat)

        # Up-sampling path: up0 expands the 1x1 vector back to h/4 x h/4
        # (kernel == stride == h // 4), then two UnetUp levels.
        self.up0 = nn.Sequential(
            nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, self.h // 4, self.h // 4),
            nn.GroupNorm(8, 2 * n_feat),  # normalize
            nn.ReLU(),
        )
        self.up1 = UnetUp(4 * n_feat, n_feat)
        self.up2 = UnetUp(2 * n_feat, n_feat)

        # Final convolutional layers map back to the same number of channels as the input image
        self.out = nn.Sequential(
            # reduce number of feature maps (in_channels, out_channels, kernel_size, stride, padding)
            nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1),
            nn.GroupNorm(8, n_feat),  # normalize
            nn.ReLU(),
            nn.Conv2d(n_feat, self.in_channels, 3, 1, 1),  # map to same number of channels as input
        )

    def forward(self, x, t, c=None):
        """Predict the noise component of ``x``.

        x : (batch, in_channels, h, w) : noisy input image
        t : time step normalized to [0, 1]; either one value per sample or a single
            value broadcast over the batch (the samplers in this file pass shape (1, 1, 1, 1))
        c : (batch, n_cfeat) : context vector; None means an all-zero context
        """
        # pass the input image through the initial convolutional layer
        x = self.init_conv(x)
        # pass the result through the down-sampling path
        down1 = self.down1(x)
        down2 = self.down2(down1)

        # convert the feature maps to a vector and apply an activation
        hiddenvec = self.to_vec(down2)

        # no context given: use an all-zero context vector of matching dtype/device
        if c is None:
            c = torch.zeros(x.shape[0], self.n_cfeat).to(x)

        # embed context and timestep
        cemb1 = self.contextembed1(c).view(-1, self.n_feat * 2, 1, 1)  # (batch, 2*n_feat, 1, 1)
        temb1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)
        cemb2 = self.contextembed2(c).view(-1, self.n_feat, 1, 1)
        temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)

        up1 = self.up0(hiddenvec)
        # condition each up level: scale by the context embedding, shift by the time embedding
        up2 = self.up1(cemb1 * up1 + temb1, down2)
        up3 = self.up2(cemb2 * up2 + temb2, down1)
        out = self.out(torch.cat((up3, x), 1))
        return out


# hyperparameters

# diffusion hyperparameters
timesteps = 1000
beta1 = 1e-4
beta2 = 0.02

# network hyperparameters
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
n_feat = 64  # 64 hidden dimension feature
n_cfeat = 5  # context vector is of size 5
height = 16  # 16x16 image
save_dir = './weights/'

# training hyperparameters (not used by this inference app)
batch_size = 1000
n_epoch = 512
lrate = 1e-3

# construct DDPM noise schedule: linear betas over timesteps + 1 entries,
# ab_t is the cumulative product of a_t computed in log space
b_t = (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1, device=device) + beta1
a_t = 1 - b_t
ab_t = torch.cumsum(a_t.log(), dim=0).exp()
ab_t[0] = 1

# construct model
nn_model = ContextUnet(in_channels=3, n_feat=n_feat, n_cfeat=n_cfeat, height=height).to(device)


def denoise_ddim(x, t, t_prev, pred_noise):
    """One DDIM update from schedule index ``t`` to ``t_prev`` given the predicted noise.

    Returns sqrt(ab_prev) * x0_estimate + sqrt(1 - ab_prev) * pred_noise.
    """
    ab = ab_t[t]
    ab_prev = ab_t[t_prev]

    # estimate of x0 from x_t, already scaled by sqrt(ab_prev)
    scaled_x0_pred = ab_prev.sqrt() / ab.sqrt() * (x - (1 - ab).sqrt() * pred_noise)
    # direction pointing back towards x_t
    dir_xt = (1 - ab_prev).sqrt() * pred_noise

    return scaled_x0_pred + dir_xt


# load in model weights and set to eval mode
# NOTE(review): nn_model is the single global every sampler uses, and the context
# weights loaded further below replace these before any sampler can run — so this
# load only checks the file exists and prints. Confirm whether that is intended.
nn_model.load_state_dict(torch.load(os.path.join(save_dir, "model_31.pth"), map_location=device))
nn_model.eval()
print("Loaded in Model without context")


@torch.no_grad()
def sample_ddim(n_sample, n=20):
    """Sample ``n_sample`` images with DDIM in about ``n`` steps using an all-zero context.

    Same as ``sample_ddim_context(n_sample, None, n)``.
    """
    return sample_ddim_context(n_sample, None, n=n)


# load in model weights and set to eval mode
nn_model.load_state_dict(torch.load(os.path.join(save_dir, "ft_context_model_31.pth"), map_location=device))
nn_model.eval()
print("Loaded in Context Model")


@torch.no_grad()
def sample_ddim_context(n_sample, context, n=20):
    """Fast DDIM sampling with an optional context.

    Args:
        n_sample: number of images to generate.
        context: (n_sample, n_cfeat) context tensor, or None for an all-zero context.
        n: requested number of sampling steps over the ``timesteps``-long schedule;
            the stride is ``max(timesteps // n, 1)``.

    Returns:
        (samples, intermediate): final samples of shape (n_sample, 3, height, height)
        on ``device``, and a numpy array stacking the samples after every step.

    Raises:
        ValueError: if ``n`` < 1.
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")

    # x_T ~ N(0, 1), sample initial noise
    samples = torch.randn(n_sample, 3, height, height).to(device)

    # array to keep track of generated steps for plotting
    intermediate = []
    # at least 1 so that n > timesteps does not hand range() a zero step
    step_size = max(timesteps // n, 1)
    for i in range(timesteps, 0, -step_size):
        print(f'sampling timestep {i:3d}', end='\r')

        # reshape time tensor
        t = torch.tensor([i / timesteps])[:, None, None, None].to(device)

        eps = nn_model(samples, t, c=context)  # predict noise e_(x_t,t)
        # clamp so the final step lands on ab_t[0] (== 1) instead of a negative
        # index that would wrap around to the end of the schedule
        samples = denoise_ddim(samples, i, max(i - step_size, 0), eps)
        intermediate.append(samples.detach().cpu().numpy())

    intermediate = np.stack(intermediate)
    return samples, intermediate


def denoise_add_noise(x, t, pred_noise, z=None):
    """One DDPM update at schedule index ``t``.

    Removes the predicted noise, then adds ``sqrt(b_t[t]) * z`` back in to avoid collapse
    (z defaults to fresh Gaussian noise).
    """
    if z is None:
        z = torch.randn_like(x)
    noise = b_t.sqrt()[t] * z
    mean = (x - pred_noise * ((1 - a_t[t]) / (1 - ab_t[t]).sqrt())) / a_t[t].sqrt()
    return mean + noise


@torch.no_grad()
def sample_ddpm(n_sample, context=None, save_rate=20):
    """Standard DDPM sampling over the full module-level schedule.

    ``context`` was previously accepted but ignored; it is now forwarded to the model
    (None, the default, means an all-zero context as before).
    """
    return sample_ddpm_context(n_sample, timesteps, context, save_rate=save_rate)


@torch.no_grad()
def sample_ddpm_context(n_sample, timesteps, context, save_rate=20):
    """Standard DDPM sampling with a context, running ``timesteps`` reverse steps.

    NOTE: the noise schedule (b_t, a_t, ab_t) is fixed at the module-level length;
    this ``timesteps`` parameter shadows it and is used both as the loop length and to
    normalize the time fed to the model. Only passing the schedule length (1000)
    denoises properly — a smaller value runs just the lowest-noise end of the schedule
    starting from pure noise. Use ``sample_ddim_context`` for fewer steps.

    Returns:
        (samples, intermediate): final samples and a numpy array of the samples saved
        every ``save_rate`` steps, at the first step, and for the last 7 steps.
    """
    # x_T ~ N(0, 1), sample initial noise
    samples = torch.randn(n_sample, 3, height, height).to(device)

    # array to keep track of generated steps for plotting
    intermediate = []
    for i in range(timesteps, 0, -1):
        print(f'sampling timestep {i:3d}', end='\r')

        # reshape time tensor
        t = torch.tensor([i / timesteps])[:, None, None, None].to(device)

        # sample some random noise to inject back in. For i = 1, don't add back in noise
        z = torch.randn_like(samples) if i > 1 else 0

        eps = nn_model(samples, t, c=context)  # predict noise e_(x_t,t)
        samples = denoise_add_noise(samples, i, eps, z)
        if i % save_rate == 0 or i == timesteps or i < 8:
            intermediate.append(samples.detach().cpu().numpy())

    intermediate = np.stack(intermediate)
    return samples, intermediate


def greet(input):
    """Gradio callback: generate one "hero" sprite and return it as a PIL image.

    Args:
        input: the "steps" textbox value; parsed with int() and used as the number
            of DDIM sampling steps.

    Returns:
        An RGB PIL image of size height x height.

    Raises:
        gr.Error: if the value is not an integer >= 1 (shown to the user by Gradio).
    """
    try:
        steps = int(input)
    except (TypeError, ValueError):
        raise gr.Error(f"steps must be an integer, got {input!r}") from None
    if steps < 1:
        raise gr.Error("steps must be at least 1")

    image_count = 1
    # one-hot context over: hero, non-hero, food, spell, side-facing -> "hero" for every image
    one_hot_enc = np.array([1, 0, 0, 0, 0])
    ctx = torch.from_numpy(np.tile(one_hot_enc, (image_count, 1))).to(device=device).float()

    # DDIM strides the trained 1000-step schedule, so a small step count still
    # yields a denoised image (DDPM with `steps` < 1000 left the result mostly noise).
    samples, _ = sample_ddim_context(image_count, ctx, n=steps)

    # (n, 3, h, w) tensor -> (1, n, h, w, 3) array, the layout previously passed to norm_all;
    # .cpu() first because the tensor may live on the GPU
    sx_gen_store = np.moveaxis(samples.detach().cpu().numpy()[None], 2, 4)
    nsx_gen_store = norm_all(sx_gen_store, sx_gen_store.shape[0], image_count)

    # NOTE(review): norm_all is assumed to rescale each image into [0, 1] — confirm in
    # diffusion_utilities. The old path went through `transform` before ToPILImage; if
    # that transform normalizes to [-1, 1], ToPILImage's mul(255).byte() wraps negative
    # values. Clipping and converting to uint8 directly avoids depending on it.
    pixels = np.clip(np.nan_to_num(nsx_gen_store[-1][0]), 0.0, 1.0)
    return im.fromarray((pixels * 255).round().astype(np.uint8))


transform2 = transforms.ToPILImage()

iface = gr.Interface(
    fn=greet,
    inputs=[gr.Textbox(label="steps", value=20)],
    outputs=[gr.Image(type="pil", width=64, label="Output Image")],
)
iface.launch()