import gradio as gr
from typing import Dict, Tuple
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
import numpy as np
from IPython.display import HTML
from diffusion_utilities import *
from PIL import Image as im
class ContextUnet(nn.Module):
    def __init__(self, in_channels, n_feat=256, n_cfeat=10, height=28):  # cfeat - context features
        super(ContextUnet, self).__init__()

        # number of input channels, number of intermediate feature maps and number of classes
        self.in_channels = in_channels
        self.n_feat = n_feat
        self.n_cfeat = n_cfeat
        self.h = height  # assume h == w; must be divisible by 4, so 28, 24, 20, 16, ...

        # Initialize the initial convolutional layer
        self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True)

        # Initialize the down-sampling path of the U-Net with two levels
        self.down1 = UnetDown(n_feat, n_feat)        # down1: (B, n_feat, h/2, w/2)
        self.down2 = UnetDown(n_feat, 2 * n_feat)    # down2: (B, 2*n_feat, h/4, w/4)

        # original: self.to_vec = nn.Sequential(nn.AvgPool2d(7), nn.GELU())
        self.to_vec = nn.Sequential(nn.AvgPool2d(4), nn.GELU())

        # Embed the timestep and context labels with a one-layer fully connected neural network
        self.timeembed1 = EmbedFC(1, 2 * n_feat)
        self.timeembed2 = EmbedFC(1, 1 * n_feat)
        self.contextembed1 = EmbedFC(n_cfeat, 2 * n_feat)
        self.contextembed2 = EmbedFC(n_cfeat, 1 * n_feat)

        # Initialize the up-sampling path of the U-Net with three levels
        self.up0 = nn.Sequential(
            nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, self.h // 4, self.h // 4),
            nn.GroupNorm(8, 2 * n_feat),  # normalize
            nn.ReLU(),
        )
        self.up1 = UnetUp(4 * n_feat, n_feat)
        self.up2 = UnetUp(2 * n_feat, n_feat)

        # Initialize the final convolutional layers to map to the same number of channels as the input image
        self.out = nn.Sequential(
            nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1),  # reduce number of feature maps (in_channels, out_channels, kernel_size, stride, padding)
            nn.GroupNorm(8, n_feat),  # normalize
            nn.ReLU(),
            nn.Conv2d(n_feat, self.in_channels, 3, 1, 1),  # map to same number of channels as input
        )
    def forward(self, x, t, c=None):
        """
        x : (batch, in_channels, h, w) : input image
        t : (batch, 1, 1, 1)           : time step, normalized to [0, 1]
        c : (batch, n_cfeat)           : context label (one-hot); masked out when None
        """
        # pass the input image through the initial convolutional layer
        x = self.init_conv(x)
        # pass the result through the down-sampling path
        down1 = self.down1(x)        # (B, n_feat, h/2, w/2)
        down2 = self.down2(down1)    # (B, 2*n_feat, h/4, w/4)

        # convert the feature maps to a vector and apply an activation
        hiddenvec = self.to_vec(down2)

        # mask out the context if none is provided
        if c is None:
            c = torch.zeros(x.shape[0], self.n_cfeat).to(x)

        # embed context and timestep
        cemb1 = self.contextembed1(c).view(-1, self.n_feat * 2, 1, 1)  # (batch, 2*n_feat, 1, 1)
        temb1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)
        cemb2 = self.contextembed2(c).view(-1, self.n_feat, 1, 1)
        temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)

        up1 = self.up0(hiddenvec)
        up2 = self.up1(cemb1 * up1 + temb1, down2)  # add and multiply embeddings
        up3 = self.up2(cemb2 * up2 + temb2, down1)
        out = self.out(torch.cat((up3, x), 1))
        return out
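
# Shape walk-through for reference (a sketch, assuming the UnetDown/UnetUp blocks in
# diffusion_utilities halve/double the spatial resolution, as in the course utilities):
#   init_conv:    (B, in_channels, h, w) -> (B, n_feat, h, w)
#   down1, down2: -> (B, n_feat, h/2, w/2) -> (B, 2*n_feat, h/4, w/4)
#   to_vec, up0:  -> (B, 2*n_feat, 1, 1)   -> (B, 2*n_feat, h/4, w/4)
#   up1, up2:     -> (B, n_feat, h/2, w/2) -> (B, n_feat, h, w)
#   out:          -> (B, in_channels, h, w)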
# hyperparameters
# diffusion hyperparameters
timesteps = 1000
beta1 = 1e-4
beta2 = 0.02
# network hyperparameters
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
n_feat = 64 # 64 hidden-dimension feature maps
n_cfeat = 5 # context vector is of size 5
height = 16 # 16x16 image
save_dir = './weights/'
# training hyperparameters (kept from the training script; unused at inference time)
batch_size = 1000
n_epoch = 512
lrate = 1e-3
# construct DDPM noise schedule
b_t = (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1, device=device) + beta1
a_t = 1 - b_t
ab_t = torch.cumsum(a_t.log(), dim=0).exp()
ab_t[0] = 1
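# Notes on the schedule above:
#   b_t  : linear beta schedule from beta1 to beta2 over `timesteps` steps
#   a_t  : alpha_t = 1 - beta_t
#   ab_t : alpha_bar_t = cumulative product of alpha_s for s <= t
#          (computed as exp(cumsum(log a_t)) for numerical stability), with ab_t[0] pinned to 1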
# construct model
nn_model = ContextUnet(in_channels=3, n_feat=n_feat, n_cfeat=n_cfeat, height=height).to(device)
# define sampling function for DDIM
# removes the noise using ddim
def denoise_ddim(x, t, t_prev, pred_noise):
    ab = ab_t[t]
    ab_prev = ab_t[t_prev]

    x0_pred = ab_prev.sqrt() / ab.sqrt() * (x - (1 - ab).sqrt() * pred_noise)
    dir_xt = (1 - ab_prev).sqrt() * pred_noise

    return x0_pred + dir_xt
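
# The step above is the deterministic DDIM update (eta = 0):
#   x0_hat     = (x_t - sqrt(1 - ab_t) * eps) / sqrt(ab_t)
#   x_{t_prev} = sqrt(ab_{t_prev}) * x0_hat + sqrt(1 - ab_{t_prev}) * eps
# (denoise_ddim folds the sqrt(ab_{t_prev}) factor into `x0_pred`)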
# load in model weights and set to eval mode
nn_model.load_state_dict(torch.load(f"{save_dir}/model_31.pth", map_location=device))
nn_model.eval()
print("Loaded in Model without context")
# sample quickly using DDIM
@torch.no_grad()
def sample_ddim(n_sample, n=20):
    # x_T ~ N(0, 1), sample initial noise
    samples = torch.randn(n_sample, 3, height, height).to(device)

    # array to keep track of generated steps for plotting
    intermediate = []
    step_size = timesteps // n
    for i in range(timesteps, 0, -step_size):
        print(f'sampling timestep {i:3d}', end='\r')

        # reshape time tensor
        t = torch.tensor([i / timesteps])[:, None, None, None].to(device)

        eps = nn_model(samples, t)  # predict noise e_(x_t, t)
        samples = denoise_ddim(samples, i, i - step_size, eps)
        intermediate.append(samples.detach().cpu().numpy())

    intermediate = np.stack(intermediate)
    return samples, intermediate
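
# Example usage (a sketch): four 16x16 samples with 20 DDIM steps
#   samples, intermediate = sample_ddim(4, n=20)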
# load in model weights and set to eval mode
nn_model.load_state_dict(torch.load(f"{save_dir}/ft_context_model_31.pth", map_location=device))
nn_model.eval()
print("Loaded in Context Model")
# fast sampling algorithm with context
@torch.no_grad()
def sample_ddim_context(n_sample, context, n=20):
    # x_T ~ N(0, 1), sample initial noise
    samples = torch.randn(n_sample, 3, height, height).to(device)

    # array to keep track of generated steps for plotting
    intermediate = []
    step_size = timesteps // n
    for i in range(timesteps, 0, -step_size):
        print(f'sampling timestep {i:3d}', end='\r')

        # reshape time tensor
        t = torch.tensor([i / timesteps])[:, None, None, None].to(device)

        eps = nn_model(samples, t, c=context)  # predict noise e_(x_t, t)
        samples = denoise_ddim(samples, i, i - step_size, eps)
        intermediate.append(samples.detach().cpu().numpy())

    intermediate = np.stack(intermediate)
    return samples, intermediate
# helper function; removes the predicted noise (but adds some noise back in to avoid collapse)
def denoise_add_noise(x, t, pred_noise, z=None):
    if z is None:
        z = torch.randn_like(x)
    noise = b_t.sqrt()[t] * z
    mean = (x - pred_noise * ((1 - a_t[t]) / (1 - ab_t[t]).sqrt())) / a_t[t].sqrt()
    return mean + noise
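
# The reverse step above uses the DDPM posterior mean,
#   mean = (x_t - eps * (1 - a_t) / sqrt(1 - ab_t)) / sqrt(a_t),
# and injects sigma_t = sqrt(b_t) scaled noise so the samples do not collapse.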
# sample using standard algorithm
@torch.no_grad()
def sample_ddpm(n_sample, context, save_rate=20):
    # x_T ~ N(0, 1), sample initial noise
    samples = torch.randn(n_sample, 3, height, height).to(device)

    # array to keep track of generated steps for plotting
    intermediate = []
    for i in range(timesteps, 0, -1):
        print(f'sampling timestep {i:3d}', end='\r')

        # reshape time tensor
        t = torch.tensor([i / timesteps])[:, None, None, None].to(device)

        # sample some random noise to inject back in. For i = 1, don't add back in noise
        z = torch.randn_like(samples) if i > 1 else 0

        eps = nn_model(samples, t)  # predict noise e_(x_t, t)
        samples = denoise_add_noise(samples, i, eps, z)
        if i % save_rate == 0 or i == timesteps or i < 8:
            intermediate.append(samples.detach().cpu().numpy())

    intermediate = np.stack(intermediate)
    return samples, intermediate
@torch.no_grad()
def sample_ddpm_context(n_sample, timesteps, context, save_rate=20):
    # x_T ~ N(0, 1), sample initial noise
    samples = torch.randn(n_sample, 3, height, height).to(device)

    # array to keep track of generated steps for plotting
    intermediate = []
    for i in range(timesteps, 0, -1):
        print(f'sampling timestep {i:3d}', end='\r')

        # reshape time tensor
        t = torch.tensor([i / timesteps])[:, None, None, None].to(device)

        # sample some random noise to inject back in. For i = 1, don't add back in noise
        z = torch.randn_like(samples) if i > 1 else 0

        eps = nn_model(samples, t, c=context)  # predict noise e_(x_t, t)
        samples = denoise_add_noise(samples, i, eps, z)
        if i % save_rate == 0 or i == timesteps or i < 8:
            intermediate.append(samples.detach().cpu().numpy())

    intermediate = np.stack(intermediate)
    return samples, intermediate
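
# Example usage (a sketch): one "hero" sprite, running 20 reverse diffusion steps
#   ctx = torch.tensor([[1., 0., 0., 0., 0.]], device=device)  # hero, non-hero, food, spell, side-facing
#   samples, intermediate = sample_ddpm_context(1, 20, ctx)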
def greet(input):
    steps = int(input)
    image_count = 1

    # build the context vector: one-hot over (hero, non-hero, food, spell, side-facing)
    one_hot_enc = np.array([1, 0, 0, 0, 0])
    mtx_2d = np.ones((image_count, 5)) * one_hot_enc
    ctx = torch.from_numpy(mtx_2d).to(device=device).float()

    # other samplers can be swapped in here, e.g.:
    #   samples, intermediate = sample_ddim_context(image_count, ctx, n=steps)
    #   samples, intermediate = sample_ddim(image_count, n=steps)
    samples, intermediate = sample_ddpm_context(image_count, steps, ctx)

    # normalize the generated sample and convert it to a PIL image
    sx_gen_store = np.moveaxis(samples.detach().cpu().numpy()[None], 2, 4)
    nsx_gen_store = norm_all(sx_gen_store, sx_gen_store.shape[0], image_count)
    response = transform2(transform(nsx_gen_store[-1][0]))
    return response
transform2 = transforms.ToPILImage()
#iface = gr.Interface(fn=greet, inputs="text", outputs="text")
#iface.launch()
#iface = gr.Interface(fn=greet, inputs=[gr.Textbox(label="Text to find entities", lines=2)], outputs=[gr.HighlightedText(label="Text with entities")], title="NER with dslim/bert-base-NER", description="Find entities using the `dslim/bert-base-NER` model under the hood!", allow_flagging="never", examples=["My name is Andrew and I live in California", "My name is Poli and work at HuggingFace"])
#iface = gr.Interface(fn=greet, inputs=[gr.Textbox(label="Co-Retailing Business")], outputs=[gr.outputs.Image(type="pil", width=64, label="Output Image"), gr.outputs.Image(type="pil", width=64, label="Output Image2"), gr.outputs.Image(type="pil", width=64, label="Output Image3"), gr.outputs.Image(type="pil", width=64, label="Output Image4")])
#iface = gr.Interface(fn=greet, inputs=[gr.Textbox(label="steps", value=20)], outputs=[gr.Textbox(label="Info"), gr.Image(type="pil", width=64, label="Output Image"), gr.Image(type="pil", width=64, label="Output Image2"), gr.Image(type="pil", width=64, label="Output Image3"), gr.Image(type="pil", width=64, label="Output Image4")])
###iface = gr.Interface(fn=greet, inputs=[gr.Textbox(label="steps", value=20)], outputs=[gr.Textbox(label="Info"), gr.Image(type="pil", width=64, label="Output Image"), gr.Image(type="pil", width=64, label="Output Image2"), gr.Image(type="pil", width=64, label="Output Image3"), gr.Image(type="pil", width=64, label="Output Image4"), gr.Image(type="pil", width=64, label="Output Image5"), gr.Image(type="pil", width=64, label="Output Image6"), gr.Image(type="pil", width=64, label="Output Image7"), gr.Image(type="pil", width=64, label="Output Image8"), gr.Image(type="pil", width=64, label="Output Image9"), gr.Image(type="pil", width=64, label="Output Image10"), gr.Image(type="pil", width=64, label="Output Image11"), gr.Image(type="pil", width=64, label="Output Image12"), gr.Image(type="pil", width=64, label="Output Image13"), gr.Image(type="pil", width=64, label="Output Image14"), gr.Image(type="pil", width=64, label="Output Image15"), gr.Image(type="pil", width=64, label="Output Image16"), gr.Image(type="pil", width=64, label="Output Image17"), gr.Image(type="pil", width=64, label="Output Image18"), gr.Image(type="pil", width=64, label="Output Image19"), gr.Image(type="pil", width=64, label="Output Image20"), gr.Image(type="pil", width=64, label="Output Image21"), gr.Image(type="pil", width=64, label="Output Image22"), gr.Image(type="pil", width=64, label="Output Image23"), gr.Image(type="pil", width=64, label="Output Image24"), gr.Image(type="pil", width=64, label="Output Image25"), gr.Image(type="pil", width=64, label="Output Image26"), gr.Image(type="pil", width=64, label="Output Image27"), gr.Image(type="pil", width=64, label="Output Image28"), gr.Image(type="pil", width=64, label="Output Image29"), gr.Image(type="pil", width=64, label="Output Image30"), gr.Image(type="pil", width=64, label="Output Image31"), gr.Image(type="pil", width=64, label="Output Image32")])
iface = gr.Interface(fn=greet, inputs=[gr.Textbox(label="steps", value=20)], outputs=[gr.Image(type="pil", width=64, label="Output Image")])
#iface = gr.Interface(fn=greet, inputs=[gr.Textbox(label="Co-Retailing Business")], outputs=[gr.Textbox()])
iface.launch()