multimodalart HF staff commited on
Commit
3b7acd7
·
verified ·
1 Parent(s): 604bbd1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -9
app.py CHANGED
@@ -19,7 +19,7 @@ from editing import get_direction, debias
19
  from sampling import sample_weights
20
  from lora_w2w import LoRAw2w
21
  from huggingface_hub import snapshot_download
22
-
23
 
24
 
25
 
@@ -32,9 +32,7 @@ global tokenizer
32
  global noise_scheduler
33
  global network
34
  device = "cuda:0"
35
- generator = torch.Generator(device=device)
36
-
37
-
38
 
39
  models_path = snapshot_download(repo_id="Snapchat/w2w")
40
 
@@ -49,7 +47,6 @@ pinverse = torch.load(f"{models_path}/files/pinverse_1000pc.pt").bfloat16().to(d
49
 
50
  unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
51
 
52
-
53
  def sample_model():
54
  global unet
55
  del unet
@@ -59,6 +56,7 @@ def sample_model():
59
  network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor = 1.00)
60
 
61
  @torch.no_grad()
 
62
  def inference( prompt, negative_prompt, guidance_scale, ddim_steps, seed):
63
  global device
64
  global generator
@@ -109,6 +107,7 @@ def inference( prompt, negative_prompt, guidance_scale, ddim_steps, seed):
109
 
110
 
111
  @torch.no_grad()
 
112
  def edit_inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed, start_noise, a1, a2, a3, a4):
113
 
114
  global device
@@ -193,7 +192,8 @@ def edit_inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed, st
193
  network.reset()
194
 
195
  return image
196
-
 
197
  def sample_then_run():
198
  sample_model()
199
  prompt = "sks person"
@@ -267,7 +267,8 @@ class CustomImageDataset(Dataset):
267
  if self.transform:
268
  image = self.transform(image)
269
  return image
270
-
 
271
  def invert(image, mask, pcs=10000, epochs=400, weight_decay = 1e-10, lr=1e-1):
272
  global unet
273
  del unet
@@ -332,7 +333,7 @@ def invert(image, mask, pcs=10000, epochs=400, weight_decay = 1e-10, lr=1e-1):
332
  return network
333
 
334
 
335
-
336
  def run_inversion(dict, pcs, epochs, weight_decay,lr):
337
  global network
338
  init_image = dict["image"].convert("RGB").resize((512, 512))
@@ -351,7 +352,7 @@ def run_inversion(dict, pcs, epochs, weight_decay,lr):
351
  return image, "model.pt"
352
 
353
 
354
-
355
  def file_upload(file):
356
  global unet
357
  del unet
 
19
  from sampling import sample_weights
20
  from lora_w2w import LoRAw2w
21
  from huggingface_hub import snapshot_download
22
+ import spaces
23
 
24
 
25
 
 
32
  global noise_scheduler
33
  global network
34
  device = "cuda:0"
35
+ #generator = torch.Generator(device=device)
 
 
36
 
37
  models_path = snapshot_download(repo_id="Snapchat/w2w")
38
 
 
47
 
48
  unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
49
 
 
50
  def sample_model():
51
  global unet
52
  del unet
 
56
  network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor = 1.00)
57
 
58
  @torch.no_grad()
59
+ @spaces.GPU
60
  def inference( prompt, negative_prompt, guidance_scale, ddim_steps, seed):
61
  global device
62
  global generator
 
107
 
108
 
109
  @torch.no_grad()
110
+ @spaces.GPU
111
  def edit_inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed, start_noise, a1, a2, a3, a4):
112
 
113
  global device
 
192
  network.reset()
193
 
194
  return image
195
+
196
+ @spaces.GPU
197
  def sample_then_run():
198
  sample_model()
199
  prompt = "sks person"
 
267
  if self.transform:
268
  image = self.transform(image)
269
  return image
270
+
271
+ @spaces.GPU
272
  def invert(image, mask, pcs=10000, epochs=400, weight_decay = 1e-10, lr=1e-1):
273
  global unet
274
  del unet
 
333
  return network
334
 
335
 
336
+ @spaces.GPU
337
  def run_inversion(dict, pcs, epochs, weight_decay,lr):
338
  global network
339
  init_image = dict["image"].convert("RGB").resize((512, 512))
 
352
  return image, "model.pt"
353
 
354
 
355
+ @spaces.GPU
356
  def file_upload(file):
357
  global unet
358
  del unet