multimodalart (HF Staff) committed
Commit dc32163 · verified · 1 Parent(s): 436eef9

Update app.py

Files changed (1): app.py (+22 -24)
app.py CHANGED
@@ -12,13 +12,13 @@ from utils import load_models, save_model_w2w, save_model_for_diffusers
 from sampling import sample_weights
 from huggingface_hub import snapshot_download
 
-global device
-global generator
-global unet
-global vae
-global text_encoder
-global tokenizer
-global noise_scheduler
+#global device
+#global generator
+#global unet
+#global vae
+#global text_encoder
+#global tokenizer
+#global noise_scheduler
 device = "cuda:0"
 generator = torch.Generator(device=device)
 
@@ -32,24 +32,23 @@ df = torch.load(f"{models_path}/identity_df.pt")
 weight_dimensions = torch.load(f"{models_path}/weight_dimensions.pt")
 
 unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
+network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor = 1.00)
+#global network
 
-global network
-
-def sample_model():
-    global unet
-    del unet
-    global network
-    unet, _, _, _, _ = load_models(device)
-    network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor = 1.00)
+#def sample_model():
+#    global unet
+#    del unet
+#    global network
+#    unet, _, _, _, _ = load_models(device)
 
 def inference( prompt, negative_prompt, guidance_scale, ddim_steps, seed):
-    global device
-    global generator
-    global unet
-    global vae
-    global text_encoder
-    global tokenizer
-    global noise_scheduler
+    #global device
+    #global generator
+    #global unet
+    #global vae
+    #global text_encoder
+    #global tokenizer
+    #global noise_scheduler
     generator = generator.manual_seed(seed)
     latents = torch.randn(
         (1, unet.in_channels, 512 // 8, 512 // 8),
@@ -57,7 +56,6 @@ def inference( prompt, negative_prompt, guidance_scale, ddim_steps, seed):
         device = device
     ).bfloat16()
 
-
    text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
 
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
@@ -119,7 +117,7 @@ with gr.Blocks() as demo:
         with gr.Column():
             gallery = gr.Gallery(label="Generated Images")
 
-    sample.click(fn=sample_model)
+    #sample.click(fn=sample_model)
 
     submit.click(fn=inference,
                  inputs=[prompt, negative_prompt, cfg, steps, seed],
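
In short, the commit comments out dead `global` statements and replaces the per-click `sample_model` callback with a single module-level `sample_weights` call, so the sampled `network` is fixed once when the app starts; with `sample.click(fn=sample_model)` disabled, the sample button no longer re-samples weights. The removed statements were no-ops to begin with: `global` only affects name binding inside a function body, and at module scope every top-level name is already global. A self-contained illustration (the names `counter` and `bump` are hypothetical, not from app.py):

# At module scope, `global` has no effect: top-level names are already
# bound in the module's global namespace.
global counter   # legal, but a no-op here
counter = 0

def bump():
    # Inside a function, `global` is required to rebind the module-level
    # name; without it, `counter += 1` would raise UnboundLocalError.
    global counter
    counter += 1

bump()
print(counter)  # prints 1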