Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -308,14 +308,15 @@ class CustomImageDataset(Dataset):
|
|
308 |
return image
|
309 |
|
310 |
@spaces.GPU
|
311 |
-
def invert(
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
|
|
316 |
|
317 |
-
|
318 |
-
network = LoRAw2w(
|
319 |
unet,
|
320 |
rank=1,
|
321 |
multiplier=1.0,
|
@@ -367,18 +368,27 @@ def invert(self, image, mask, pcs=10000, epochs=400, weight_decay = 1e-10, lr=1e
|
|
367 |
optim.zero_grad()
|
368 |
loss.backward()
|
369 |
optim.step()
|
370 |
-
|
371 |
-
|
372 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
373 |
|
374 |
|
375 |
@spaces.GPU
|
376 |
-
def run_inversion(dict, pcs, epochs, weight_decay,lr):
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
#sample an image
|
383 |
prompt = "sks person"
|
384 |
negative_prompt = "low quality, blurry, unfinished, nudity"
|
@@ -387,7 +397,7 @@ def run_inversion(dict, pcs, epochs, weight_decay,lr):
|
|
387 |
steps = 25
|
388 |
image = inference( prompt, negative_prompt, cfg, steps, seed)
|
389 |
torch.save(network.proj, "model.pt" )
|
390 |
-
return
|
391 |
|
392 |
|
393 |
@spaces.GPU
|
@@ -408,7 +418,7 @@ def file_upload(file, net):
|
|
408 |
cfg = 3.0
|
409 |
steps = 25
|
410 |
image = inference(net, prompt, negative_prompt, cfg, steps, seed)
|
411 |
-
return net,image
|
412 |
|
413 |
|
414 |
|
@@ -504,8 +514,8 @@ with gr.Blocks(css="style.css") as demo:
|
|
504 |
|
505 |
|
506 |
invert_button.click(fn=run_inversion,
|
507 |
-
inputs=[input_image, pcs, epochs, weight_decay,lr],
|
508 |
-
outputs = [
|
509 |
|
510 |
|
511 |
sample.click(fn=sample_then_run,inputs = [net], outputs=[net, file_output, input_image])
|
|
|
308 |
return image
|
309 |
|
310 |
@spaces.GPU
|
311 |
+
def invert(image, mask, pcs=10000, epochs=400, weight_decay = 1e-10, lr=1e-1):
|
312 |
+
device = "cuda"
|
313 |
+
mean.to(device)
|
314 |
+
std.to(device)
|
315 |
+
v.to(device)
|
316 |
+
|
317 |
|
318 |
+
weights = torch.zeros(1,pcs).bfloat16().to(device)
|
319 |
+
network = LoRAw2w( weights, mean, std, v[:, :pcs],
|
320 |
unet,
|
321 |
rank=1,
|
322 |
multiplier=1.0,
|
|
|
368 |
optim.zero_grad()
|
369 |
loss.backward()
|
370 |
optim.step()
|
371 |
+
|
372 |
+
|
373 |
+
#pad to 10000 PCs
|
374 |
+
pcs_original = weights.shape[1]
|
375 |
+
padding = torch.zeros((1,10000-pcs_original)).to(device)
|
376 |
+
weights = network.proj.detach()
|
377 |
+
weights = torch.cat((weights, padding), 1)
|
378 |
+
|
379 |
+
net = "model_"+str(uuid.uuid4())[:4]+".pt"
|
380 |
+
torch.save(weights, net)
|
381 |
+
|
382 |
+
return net
|
383 |
|
384 |
|
385 |
@spaces.GPU
|
386 |
+
def run_inversion(net, dict, pcs, epochs, weight_decay,lr):
|
387 |
+
init_image = dict["background"].convert("RGB").resize((512, 512))
|
388 |
+
mask = dict["layers"][0].convert("RGB").resize((512, 512))
|
389 |
+
|
390 |
+
net = invert(init_image, mask, pcs, epochs, weight_decay,lr)
|
391 |
+
|
392 |
#sample an image
|
393 |
prompt = "sks person"
|
394 |
negative_prompt = "low quality, blurry, unfinished, nudity"
|
|
|
397 |
steps = 25
|
398 |
image = inference( prompt, negative_prompt, cfg, steps, seed)
|
399 |
torch.save(network.proj, "model.pt" )
|
400 |
+
return net, net, image
|
401 |
|
402 |
|
403 |
@spaces.GPU
|
|
|
418 |
cfg = 3.0
|
419 |
steps = 25
|
420 |
image = inference(net, prompt, negative_prompt, cfg, steps, seed)
|
421 |
+
return net, image
|
422 |
|
423 |
|
424 |
|
|
|
514 |
|
515 |
|
516 |
invert_button.click(fn=run_inversion,
|
517 |
+
inputs=[net, input_image, pcs, epochs, weight_decay,lr],
|
518 |
+
outputs = [net, file_output, input_image])
|
519 |
|
520 |
|
521 |
sample.click(fn=sample_then_run,inputs = [net], outputs=[net, file_output, input_image])
|