amildravid4292 commited on
Commit
791ac08
·
verified ·
1 Parent(s): 7152f14

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -68
app.py CHANGED
@@ -22,6 +22,8 @@ from sampling import sample_weights
22
  from lora_w2w import LoRAw2w
23
  from huggingface_hub import snapshot_download
24
  import numpy as np
 
 
25
  global device
26
  global generator
27
  global unet
@@ -59,9 +61,8 @@ def sample_model():
59
  unet, _, _, _, _ = load_models(device)
60
  network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor = 1.00)
61
 
62
-
63
  @torch.no_grad()
64
- def inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed):
65
  global device
66
  global generator
67
  global unet
@@ -110,7 +111,6 @@ def inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed):
110
  return image
111
 
112
 
113
-
114
  @torch.no_grad()
115
  def edit_inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed, start_noise, a1, a2, a3, a4):
116
 
@@ -199,7 +199,7 @@ def edit_inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed, st
199
  network.reset()
200
 
201
  return (original_image, image)
202
-
203
  def sample_then_run():
204
  global original_image
205
  sample_model()
@@ -210,6 +210,7 @@ def sample_then_run():
210
  steps = 50
211
  original_image = inference( prompt, negative_prompt, cfg, steps, seed)
212
  torch.save(network.proj, "model.pt" )
 
213
 
214
  return (original_image, original_image), "model.pt"
215
 
@@ -273,11 +274,15 @@ class CustomImageDataset(Dataset):
273
  if self.transform:
274
  image = self.transform(image)
275
  return image
276
-
277
- def invert(image, mask, pcs=10000, epochs=400, weight_decay = 1e-10, lr=1e-1):
278
  global unet
279
  del unet
280
  global network
 
 
 
 
281
  unet, _, _, _, _ = load_models(device)
282
 
283
  proj = torch.zeros(1,pcs).bfloat16().to(device)
@@ -308,7 +313,7 @@ def invert(image, mask, pcs=10000, epochs=400, weight_decay = 1e-10, lr=1e-1):
308
  transforms.Normalize([0.5], [0.5])])
309
 
310
 
311
- train_dataset = CustomImageDataset(image, transform=image_transforms)
312
  train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)
313
 
314
 
@@ -346,12 +351,14 @@ def invert(image, mask, pcs=10000, epochs=400, weight_decay = 1e-10, lr=1e-1):
346
  return network
347
 
348
 
349
- # @spaces.GPU(duration=200)
350
- def run_inversion(input_image, pcs, epochs, weight_decay,lr):
 
351
  global network
352
- init_image = input_image["background"].convert("RGB").resize((512, 512))
353
- mask = input_image["layers"][0].convert("RGB").resize((512, 512))
354
- network = invert([init_image], mask, pcs, epochs, weight_decay,lr)
 
355
 
356
 
357
  #sample an image
@@ -360,20 +367,19 @@ def run_inversion(input_image, pcs, epochs, weight_decay,lr):
360
  seed = 5
361
  cfg = 3.0
362
  steps = 50
363
- image = inference( prompt, negative_prompt, cfg, steps, seed)
364
  torch.save(network.proj, "model.pt" )
365
- return (image,init_image), "model.pt"
366
 
367
 
368
 
369
 
370
-
371
- # @spaces.GPU()
372
  def file_upload(file):
373
  global unet
374
  del unet
375
  global network
376
  global device
 
377
 
378
 
379
 
@@ -402,18 +408,20 @@ def file_upload(file):
402
  seed = 5
403
  cfg = 3.0
404
  steps = 50
405
- image = inference( prompt, negative_prompt, cfg, steps, seed)
406
- return image
407
 
408
 
409
 
410
 
411
 
412
 
413
-
414
-
415
 
416
 
 
 
 
 
