amildravid4292 committed on
Commit f15739b · verified · 1 Parent(s): 8ba180b

Update app.py

Files changed (1)
  1. app.py +428 -445
app.py CHANGED
@@ -20,482 +20,465 @@ from huggingface_hub import snapshot_download
20
  import spaces
21
 
22
 
23
-
24
-
25
- models_path = snapshot_download(repo_id="Snapchat/w2w")
26
-
27
- device = "cuda"
28
- mean = torch.load(f"{models_path}/files/mean.pt", map_location=torch.device('cpu')).bfloat16().to(device)
29
- std = torch.load(f"{models_path}/files/std.pt", map_location=torch.device('cpu')).bfloat16().to(device)
30
- v = torch.load(f"{models_path}/files/V.pt", map_location=torch.device('cpu')).bfloat16().to(device)
31
- proj = torch.load(f"{models_path}/files/proj_1000pc.pt", map_location=torch.device('cpu')).bfloat16().to(device)
32
- df = torch.load(f"{models_path}/files/identity_df.pt")
33
- weight_dimensions = torch.load(f"{models_path}/files/weight_dimensions.pt")
34
- pinverse = torch.load(f"{models_path}/files/pinverse_1000pc.pt", map_location=torch.device('cpu')).bfloat16().to(device)
35
-
36
-
37
- unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
38
-
39
-
40
- young = get_direction(df, "Young", pinverse, 1000, device)
41
- young = debias(young, "Male", df, pinverse, device)
42
- young = debias(young, "Pointy_Nose", df, pinverse, device)
43
- young = debias(young, "Wavy_Hair", df, pinverse, device)
44
- young = debias(young, "Chubby", df, pinverse, device)
45
- young = debias(young, "No_Beard", df, pinverse, device)
46
- young = debias(young, "Mustache", df, pinverse, device)
47
-
48
-
49
- pointy = get_direction(df, "Pointy_Nose", pinverse, 1000, device)
50
- pointy = debias(pointy, "Young", df, pinverse, device)
51
- pointy = debias(pointy, "Male", df, pinverse, device)
52
- pointy = debias(pointy, "Wavy_Hair", df, pinverse, device)
53
- pointy = debias(pointy, "Chubby", df, pinverse, device)
54
- pointy = debias(pointy, "Heavy_Makeup", df, pinverse, device)
55
-
56
-
57
-
58
- wavy = get_direction(df, "Wavy_Hair", pinverse, 1000, device)
59
- wavy = debias(wavy, "Young", df, pinverse, device)
60
- wavy = debias(wavy, "Male", df, pinverse, device)
61
- wavy = debias(wavy, "Pointy_Nose", df, pinverse, device)
62
- wavy = debias(wavy, "Chubby", df, pinverse, device)
63
- wavy = debias(wavy, "Heavy_Makeup", df, pinverse, device)
64
-
65
-
66
- thick = get_direction(df, "Bushy_Eyebrows", pinverse, 1000, device)
67
- thick = debias(thick, "Male", df, pinverse, device)
68
- thick = debias(thick, "Young", df, pinverse, device)
69
- thick = debias(thick, "Pointy_Nose", df, pinverse, device)
70
- thick = debias(thick, "Wavy_Hair", df, pinverse, device)
71
- thick = debias(thick, "Mustache", df, pinverse, device)
72
- thick = debias(thick, "No_Beard", df, pinverse, device)
73
- thick = debias(thick, "Sideburns", df, pinverse, device)
74
- thick = debias(thick, "Big_Nose", df, pinverse, device)
75
- thick = debias(thick, "Big_Lips", df, pinverse, device)
76
- thick = debias(thick, "Black_Hair", df, pinverse, device)
77
- thick = debias(thick, "Brown_Hair", df, pinverse, device)
78
- thick = debias(thick, "Pale_Skin", df, pinverse, device)
79
- thick = debias(thick, "Heavy_Makeup", df, pinverse, device)
80
-
81
- def sample_model(unet, network):
82
- del unet
83
- del network
84
- mean.to(device)
85
- std.to(device)
86
- v.to(device)
87
- proj.to(device)
88
- unet, _, _, _, _ = load_models(device)
89
- network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor = 1.00)
90
-
91
- @torch.no_grad()
92
- @spaces.GPU
93
- def inference( prompt, negative_prompt, guidance_scale, ddim_steps, seed):
94
- generator = torch.Generator(device=device).manual_seed(seed)
95
- latents = torch.randn(
96
- (1, unet.in_channels, 512 // 8, 512 // 8),
97
- generator = generator,
98
- device = device
99
- ).bfloat16()
100
-
101
-
102
- text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
103
-
104
- text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
105
-
106
- max_length = text_input.input_ids.shape[-1]
107
- uncond_input = tokenizer(
108
- [negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt"
109
- )
110
- uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
111
- text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
112
- noise_scheduler.set_timesteps(ddim_steps)
113
- latents = latents * noise_scheduler.init_noise_sigma
114
-
115
- for i,t in enumerate(tqdm.tqdm(noise_scheduler.timesteps)):
116
- latent_model_input = torch.cat([latents] * 2)
117
- latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
118
- with network:
119
- noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond= None).sample
120
- #guidance
121
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
122
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
123
- latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
124
-
125
- latents = 1 / 0.18215 * latents
126
- image = vae.decode(latents).sample
127
- image = (image / 2 + 0.5).clamp(0, 1)
128
- image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
129
-
130
- image = Image.fromarray((image * 255).round().astype("uint8"))
131
-
132
- return image
133
-
134
-
135
- @torch.no_grad()
136
- @spaces.GPU
137
- def edit_inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed, start_noise, a1, a2, a3, a4):
138
- global device
139
- #global generator
140
- global unet
141
- global vae
142
- global text_encoder
143
- global tokenizer
144
- global noise_scheduler
145
- global young
146
- global pointy
147
- global wavy
148
- global thick
149
-
150
- original_weights = network.proj.clone()
151
-
152
- #pad to same number of PCs
153
- pcs_original = original_weights.shape[1]
154
- pcs_edits = young.shape[1]
155
- padding = torch.zeros((1,pcs_original-pcs_edits)).to(device)
156
- young_pad = torch.cat((young, padding), 1)
157
- pointy_pad = torch.cat((pointy, padding), 1)
158
- wavy_pad = torch.cat((wavy, padding), 1)
159
- thick_pad = torch.cat((thick, padding), 1)
160
 
161
-
162
- edited_weights = original_weights+a1*1e6*young_pad+a2*1e6*pointy_pad+a3*1e6*wavy_pad+a4*2e6*thick_pad
163
-
164
- generator = torch.Generator(device=device).manual_seed(seed)
165
- latents = torch.randn(
166
- (1, unet.in_channels, 512 // 8, 512 // 8),
167
- generator = generator,
168
- device = device
169
- ).bfloat16()
170
-
171
-
172
- text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
173
-
174
- text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
175
-
176
- max_length = text_input.input_ids.shape[-1]
177
- uncond_input = tokenizer(
178
- [negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt"
179
- )
180
- uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
181
- text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
182
- noise_scheduler.set_timesteps(ddim_steps)
183
- latents = latents * noise_scheduler.init_noise_sigma
184
 
185
-
186
 
187
- for i,t in enumerate(tqdm.tqdm(noise_scheduler.timesteps)):
188
- latent_model_input = torch.cat([latents] * 2)
189
- latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
190
 
191
- if t>start_noise:
192
- pass
193
- elif t<=start_noise:
194
- network.proj = torch.nn.Parameter(edited_weights)
195
- network.reset()
196
-
197
-
198
- with network:
199
- noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond= None).sample
200
-
201
 
202
- #guidance
203
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
204
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
205
- latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
206
-
207
- latents = 1 / 0.18215 * latents
208
- image = vae.decode(latents).sample
209
- image = (image / 2 + 0.5).clamp(0, 1)
210
-
211
- image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
212
-
213
- image = Image.fromarray((image * 255).round().astype("uint8"))
214
-
215
- #reset weights back to original
216
- network.proj = torch.nn.Parameter(original_weights)
217
- network.reset()
218
-
219
- return image
220
-
221
- @spaces.GPU
222
- def sample_then_run():
223
- sample_model()
224
- prompt = "sks person"
225
- negative_prompt = "low quality, blurry, unfinished, nudity, weapon"
226
- seed = 5
227
- cfg = 3.0
228
- steps = 25
229
- image = inference( prompt, negative_prompt, cfg, steps, seed)
230
- torch.save(network.proj, "model.pt" )
231
- return image, "model.pt"
232
-
233
-
234
-
235
- class CustomImageDataset(Dataset):
236
- def __init__(self, images, transform=None):
237
- self.images = images
238
- self.transform = transform
239
-
240
- def __len__(self):
241
- return len(self.images)
242
-
243
- def __getitem__(self, idx):
244
- image = self.images[idx]
245
- if self.transform:
246
- image = self.transform(image)
247
  return image
248
 
249
- @spaces.GPU
250
- def invert(image, mask, pcs=10000, epochs=400, weight_decay = 1e-10, lr=1e-1):
251
- global unet
252
- del unet
253
- global network
254
- unet, _, _, _, _ = load_models(device)
255
-
256
- proj = torch.zeros(1,pcs).bfloat16().to(device)
257
- network = LoRAw2w( proj, mean, std, v[:, :pcs],
258
- unet,
259
- rank=1,
260
- multiplier=1.0,
261
- alpha=27.0,
262
- train_method="xattn-strict"
263
- ).to(device, torch.bfloat16)
264
 
265
- ### load mask
266
- mask = transforms.Resize((64,64), interpolation=transforms.InterpolationMode.BILINEAR)(mask)
267
- mask = torchvision.transforms.functional.pil_to_tensor(mask).unsqueeze(0).to(device).bfloat16()[:,0,:,:].unsqueeze(1)
268
- ### check if an actual mask was drawn, otherwise the mask is just all ones
269
- if torch.sum(mask) == 0:
270
- mask = torch.ones((1,1,64,64)).to(device).bfloat16()
271
 
272
- ### single image dataset
273
- image_transforms = transforms.Compose([transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
274
- transforms.RandomCrop(512),
275
- transforms.ToTensor(),
276
- transforms.Normalize([0.5], [0.5])])
277
-
278
-
279
- train_dataset = CustomImageDataset(image, transform=image_transforms)
280
- train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)
281
-
282
- ### optimizer
283
- optim = torch.optim.Adam(network.parameters(), lr=lr, weight_decay=weight_decay)
284
-
285
- ### training loop
286
- unet.train()
287
- for epoch in tqdm.tqdm(range(epochs)):
288
- for batch in train_dataloader:
289
- ### prepare inputs
290
- batch = batch.to(device).bfloat16()
291
- latents = vae.encode(batch).latent_dist.sample()
292
- latents = latents*0.18215
293
- noise = torch.randn_like(latents)
294
- bsz = latents.shape[0]
295
-
296
- timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
297
- timesteps = timesteps.long()
298
- noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
299
- text_input = tokenizer("sks person", padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
300
- text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
301
-
302
- ### loss + sgd step
303
  with network:
304
- model_pred = unet(noisy_latents, timesteps, text_embeddings).sample
305
- loss = torch.nn.functional.mse_loss(mask*model_pred.float(), mask*noise.float(), reduction="mean")
306
- optim.zero_grad()
307
- loss.backward()
308
- optim.step()
309
-
310
- ### return optimized network
311
- return network
312
-
313
-
314
- @spaces.GPU
315
- def run_inversion(dict, pcs, epochs, weight_decay,lr):
316
- global network
317
- init_image = dict["image"].convert("RGB").resize((512, 512))
318
- mask = dict["mask"].convert("RGB").resize((512, 512))
319
- network = invert([init_image], mask, pcs, epochs, weight_decay,lr)
320
-
321
-
322
- #sample an image
323
- prompt = "sks person"
324
- negative_prompt = "low quality, blurry, unfinished, nudity"
325
- seed = 5
326
- cfg = 3.0
327
- steps = 25
328
- image = inference( prompt, negative_prompt, cfg, steps, seed)
329
- torch.save(network.proj, "model.pt" )
330
- return image, "model.pt"
331
-
332
-
333
- @spaces.GPU
334
- def file_upload(file):
335
- global unet
336
- del unet
337
- global network
338
- global device
339
-
340
 
 
341
 
342
- proj = torch.load(file.name).to(device)
343
 
344
- #pad to 10000 Principal components to keep everything consistent
345
- pcs = proj.shape[1]
346
- padding = torch.zeros((1,10000-pcs)).to(device)
347
- proj = torch.cat((proj, padding), 1)
348
 
349
- unet, _, _, _, _ = load_models(device)
350
 
351
 
352
- network = LoRAw2w( proj, mean, std, v[:, :10000],
353
- unet,
354
- rank=1,
355
- multiplier=1.0,
356
- alpha=27.0,
357
- train_method="xattn-strict"
358
- ).to(device, torch.bfloat16)
359
 
360
 
361
- prompt = "sks person"
362
- negative_prompt = "low quality, blurry, unfinished, nudity"
363
- seed = 5
364
- cfg = 3.0
365
- steps = 25
366
- image = inference( prompt, negative_prompt, cfg, steps, seed)
367
- return image
368
 
369
 
370
-
 
371
 
372
- intro = """
373
- <div style="display: flex;align-items: center;justify-content: center">
374
- <h1 style="margin-left: 12px;text-align: center;margin-bottom: 7px;display: inline-block"><em>weights2weights</em> Demo</h1>
375
- <h3 style="display: inline-block;margin-left: 10px;margin-top: 6px;font-weight: 500">Interpreting the Weight Space of Customized Diffusion Models</h3>
376
- </div>
377
- <p style="font-size: 0.95rem;margin: 0rem;line-height: 1.2em;margin-top:1em;display: inline-block">
378
- <a href="https://snap-research.github.io/weights2weights/" target="_blank">Project Page</a> | <a href="https://arxiv.org/abs/2406.09413" target="_blank">Paper</a>
379
- | <a href="https://github.com/snap-research/weights2weights" target="_blank">Code</a> |
380
- <a href="https://huggingface.co/spaces/Snapchat/w2w-demo?duplicate=true" target="_blank" style="
381
- display: inline-block;
382
- ">
383
- <img style="margin-top: -1em;margin-bottom: 0em;position: absolute;" src="https://bit.ly/3CWLGkA" alt="Duplicate Space"></a>
384
- </p>
385
- """
386
-
387
-
388
-
389
- with gr.Blocks(css="style.css") as demo:
390
-
391
- unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
392
- network = None
393
 
394
- gr.HTML(intro)
395
-
396
- gr.Markdown("""<div style="text-align: justify;"> In this demo, you can get an identity-encoding model by sampling or inverting. To use a model previously downloaded from this demo see \"Uploading a model\" in the Advanced Options. Next, you can generate new images from it, or edit the identity encoded in the model and generate images from the edited model. We provide detailed instructions and tips at the bottom of the page.""")
397
- with gr.Column():
398
- with gr.Row():
399
- with gr.Column():
400
- gr.Markdown("""1) Either sample a new model, or upload an image (optionally draw a mask over the head) and click `invert`.""")
401
- sample = gr.Button("🎲 Sample New Model")
402
- input_image = gr.ImageEditor(elem_id="image_upload", type='pil', label="Reference Identity",
403
- width=512, height=512)
404
-
405
- with gr.Row():
406
- invert_button = gr.Button("⬆️ Invert")
407
-
408
 
409
 
410
- with gr.Column():
411
- gr.Markdown("""2) Generate images of the sampled/inverted identity or edit the identity with the sliders and generate new images with various prompts and seeds.""")
412
- gallery = gr.Image(label="Generated Image",height=512, width=512, interactive=False)
413
- submit = gr.Button("Generate")
414
 
415
 
416
- prompt = gr.Textbox(label="Prompt",
417
- info="Make sure to include 'sks person'" ,
418
- placeholder="sks person",
419
- value="sks person")
420
 
421
- seed = gr.Number(value=5, label="Seed", precision=0, interactive=True)
422
 
423
- # Editing
424
- with gr.Column():
425
- with gr.Row():
426
- a1 = gr.Slider(label="- Young +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
427
- a2 = gr.Slider(label="- Pointy Nose +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
428
- with gr.Row():
429
- a3 = gr.Slider(label="- Curly Hair +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
430
- a4 = gr.Slider(label="- Thick Eyebrows +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
431
-
432
 
433
- with gr.Accordion("Advanced Options", open=False):
434
- with gr.Tab("Inversion"):
435
- with gr.Row():
436
- lr = gr.Number(value=1e-1, label="Learning Rate", interactive=True)
437
- pcs = gr.Slider(label="# Principal Components", value=10000, step=1, minimum=1, maximum=10000, interactive=True)
438
- with gr.Row():
439
- epochs = gr.Slider(label="Epochs", value=800, step=1, minimum=1, maximum=2000, interactive=True)
440
- weight_decay = gr.Number(value=1e-10, label="Weight Decay", interactive=True)
441
- with gr.Tab("Sampling"):
442
- with gr.Row():
443
- cfg= gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
444
- steps = gr.Slider(label="Inference Steps", value=25, step=1, minimum=0, maximum=100, interactive=True)
445
- with gr.Row():
446
- negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, nudity, weapon", value="low quality, blurry, unfinished, nudity, weapon")
447
- injection_step = gr.Slider(label="Injection Step", value=800, step=1, minimum=0, maximum=1000, interactive=True)
448
-
449
- with gr.Tab("Uploading a model"):
450
- gr.Markdown("""<div style="text-align: justify;">Upload a model downloaded from this demo below.""")
451
-
452
- file_input = gr.File(label="Upload Model", container=True)
453
-
454
-
455
-
456
-
457
- gr.Markdown("""<div style="text-align: justify;"> After sampling a new model or inverting, you can download the model below.""")
458
-
459
- with gr.Row():
460
- file_output = gr.File(label="Download Sampled/Inverted Model", container=True, interactive=False)
461
-
462
-
463
-
464
-
465
- invert_button.click(fn=run_inversion,
466
- inputs=[input_image, pcs, epochs, weight_decay,lr],
467
- outputs = [input_image, file_output])
468
 
469
 
470
- sample.click(fn=sample_then_run, inputs=[unet, network], outputs=[input_image, file_output])
 
471
 
472
- submit.click(
473
- fn=edit_inference, inputs=[prompt, negative_prompt, cfg, steps, seed, injection_step, a1, a2, a3, a4], outputs=[gallery]
474
- )
475
- file_input.change(fn=file_upload, inputs=file_input, outputs = gallery)
476
-
477
 
478
 
479
- help_text1 = """
480
- <b>Instructions</b>:
481
- 1. To get results faster without waiting in queue, you can duplicate into a private space with an A100 GPU.
482
- 2. To begin, you will have to get an identity-encoding model. You can either sample one from the *weights2weights* space by clicking `Sample New Model`, or invert an identity into a model by uploading an image and clicking `invert`. You can optionally draw over the head to define a mask in the image for better results. Sampling a model takes around 10 seconds and inversion takes around 2 minutes. After this is done, you can optionally download this model for later use. A model can be uploaded in the \"Uploading a model\" tab in the `Advanced Options`.
483
- 3. After getting a model, an image of the identity will be displayed on the right. You can sample from the model by changing seeds as well as prompts and then clicking `Generate`. Make sure to include \"sks person\" in your prompt to keep the same identity.
484
- 4. The identity in the model can be edited by changing the sliders for various attributes. After clicking `Generate`, you can see how the identity has changed and the effects are maintained across different seeds and prompts.
485
- """
486
- help_text2 = """<b>Tips</b>:
487
- 1. Editing and Identity Generation
488
- * If you are interested in preserving more of the image during identity-editing (i.e., where the same seed and prompt result in the same image with only the identity changed), you can play with the "Injection Step" parameter in the \"Sampling\" tab in the `Advanced Options`. During the first *n* timesteps, the original model's weights will be used, and then the edited weights will be set during the remaining steps. Values closer to 1000 will set the edited weights early, having a more pronounced effect, which may disrupt some semantics and structure of the generated image. Lower values will set the edited weights later, better preserving image context. We notice that values around 600-800 tend to produce the best results. Larger values (700-1000) are helpful for more global attribute changes, while smaller values (400-700) can be used for more fine-grained edits, although adjusting this is not always needed.
489
- * You can play around with negative prompts, number of inference steps, and CFG in the \"Sampling\" tab in the `Advanced Options` to affect the ultimate image quality.
490
- * Sometimes the identity will not be perfectly consistent (e.g., there might be small variations of the face) when you use some seeds or prompts. This is a limitation of our method as well as an open problem in personalized models.
491
- 2. Inversion
492
- * To obtain the best results for inversion, upload a high-resolution photo of the face with minimal occlusion. It is recommended to draw over the face and hair to define a mask, but inversion should still generally work for non-close-up face shots.
493
- * For inverting a realistic photo of an identity, typically 800 epochs with lr=1e-1 and 10,000 principal components (PCs) works well. If the resulting generations have artifacted and unrealistic textures, there is probably overfitting and you may want to reduce the number of epochs or the learning rate, or play with weight decay. If the generations do not look like the input photo, then you may want to increase the number of epochs.
494
- * For inverting out-of-distribution identities, such as artistic renditions of people or non-humans (e.g. the ones shown in the paper), it is recommended to use 1000 PCs, lr=1, and train for 800 epochs.
495
- * Note that if you change the number of PCs, you will probably need to change the learning rate. For fewer PCs, higher learning rates are typically required."""
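For reference, the inversion settings recommended in the tips above map onto the invert helper defined earlier in this file roughly as follows. This is an illustrative sketch, not part of the app; face_photo and face_mask stand in for an already-loaded PIL image and mask.

# Hypothetical call to invert() using the hyperparameters suggested in the tips.
# For a realistic photo: 10,000 principal components, lr=1e-1, ~800 epochs.
network = invert([face_photo], face_mask, pcs=10000, epochs=800, weight_decay=1e-10, lr=1e-1)
# For out-of-distribution identities the tips instead suggest fewer PCs and a larger learning rate:
# network = invert([face_photo], face_mask, pcs=1000, epochs=800, weight_decay=1e-10, lr=1)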
496
-
497
-
498
- gr.Markdown(help_text1)
499
- gr.Markdown(help_text2)
500
- #demo.load(fn=start_items)
501
- demo.queue().launch()
 
20
  import spaces
21
 
22
 
23
+ def main():
24
 
25
+ models_path = snapshot_download(repo_id="Snapchat/w2w")
26
 
27
+ device = "cuda"
28
+ mean = torch.load(f"{models_path}/files/mean.pt", map_location=torch.device('cpu')).bfloat16().to(device)
29
+ std = torch.load(f"{models_path}/files/std.pt", map_location=torch.device('cpu')).bfloat16().to(device)
30
+ v = torch.load(f"{models_path}/files/V.pt", map_location=torch.device('cpu')).bfloat16().to(device)
31
+ proj = torch.load(f"{models_path}/files/proj_1000pc.pt", map_location=torch.device('cpu')).bfloat16().to(device)
32
+ df = torch.load(f"{models_path}/files/identity_df.pt")
33
+ weight_dimensions = torch.load(f"{models_path}/files/weight_dimensions.pt")
34
+ pinverse = torch.load(f"{models_path}/files/pinverse_1000pc.pt", map_location=torch.device('cpu')).bfloat16().to(device)
35
+
36
+
37
+ unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
38
+ network = None
39
+
40
+ young = get_direction(df, "Young", pinverse, 1000, device)
41
+ young = debias(young, "Male", df, pinverse, device)
42
+ young = debias(young, "Pointy_Nose", df, pinverse, device)
43
+ young = debias(young, "Wavy_Hair", df, pinverse, device)
44
+ young = debias(young, "Chubby", df, pinverse, device)
45
+ young = debias(young, "No_Beard", df, pinverse, device)
46
+ young = debias(young, "Mustache", df, pinverse, device)
47
+
48
+ pointy = get_direction(df, "Pointy_Nose", pinverse, 1000, device)
49
+ pointy = debias(pointy, "Young", df, pinverse, device)
50
+ pointy = debias(pointy, "Male", df, pinverse, device)
51
+ pointy = debias(pointy, "Wavy_Hair", df, pinverse, device)
52
+ pointy = debias(pointy, "Chubby", df, pinverse, device)
53
+ pointy = debias(pointy, "Heavy_Makeup", df, pinverse, device)
54
+
55
+ wavy = get_direction(df, "Wavy_Hair", pinverse, 1000, device)
56
+ wavy = debias(wavy, "Young", df, pinverse, device)
57
+ wavy = debias(wavy, "Male", df, pinverse, device)
58
+ wavy = debias(wavy, "Pointy_Nose", df, pinverse, device)
59
+ wavy = debias(wavy, "Chubby", df, pinverse, device)
60
+ wavy = debias(wavy, "Heavy_Makeup", df, pinverse, device)
61
+
62
+
63
+ thick = get_direction(df, "Bushy_Eyebrows", pinverse, 1000, device)
64
+ thick = debias(thick, "Male", df, pinverse, device)
65
+ thick = debias(thick, "Young", df, pinverse, device)
66
+ thick = debias(thick, "Pointy_Nose", df, pinverse, device)
67
+ thick = debias(thick, "Wavy_Hair", df, pinverse, device)
68
+ thick = debias(thick, "Mustache", df, pinverse, device)
69
+ thick = debias(thick, "No_Beard", df, pinverse, device)
70
+ thick = debias(thick, "Sideburns", df, pinverse, device)
71
+ thick = debias(thick, "Big_Nose", df, pinverse, device)
72
+ thick = debias(thick, "Big_Lips", df, pinverse, device)
73
+ thick = debias(thick, "Black_Hair", df, pinverse, device)
74
+ thick = debias(thick, "Brown_Hair", df, pinverse, device)
75
+ thick = debias(thick, "Pale_Skin", df, pinverse, device)
76
+ thick = debias(thick, "Heavy_Makeup", df, pinverse, device)
77
+
78
+ def sample_model(unet, network):
79
+ del unet
80
+ del network
81
+ mean.to(device)
82
+ std.to(device)
83
+ v.to(device)
84
+ proj.to(device)
85
+ unet, _, _, _, _ = load_models(device)
86
+ network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor = 1.00)
87
 
88
+ @torch.no_grad()
89
+ @spaces.GPU
90
+ def inference( prompt, negative_prompt, guidance_scale, ddim_steps, seed):
91
+ generator = torch.Generator(device=device).manual_seed(seed)
92
+ latents = torch.randn(
93
+ (1, unet.in_channels, 512 // 8, 512 // 8),
94
+ generator = generator,
95
+ device = device
96
+ ).bfloat16()
97
+
98
+
99
+ text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
100
+
101
+ text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
102
+
103
+ max_length = text_input.input_ids.shape[-1]
104
+ uncond_input = tokenizer(
105
+ [negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt"
106
+ )
107
+ uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
108
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
109
+ noise_scheduler.set_timesteps(ddim_steps)
110
+ latents = latents * noise_scheduler.init_noise_sigma
111
 
112
+ for i,t in enumerate(tqdm.tqdm(noise_scheduler.timesteps)):
113
+ latent_model_input = torch.cat([latents] * 2)
114
+ latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
115
+ with network:
116
+ noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond= None).sample
117
+ #guidance
118
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
119
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
120
+ latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
 
121
 
122
+ latents = 1 / 0.18215 * latents
123
+ image = vae.decode(latents).sample
124
+ image = (image / 2 + 0.5).clamp(0, 1)
125
+ image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
126
+
127
+ image = Image.fromarray((image * 255).round().astype("uint8"))
128
+
129
  return image
130
 
131
 
132
+ @torch.no_grad()
133
+ @spaces.GPU
134
+ def edit_inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed, start_noise, a1, a2, a3, a4):
135
+ original_weights = network.proj.clone()
 
 
136
 
137
+ #pad to same number of PCs
138
+ pcs_original = original_weights.shape[1]
139
+ pcs_edits = young.shape[1]
140
+ padding = torch.zeros((1,pcs_original-pcs_edits)).to(device)
141
+ young_pad = torch.cat((young, padding), 1)
142
+ pointy_pad = torch.cat((pointy, padding), 1)
143
+ wavy_pad = torch.cat((wavy, padding), 1)
144
+ thick_pad = torch.cat((thick, padding), 1)
145
+
146
+
147
+ edited_weights = original_weights+a1*1e6*young_pad+a2*1e6*pointy_pad+a3*1e6*wavy_pad+a4*2e6*thick_pad
148
+
149
+ generator = torch.Generator(device=device).manual_seed(seed)
150
+ latents = torch.randn(
151
+ (1, unet.in_channels, 512 // 8, 512 // 8),
152
+ generator = generator,
153
+ device = device
154
+ ).bfloat16()
155
+
156
+
157
+ text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
158
+
159
+ text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
160
+
161
+ max_length = text_input.input_ids.shape[-1]
162
+ uncond_input = tokenizer(
163
+ [negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt"
164
+ )
165
+ uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
166
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
167
+ noise_scheduler.set_timesteps(ddim_steps)
168
+ latents = latents * noise_scheduler.init_noise_sigma
169
+
170
+
171
+
172
+ for i,t in enumerate(tqdm.tqdm(noise_scheduler.timesteps)):
173
+ latent_model_input = torch.cat([latents] * 2)
174
+ latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
175
+
176
+ if t>start_noise:
177
+ pass
178
+ elif t<=start_noise:
179
+ network.proj = torch.nn.Parameter(edited_weights)
180
+ network.reset()
181
+
182
+
183
  with network:
184
+ noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond= None).sample
185
+
186
+
187
+ #guidance
188
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
189
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
190
+ latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
191
+
192
+ latents = 1 / 0.18215 * latents
193
+ image = vae.decode(latents).sample
194
+ image = (image / 2 + 0.5).clamp(0, 1)
195
 
196
+ image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
197
 
198
+ image = Image.fromarray((image * 255).round().astype("uint8"))
199
+
200
+ #reset weights back to original
201
+ network.proj = torch.nn.Parameter(original_weights)
202
+ network.reset()
203
+
204
+ return image
205
 
206
+ @spaces.GPU
207
+ def sample_then_run():
208
+ sample_model()
209
+ prompt = "sks person"
210
+ negative_prompt = "low quality, blurry, unfinished, nudity, weapon"
211
+ seed = 5
212
+ cfg = 3.0
213
+ steps = 25
214
+ image = inference( prompt, negative_prompt, cfg, steps, seed)
215
+ torch.save(network.proj, "model.pt" )
216
+ return image, "model.pt"
217
 
 
218
 
219
 
220
+ class CustomImageDataset(Dataset):
221
+ def __init__(self, images, transform=None):
222
+ self.images = images
223
+ self.transform = transform
224
+
225
+ def __len__(self):
226
+ return len(self.images)
227
+
228
+ def __getitem__(self, idx):
229
+ image = self.images[idx]
230
+ if self.transform:
231
+ image = self.transform(image)
232
+ return image
233
+
234
+ @spaces.GPU
235
+ def invert(image, mask, pcs=10000, epochs=400, weight_decay = 1e-10, lr=1e-1):
236
 
237
+ del unet
238
+ del network
239
+ unet, _, _, _, _ = load_models(device)
240
+
241
+ proj = torch.zeros(1,pcs).bfloat16().to(device)
242
+ network = LoRAw2w( proj, mean, std, v[:, :pcs],
243
+ unet,
244
+ rank=1,
245
+ multiplier=1.0,
246
+ alpha=27.0,
247
+ train_method="xattn-strict"
248
+ ).to(device, torch.bfloat16)
249
 
250
+ ### load mask
251
+ mask = transforms.Resize((64,64), interpolation=transforms.InterpolationMode.BILINEAR)(mask)
252
+ mask = torchvision.transforms.functional.pil_to_tensor(mask).unsqueeze(0).to(device).bfloat16()[:,0,:,:].unsqueeze(1)
253
+ ### check if an actual mask was drawn, otherwise the mask is just all ones
254
+ if torch.sum(mask) == 0:
255
+ mask = torch.ones((1,1,64,64)).to(device).bfloat16()
256
+
257
+ ### single image dataset
258
+ image_transforms = transforms.Compose([transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
259
+ transforms.RandomCrop(512),
260
+ transforms.ToTensor(),
261
+ transforms.Normalize([0.5], [0.5])])
262
 
263
 
264
+ train_dataset = CustomImageDataset(image, transform=image_transforms)
265
+ train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)
266
 
267
+ ### optimizer
268
+ optim = torch.optim.Adam(network.parameters(), lr=lr, weight_decay=weight_decay)
269
 
270
+ ### training loop
271
+ unet.train()
272
+ for epoch in tqdm.tqdm(range(epochs)):
273
+ for batch in train_dataloader:
274
+ ### prepare inputs
275
+ batch = batch.to(device).bfloat16()
276
+ latents = vae.encode(batch).latent_dist.sample()
277
+ latents = latents*0.18215
278
+ noise = torch.randn_like(latents)
279
+ bsz = latents.shape[0]
280
+
281
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
282
+ timesteps = timesteps.long()
283
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
284
+ text_input = tokenizer("sks person", padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
285
+ text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
286
+
287
+ ### loss + sgd step
288
+ with network:
289
+ model_pred = unet(noisy_latents, timesteps, text_embeddings).sample
290
+ loss = torch.nn.functional.mse_loss(mask*model_pred.float(), mask*noise.float(), reduction="mean")
291
+ optim.zero_grad()
292
+ loss.backward()
293
+ optim.step()
294
+
295
+ ### return optimized network
296
+ return network
297
 
298
 
299
+ @spaces.GPU
300
+ def run_inversion(dict, pcs, epochs, weight_decay,lr):
301
+ init_image = dict["image"].convert("RGB").resize((512, 512))
302
+ mask = dict["mask"].convert("RGB").resize((512, 512))
303
+ network = invert([init_image], mask, pcs, epochs, weight_decay,lr)
304
+
305
+
306
+ #sample an image
307
+ prompt = "sks person"
308
+ negative_prompt = "low quality, blurry, unfinished, nudity"
309
+ seed = 5
310
+ cfg = 3.0
311
+ steps = 25
312
+ image = inference( prompt, negative_prompt, cfg, steps, seed)
313
+ torch.save(network.proj, "model.pt" )
314
+ return image, "model.pt"
315
+
316
+
317
+ @spaces.GPU
318
+ def file_upload(file):
319
+ del unet
320
+ del network
321
+
322
+ proj = torch.load(file.name).to(device)
323
+
324
+ #pad to 10000 Principal components to keep everything consistent
325
+ pcs = proj.shape[1]
326
+ padding = torch.zeros((1,10000-pcs)).to(device)
327
+ proj = torch.cat((proj, padding), 1)
328
+
329
+ unet, _, _, _, _ = load_models(device)
330
+
331
+
332
+ network = LoRAw2w( proj, mean, std, v[:, :10000],
333
+ unet,
334
+ rank=1,
335
+ multiplier=1.0,
336
+ alpha=27.0,
337
+ train_method="xattn-strict"
338
+ ).to(device, torch.bfloat16)
339
+
340
+
341
+ prompt = "sks person"
342
+ negative_prompt = "low quality, blurry, unfinished, nudity"
343
+ seed = 5
344
+ cfg = 3.0
345
+ steps = 25
346
+ image = inference( prompt, negative_prompt, cfg, steps, seed)
347
+ return image
348
+
349
+
350
 
351
+
352
+ intro = """
353
+ <div style="display: flex;align-items: center;justify-content: center">
354
+ <h1 style="margin-left: 12px;text-align: center;margin-bottom: 7px;display: inline-block"><em>weights2weights</em> Demo</h1>
355
+ <h3 style="display: inline-block;margin-left: 10px;margin-top: 6px;font-weight: 500">Interpreting the Weight Space of Customized Diffusion Models</h3>
356
+ </div>
357
+ <p style="font-size: 0.95rem;margin: 0rem;line-height: 1.2em;margin-top:1em;display: inline-block">
358
+ <a href="https://snap-research.github.io/weights2weights/" target="_blank">Project Page</a> | <a href="https://arxiv.org/abs/2406.09413" target="_blank">Paper</a>
359
+ | <a href="https://github.com/snap-research/weights2weights" target="_blank">Code</a> |
360
+ <a href="https://huggingface.co/spaces/Snapchat/w2w-demo?duplicate=true" target="_blank" style="
361
+ display: inline-block;
362
+ ">
363
+ <img style="margin-top: -1em;margin-bottom: 0em;position: absolute;" src="https://bit.ly/3CWLGkA" alt="Duplicate Space"></a>
364
+ </p>
365
+ """
366
+
367
+
368
+
369
+ with gr.Blocks(css="style.css") as demo:
370
+
371
+
372
+
373
+ gr.HTML(intro)
374
+
375
+ gr.Markdown("""<div style="text-align: justify;"> In this demo, you can get an identity-encoding model by sampling or inverting. To use a model previously downloaded from this demo see \"Uploading a model\" in the Advanced Options. Next, you can generate new images from it, or edit the identity encoded in the model and generate images from the edited model. We provide detailed instructions and tips at the bottom of the page.""")
376
+ with gr.Column():
377
+ with gr.Row():
378
+ with gr.Column():
379
+ gr.Markdown("""1) Either sample a new model, or upload an image (optionally draw a mask over the head) and click `invert`.""")
380
+ sample = gr.Button("🎲 Sample New Model")
381
+ input_image = gr.ImageEditor(elem_id="image_upload", type='pil', label="Reference Identity",
382
+ width=512, height=512)
383
+
384
+ with gr.Row():
385
+ invert_button = gr.Button("⬆️ Invert")
386
+
387
+
388
+
389
+ with gr.Column():
390
+ gr.Markdown("""2) Generate images of the sampled/inverted identity or edit the identity with the sliders and generate new images with various prompts and seeds.""")
391
+ gallery = gr.Image(label="Generated Image",height=512, width=512, interactive=False)
392
+ submit = gr.Button("Generate")
393
+
394
+
395
+ prompt = gr.Textbox(label="Prompt",
396
+ info="Make sure to include 'sks person'" ,
397
+ placeholder="sks person",
398
+ value="sks person")
399
+
400
+ seed = gr.Number(value=5, label="Seed", precision=0, interactive=True)
401
 
402
+ # Editing
403
+ with gr.Column():
404
+ with gr.Row():
405
+ a1 = gr.Slider(label="- Young +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
406
+ a2 = gr.Slider(label="- Pointy Nose +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
407
+ with gr.Row():
408
+ a3 = gr.Slider(label="- Curly Hair +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
409
+ a4 = gr.Slider(label="- Thick Eyebrows +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
410
+
411
+
412
+ with gr.Accordion("Advanced Options", open=False):
413
+ with gr.Tab("Inversion"):
414
+ with gr.Row():
415
+ lr = gr.Number(value=1e-1, label="Learning Rate", interactive=True)
416
+ pcs = gr.Slider(label="# Principal Components", value=10000, step=1, minimum=1, maximum=10000, interactive=True)
417
+ with gr.Row():
418
+ epochs = gr.Slider(label="Epochs", value=800, step=1, minimum=1, maximum=2000, interactive=True)
419
+ weight_decay = gr.Number(value=1e-10, label="Weight Decay", interactive=True)
420
+ with gr.Tab("Sampling"):
421
+ with gr.Row():
422
+ cfg= gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
423
+ steps = gr.Slider(label="Inference Steps", value=25, step=1, minimum=0, maximum=100, interactive=True)
424
+ with gr.Row():
425
+ negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, nudity, weapon", value="low quality, blurry, unfinished, nudity, weapon")
426
+ injection_step = gr.Slider(label="Injection Step", value=800, step=1, minimum=0, maximum=1000, interactive=True)
427
+
428
+ with gr.Tab("Uploading a model"):
429
+ gr.Markdown("""<div style="text-align: justify;">Upload a model downloaded from this demo below.""")
430
+
431
+ file_input = gr.File(label="Upload Model", container=True)
432
+
433
+
434
+
435
+
436
+ gr.Markdown("""<div style="text-align: justify;"> After sampling a new model or inverting, you can download the model below.""")
437
+
438
+ with gr.Row():
439
+ file_output = gr.File(label="Download Sampled/Inverted Model", container=True, interactive=False)
440
+
441
+
442
+
443
+
444
+ invert_button.click(fn=run_inversion,
445
+ inputs=[input_image, pcs, epochs, weight_decay,lr],
446
+ outputs = [input_image, file_output])
447
 
 
448
 
449
+ sample.click(fn=sample_then_run, inputs=[unet, network], outputs=[input_image, file_output])
450
 
451
+ submit.click(
452
+ fn=edit_inference, inputs=[prompt, negative_prompt, cfg, steps, seed, injection_step, a1, a2, a3, a4], outputs=[gallery]
453
+ )
454
+ file_input.change(fn=file_upload, inputs=file_input, outputs = gallery)
455
+
456
+
457
+
458
+ help_text1 = """
459
+ <b>Instructions</b>:
460
+ 1. To get results faster without waiting in queue, you can duplicate into a private space with an A100 GPU.
461
+ 2. To begin, you will have to get an identity-encoding model. You can either sample one from the *weights2weights* space by clicking `Sample New Model`, or invert an identity into a model by uploading an image and clicking `invert`. You can optionally draw over the head to define a mask in the image for better results. Sampling a model takes around 10 seconds and inversion takes around 2 minutes. After this is done, you can optionally download this model for later use. A model can be uploaded in the \"Uploading a model\" tab in the `Advanced Options`.
462
+ 3. After getting a model, an image of the identity will be displayed on the right. You can sample from the model by changing seeds as well as prompts and then clicking `Generate`. Make sure to include \"sks person\" in your prompt to keep the same identity.
463
+ 4. The identity in the model can be edited by changing the sliders for various attributes. After clicking `Generate`, you can see how the identity has changed and the effects are maintained across different seeds and prompts.
464
+ """
465
+ help_text2 = """<b>Tips</b>:
466
+ 1. Editing and Identity Generation
467
+ * If you are interested in preserving more of the image during identity-editing (i.e., where the same seed and prompt result in the same image with only the identity changed), you can play with the "Injection Step" parameter in the \"Sampling\" tab in the `Advanced Options`. During the first *n* timesteps, the original model's weights will be used, and then the edited weights will be set during the remaining steps. Values closer to 1000 will set the edited weights early, having a more pronounced effect, which may disrupt some semantics and structure of the generated image. Lower values will set the edited weights later, better preserving image context. We notice that values around 600-800 tend to produce the best results. Larger values (700-1000) are helpful for more global attribute changes, while smaller values (400-700) can be used for more fine-grained edits, although adjusting this is not always needed.
468
+ * You can play around with negative prompts, number of inference steps, and CFG in the \"Sampling\" tab in the `Advanced Options` to affect the ultimate image quality.
469
+ * Sometimes the identity will not be perfectly consistent (e.g., there might be small variations of the face) when you use some seeds or prompts. This is a limitation of our method as well as an open problem in personalized models.
470
+ 2. Inversion
471
+ * To obtain the best results for inversion, upload a high-resolution photo of the face with minimal occlusion. It is recommended to draw over the face and hair to define a mask, but inversion should still generally work for non-close-up face shots.
472
+ * For inverting a realistic photo of an identity, typically 800 epochs with lr=1e-1 and 10,000 principal components (PCs) works well. If the resulting generations have artifacted and unrealistic textures, there is probably overfitting and you may want to reduce the number of epochs or the learning rate, or play with weight decay. If the generations do not look like the input photo, then you may want to increase the number of epochs.
473
+ * For inverting out-of-distribution identities, such as artistic renditions of people or non-humans (e.g. the ones shown in the paper), it is recommended to use 1000 PCs, lr=1, and train for 800 epochs.
474
+ * Note that if you change the number of PCs, you will probably need to change the learning rate. For fewer PCs, higher learning rates are typically required."""
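The "Injection Step" behavior described in the first tip corresponds to the start_noise logic inside edit_inference above. A minimal sketch of that denoising-loop logic, assuming noise_scheduler, network, injection_step, and the original_weights / edited_weights tensors are in scope as in this file:

# Original projection weights are used while t > injection_step; the edited
# weights are swapped in for the remaining (lower-noise) steps.
for t in noise_scheduler.timesteps:
    if t <= injection_step:
        network.proj = torch.nn.Parameter(edited_weights)
        network.reset()
    # ... classifier-free-guidance UNet call and noise_scheduler.step() as in edit_inference ...
# afterwards the weights are restored, exactly as edit_inference does
network.proj = torch.nn.Parameter(original_weights)
network.reset()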
475
 
476
 
477
+ gr.Markdown(help_text1)
478
+ gr.Markdown(help_text2)
479
 
480
+ demo.queue().launch()
481
 
482
 
483
+ if __name__ == "__main__":
484
+ main()