amildravid4292 commited on
Commit
c1cc7f8
·
verified ·
1 Parent(s): 41f3674

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -19
app.py CHANGED
@@ -12,13 +12,13 @@ from utils import load_models, save_model_w2w, save_model_for_diffusers
12
  from sampling import sample_weights
13
  from huggingface_hub import snapshot_download
14
 
15
- #global device
16
- #global generator
17
- #global unet
18
- #global vae
19
- #global text_encoder
20
- #global tokenizer
21
- #global noise_scheduler
22
  device = "cuda:0"
23
 
24
  models_path = snapshot_download(repo_id="Snapchat/w2w")
@@ -34,20 +34,20 @@ unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
34
  network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor = 1.00)
35
  #global network
36
 
37
- #def sample_model():
38
- # global unet
39
- # del unet
40
- # global network
41
- # unet, _, _, _, _ = load_models(device)
42
 
43
  def inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed):
44
- #global device
45
- #global generator
46
- #global unet
47
- #global vae
48
- #global text_encoder
49
- #global tokenizer
50
- #global noise_scheduler
51
  generator = torch.Generator(device=device).manual_seed(seed)
52
  latents = torch.randn(
53
  (1, unet.in_channels, 512 // 8, 512 // 8),
 
12
  from sampling import sample_weights
13
  from huggingface_hub import snapshot_download
14
 
15
+ global device
16
+ global generator
17
+ global unet
18
+ global vae
19
+ global text_encoder
20
+ global tokenizer
21
+ global noise_scheduler
22
  device = "cuda:0"
23
 
24
  models_path = snapshot_download(repo_id="Snapchat/w2w")
 
34
  network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor = 1.00)
35
  #global network
36
 
37
+ def sample_model():
38
+ global unet
39
+ del unet
40
+ global network
41
+ unet, _, _, _, _ = load_models(device)
42
 
43
  def inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed):
44
+ global device
45
+ global generator
46
+ global unet
47
+ global vae
48
+ global text_encoder
49
+ global tokenizer
50
+ global noise_scheduler
51
  generator = torch.Generator(device=device).manual_seed(seed)
52
  latents = torch.randn(
53
  (1, unet.in_channels, 512 // 8, 512 // 8),