amildravid4292 commited on
Commit
0d54165
·
verified ·
1 Parent(s): 31081a5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -32,13 +32,13 @@ device = gr.State("cuda")
32
 
33
  models_path = snapshot_download(repo_id="Snapchat/w2w")
34
 
35
- mean = torch.load(f"{models_path}/files/mean.pt", map_location=torch.device('cpu')).bfloat16().to(device)
36
- std = torch.load(f"{models_path}/files/std.pt", map_location=torch.device('cpu')).bfloat16().to(device)
37
- v = torch.load(f"{models_path}/files/V.pt", map_location=torch.device('cpu')).bfloat16().to(device)
38
- proj = torch.load(f"{models_path}/files/proj_1000pc.pt", map_location=torch.device('cpu')).bfloat16().to(device)
39
  df = torch.load(f"{models_path}/files/identity_df.pt")
40
  weight_dimensions = torch.load(f"{models_path}/files/weight_dimensions.pt")
41
- pinverse = torch.load(f"{models_path}/files/pinverse_1000pc.pt", map_location=torch.device('cpu')).bfloat16().to(device)
42
 
43
  unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
44
 
 
32
 
33
  models_path = snapshot_download(repo_id="Snapchat/w2w")
34
 
35
+ mean = torch.load(f"{models_path}/files/mean.pt", map_location=torch.device('cpu')).bfloat16().to(device.value)
36
+ std = torch.load(f"{models_path}/files/std.pt", map_location=torch.device('cpu')).bfloat16().to(device.value)
37
+ v = torch.load(f"{models_path}/files/V.pt", map_location=torch.device('cpu')).bfloat16().to(device.value)
38
+ proj = torch.load(f"{models_path}/files/proj_1000pc.pt", map_location=torch.device('cpu')).bfloat16().to(device.value)
39
  df = torch.load(f"{models_path}/files/identity_df.pt")
40
  weight_dimensions = torch.load(f"{models_path}/files/weight_dimensions.pt")
41
+ pinverse = torch.load(f"{models_path}/files/pinverse_1000pc.pt", map_location=torch.device('cpu')).bfloat16().to(device.value)
42
 
43
  unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
44