amildravid4292 commited on
Commit
99172cd
·
verified ·
1 Parent(s): e789a6b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -20
app.py CHANGED
@@ -243,7 +243,25 @@ class main():
243
  @spaces.GPU
244
  def edit_inference(self, prompt, negative_prompt, guidance_scale, ddim_steps, seed, start_noise, a1, a2, a3, a4):
245
  device = self.device
246
- original_weights = self,network.proj.clone()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
 
248
  #pad to same number of PCs
249
  pcs_original = original_weights.shape[1]
@@ -256,7 +274,7 @@ class main():
256
 
257
 
258
  edited_weights = original_weights+a1*1e6*young_pad+a2*1e6*pointy_pad+a3*1e6*wavy_pad+a4*2e6*thick_pad
259
-
260
  generator = torch.Generator(device=device).manual_seed(seed)
261
  latents = torch.randn(
262
  (1, self.unet.in_channels, 512 // 8, 512 // 8),
@@ -267,19 +285,19 @@ class main():
267
 
268
  text_input = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
269
 
270
- text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
271
 
272
  max_length = text_input.input_ids.shape[-1]
273
- uncond_input = tokenizer(
274
  [negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt"
275
  )
276
- uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
277
- text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
278
- noise_scheduler.set_timesteps(ddim_steps)
279
- latents = latents * noise_scheduler.init_noise_sigma
280
-
281
 
282
-
283
  for i,t in enumerate(tqdm.tqdm(self.noise_scheduler.timesteps)):
284
  latent_model_input = torch.cat([latents] * 2)
285
  latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep=t)
@@ -287,11 +305,10 @@ class main():
287
  if t>start_noise:
288
  pass
289
  elif t<=start_noise:
290
- self.network.proj = torch.nn.Parameter(edited_weights)
291
- self.network.reset()
292
-
293
-
294
- with self.network:
295
  noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond= None).sample
296
 
297
 
@@ -301,16 +318,13 @@ class main():
301
  latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
302
 
303
  latents = 1 / 0.18215 * latents
304
- image = self.vae.decode(latents).sample
305
  image = (image / 2 + 0.5).clamp(0, 1)
306
 
307
  image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
308
 
309
  image = Image.fromarray((image * 255).round().astype("uint8"))
310
-
311
- #reset weights back to original
312
- self.network.proj = torch.nn.Parameter(original_weights)
313
- self.network.reset()
314
 
315
  return image
316
 
 
243
  @spaces.GPU
244
  def edit_inference(self, prompt, negative_prompt, guidance_scale, ddim_steps, seed, start_noise, a1, a2, a3, a4):
245
  device = self.device
246
+ self.unet.to(device)
247
+ self.text_encoder.to(device)
248
+ self.vae.to(device)
249
+ self.mean.to(device)
250
+ self.std.to(device)
251
+ self.v.to(device)
252
+ self.proj.to(device)
253
+ self.weights.to(device)
254
+
255
+ network = LoRAw2w( self.weights.bfloat16(), self.mean.bfloat16(), self.std.bfloat16(), self.v[:, :1000].bfloat16(),
256
+ self.unet,
257
+ rank=1,
258
+ multiplier=1.0,
259
+ alpha=27.0,
260
+ train_method="xattn-strict"
261
+ ).to(device, torch.bfloat16)
262
+
263
+
264
+ original_weights = self.weights.clone()
265
 
266
  #pad to same number of PCs
267
  pcs_original = original_weights.shape[1]
 
274
 
275
 
276
  edited_weights = original_weights+a1*1e6*young_pad+a2*1e6*pointy_pad+a3*1e6*wavy_pad+a4*2e6*thick_pad
277
+
278
  generator = torch.Generator(device=device).manual_seed(seed)
279
  latents = torch.randn(
280
  (1, self.unet.in_channels, 512 // 8, 512 // 8),
 
285
 
286
  text_input = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
287
 
288
+ text_embeddings = self.text_encoder(text_input.input_ids.to(device))[0]
289
 
290
  max_length = text_input.input_ids.shape[-1]
291
+ uncond_input = self.tokenizer(
292
  [negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt"
293
  )
294
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
295
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).bfloat16()
296
+ self.noise_scheduler.set_timesteps(ddim_steps)
297
+ latents = latents * self.noise_scheduler.init_noise_sigma
298
+
299
 
300
+
301
  for i,t in enumerate(tqdm.tqdm(self.noise_scheduler.timesteps)):
302
  latent_model_input = torch.cat([latents] * 2)
303
  latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep=t)
 
305
  if t>start_noise:
306
  pass
307
  elif t<=start_noise:
308
+ network.proj = torch.nn.Parameter(edited_weights)
309
+ network.reset()
310
+
311
+ with network:
 
312
  noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond= None).sample
313
 
314
 
 
318
  latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
319
 
320
  latents = 1 / 0.18215 * latents
321
+ image = self.vae.decode(latents.float()).sample
322
  image = (image / 2 + 0.5).clamp(0, 1)
323
 
324
  image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
325
 
326
  image = Image.fromarray((image * 255).round().astype("uint8"))
327
+
 
 
 
328
 
329
  return image
330