Update app.py
app.py CHANGED
@@ -19,7 +19,7 @@ from editing import get_direction, debias
 from sampling import sample_weights
 from lora_w2w import LoRAw2w
 from huggingface_hub import snapshot_download
-
+import spaces
 
 
 
@@ -32,9 +32,7 @@ global tokenizer
 global noise_scheduler
 global network
 device = "cuda:0"
-generator = torch.Generator(device=device)
-
-
+#generator = torch.Generator(device=device)
 
 models_path = snapshot_download(repo_id="Snapchat/w2w")
 
@@ -49,7 +47,6 @@ pinverse = torch.load(f"{models_path}/files/pinverse_1000pc.pt").bfloat16().to(d
 
 unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
 
-
 def sample_model():
     global unet
     del unet
@@ -59,6 +56,7 @@ def sample_model():
     network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor = 1.00)
 
 @torch.no_grad()
+@spaces.GPU
 def inference( prompt, negative_prompt, guidance_scale, ddim_steps, seed):
     global device
     global generator
@@ -109,6 +107,7 @@ def inference( prompt, negative_prompt, guidance_scale, ddim_steps, seed):
 
 
 @torch.no_grad()
+@spaces.GPU
 def edit_inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed, start_noise, a1, a2, a3, a4):
 
     global device
@@ -193,7 +192,8 @@ def edit_inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed, st
     network.reset()
 
     return image
-
+
+@spaces.GPU
 def sample_then_run():
     sample_model()
     prompt = "sks person"
@@ -267,7 +267,8 @@ class CustomImageDataset(Dataset):
         if self.transform:
            image = self.transform(image)
         return image
-
+
+@spaces.GPU
 def invert(image, mask, pcs=10000, epochs=400, weight_decay = 1e-10, lr=1e-1):
     global unet
     del unet
@@ -332,7 +333,7 @@ def invert(image, mask, pcs=10000, epochs=400, weight_decay = 1e-10, lr=1e-1):
    return network
 
 
-
+@spaces.GPU
 def run_inversion(dict, pcs, epochs, weight_decay,lr):
     global network
     init_image = dict["image"].convert("RGB").resize((512, 512))
@@ -351,7 +352,7 @@ def run_inversion(dict, pcs, epochs, weight_decay,lr):
     return image, "model.pt"
 
 
-
+@spaces.GPU
 def file_upload(file):
     global unet
     del unet
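In short, this commit ports the demo to ZeroGPU hardware: it imports the spaces package, marks every GPU-bound entry point (inference, edit_inference, sample_then_run, invert, run_inversion, file_upload) with @spaces.GPU, and comments out the module-level torch.Generator(device=device). On ZeroGPU a GPU is attached only while a decorated call is running, so CUDA state cannot be created at import time. A minimal sketch of the pattern follows; the function body is a placeholder, and only import spaces and the decorator placement are taken from this diff.

import torch
import spaces  # preinstalled on Hugging Face Spaces; drives ZeroGPU allocation

device = "cuda:0"

# Module level runs in the main (GPU-less) process on ZeroGPU, so a CUDA
# call like the one below would fail at import time -- hence the commit
# comments it out in app.py.
# generator = torch.Generator(device=device)

@torch.no_grad()
@spaces.GPU  # a GPU is attached only for the duration of this call
def inference(prompt, seed):
    # CUDA state is created inside the decorated function instead.
    generator = torch.Generator(device=device).manual_seed(seed)
    noise = torch.randn(1, 4, 64, 64, device=device, generator=generator)
    return noise.mean().item()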
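One thing the commit leaves at the default: invert and run_inversion drive an optimization loop (400 epochs by default), which can outlast ZeroGPU's default per-call GPU window. The decorator accepts a duration argument (in seconds) for such cases; a hedged sketch, where the 300-second value is illustrative and not from this diff.

import spaces

# Hypothetical variant (not in this commit): request a longer GPU window
# for the long-running inversion call.
@spaces.GPU(duration=300)
def run_inversion(dict, pcs, epochs, weight_decay, lr):
    ...  # 400-epoch optimization loop as in app.py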