amildravid4292 commited on
Commit
c75298d
·
verified ·
1 Parent(s): ac24ff3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -11
app.py CHANGED
@@ -20,15 +20,19 @@ from huggingface_hub import snapshot_download
20
  import spaces
21
 
22
 
23
- gr.State(generator)
24
- gr.State(unet)
25
- gr.State(vae)
26
- gr.State(text_encoder)
27
- gr.State(tokenizer)
28
- gr.State(noise_scheduler)
29
- gr.State(network)
30
  device = gr.State("cuda")
31
  #generator = torch.Generator(device=device)
 
 
 
 
32
 
33
  models_path = snapshot_download(repo_id="Snapchat/w2w")
34
 
@@ -43,10 +47,7 @@ pinverse = torch.load(f"{models_path}/files/pinverse_1000pc.pt", map_location=to
43
  unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
44
 
45
 
46
- gr.State(young)
47
- gr.State(pointy)
48
- gr.State(wavy)
49
- gr.State(thick)
50
 
51
  young = get_direction(df, "Young", pinverse, 1000, device)
52
  young = debias(young, "Male", df, pinverse, device)
 
20
  import spaces
21
 
22
 
23
+ generator = gr.State()
24
+ unet = gr.State()
25
+ vae = gr.State()
26
+ text_encoder = gr.State()
27
+ tokenizer = gr.State()
28
+ noise_scheduler = gr.State()
29
+ network = gr.State()
30
  device = gr.State("cuda")
31
  #generator = torch.Generator(device=device)
32
+ young = gr.State()
33
+ pointy = gr.State()
34
+ wavy = gr.State()
35
+ thick = gr.State()
36
 
37
  models_path = snapshot_download(repo_id="Snapchat/w2w")
38
 
 
47
  unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
48
 
49
 
50
+
 
 
 
51
 
52
  young = get_direction(df, "Young", pinverse, 1000, device)
53
  young = debias(young, "Male", df, pinverse, device)