417
  intro = """
418
  <div style="display: flex;align-items: center;justify-content: center">
419
  <h1 style="margin-left: 12px;text-align: center;margin-bottom: 7px;display: inline-block">weights2weights</h1>
@@ -443,13 +451,12 @@ with gr.Blocks(css="style.css") as demo:
443
  with gr.Row():
444
  with gr.Column():
445
  sample = gr.Button("🎲 Sample New Model")
446
- file_output = gr.File(label="Download Sampled Model", container=True, interactive=False)
447
  file_input = gr.File(label="Upload Model", container=True)
448
 
449
 
450
- # invert_button = gr.Button("⏪ Invert")
451
  with gr.Column():
452
- image_slider = ImageSlider(position=0.5, type="pil", height=512, width=512)
453
 
454
  prompt1 = gr.Textbox(label="Prompt",
455
  info="Make sure to include 'sks person'" ,
@@ -461,11 +468,11 @@ with gr.Blocks(css="style.css") as demo:
461
 
462
 
463
  with gr.Row():
464
- a1 = gr.Slider(label="- Young +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
465
- a2 = gr.Slider(label="- Pointy Nose +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
466
  with gr.Row():
467
- a3 = gr.Slider(label="- Curly Hair +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
468
- a4 = gr.Slider(label="- Thick Eyebrows +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
469
 
470
 
471
 
@@ -473,7 +480,7 @@ with gr.Blocks(css="style.css") as demo:
473
  cfg1= gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
474
  steps1 = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True)
475
  negative_prompt1 = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, nudity, weapon", value="low quality, blurry, unfinished, nudity, weapon")
476
- injection_step = gr.Slider(label="Injection Step", value=800, step=1, minimum=0, maximum=1000, interactive=True)
477
 
478
 
479
 
@@ -481,44 +488,48 @@ with gr.Blocks(css="style.css") as demo:
481
  submit1 = gr.Button("Generate")
482
 
483
  with gr.Tab("Inversion"):
484
- with gr.Column():
485
- input_image = gr.ImageEditor(elem_id="image_upload", type='pil', label="Upload image and draw to define mask", height=512, width=512, brush=gr.Brush(), layers=False)
486
- lr = gr.Number(value=1e-1, label="Learning Rate", interactive=True)
487
- pcs = gr.Slider(label="# Principal Components", value=10000, step=1, minimum=1, maximum=10000, interactive=True)
488
- epochs = gr.Slider(label="Epochs", value=400, step=1, minimum=1, maximum=2000, interactive=True)
489
- weight_decay = gr.Number(value=1e-10, label="Weight Decay", interactive=True)
 
 
 
 
 
 
 
 
 
 
 
490
 
491
- with gr.Column():
492
- image_slider = ImageSlider(position=0.5, type="pil", height=512, width=512)
493
 
494
- prompt1 = gr.Textbox(label="Prompt",
495
- info="Make sure to include 'sks person'" ,
496
- placeholder="sks person",
497
- value="sks person")
498
- seed1 = gr.Number(value=5, label="Seed", precision=0, interactive=True)
499
-
500
 
501
-
502
 
503
- with gr.Row():
504
- a1 = gr.Slider(label="- Young +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
505
- a2 = gr.Slider(label="- Pointy Nose +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
506
- with gr.Row():
507
- a3 = gr.Slider(label="- Curly Hair +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
508
- a4 = gr.Slider(label="- Thick Eyebrows +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
509
 
510
-
511
-
512
- with gr.Accordion("Advanced Options", open=False):
513
- cfg1= gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
514
- steps1 = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True)
515
- negative_prompt1 = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, nudity, weapon", value="low quality, blurry, unfinished, nudity, weapon")
516
- injection_step = gr.Slider(label="Injection Step", value=800, step=1, minimum=0, maximum=1000, interactive=True)
517
-
518
-
519
 
 
 
 
 
 
 
 
 
 
520
 
521
- submit2 = gr.Button("Generate")
 
 
522
 
523
 
524
 
@@ -527,17 +538,14 @@ with gr.Blocks(css="style.css") as demo:
527
 
528
 
529
 
530
- sample.click(fn=sample_then_run, outputs=[image_slider, file_output])
531
-
532
-
533
- submit1.click(fn=edit_inference, inputs=[ prompt1, negative_prompt1, cfg1, steps1, seed1, injection_step, a1, a2, a3, a4], outputs=image_slider)
534
- file_input.change(fn=file_upload, inputs=file_input, outputs = image_slider)
535
 
536
 
 
 
537
 
538
 
539
-
540
-
541
-
542
-
543
  demo.queue().launch()
 
22
  from lora_w2w import LoRAw2w
23
  from huggingface_hub import snapshot_download
24
  import numpy as np
25
+
26
+
27
  global device
28
  global generator
29
  global unet
 
61
  unet, _, _, _, _ = load_models(device)
62
  network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor = 1.00)
63
 
 
64
  @torch.no_grad()
65
+ def inference( prompt, negative_prompt, guidance_scale, ddim_steps, seed):
66
  global device
67
  global generator
68
  global unet
 
111
  return image
112
 
113
 
 
114
  @torch.no_grad()
115
  def edit_inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed, start_noise, a1, a2, a3, a4):
116
 
 
199
  network.reset()
200
 
201
  return (original_image, image)
202
+
203
  def sample_then_run():
204
  global original_image
205
  sample_model()
 
210
  steps = 50
211
  original_image = inference( prompt, negative_prompt, cfg, steps, seed)
212
  torch.save(network.proj, "model.pt" )
213
+
214
 
215
  return (original_image, original_image), "model.pt"
216
 
 
274
  if self.transform:
275
  image = self.transform(image)
276
  return image
277
+
278
+ def invert(dict, pcs=10000, epochs=400, weight_decay = 1e-10, lr=1e-1):
279
  global unet
280
  del unet
281
  global network
282
+
283
+ image = dict["background"].convert("RGB").resize((512, 512))
284
+ mask = dict["layers"][0].convert("RGB").resize((512, 512))
285
+
286
  unet, _, _, _, _ = load_models(device)
287
 
288
  proj = torch.zeros(1,pcs).bfloat16().to(device)
 
313
  transforms.Normalize([0.5], [0.5])])
314
 
315
 
316
+ train_dataset = CustomImageDataset([image], transform=image_transforms)
317
  train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)
318
 
319
 
 
351
  return network
352
 
353
 
354
+
355
+
356
+ def run_inversion(dict, pcs, epochs, weight_decay,lr):
357
  global network
358
+ global original_image
359
+ # init_image = dict["image"].convert("RGB").resize((512, 512))
360
+ # mask = dict["ma print(dict)
361
+ network = invert( dict, pcs, epochs, weight_decay,lr)
362
 
363
 
364
  #sample an image
 
367
  seed = 5
368
  cfg = 3.0
369
  steps = 50
370
+ original_image = inference( prompt, negative_prompt, cfg, steps, seed)
371
  torch.save(network.proj, "model.pt" )
372
+ return (original_image, original_image), "model.pt"
373
 
374
 
375
 
376
 
 
 
377
  def file_upload(file):
378
  global unet
379
  del unet
380
  global network
381
  global device
382
+ global original_image
383
 
384
 
385
 
 
408
  seed = 5
409
  cfg = 3.0
410
  steps = 50
411
+ original_image = inference( prompt, negative_prompt, cfg, steps, seed)
412
+ return (original_image, original_image)
413
 
414
 
415
 
416
 
417
 
418
 
 
 
419
 
420
 
421
+
422
+
423
+
424
+
425
  intro = """
