amildravid4292 committed on
Commit
216fb80
·
verified ·
1 Parent(s): 01aa716

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -25
app.py CHANGED
@@ -12,15 +12,14 @@ from utils import load_models, save_model_w2w, save_model_for_diffusers
12
  from sampling import sample_weights
13
  from huggingface_hub import snapshot_download
14
 
15
- global device
16
- global generator
17
- global unet
18
- global vae
19
- global text_encoder
20
- global tokenizer
21
- global noise_scheduler
22
  device = "cuda:0"
23
- generator = torch.Generator(device=device)
24
 
25
  models_path = snapshot_download(repo_id="Snapchat/w2w")
26
 
@@ -32,24 +31,23 @@ df = torch.load(f"{models_path}/identity_df.pt")
32
  weight_dimensions = torch.load(f"{models_path}/weight_dimensions.pt")
33
 
34
  unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
35
- global network
 
 
 
 
 
 
 
36
 
37
- def sample_model():
38
- global unet
39
- del unet
40
- global network
41
- unet, _, _, _, _ = load_models(device)
42
-
43
- @torch.no_grad()
44
  def inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed):
45
- global network
46
- global device
47
- global generator
48
- global unet
49
- global vae
50
- global text_encoder
51
- global tokenizer
52
- global noise_scheduler
53
  generator = torch.Generator(device=device).manual_seed(seed)
54
  latents = torch.randn(
55
  (1, unet.in_channels, 512 // 8, 512 // 8),
@@ -118,7 +116,7 @@ with gr.Blocks() as demo:
118
  with gr.Column():
119
  gallery = gr.Gallery(label="Generated Images")
120
 
121
- sample.click(fn=sample_model)
122
 
123
  submit.click(fn=inference,
124
  inputs=[prompt, negative_prompt, cfg, steps, seed],
 
12
  from sampling import sample_weights
13
  from huggingface_hub import snapshot_download
14
 
15
+ #global device
16
+ #global generator
17
+ #global unet
18
+ #global vae
19
+ #global text_encoder
20
+ #global tokenizer
21
+ #global noise_scheduler
22
  device = "cuda:0"
 
23
 
24
  models_path = snapshot_download(repo_id="Snapchat/w2w")
25
 
 
31
  weight_dimensions = torch.load(f"{models_path}/weight_dimensions.pt")
32
 
33
  unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
34
+ network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor = 1.00)
35
+ #global network
36
+
37
+ #def sample_model():
38
+ # global unet
39
+ # del unet
40
+ # global network
41
+ # unet, _, _, _, _ = load_models(device)
42
 
 
 
 
 
 
 
 
43
  def inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed):
44
+ #global device
45
+ #global generator
46
+ #global unet
47
+ #global vae
48
+ #global text_encoder
49
+ #global tokenizer
50
+ #global noise_scheduler
 
51
  generator = torch.Generator(device=device).manual_seed(seed)
52
  latents = torch.randn(
53
  (1, unet.in_channels, 512 // 8, 512 // 8),
 
116
  with gr.Column():
117
  gallery = gr.Gallery(label="Generated Images")
118
 
119
+ #sample.click(fn=sample_model)
120
 
121
  submit.click(fn=inference,
122
  inputs=[prompt, negative_prompt, cfg, steps, seed],