amildravid4292 committed · verified
Commit 4e3e10a · 1 Parent(s): 9f713c2

Update app.py

Files changed (1):
  1. app.py +73 -77
app.py CHANGED
@@ -21,15 +21,8 @@ import spaces
 import uuid
 
 global device
-global generator
-global unet
-global vae
-global text_encoder
-global tokenizer
-global noise_scheduler
-global network
 device = "cuda"
-#generator = torch.Generator(device=device)
+
 
 models_path = snapshot_download(repo_id="Snapchat/w2w")
 
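Editor's note on the hunk above, the crux of this commit: module-level globals in a Gradio Space are shared by every visitor, so per-user model state cannot live there. A minimal sketch of the failure mode (hypothetical names, not from app.py):

    import gradio as gr

    counter = 0  # module-level: one value shared by ALL user sessions

    def bump():
        global counter
        counter += 1  # user A's clicks are visible to user B
        return counter

    with gr.Blocks() as demo:
        out = gr.Number(label="shared count")
        gr.Button("Bump").click(fn=bump, outputs=out)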
@@ -41,29 +34,76 @@ df = torch.load(f"{models_path}/files/identity_df.pt")
 weight_dimensions = torch.load(f"{models_path}/files/weight_dimensions.pt")
 pinverse = torch.load(f"{models_path}/files/pinverse_1000pc.pt", map_location=torch.device('cpu')).bfloat16().to(device)
 
-unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
+global young
+global pointy
+global wavy
+global thick
+young = get_direction(df, "Young", pinverse, 1000, device)
+young = debias(young, "Male", df, pinverse, device)
+young = debias(young, "Pointy_Nose", df, pinverse, device)
+young = debias(young, "Wavy_Hair", df, pinverse, device)
+young = debias(young, "Chubby", df, pinverse, device)
+young = debias(young, "No_Beard", df, pinverse, device)
+young = debias(young, "Mustache", df, pinverse, device)
+
+pointy = get_direction(df, "Pointy_Nose", pinverse, 1000, device)
+pointy = debias(pointy, "Young", df, pinverse, device)
+pointy = debias(pointy, "Male", df, pinverse, device)
+pointy = debias(pointy, "Wavy_Hair", df, pinverse, device)
+pointy = debias(pointy, "Chubby", df, pinverse, device)
+pointy = debias(pointy, "Heavy_Makeup", df, pinverse, device)
+
+wavy = get_direction(df, "Wavy_Hair", pinverse, 1000, device)
+wavy = debias(wavy, "Young", df, pinverse, device)
+wavy = debias(wavy, "Male", df, pinverse, device)
+wavy = debias(wavy, "Pointy_Nose", df, pinverse, device)
+wavy = debias(wavy, "Chubby", df, pinverse, device)
+wavy = debias(wavy, "Heavy_Makeup", df, pinverse, device)
+
+thick = get_direction(df, "Bushy_Eyebrows", pinverse, 1000, device)
+thick = debias(thick, "Male", df, pinverse, device)
+thick = debias(thick, "Young", df, pinverse, device)
+thick = debias(thick, "Pointy_Nose", df, pinverse, device)
+thick = debias(thick, "Wavy_Hair", df, pinverse, device)
+thick = debias(thick, "Mustache", df, pinverse, device)
+thick = debias(thick, "No_Beard", df, pinverse, device)
+thick = debias(thick, "Sideburns", df, pinverse, device)
+thick = debias(thick, "Big_Nose", df, pinverse, device)
+thick = debias(thick, "Big_Lips", df, pinverse, device)
+thick = debias(thick, "Black_Hair", df, pinverse, device)
+thick = debias(thick, "Brown_Hair", df, pinverse, device)
+thick = debias(thick, "Pale_Skin", df, pinverse, device)
+thick = debias(thick, "Heavy_Makeup", df, pinverse, device)
 
-def sample_model():
-    global unet
-    del unet
-    global network
+
+
+@torch.no_grad()
+@spaces.GPU
+def sample_then_run(network, unet):
+    # load models
     mean.to(device)
     std.to(device)
     v.to(device)
     proj.to(device)
     unet, _, _, _, _ = load_models(device)
     network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor = 1.00)
-
+    # inference
+    prompt = "sks person"
+    negative_prompt = "low quality, blurry, unfinished, nudity, weapon"
+    seed = 5
+    cfg = 3.0
+    steps = 25
+    image = inference(network, unet, prompt, negative_prompt, cfg, steps, seed)
+    torch.save(network.proj, "model.pt")
+    # return the sampled model first so it lands back in the session state
+    return network, unet, image, "model.pt"
+
+
+
 @torch.no_grad()
 @spaces.GPU
-def inference( prompt, negative_prompt, guidance_scale, ddim_steps, seed):
+def inference(network, unet, prompt, negative_prompt, guidance_scale, ddim_steps, seed):
     global device
-    #global generator
-    global unet
-    global vae
-    global text_encoder
-    global tokenizer
-    global noise_scheduler
     generator = torch.Generator(device=device).manual_seed(seed)
     latents = torch.randn(
         (1, unet.in_channels, 512 // 8, 512 // 8),
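The new signatures follow Gradio's session-state convention: gr.State values arrive as plain positional arguments and persist only if returned back out, in the same order as the outputs= list. A minimal, self-contained sketch of that round trip (the string values are stand-ins for the real models):

    import gradio as gr

    def sample_then_run(network, unet):
        network, unet = "sampled network", "fresh unet"  # stand-ins for sample_weights()/load_models()
        image = f"image generated with {network}"
        return network, unet, image  # order must match outputs= below

    with gr.Blocks() as demo:
        network = gr.State()  # starts as None for each new session
        unet = gr.State()
        img = gr.Textbox(label="result")
        sample = gr.Button("Sample")
        sample.click(fn=sample_then_run, inputs=[network, unet], outputs=[network, unet, img])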
@@ -192,61 +232,10 @@ def edit_inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed, st
 
     return image
 
-@spaces.GPU
-def sample_then_run():
-    sample_model()
-    prompt = "sks person"
-    negative_prompt = "low quality, blurry, unfinished, nudity, weapon"
-    seed = 5
-    cfg = 3.0
-    steps = 25
-    image = inference( prompt, negative_prompt, cfg, steps, seed)
-    torch.save(network.proj, "model.pt" )
-    return image, "model.pt"
 
-#@spaces.GPU
-def start_items():
-    print("Starting items")
-    global young
-    global pointy
-    global wavy
-    global thick
-    young = get_direction(df, "Young", pinverse, 1000, device)
-    young = debias(young, "Male", df, pinverse, device)
-    young = debias(young, "Pointy_Nose", df, pinverse, device)
-    young = debias(young, "Wavy_Hair", df, pinverse, device)
-    young = debias(young, "Chubby", df, pinverse, device)
-    young = debias(young, "No_Beard", df, pinverse, device)
-    young = debias(young, "Mustache", df, pinverse, device)
-
-    pointy = get_direction(df, "Pointy_Nose", pinverse, 1000, device)
-    pointy = debias(pointy, "Young", df, pinverse, device)
-    pointy = debias(pointy, "Male", df, pinverse, device)
-    pointy = debias(pointy, "Wavy_Hair", df, pinverse, device)
-    pointy = debias(pointy, "Chubby", df, pinverse, device)
-    pointy = debias(pointy, "Heavy_Makeup", df, pinverse, device)
-
-    wavy = get_direction(df, "Wavy_Hair", pinverse, 1000, device)
-    wavy = debias(wavy, "Young", df, pinverse, device)
-    wavy = debias(wavy, "Male", df, pinverse, device)
-    wavy = debias(wavy, "Pointy_Nose", df, pinverse, device)
-    wavy = debias(wavy, "Chubby", df, pinverse, device)
-    wavy = debias(wavy, "Heavy_Makeup", df, pinverse, device)
-
-    thick = get_direction(df, "Bushy_Eyebrows", pinverse, 1000, device)
-    thick = debias(thick, "Male", df, pinverse, device)
-    thick = debias(thick, "Young", df, pinverse, device)
-    thick = debias(thick, "Pointy_Nose", df, pinverse, device)
-    thick = debias(thick, "Wavy_Hair", df, pinverse, device)
-    thick = debias(thick, "Mustache", df, pinverse, device)
-    thick = debias(thick, "No_Beard", df, pinverse, device)
-    thick = debias(thick, "Sideburns", df, pinverse, device)
-    thick = debias(thick, "Big_Nose", df, pinverse, device)
-    thick = debias(thick, "Big_Lips", df, pinverse, device)
-    thick = debias(thick, "Black_Hair", df, pinverse, device)
-    thick = debias(thick, "Brown_Hair", df, pinverse, device)
-    thick = debias(thick, "Pale_Skin", df, pinverse, device)
-    thick = debias(thick, "Heavy_Makeup", df, pinverse, device)
+
+
+
 
 class CustomImageDataset(Dataset):
     def __init__(self, images, transform=None):
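For readers unfamiliar with get_direction/debias (which this commit moves from the removed start_items into module scope): a common way to debias an edit direction is to project out its component along a correlated attribute direction. The sketch below is an assumption about what debias might do, not the actual w2w implementation:

    import torch

    def debias_sketch(direction: torch.Tensor, confound: torch.Tensor) -> torch.Tensor:
        # Hypothetical: remove the component of `direction` lying along `confound`
        # so that, e.g., a "Young" edit does not also shift "Male".
        unit = confound / confound.norm()
        return direction - (direction @ unit) * unit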
@@ -403,8 +392,15 @@ intro = """
 
 
 with gr.Blocks(css="style.css") as demo:
-    gr.HTML(intro)
+    network = gr.State()
+    unet = gr.State()
+    vae = gr.State()
+    text_encoder = gr.State()
+    tokenizer = gr.State()
+    noise_scheduler = gr.State()
+    _, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
 
+    gr.HTML(intro)
     gr.Markdown("""<div style="text-align: justify;"> In this demo, you can get an identity-encoding model by sampling or inverting. To use a model previously downloaded from this demo see \"Uploading a model\" in the Advanced Options. Next, you can generate new images from it, or edit the identity encoded in the model and generate images from the edited model. We provide detailed instructions and tips at the bottom of the page.""")
     with gr.Column():
         with gr.Row():
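A design note on the hunk above: four of the six gr.State() components (vae, text_encoder, tokenizer, noise_scheduler) are immediately shadowed by the load_models() unpacking, so only network and unet remain session state; the rest end up as ordinary module objects shared across sessions. If true per-session copies were wanted, gr.State accepts an initial value (deep-copied for each session). A sketch, assuming load_models and device from app.py:

    import gradio as gr

    with gr.Blocks() as demo:
        _, vae_model, text_model, tok, sched = load_models(device)  # assumed from app.py
        vae = gr.State(value=vae_model)  # the name stays bound to the State component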
@@ -479,7 +475,7 @@ with gr.Blocks(css="style.css") as demo:
         outputs = [input_image, file_output])
 
 
-    sample.click(fn=sample_then_run, outputs=[input_image, file_output])
+    sample.click(fn=sample_then_run, inputs=[network, unet], outputs=[network, unet, input_image, file_output])
 
     submit.click(
         fn=edit_inference, inputs=[prompt, negative_prompt, cfg, steps, seed, injection_step, a1, a2, a3, a4], outputs=[gallery]
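Finally, the unchanged context lines in inference() show the standard seeded-sampling pattern: a device-pinned torch.Generator makes torch.randn reproducible. A self-contained sketch (the channel count 4 is typical for Stable Diffusion latents; app.py reads it from unet.in_channels):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    generator = torch.Generator(device=device).manual_seed(5)
    latents = torch.randn(
        (1, 4, 512 // 8, 512 // 8),  # (batch, channels, H/8, W/8)
        generator=generator,
        device=device,
    )  # identical seeds now yield identical latents, hence reproducible images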
 