multimodalart HF Staff commited on
Commit
98f7800
·
verified ·
1 Parent(s): 05fbd23

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -8
app.py CHANGED
@@ -21,8 +21,6 @@ from lora_w2w import LoRAw2w
21
  from huggingface_hub import snapshot_download
22
  import spaces
23
 
24
-
25
-
26
  global device
27
  global generator
28
  global unet
@@ -36,13 +34,13 @@ device = "cuda"
36
 
37
  models_path = snapshot_download(repo_id="Snapchat/w2w")
38
 
39
- mean = torch.load(f"{models_path}/files/mean.pt").bfloat16().to(device)
40
- std = torch.load(f"{models_path}/files/std.pt").bfloat16().to(device)
41
- v = torch.load(f"{models_path}/files/V.pt").bfloat16().to(device)
42
- proj = torch.load(f"{models_path}/files/proj_1000pc.pt").bfloat16().to(device)
43
  df = torch.load(f"{models_path}/files/identity_df.pt")
44
  weight_dimensions = torch.load(f"{models_path}/files/weight_dimensions.pt")
45
- pinverse = torch.load(f"{models_path}/files/pinverse_1000pc.pt").bfloat16().to(device)
46
 
47
 
48
  unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
@@ -51,7 +49,10 @@ def sample_model():
51
  global unet
52
  del unet
53
  global network
54
-
 
 
 
55
  unet, _, _, _, _ = load_models(device)
56
  network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor = 1.00)
57
 
 
21
  from huggingface_hub import snapshot_download
22
  import spaces
23
 
 
 
24
  global device
25
  global generator
26
  global unet
 
34
 
35
  models_path = snapshot_download(repo_id="Snapchat/w2w")
36
 
37
+ mean = torch.load(f"{models_path}/files/mean.pt").bfloat16()#.to(device)
38
+ std = torch.load(f"{models_path}/files/std.pt").bfloat16()#.to(device)
39
+ v = torch.load(f"{models_path}/files/V.pt").bfloat16()#.to(device)
40
+ proj = torch.load(f"{models_path}/files/proj_1000pc.pt").bfloat16()#.to(device)
41
  df = torch.load(f"{models_path}/files/identity_df.pt")
42
  weight_dimensions = torch.load(f"{models_path}/files/weight_dimensions.pt")
43
+ pinverse = torch.load(f"{models_path}/files/pinverse_1000pc.pt")#.bfloat16()#.to(device)
44
 
45
 
46
  unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
 
49
  global unet
50
  del unet
51
  global network
52
+ mean.to(device)
53
+ std.to(device)
54
+ v.to(device)
55
+ proj.to(device)
56
  unet, _, _, _, _ = load_models(device)
57
  network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor = 1.00)
58