Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -12,7 +12,6 @@ import warnings
|
|
12 |
warnings.filterwarnings("ignore")
|
13 |
from PIL import Image
|
14 |
import numpy as np
|
15 |
-
from utils import load_models
|
16 |
from editing import get_direction, debias
|
17 |
from lora_w2w import LoRAw2w
|
18 |
from huggingface_hub import snapshot_download
|
@@ -56,7 +55,7 @@ df = torch.load(f"{models_path}/files/identity_df.pt")
|
|
56 |
weight_dimensions = torch.load(f"{models_path}/files/weight_dimensions.pt")
|
57 |
pinverse = torch.load(f"{models_path}/files/pinverse_1000pc.pt", map_location=torch.device('cpu')).bfloat16().to(device.value)
|
58 |
|
59 |
-
unet.value, vae.value, text_encoder.value, tokenizer.value, noise_scheduler.value = load_models(
|
60 |
|
61 |
young = gr.State()
|
62 |
young.value = get_direction(df, "Young", pinverse, 1000, device.value)
|
@@ -168,7 +167,7 @@ def sample_weights(unet, proj, mean, std, v, device, factor = 1.0):
|
|
168 |
@torch.no_grad()
|
169 |
@spaces.GPU
|
170 |
def sample_model():
|
171 |
-
unet.value, _, _, _, _ = load_models(
|
172 |
network.value = sample_weights(unet.value, proj, mean, std, v[:, :1000], device.value, factor = 1.00)
|
173 |
|
174 |
@torch.no_grad()
|
|
|
12 |
warnings.filterwarnings("ignore")
|
13 |
from PIL import Image
|
14 |
import numpy as np
|
|
|
15 |
from editing import get_direction, debias
|
16 |
from lora_w2w import LoRAw2w
|
17 |
from huggingface_hub import snapshot_download
|
|
|
55 |
weight_dimensions = torch.load(f"{models_path}/files/weight_dimensions.pt")
|
56 |
pinverse = torch.load(f"{models_path}/files/pinverse_1000pc.pt", map_location=torch.device('cpu')).bfloat16().to(device.value)
|
57 |
|
58 |
+
unet.value, vae.value, text_encoder.value, tokenizer.value, noise_scheduler.value = load_models()
|
59 |
|
60 |
young = gr.State()
|
61 |
young.value = get_direction(df, "Young", pinverse, 1000, device.value)
|
|
|
167 |
@torch.no_grad()
|
168 |
@spaces.GPU
|
169 |
def sample_model():
|
170 |
+
unet.value, _, _, _, _ = load_models()
|
171 |
network.value = sample_weights(unet.value, proj, mean, std, v[:, :1000], device.value, factor = 1.00)
|
172 |
|
173 |
@torch.no_grad()
|