Spaces:
Running
on
Zero
Running
on
Zero
change to zerogpu
Browse files
app.py
CHANGED
@@ -33,7 +33,7 @@ global network
|
|
33 |
device = "cuda:0"
|
34 |
generator = torch.Generator(device=device)
|
35 |
from gradio_imageslider import ImageSlider
|
36 |
-
|
37 |
|
38 |
|
39 |
models_path = snapshot_download(repo_id="Snapchat/w2w")
|
@@ -46,9 +46,10 @@ df = torch.load(f"{models_path}/files/identity_df.pt")
|
|
46 |
weight_dimensions = torch.load(f"{models_path}/files/weight_dimensions.pt")
|
47 |
pinverse = torch.load(f"{models_path}/files/pinverse_1000pc.pt").bfloat16().to(device)
|
48 |
|
|
|
49 |
unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
|
50 |
|
51 |
-
|
52 |
def sample_model():
|
53 |
global unet
|
54 |
del unet
|
@@ -57,9 +58,7 @@ def sample_model():
|
|
57 |
unet, _, _, _, _ = load_models(device)
|
58 |
network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor = 1.00)
|
59 |
|
60 |
-
|
61 |
|
62 |
-
|
63 |
@torch.no_grad()
|
64 |
def inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed):
|
65 |
global device
|
@@ -111,7 +110,7 @@ def inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed):
|
|
111 |
|
112 |
|
113 |
|
114 |
-
|
115 |
@torch.no_grad()
|
116 |
def edit_inference(input_image, prompt, negative_prompt, guidance_scale, ddim_steps, seed, start_noise, a1, a2, a3, a4):
|
117 |
|
@@ -198,7 +197,7 @@ def edit_inference(input_image, prompt, negative_prompt, guidance_scale, ddim_st
|
|
198 |
|
199 |
return (image, input_image["background"])
|
200 |
|
201 |
-
|
202 |
def sample_then_run():
|
203 |
sample_model()
|
204 |
prompt = "sks person"
|
@@ -270,7 +269,7 @@ class CustomImageDataset(Dataset):
|
|
270 |
if self.transform:
|
271 |
image = self.transform(image)
|
272 |
return image
|
273 |
-
|
274 |
def invert(image, mask, pcs=10000, epochs=400, weight_decay = 1e-10, lr=1e-1):
|
275 |
global unet
|
276 |
del unet
|
@@ -343,10 +342,9 @@ def invert(image, mask, pcs=10000, epochs=400, weight_decay = 1e-10, lr=1e-1):
|
|
343 |
return network
|
344 |
|
345 |
|
346 |
-
|
347 |
def run_inversion(input_image, pcs, epochs, weight_decay,lr):
|
348 |
global network
|
349 |
-
print(len(input_image["layers"]))
|
350 |
init_image = input_image["background"].convert("RGB").resize((512, 512))
|
351 |
mask = input_image["layers"][0].convert("RGB").resize((512, 512))
|
352 |
network = invert([init_image], mask, pcs, epochs, weight_decay,lr)
|
@@ -366,7 +364,7 @@ def run_inversion(input_image, pcs, epochs, weight_decay,lr):
|
|
366 |
|
367 |
|
368 |
|
369 |
-
|
370 |
def file_upload(file):
|
371 |
global unet
|
372 |
del unet
|
|
|
33 |
device = "cuda:0"
|
34 |
generator = torch.Generator(device=device)
|
35 |
from gradio_imageslider import ImageSlider
|
36 |
+
import spaces
|
37 |
|
38 |
|
39 |
models_path = snapshot_download(repo_id="Snapchat/w2w")
|
|
|
46 |
weight_dimensions = torch.load(f"{models_path}/files/weight_dimensions.pt")
|
47 |
pinverse = torch.load(f"{models_path}/files/pinverse_1000pc.pt").bfloat16().to(device)
|
48 |
|
49 |
+
@spaces.GPU()
|
50 |
unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
|
51 |
|
52 |
+
@spaces.GPU()
|
53 |
def sample_model():
|
54 |
global unet
|
55 |
del unet
|
|
|
58 |
unet, _, _, _, _ = load_models(device)
|
59 |
network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor = 1.00)
|
60 |
|
|
|
61 |
|
|
|
62 |
@torch.no_grad()
|
63 |
def inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed):
|
64 |
global device
|
|
|
110 |
|
111 |
|
112 |
|
113 |
+
@spaces.GPU()
|
114 |
@torch.no_grad()
|
115 |
def edit_inference(input_image, prompt, negative_prompt, guidance_scale, ddim_steps, seed, start_noise, a1, a2, a3, a4):
|
116 |
|
|
|
197 |
|
198 |
return (image, input_image["background"])
|
199 |
|
200 |
+
@spaces.GPU()
|
201 |
def sample_then_run():
|
202 |
sample_model()
|
203 |
prompt = "sks person"
|
|
|
269 |
if self.transform:
|
270 |
image = self.transform(image)
|
271 |
return image
|
272 |
+
|
273 |
def invert(image, mask, pcs=10000, epochs=400, weight_decay = 1e-10, lr=1e-1):
|
274 |
global unet
|
275 |
del unet
|
|
|
342 |
return network
|
343 |
|
344 |
|
345 |
+
@spaces.GPU(duration=200)
|
346 |
def run_inversion(input_image, pcs, epochs, weight_decay,lr):
|
347 |
global network
|
|
|
348 |
init_image = input_image["background"].convert("RGB").resize((512, 512))
|
349 |
mask = input_image["layers"][0].convert("RGB").resize((512, 512))
|
350 |
network = invert([init_image], mask, pcs, epochs, weight_decay,lr)
|
|
|
364 |
|
365 |
|
366 |
|
367 |
+
@spaces.GPU()
|
368 |
def file_upload(file):
|
369 |
global unet
|
370 |
del unet
|