amildravid4292 commited on
Commit
32ac86d
·
verified ·
1 Parent(s): 5253e77

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -20
app.py CHANGED
@@ -308,14 +308,15 @@ class CustomImageDataset(Dataset):
308
  return image
309
 
310
  @spaces.GPU
311
- def invert(self, image, mask, pcs=10000, epochs=400, weight_decay = 1e-10, lr=1e-1):
312
-
313
- del unet
314
- del network
315
- unet, _, _, _, _ = load_models(device)
 
316
 
317
- proj = torch.zeros(1,pcs).bfloat16().to(device)
318
- network = LoRAw2w( proj, mean, std, v[:, :pcs],
319
  unet,
320
  rank=1,
321
  multiplier=1.0,
@@ -367,18 +368,27 @@ def invert(self, image, mask, pcs=10000, epochs=400, weight_decay = 1e-10, lr=1e
367
  optim.zero_grad()
368
  loss.backward()
369
  optim.step()
370
-
371
- ### return optimized network
372
- return network
 
 
 
 
 
 
 
 
 
373
 
374
 
375
  @spaces.GPU
376
- def run_inversion(dict, pcs, epochs, weight_decay,lr):
377
- print(dict)
378
- print(dict.keys())
379
- init_image = dict["image"].convert("RGB").resize((512, 512))
380
- mask = dict["mask"].convert("RGB").resize((512, 512))
381
- network = invert([init_image], mask, pcs, epochs, weight_decay,lr)
382
  #sample an image
383
  prompt = "sks person"
384
  negative_prompt = "low quality, blurry, unfinished, nudity"
@@ -387,7 +397,7 @@ def run_inversion(dict, pcs, epochs, weight_decay,lr):
387
  steps = 25
388
  image = inference( prompt, negative_prompt, cfg, steps, seed)
389
  torch.save(network.proj, "model.pt" )
390
- return image, "model.pt"
391
 
392
 
393
  @spaces.GPU
@@ -408,7 +418,7 @@ def file_upload(file, net):
408
  cfg = 3.0
409
  steps = 25
410
  image = inference(net, prompt, negative_prompt, cfg, steps, seed)
411
- return net,image
412
 
413
 
414
 
@@ -504,8 +514,8 @@ with gr.Blocks(css="style.css") as demo:
504
 
505
 
506
  invert_button.click(fn=run_inversion,
507
- inputs=[input_image, pcs, epochs, weight_decay,lr],
508
- outputs = [input_image, file_output])
509
 
510
 
511
  sample.click(fn=sample_then_run,inputs = [net], outputs=[net, file_output, input_image])
 
308
  return image
309
 
310
  @spaces.GPU
311
+ def invert(image, mask, pcs=10000, epochs=400, weight_decay = 1e-10, lr=1e-1):
312
+ device = "cuda"
313
+ mean.to(device)
314
+ std.to(device)
315
+ v.to(device)
316
+
317
 
318
+ weights = torch.zeros(1,pcs).bfloat16().to(device)
319
+ network = LoRAw2w( weights, mean, std, v[:, :pcs],
320
  unet,
321
  rank=1,
322
  multiplier=1.0,
 
368
  optim.zero_grad()
369
  loss.backward()
370
  optim.step()
371
+
372
+
373
+ #pad to 10000 PCs
374
+ pcs_original = weights.shape[1]
375
+ padding = torch.zeros((1,10000-pcs_original)).to(device)
376
+ weights = network.proj.detach()
377
+ weights = torch.cat((weights, padding), 1)
378
+
379
+ net = "model_"+str(uuid.uuid4())[:4]+".pt"
380
+ torch.save(weights, net)
381
+
382
+ return net
383
 
384
 
385
  @spaces.GPU
386
+ def run_inversion(net, dict, pcs, epochs, weight_decay,lr):
387
+ init_image = dict["background"].convert("RGB").resize((512, 512))
388
+ mask = dict["layers"][0].convert("RGB").resize((512, 512))
389
+
390
+ net = invert(init_image, mask, pcs, epochs, weight_decay,lr)
391
+
392
  #sample an image
393
  prompt = "sks person"
394
  negative_prompt = "low quality, blurry, unfinished, nudity"
 
397
  steps = 25
398
  image = inference( prompt, negative_prompt, cfg, steps, seed)
399
  torch.save(network.proj, "model.pt" )
400
+ return net, net, image
401
 
402
 
403
  @spaces.GPU
 
418
  cfg = 3.0
419
  steps = 25
420
  image = inference(net, prompt, negative_prompt, cfg, steps, seed)
421
+ return net, image
422
 
423
 
424
 
 
514
 
515
 
516
  invert_button.click(fn=run_inversion,
517
+ inputs=[net, input_image, pcs, epochs, weight_decay,lr],
518
+ outputs = [net, file_output, input_image])
519
 
520
 
521
  sample.click(fn=sample_then_run,inputs = [net], outputs=[net, file_output, input_image])