amildravid4292 committed
Commit 9935195 · verified · 1 Parent(s): d169e6d

Update app.py

Files changed (1): app.py (+178 −178)
app.py CHANGED
@@ -88,166 +88,166 @@ class main():
        thick = debias(thick, "Heavy_Makeup", df, pinverse, device)
        self.thick = thick

    def sample_model(self):
        self.unet, _, _, _, _ = load_models(self.device)
        self.network = sample_weights(self.unet, self.proj, self.mean, self.std, self.v[:, :1000], self.device, factor=1.00)
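    # sample_weights (imported from the w2w utilities) is assumed here to draw
    # a new identity as a point in the first 1000 principal components of the
    # weight manifold and wrap it as a LoRAw2w network over the fresh UNet.
    # An illustrative sketch only; names and shapes are assumptions, not this
    # repo's actual API:
    #
    #   coeffs = factor * torch.randn(1, 1000, device=self.device)   # hypothetical
    #   sampled_proj = coeffs * self.std[:, :1000] + self.mean[:, :1000]
    #   network = LoRAw2w(sampled_proj, self.mean, self.std, self.v[:, :1000], self.unet, ...)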
 
    @torch.no_grad()
    @spaces.GPU
    def inference(self, prompt, negative_prompt, guidance_scale, ddim_steps, seed):
        generator = torch.Generator(device=self.device).manual_seed(seed)
        latents = torch.randn(
            (1, self.unet.in_channels, 512 // 8, 512 // 8),
            generator=generator,
            device=self.device,
        ).bfloat16()

        # encode the prompt and the negative prompt for classifier-free guidance
        text_input = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]

        max_length = text_input.input_ids.shape[-1]
        uncond_input = self.tokenizer(
            [negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt"
        )
        uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
        self.noise_scheduler.set_timesteps(ddim_steps)
        latents = latents * self.noise_scheduler.init_noise_sigma

        for i, t in enumerate(tqdm.tqdm(self.noise_scheduler.timesteps)):
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep=t)
            # run the UNet with the sampled w2w weights patched in
            with self.network:
                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond=None).sample
            # classifier-free guidance
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            latents = self.noise_scheduler.step(noise_pred, t, latents).prev_sample

        # decode latents back to pixel space
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
        image = Image.fromarray((image * 255).round().astype("uint8"))

        return image

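    # Classifier-free guidance, as used in the loop above: a single batched
    # UNet call predicts noise for the unconditional and conditional
    # embeddings, and the two halves are recombined as
    #
    #   noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    #
    # guidance_scale = 1 recovers the purely conditional prediction; larger
    # values push samples toward the prompt at some cost in diversity. The
    # 1/0.18215 rescaling afterwards undoes the standard Stable Diffusion
    # latent scaling before the VAE decode.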
    @torch.no_grad()
    @spaces.GPU
    def edit_inference(self, prompt, negative_prompt, guidance_scale, ddim_steps, seed, start_noise, a1, a2, a3, a4):
        original_weights = self.network.proj.clone()

        # pad the edit directions to the same number of PCs as the sampled model
        pcs_original = original_weights.shape[1]
        pcs_edits = self.young.shape[1]
        padding = torch.zeros((1, pcs_original - pcs_edits)).to(self.device)
        young_pad = torch.cat((self.young, padding), 1)
        pointy_pad = torch.cat((self.pointy, padding), 1)
        wavy_pad = torch.cat((self.wavy, padding), 1)
        thick_pad = torch.cat((self.thick, padding), 1)

        edited_weights = original_weights + a1*1e6*young_pad + a2*1e6*pointy_pad + a3*1e6*wavy_pad + a4*2e6*thick_pad

        generator = torch.Generator(device=self.device).manual_seed(seed)
        latents = torch.randn(
            (1, self.unet.in_channels, 512 // 8, 512 // 8),
            generator=generator,
            device=self.device,
        ).bfloat16()

        text_input = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]

        max_length = text_input.input_ids.shape[-1]
        uncond_input = self.tokenizer(
            [negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt"
        )
        uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
        self.noise_scheduler.set_timesteps(ddim_steps)
        latents = latents * self.noise_scheduler.init_noise_sigma

        for i, t in enumerate(tqdm.tqdm(self.noise_scheduler.timesteps)):
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep=t)

            # keep the original identity weights for the high-noise steps,
            # then switch to the edited weights once t drops to start_noise
            if t <= start_noise:
                self.network.proj = torch.nn.Parameter(edited_weights)
                self.network.reset()

            with self.network:
                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond=None).sample

            # classifier-free guidance
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            latents = self.noise_scheduler.step(noise_pred, t, latents).prev_sample

        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
        image = Image.fromarray((image * 255).round().astype("uint8"))

        # reset weights back to the original identity
        self.network.proj = torch.nn.Parameter(original_weights)
        self.network.reset()

        return image

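    # The edit above is a linear traversal in weight space: with attribute
    # directions d_i (young, pointy, wavy, thick) zero-padded to the same
    # number of PCs as the sampled identity,
    #
    #   edited_weights = original_weights + sum_i (a_i * s_i) * d_i
    #
    # where the s_i are the fixed 1e6 / 2e6 scale factors. Applying the edit
    # only for t <= start_noise preserves the overall layout fixed by the
    # early, high-noise denoising steps.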
    @spaces.GPU
    def sample_then_run(self):
        self.sample_model()
        prompt = "sks person"
        negative_prompt = "low quality, blurry, unfinished, nudity, weapon"
        seed = 5
        cfg = 3.0
        steps = 25
        image = self.inference(prompt, negative_prompt, cfg, steps, seed)
        torch.save(self.network.proj, "model.pt")
        return image, "model.pt"

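    # Hypothetical UI wiring (the Gradio layout is outside this diff; the
    # handle names below are illustrative only):
    #
    #   sample_button.click(fn=model.sample_then_run,
    #                       outputs=[gallery, file_output])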
    class CustomImageDataset(Dataset):
        """Minimal dataset over an in-memory list of PIL images
        (nested helper, referenced as self.CustomImageDataset in invert)."""
        def __init__(self, images, transform=None):
            self.images = images
            self.transform = transform

        def __len__(self):
            return len(self.images)

        def __getitem__(self, idx):
            image = self.images[idx]
            if self.transform:
                image = self.transform(image)
            return image

    @spaces.GPU
    def invert(self, image, mask, pcs=10000, epochs=400, weight_decay=1e-10, lr=1e-1):
        # free the previously loaded models before re-initializing
        del self.unet
        del self.network
        self.unet, _, _, _, _ = load_models(self.device)

        # optimize only the pcs-dimensional projection; the PCA basis stays frozen
        proj = torch.zeros(1, pcs).bfloat16().to(self.device)
        self.network = LoRAw2w(proj, self.mean, self.std, self.v[:, :pcs],
                               self.unet,
                               rank=1,
                               multiplier=1.0,
@@ -255,87 +255,87 @@ class main():
                               train_method="xattn-strict"
                               ).to(self.device, torch.bfloat16)

        ### load mask
        mask = transforms.Resize((64, 64), interpolation=transforms.InterpolationMode.BILINEAR)(mask)
        mask = torchvision.transforms.functional.pil_to_tensor(mask).unsqueeze(0).to(self.device).bfloat16()[:, 0, :, :].unsqueeze(1)
        ### check whether an actual mask was drawn; otherwise use an all-ones mask
        if torch.sum(mask) == 0:
            mask = torch.ones((1, 1, 64, 64)).to(self.device).bfloat16()

        ### single-image dataset with random-crop augmentation
        image_transforms = transforms.Compose([transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
                                               transforms.RandomCrop(512),
                                               transforms.ToTensor(),
                                               transforms.Normalize([0.5], [0.5])])

        train_dataset = self.CustomImageDataset(image, transform=image_transforms)
        train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)

        ### optimizer
        optim = torch.optim.Adam(self.network.parameters(), lr=lr, weight_decay=weight_decay)

        ### training loop
        self.unet.train()
        for epoch in tqdm.tqdm(range(epochs)):
            for batch in train_dataloader:
                ### prepare inputs
                batch = batch.to(self.device).bfloat16()
                latents = self.vae.encode(batch).latent_dist.sample()
                latents = latents * 0.18215
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]

                timesteps = torch.randint(0, self.noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()
                noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
                text_input = self.tokenizer("sks person", padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
                text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]

                ### masked loss + optimizer step
                with self.network:
                    model_pred = self.unet(noisy_latents, timesteps, text_embeddings).sample
                loss = torch.nn.functional.mse_loss(mask * model_pred.float(), mask * noise.float(), reduction="mean")
                optim.zero_grad()
                loss.backward()
                optim.step()

        ### return the optimized network
        return self.network

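    # The inversion above optimizes only the pcs-dimensional projection over
    # a frozen PCA basis, using the usual denoising objective restricted to
    # the drawn mask m (resized to the 64x64 latent grid):
    #
    #   L = || m * (eps_pred(x_t, t, "sks person") - eps) ||^2
    #
    # so the masked subject is pulled toward a single point in weight space,
    # which the handlers below can then sample from, edit, and save.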
    @spaces.GPU
    def run_inversion(self, input_dict, pcs, epochs, weight_decay, lr):
        # input_dict comes from a Gradio image+sketch component with "image" and "mask" keys
        init_image = input_dict["image"].convert("RGB").resize((512, 512))
        mask = input_dict["mask"].convert("RGB").resize((512, 512))
        network = self.invert([init_image], mask, pcs, epochs, weight_decay, lr)

        # sample an image with the inverted identity
        prompt = "sks person"
        negative_prompt = "low quality, blurry, unfinished, nudity"
        seed = 5
        cfg = 3.0
        steps = 25
        image = self.inference(prompt, negative_prompt, cfg, steps, seed)
        torch.save(network.proj, "model.pt")
        return image, "model.pt"

    @spaces.GPU
    def file_upload(self, file):
        proj = torch.load(file.name).to(self.device)

        # pad to 10000 principal components to keep everything consistent
        pcs = proj.shape[1]
        padding = torch.zeros((1, 10000 - pcs)).to(self.device)
        proj = torch.cat((proj, padding), 1)

        self.unet, _, _, _, _ = load_models(self.device)

        self.network = LoRAw2w(proj, self.mean, self.std, self.v[:, :10000],
                               self.unet,
                               rank=1,
                               multiplier=1.0,
@@ -344,13 +344,13 @@ class main():
                               ).to(self.device, torch.bfloat16)

        prompt = "sks person"
        negative_prompt = "low quality, blurry, unfinished, nudity"
        seed = 5
        cfg = 3.0
        steps = 25
        image = self.inference(prompt, negative_prompt, cfg, steps, seed)
        return image
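# A hypothetical round trip through the handlers above (instance and widget
# names are illustrative; the Gradio setup is not part of this diff):
#
#   model = main()
#   image, ckpt_path = model.sample_then_run()        # new identity + model.pt
#   edited = model.edit_inference("sks person", "", 3.0, 25, 5,
#                                 start_noise=800, a1=0.5, a2=0.0, a3=0.0, a4=0.0)
#   reloaded = model.file_upload(uploaded_file)       # restore a saved model.pt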
 