linoyts HF staff commited on
Commit
40c14f1
·
verified ·
1 Parent(s): b69c69e

change to zerogpu

Browse files
Files changed (1) hide show
  1. app.py +8 -10
app.py CHANGED
@@ -33,7 +33,7 @@ global network
33
  device = "cuda:0"
34
  generator = torch.Generator(device=device)
35
  from gradio_imageslider import ImageSlider
36
-
37
 
38
 
39
  models_path = snapshot_download(repo_id="Snapchat/w2w")
@@ -46,9 +46,10 @@ df = torch.load(f"{models_path}/files/identity_df.pt")
46
  weight_dimensions = torch.load(f"{models_path}/files/weight_dimensions.pt")
47
  pinverse = torch.load(f"{models_path}/files/pinverse_1000pc.pt").bfloat16().to(device)
48
 
 
49
  unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
50
 
51
-
52
  def sample_model():
53
  global unet
54
  del unet
@@ -57,9 +58,7 @@ def sample_model():
57
  unet, _, _, _, _ = load_models(device)
58
  network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor = 1.00)
59
 
60
-
61
 
62
-
63
  @torch.no_grad()
64
  def inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed):
65
  global device
@@ -111,7 +110,7 @@ def inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed):
111
 
112
 
113
 
114
-
115
  @torch.no_grad()
116
  def edit_inference(input_image, prompt, negative_prompt, guidance_scale, ddim_steps, seed, start_noise, a1, a2, a3, a4):
117
 
@@ -198,7 +197,7 @@ def edit_inference(input_image, prompt, negative_prompt, guidance_scale, ddim_st
198
 
199
  return (image, input_image["background"])
200
 
201
-
202
  def sample_then_run():
203
  sample_model()
204
  prompt = "sks person"
@@ -270,7 +269,7 @@ class CustomImageDataset(Dataset):
270
  if self.transform:
271
  image = self.transform(image)
272
  return image
273
-
274
  def invert(image, mask, pcs=10000, epochs=400, weight_decay = 1e-10, lr=1e-1):
275
  global unet
276
  del unet
@@ -343,10 +342,9 @@ def invert(image, mask, pcs=10000, epochs=400, weight_decay = 1e-10, lr=1e-1):
343
  return network
344
 
345
 
346
-
347
  def run_inversion(input_image, pcs, epochs, weight_decay,lr):
348
  global network
349
- print(len(input_image["layers"]))
350
  init_image = input_image["background"].convert("RGB").resize((512, 512))
351
  mask = input_image["layers"][0].convert("RGB").resize((512, 512))
352
  network = invert([init_image], mask, pcs, epochs, weight_decay,lr)
@@ -366,7 +364,7 @@ def run_inversion(input_image, pcs, epochs, weight_decay,lr):
366
 
367
 
368
 
369
-
370
  def file_upload(file):
371
  global unet
372
  del unet
 
33
  device = "cuda:0"
34
  generator = torch.Generator(device=device)
35
  from gradio_imageslider import ImageSlider
36
+ import spaces
37
 
38
 
39
  models_path = snapshot_download(repo_id="Snapchat/w2w")
 
46
  weight_dimensions = torch.load(f"{models_path}/files/weight_dimensions.pt")
47
  pinverse = torch.load(f"{models_path}/files/pinverse_1000pc.pt").bfloat16().to(device)
48
 
49
+ @spaces.GPU()
50
  unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
51
 
52
+ @spaces.GPU()
53
  def sample_model():
54
  global unet
55
  del unet
 
58
  unet, _, _, _, _ = load_models(device)
59
  network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor = 1.00)
60
 
 
61
 
 
62
  @torch.no_grad()
63
  def inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed):
64
  global device
 
110
 
111
 
112
 
113
+ @spaces.GPU()
114
  @torch.no_grad()
115
  def edit_inference(input_image, prompt, negative_prompt, guidance_scale, ddim_steps, seed, start_noise, a1, a2, a3, a4):
116
 
 
197
 
198
  return (image, input_image["background"])
199
 
200
+ @spaces.GPU()
201
  def sample_then_run():
202
  sample_model()
203
  prompt = "sks person"
 
269
  if self.transform:
270
  image = self.transform(image)
271
  return image
272
+
273
  def invert(image, mask, pcs=10000, epochs=400, weight_decay = 1e-10, lr=1e-1):
274
  global unet
275
  del unet
 
342
  return network
343
 
344
 
345
+ @spaces.GPU(duration=200)
346
  def run_inversion(input_image, pcs, epochs, weight_decay,lr):
347
  global network
 
348
  init_image = input_image["background"].convert("RGB").resize((512, 512))
349
  mask = input_image["layers"][0].convert("RGB").resize((512, 512))
350
  network = invert([init_image], mask, pcs, epochs, weight_decay,lr)
 
364
 
365
 
366
 
367
+ @spaces.GPU()
368
  def file_upload(file):
369
  global unet
370
  del unet