amildravid4292 commited on
Commit
116818c
·
verified ·
1 Parent(s): cfbf1bf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -3
app.py CHANGED
@@ -12,7 +12,6 @@ import warnings
12
  warnings.filterwarnings("ignore")
13
  from PIL import Image
14
  import numpy as np
15
- from utils import load_models
16
  from editing import get_direction, debias
17
  from lora_w2w import LoRAw2w
18
  from huggingface_hub import snapshot_download
@@ -56,7 +55,7 @@ df = torch.load(f"{models_path}/files/identity_df.pt")
56
  weight_dimensions = torch.load(f"{models_path}/files/weight_dimensions.pt")
57
  pinverse = torch.load(f"{models_path}/files/pinverse_1000pc.pt", map_location=torch.device('cpu')).bfloat16().to(device.value)
58
 
59
- unet.value, vae.value, text_encoder.value, tokenizer.value, noise_scheduler.value = load_models(device.value)
60
 
61
  young = gr.State()
62
  young.value = get_direction(df, "Young", pinverse, 1000, device.value)
@@ -168,7 +167,7 @@ def sample_weights(unet, proj, mean, std, v, device, factor = 1.0):
168
  @torch.no_grad()
169
  @spaces.GPU
170
  def sample_model():
171
- unet.value, _, _, _, _ = load_models(device.value)
172
  network.value = sample_weights(unet.value, proj, mean, std, v[:, :1000], device.value, factor = 1.00)
173
 
174
  @torch.no_grad()
 
12
  warnings.filterwarnings("ignore")
13
  from PIL import Image
14
  import numpy as np
 
15
  from editing import get_direction, debias
16
  from lora_w2w import LoRAw2w
17
  from huggingface_hub import snapshot_download
 
55
  weight_dimensions = torch.load(f"{models_path}/files/weight_dimensions.pt")
56
  pinverse = torch.load(f"{models_path}/files/pinverse_1000pc.pt", map_location=torch.device('cpu')).bfloat16().to(device.value)
57
 
58
+ unet.value, vae.value, text_encoder.value, tokenizer.value, noise_scheduler.value = load_models()
59
 
60
  young = gr.State()
61
  young.value = get_direction(df, "Young", pinverse, 1000, device.value)
 
167
  @torch.no_grad()
168
  @spaces.GPU
169
  def sample_model():
170
+ unet.value, _, _, _, _ = load_models()
171
  network.value = sample_weights(unet.value, proj, mean, std, v[:, :1000], device.value, factor = 1.00)
172
 
173
  @torch.no_grad()