426
  <div style="display: flex;align-items: center;justify-content: center">
427
  <h1 style="margin-left: 12px;text-align: center;margin-bottom: 7px;display: inline-block">weights2weights</h1>
 
451
  with gr.Row():
452
  with gr.Column():
453
  sample = gr.Button("🎲 Sample New Model")
454
+ file_output1 = gr.File(label="Download Sampled Model", container=True, interactive=False)
455
  file_input = gr.File(label="Upload Model", container=True)
456
 
457
 
 
458
  with gr.Column():
459
+ image_slider1 = ImageSlider(position=0.5, type="pil", height=512, width=512)
460
 
461
  prompt1 = gr.Textbox(label="Prompt",
462
  info="Make sure to include 'sks person'" ,
 
468
 
469
 
470
  with gr.Row():
471
+ a1_1 = gr.Slider(label="- Young +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
472
+ a2_1 = gr.Slider(label="- Pointy Nose +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
473
  with gr.Row():
474
+ a3_1 = gr.Slider(label="- Curly Hair +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
475
+ a4_1 = gr.Slider(label="- Thick Eyebrows +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
476
 
477
 
478
 
 
480
  cfg1= gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
481
  steps1 = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True)
482
  negative_prompt1 = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, nudity, weapon", value="low quality, blurry, unfinished, nudity, weapon")
483
+ injection_step1 = gr.Slider(label="Injection Step", value=800, step=1, minimum=0, maximum=1000, interactive=True)
484
 
485
 
486
 
 
488
  submit1 = gr.Button("Generate")
489
 
490
  with gr.Tab("Inversion"):
491
+ with gr.Row():
492
+ with gr.Column():
493
+ input_image = gr.ImageEditor(elem_id="image_upload", type='pil', label="Upload image and draw to define mask", height=512, width=512, brush=gr.Brush(), layers=False)
494
+ lr = gr.Number(value=1e-1, label="Learning Rate", interactive=True)
495
+ pcs = gr.Slider(label="# Principal Components", value=10000, step=1, minimum=1, maximum=10000, interactive=True)
496
+ epochs = gr.Slider(label="Epochs", value=400, step=1, minimum=1, maximum=2000, interactive=True)
497
+ weight_decay = gr.Number(value=1e-10, label="Weight Decay", interactive=True)
498
+ invert_button = gr.Button("🎲 Invert")
499
+
500
+ with gr.Column():
501
+ image_slider2 = ImageSlider(position=0.5, type="pil", height=512, width=512)
502
+
503
+ prompt2 = gr.Textbox(label="Prompt",
504
+ info="Make sure to include 'sks person'" ,
505
+ placeholder="sks person",
506
+ value="sks person")
507
+ seed2 = gr.Number(value=5, label="Seed", precision=0, interactive=True)
508
 
 
 
509
 
 
 
 
 
 
 
510
 
 
511
 
512
+ with gr.Row():
513
+ a1_2 = gr.Slider(label="- Young +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
514
+ a2_2 = gr.Slider(label="- Pointy Nose +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
515
+ with gr.Row():
516
+ a3_2 = gr.Slider(label="- Curly Hair +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
517
+ a4_2 = gr.Slider(label="- Thick Eyebrows +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
518
 
 
 
 
 
 
 
 
 
 
519
 
520
+
521
+ with gr.Accordion("Advanced Options", open=False):
522
+ cfg2= gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
523
+ steps2 = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True)
524
+ negative_prompt2 = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, nudity, weapon", value="low quality, blurry, unfinished, nudity, weapon")
525
+ injection_step2 = gr.Slider(label="Injection Step", value=800, step=1, minimum=0, maximum=1000, interactive=True)
526
+
527
+
528
+
529
 
530
+ submit2 = gr.Button("Generate")
531
+
532
+ file_output2 = gr.File(label="Download Inverted Model", container=True, interactive=False)
533
 
534
 
535
 
 
538
 
539
 
540
 
541
+ sample.click(fn=sample_then_run, outputs=[image_slider1, file_output1])
542
+ submit1.click(fn=edit_inference, inputs=[ prompt1, negative_prompt1, cfg1, steps1, seed1, injection_step1, a1_1, a2_1, a3_1, a4_1], outputs=image_slider1)
543
+ file_input.change(fn=file_upload, inputs=file_input, outputs = image_slider1)
 
 
544
 
545
 
546
+ invert_button.click(fn=run_inversion, inputs=[input_image, pcs, epochs, weight_decay,lr], outputs = [image_slider2, file_output2])
547
+ submit2.click(fn=edit_inference, inputs=[ prompt2, negative_prompt2, cfg2, steps2, seed2, injection_step2, a1_2, a2_2, a3_2, a4_2], outputs=image_slider2)
548
 
549
 
550
+
 
 
 
551
  demo.queue().launch()