Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -22,6 +22,8 @@ from sampling import sample_weights
 from lora_w2w import LoRAw2w
 from huggingface_hub import snapshot_download
 import numpy as np
+
+
 global device
 global generator
 global unet
@@ -59,9 +61,8 @@ def sample_model():
     unet, _, _, _, _ = load_models(device)
     network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor = 1.00)
 
-
 @torch.no_grad()
-def inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed):
+def inference( prompt, negative_prompt, guidance_scale, ddim_steps, seed):
     global device
     global generator
     global unet
@@ -110,7 +111,6 @@ def inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed):
     return image
 
 
-
 @torch.no_grad()
 def edit_inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed, start_noise, a1, a2, a3, a4):
 
@@ -199,7 +199,7 @@ def edit_inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed, start_noise, a1, a2, a3, a4):
     network.reset()
 
     return (original_image, image)
-
+
 def sample_then_run():
     global original_image
     sample_model()
@@ -210,6 +210,7 @@ def sample_then_run():
     steps = 50
     original_image = inference( prompt, negative_prompt, cfg, steps, seed)
     torch.save(network.proj, "model.pt" )
+
 
     return (original_image, original_image), "model.pt"
 
@@ -273,11 +274,15 @@ class CustomImageDataset(Dataset):
         if self.transform:
             image = self.transform(image)
         return image
-
-def invert(image, mask, pcs=10000, epochs=400, weight_decay = 1e-10, lr=1e-1):
+
+def invert(dict, pcs=10000, epochs=400, weight_decay = 1e-10, lr=1e-1):
     global unet
     del unet
     global network
+
+    image = dict["background"].convert("RGB").resize((512, 512))
+    mask = dict["layers"][0].convert("RGB").resize((512, 512))
+
     unet, _, _, _, _ = load_models(device)
 
     proj = torch.zeros(1,pcs).bfloat16().to(device)
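Note on the new `invert` signature: `gr.ImageEditor` delivers its value to an event handler as a single dict rather than as separate image and mask arguments, which is what the unpacking above handles. (The parameter name `dict` shadows the Python builtin; it works, but a more specific name would be safer.) A minimal sketch of that unpacking as a standalone helper, assuming a Gradio 4 EditorValue payload of PIL images as produced with type='pil' (`unpack_editor_value` is an illustrative name, not from the commit):

    def unpack_editor_value(editor_value):
        # "background" holds the uploaded photo; each brush stroke the user
        # draws is stored under "layers", so the first layer serves as the mask.
        image = editor_value["background"].convert("RGB").resize((512, 512))
        mask = editor_value["layers"][0].convert("RGB").resize((512, 512))
        return image, mask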
@@ -308,7 +313,7 @@ def invert(image, mask, pcs=10000, epochs=400, weight_decay = 1e-10, lr=1e-1):
                                             transforms.Normalize([0.5], [0.5])])
 
 
-    train_dataset = CustomImageDataset(image, transform=image_transforms)
+    train_dataset = CustomImageDataset([image], transform=image_transforms)
     train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)
 
 
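The `[image]` wrapping is the substantive fix in this hunk: `CustomImageDataset` indexes into whatever container it is given, and a bare PIL image is not a sequence of samples. A sketch of the pattern, assuming the class follows the standard map-style Dataset shape (`SingleImageDataset` is an illustrative stand-in, not the repo's class):

    from torch.utils.data import Dataset

    class SingleImageDataset(Dataset):
        # A list of PIL images in, transformed samples out; passing [image]
        # at the call site yields a one-sample dataset for the DataLoader.
        def __init__(self, images, transform=None):
            self.images = images
            self.transform = transform

        def __len__(self):
            return len(self.images)

        def __getitem__(self, idx):
            image = self.images[idx]
            if self.transform:
                image = self.transform(image)
            return image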
@@ -346,12 +351,14 @@ def invert(image, mask, pcs=10000, epochs=400, weight_decay = 1e-10, lr=1e-1):
     return network
 
 
-
-def run_inversion(input_image, pcs, epochs, weight_decay,lr):
+
+
+def run_inversion(dict, pcs, epochs, weight_decay,lr):
     global network
-
-
-
+    global original_image
+    # init_image = dict["image"].convert("RGB").resize((512, 512))
+    # mask = dict["ma print(dict)
+    network = invert( dict, pcs, epochs, weight_decay,lr)
 
 
     #sample an image
@@ -360,20 +367,19 @@ def run_inversion(input_image, pcs, epochs, weight_decay,lr):
     seed = 5
     cfg = 3.0
     steps = 50
-
+    original_image = inference( prompt, negative_prompt, cfg, steps, seed)
     torch.save(network.proj, "model.pt" )
-    return (
+    return (original_image, original_image), "model.pt"
 
 
 
 
-
-# @spaces.GPU()
 def file_upload(file):
     global unet
     del unet
     global network
     global device
+    global original_image
 
 
 
@@ -402,18 +408,20 @@ def file_upload(file):
     seed = 5
     cfg = 3.0
     steps = 50
-
-    return
+    original_image = inference( prompt, negative_prompt, cfg, steps, seed)
+    return (original_image, original_image)
 
 
 
 
 
 
-
-
 
 
+
+
+
+
 intro = """
 <div style="display: flex;align-items: center;justify-content: center">
     <h1 style="margin-left: 12px;text-align: center;margin-bottom: 7px;display: inline-block">weights2weights</h1>
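Both `run_inversion` and `file_upload` now route their result through the global `original_image` and persist the trained projection as model.pt so the gr.File components can serve it. A sketch of that round trip, assuming `network.proj` is the only state worth persisting; the load side is an assumption, since this hunk does not show how `file_upload` reads the uploaded file:

    import torch

    def save_projection(network, path="model.pt"):
        # Persist only the low-rank projection coefficients, as the demo does.
        torch.save(network.proj, path)

    def load_projection(network, file_path, device):
        # Assumed mirror of file_upload: a gr.File upload arrives as a
        # tempfile path, so load from the path and move to the right device.
        network.proj = torch.load(file_path, map_location=device)
        return network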
@@ -443,13 +451,12 @@ with gr.Blocks(css="style.css") as demo:
         with gr.Row():
             with gr.Column():
                 sample = gr.Button("🎲 Sample New Model")
-
+                file_output1 = gr.File(label="Download Sampled Model", container=True, interactive=False)
                 file_input = gr.File(label="Upload Model", container=True)
 
 
-                # invert_button = gr.Button("⏪ Invert")
             with gr.Column():
-
+                image_slider1 = ImageSlider(position=0.5, type="pil", height=512, width=512)
 
                 prompt1 = gr.Textbox(label="Prompt",
                                      info="Make sure to include 'sks person'" ,
|
|
461 |
|
462 |
|
463 |
with gr.Row():
|
464 |
-
|
465 |
-
|
466 |
with gr.Row():
|
467 |
-
|
468 |
-
|
469 |
|
470 |
|
471 |
|
@@ -473,7 +480,7 @@ with gr.Blocks(css="style.css") as demo:
                     cfg1= gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
                     steps1 = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True)
                     negative_prompt1 = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, nudity, weapon", value="low quality, blurry, unfinished, nudity, weapon")
-
+                    injection_step1 = gr.Slider(label="Injection Step", value=800, step=1, minimum=0, maximum=1000, interactive=True)
 
 
 
@@ -481,44 +488,48 @@ with gr.Blocks(css="style.css") as demo:
             submit1 = gr.Button("Generate")
 
     with gr.Tab("Inversion"):
-        with gr.
-
-
-
-
-
+        with gr.Row():
+            with gr.Column():
+                input_image = gr.ImageEditor(elem_id="image_upload", type='pil', label="Upload image and draw to define mask", height=512, width=512, brush=gr.Brush(), layers=False)
+                lr = gr.Number(value=1e-1, label="Learning Rate", interactive=True)
+                pcs = gr.Slider(label="# Principal Components", value=10000, step=1, minimum=1, maximum=10000, interactive=True)
+                epochs = gr.Slider(label="Epochs", value=400, step=1, minimum=1, maximum=2000, interactive=True)
+                weight_decay = gr.Number(value=1e-10, label="Weight Decay", interactive=True)
+                invert_button = gr.Button("🎲 Invert")
+
+            with gr.Column():
+                image_slider2 = ImageSlider(position=0.5, type="pil", height=512, width=512)
+
+                prompt2 = gr.Textbox(label="Prompt",
+                                     info="Make sure to include 'sks person'" ,
+                                     placeholder="sks person",
+                                     value="sks person")
+                seed2 = gr.Number(value=5, label="Seed", precision=0, interactive=True)
 
-            with gr.Column():
-                image_slider = ImageSlider(position=0.5, type="pil", height=512, width=512)
 
-                prompt1 = gr.Textbox(label="Prompt",
-                                     info="Make sure to include 'sks person'" ,
-                                     placeholder="sks person",
-                                     value="sks person")
-                seed1 = gr.Number(value=5, label="Seed", precision=0, interactive=True)
-
 
-
 
-
-
-
-
-
-
+                with gr.Row():
+                    a1_2 = gr.Slider(label="- Young +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
+                    a2_2 = gr.Slider(label="- Pointy Nose +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
+                with gr.Row():
+                    a3_2 = gr.Slider(label="- Curly Hair +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
+                    a4_2 = gr.Slider(label="- Thick Eyebrows +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
 
-
-
-                with gr.Accordion("Advanced Options", open=False):
-                    cfg1= gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
-                    steps1 = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True)
-                    negative_prompt1 = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, nudity, weapon", value="low quality, blurry, unfinished, nudity, weapon")
-                    injection_step = gr.Slider(label="Injection Step", value=800, step=1, minimum=0, maximum=1000, interactive=True)
-
-
 
+
+                with gr.Accordion("Advanced Options", open=False):
+                    cfg2= gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
+                    steps2 = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True)
+                    negative_prompt2 = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, nudity, weapon", value="low quality, blurry, unfinished, nudity, weapon")
+                    injection_step2 = gr.Slider(label="Injection Step", value=800, step=1, minimum=0, maximum=1000, interactive=True)
+
+
+
 
-
+            submit2 = gr.Button("Generate")
+
+            file_output2 = gr.File(label="Download Inverted Model", container=True, interactive=False)
 
 
 
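Both tabs display results through the `gradio_imageslider` custom component, which renders a before/after pair; that is why every handler here returns a 2-tuple of images. A minimal self-contained sketch of the pattern, assuming the same `gradio_imageslider` package this Space depends on:

    import gradio as gr
    from gradio_imageslider import ImageSlider
    from PIL import Image

    def show_pair():
        # Returning the same frame twice fills both halves of the slider,
        # mirroring sample_then_run's (original_image, original_image).
        img = Image.new("RGB", (512, 512), "gray")
        return (img, img)

    with gr.Blocks() as sketch:
        slider = ImageSlider(position=0.5, type="pil")
        gr.Button("Show").click(fn=show_pair, outputs=slider)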
@@ -527,17 +538,14 @@ with gr.Blocks(css="style.css") as demo:
 
 
 
-    sample.click(fn=sample_then_run, outputs=[
-
-
-    submit1.click(fn=edit_inference, inputs=[ prompt1, negative_prompt1, cfg1, steps1, seed1, injection_step, a1, a2, a3, a4], outputs=image_slider)
-    file_input.change(fn=file_upload, inputs=file_input, outputs = image_slider)
+    sample.click(fn=sample_then_run, outputs=[image_slider1, file_output1])
+    submit1.click(fn=edit_inference, inputs=[ prompt1, negative_prompt1, cfg1, steps1, seed1, injection_step1, a1_1, a2_1, a3_1, a4_1], outputs=image_slider1)
+    file_input.change(fn=file_upload, inputs=file_input, outputs = image_slider1)
 
 
+    invert_button.click(fn=run_inversion, inputs=[input_image, pcs, epochs, weight_decay,lr], outputs = [image_slider2, file_output2])
+    submit2.click(fn=edit_inference, inputs=[ prompt2, negative_prompt2, cfg2, steps2, seed2, injection_step2, a1_2, a2_2, a3_2, a4_2], outputs=image_slider2)
 
 
-
-
-
-
+
 demo.queue().launch()
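The rewired handlers are the payoff of the `*_1`/`*_2` renames: each tab now owns its components, so the two `edit_inference` bindings no longer share sliders. The multi-output pattern used by `sample.click` and `invert_button.click` spreads a `(image_pair, file_path)` return across an ImageSlider and a gr.File, as in this illustrative sketch (names are not from the commit, and it assumes model.pt exists when the button is clicked):

    import gradio as gr
    from gradio_imageslider import ImageSlider
    from PIL import Image

    def fake_sample():
        img = Image.new("RGB", (64, 64), "white")
        # The tuple feeds the ImageSlider; the path feeds the downloadable
        # gr.File, matching `return (original_image, original_image), "model.pt"`.
        return (img, img), "model.pt"

    with gr.Blocks() as sketch:
        btn = gr.Button("Sample")
        slider = ImageSlider(type="pil")
        download = gr.File(interactive=False)
        btn.click(fn=fake_sample, outputs=[slider, download])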