import torch
from tqdm import tqdm
from utils import guidance, schedule, boxdiff
import utils
from PIL import Image
import gc
import numpy as np
from .attention import GatedSelfAttentionDense
from .models import process_input_embeddings, torch_device
import warnings

# All keys: [('down', 0, 0, 0), ('down', 0, 1, 0), ('down', 1, 0, 0), ('down', 1, 1, 0), ('down', 2, 0, 0), ('down', 2, 1, 0), ('mid', 0, 0, 0), ('up', 1, 0, 0), ('up', 1, 1, 0), ('up', 1, 2, 0), ('up', 2, 0, 0), ('up', 2, 1, 0), ('up', 2, 2, 0), ('up', 3, 0, 0), ('up', 3, 1, 0), ('up', 3, 2, 0)]
# Note that the first up block is `UpBlock2D` rather than `CrossAttnUpBlock2D` and does not have attention. The last index is always 0 in our case since we have one `BasicTransformerBlock` in each `Transformer2DModel`.
DEFAULT_GUIDANCE_ATTN_KEYS = [("mid", 0, 0, 0), ("up", 1, 0, 0), ("up", 1, 1, 0), ("up", 1, 2, 0)]
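# For reference (an assumption about the underlying diffusers layout, not stated in this file):
# a key such as ("up", 1, 0, 0) is expected to address the cross-attention in
# unet.up_blocks[1].attentions[0].transformer_blocks[0], i.e. the tuple is
# (block type, block index, attention index, transformer block index).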

def latent_backward_guidance(scheduler, unet, cond_embeddings, index, bboxes, object_positions, t, latents, loss, loss_scale = 30, loss_threshold = 0.2, max_iter = 5, max_index_step = 10, cross_attention_kwargs=None, ref_ca_saved_attns=None, guidance_attn_keys=None, verbose=False, clear_cache=False, **kwargs):
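    """
    Loss-guided latent update (a summary of the loop below, not a formal spec): while the
    de-scaled loss is above `loss_threshold`, run the UNet once to save the cross-attention
    maps listed in `guidance_attn_keys`, recompute the cross-attention loss against
    `bboxes`/`object_positions`, and take one gradient step on `latents`. The update runs for
    at most `max_iter` iterations and only while `index < max_index_step`.
    """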

    iteration = 0
    
    if index < max_index_step:
        if isinstance(max_iter, list):
            if len(max_iter) > index:
                max_iter = max_iter[index]
            else:
                max_iter = max_iter[-1]
        
        if verbose:
            print(f"time index {index}, loss: {loss.item()/loss_scale:.3f} (de-scaled with scale {loss_scale:.1f}), loss threshold: {loss_threshold:.3f}")
        
        while (loss.item() / loss_scale > loss_threshold and iteration < max_iter and index < max_index_step):
            saved_attn = {}
            full_cross_attention_kwargs = {
                'save_attn_to_dict': saved_attn,
                'save_keys': guidance_attn_keys,
            }
            
            if cross_attention_kwargs is not None:
                full_cross_attention_kwargs.update(cross_attention_kwargs)
            
            latents.requires_grad_(True)
            latent_model_input = latents
            latent_model_input = scheduler.scale_model_input(latent_model_input, t)
            
            unet(latent_model_input, t, encoder_hidden_states=cond_embeddings, return_cross_attention_probs=False, cross_attention_kwargs=full_cross_attention_kwargs)

            # TODO: could return the attention maps for the required blocks only and not necessarily the final output
            # update latents with guidance
            loss = guidance.compute_ca_lossv3(saved_attn=saved_attn, bboxes=bboxes, object_positions=object_positions, guidance_attn_keys=guidance_attn_keys, ref_ca_saved_attns=ref_ca_saved_attns, index=index, verbose=verbose, **kwargs) * loss_scale

            if torch.isnan(loss):
                print("**Loss is NaN**")

            del full_cross_attention_kwargs, saved_attn
            # calling gc.collect() here may release some memory

            grad_cond = torch.autograd.grad(loss.requires_grad_(True), [latents])[0]

            latents.requires_grad_(False)
            
            if hasattr(scheduler, 'sigmas'):
                latents = latents - grad_cond * scheduler.sigmas[index] ** 2
            elif hasattr(scheduler, 'alphas_cumprod'):
                warnings.warn("Using guidance scaled with alphas_cumprod")
                # Scaling with classifier guidance
                alpha_prod_t = scheduler.alphas_cumprod[t]
                # Classifier guidance: https://arxiv.org/pdf/2105.05233.pdf
                # DDIM: https://arxiv.org/pdf/2010.02502.pdf
                scale = (1 - alpha_prod_t) ** (0.5)
                latents = latents - scale * grad_cond
            else:
                # NOTE: no scaling is performed
                warnings.warn("No scaling in guidance is performed")
                latents = latents - grad_cond
            iteration += 1
            
            if clear_cache:
                utils.free_memory()
            
            if verbose:
                print(f"time index {index}, loss: {loss.item()/loss_scale:.3f}, loss threshold: {loss_threshold:.3f}, iteration: {iteration}")
            
    return latents, loss

@torch.no_grad()
def encode(model_dict, image, generator):
    """
    image should be a PIL Image or a uint8 numpy array with values in the range 0 to 255
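
    Example (an illustrative sketch; the image path and generator setup are placeholders):
        >>> from PIL import Image
        >>> image = Image.open("input.png").convert("RGB")   # height/width must be multiples of 8
        >>> generator = torch.Generator(device=torch_device).manual_seed(0)
        >>> latents = encode(model_dict, image, generator)   # shape (1, 4, H/8, W/8) for the SD VAE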
    """
    
    vae, dtype = model_dict.vae, model_dict.dtype
    
    if isinstance(image, Image.Image):
        w, h = image.size
        assert w % 8 == 0 and h % 8 == 0, f"h ({h}) and w ({w}) should be multiples of 8"
        # w, h = (x - x % 8 for x in (w, h))  # resize to integer multiple of 8
        # image = np.array(image.resize((w, h), resample=Image.Resampling.LANCZOS))[None, :]
        image = np.array(image)
    
    if isinstance(image, np.ndarray):
        assert image.dtype == np.uint8, f"Should have dtype uint8 (dtype: {image.dtype})"
        image = image.astype(np.float32) / 255.0
        image = image[None, ...]
        image = image.transpose(0, 3, 1, 2)
        image = 2.0 * image - 1.0
        image = torch.from_numpy(image)
    
    assert isinstance(image, torch.Tensor), f"type of image: {type(image)}"
    
    image = image.to(device=torch_device, dtype=dtype)
    latents = vae.encode(image).latent_dist.sample(generator)
    
    latents = vae.config.scaling_factor * latents

    return latents

@torch.no_grad()
def decode(vae, latents):
    # scale and decode the image latents with vae
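    # Note: 0.18215 is the scaling factor of the Stable Diffusion v1.x VAE
    # (the same value exposed as `vae.config.scaling_factor`, which `encode` uses above).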
    scaled_latents = 1 / 0.18215 * latents
    with torch.no_grad():
        image = vae.decode(scaled_latents).sample
        
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    
    return images

def generate_semantic_guidance(model_dict, latents, input_embeddings, num_inference_steps, bboxes, phrases, object_positions, guidance_scale = 7.5, semantic_guidance_kwargs=None, 
                           return_cross_attn=False, return_saved_cross_attn=False, saved_cross_attn_keys=None, return_cond_ca_only=False, return_token_ca_only=None, offload_guidance_cross_attn_to_cpu=False,
                           offload_cross_attn_to_cpu=False, offload_latents_to_cpu=True, return_box_vis=False, show_progress=True, save_all_latents=False, 
                           dynamic_num_inference_steps=False, fast_after_steps=None, fast_rate=2, use_boxdiff=False):
    """
    object_positions: indices of each object's tokens in the text prompt
    return_cross_attn: deprecated; use `return_saved_cross_attn` and the new format instead.
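
    Example of the box/position format (illustrative; exact token indices depend on the tokenizer):
        for the prompt "a cat and a dog", `bboxes = [[0.1, 0.4, 0.45, 0.9], [0.55, 0.4, 0.9, 0.9]]`
        (assumed to be normalized x_min, y_min, x_max, y_max) and `object_positions = [[2], [5]]`
        (the token indices of "cat" and "dog"), with one entry per box.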
    """
    vae, tokenizer, text_encoder, unet, scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict.scheduler, model_dict.dtype
    text_embeddings, uncond_embeddings, cond_embeddings = input_embeddings
    
    # Just in case there are in-place ops
    latents = latents.clone()
    
    if save_all_latents:
        # offload to cpu to save space
        if offload_latents_to_cpu:
            latents_all = [latents.cpu()]
        else:
            latents_all = [latents]
    
    scheduler.set_timesteps(num_inference_steps)
    if fast_after_steps is not None:
        scheduler.timesteps = schedule.get_fast_schedule(scheduler.timesteps, fast_after_steps, fast_rate)
    
    if dynamic_num_inference_steps:
        original_num_inference_steps = scheduler.num_inference_steps

    cross_attention_probs_down = []
    cross_attention_probs_mid = []
    cross_attention_probs_up = []

    loss = torch.tensor(10000.)

    # TODO: we can also save necessary tokens only to save memory.
    # offload_guidance_cross_attn_to_cpu does not save too much since we only store attention map for each timestep.
    guidance_cross_attention_kwargs = {
        'offload_cross_attn_to_cpu': offload_guidance_cross_attn_to_cpu,
        'enable_flash_attn': False
    }
    
    if return_saved_cross_attn:
        saved_attns = []
    
    main_cross_attention_kwargs = {
        'offload_cross_attn_to_cpu': offload_cross_attn_to_cpu,
        'return_cond_ca_only': return_cond_ca_only,
        'return_token_ca_only': return_token_ca_only,
        'save_keys': saved_cross_attn_keys,
    }
    
    # Repeating keys leads to different weights for each key.
    # assert len(set(semantic_guidance_kwargs['guidance_attn_keys'])) == len(semantic_guidance_kwargs['guidance_attn_keys']), f"guidance_attn_keys not unique: {semantic_guidance_kwargs['guidance_attn_keys']}"

    for index, t in enumerate(tqdm(scheduler.timesteps, disable=not show_progress)):
        # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
        
        if bboxes:
            if use_boxdiff:
                latents, loss = boxdiff.latent_backward_guidance_boxdiff(scheduler, unet, cond_embeddings, index, bboxes, object_positions, t, latents, loss, cross_attention_kwargs=guidance_cross_attention_kwargs, **semantic_guidance_kwargs)
            else:
                # If a `None` error is raised for `guidance_attn_keys`, check that `guidance_attn_keys` is included in `semantic_guidance_kwargs`; the default value has been removed.
                latents, loss = latent_backward_guidance(scheduler, unet, cond_embeddings, index, bboxes, object_positions, t, latents, loss, cross_attention_kwargs=guidance_cross_attention_kwargs, **semantic_guidance_kwargs)
        
        # predict the noise residual
        with torch.no_grad():
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)
            
            main_cross_attention_kwargs['save_attn_to_dict'] = {}
            
            unet_output = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, return_cross_attention_probs=return_cross_attn, cross_attention_kwargs=main_cross_attention_kwargs)
            noise_pred = unet_output.sample
            
            if return_cross_attn:
                cross_attention_probs_down.append(unet_output.cross_attention_probs_down)
                cross_attention_probs_mid.append(unet_output.cross_attention_probs_mid)
                cross_attention_probs_up.append(unet_output.cross_attention_probs_up)
                
            if return_saved_cross_attn:
                saved_attns.append(main_cross_attention_kwargs['save_attn_to_dict'])
                
                del main_cross_attention_kwargs['save_attn_to_dict']

        # perform guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        if dynamic_num_inference_steps:
            schedule.dynamically_adjust_inference_steps(scheduler, index, t)

        # compute the previous noisy sample x_t -> x_t-1
        latents = scheduler.step(noise_pred, t, latents).prev_sample
        
        if save_all_latents:
            if offload_latents_to_cpu:
                latents_all.append(latents.cpu())
            else:
                latents_all.append(latents)
    
    if dynamic_num_inference_steps:
        # Restore num_inference_steps to avoid confusion in the next generation if it is not dynamic
        scheduler.num_inference_steps = original_num_inference_steps
    
    images = decode(vae, latents)
    
    ret = [latents, images]
    
    if return_cross_attn:
        ret.append((cross_attention_probs_down, cross_attention_probs_mid, cross_attention_probs_up))
    if return_saved_cross_attn:
        ret.append(saved_attns)
    if return_box_vis:
        pil_images = [utils.draw_box(Image.fromarray(image), bboxes, phrases) for image in images]
        ret.append(pil_images)
    if save_all_latents:
        latents_all = torch.stack(latents_all, dim=0)
        ret.append(latents_all)
    return tuple(ret)

@torch.no_grad()
def generate(model_dict, latents, input_embeddings, num_inference_steps, guidance_scale = 7.5, no_set_timesteps=False, scheduler_key='dpm_scheduler'):
    vae, tokenizer, text_encoder, unet, scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict[scheduler_key], model_dict.dtype
    text_embeddings, uncond_embeddings, cond_embeddings = input_embeddings
    
    if not no_set_timesteps:
        scheduler.set_timesteps(num_inference_steps)

    for t in tqdm(scheduler.timesteps):
        # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
        latent_model_input = torch.cat([latents] * 2)

        latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)

        # predict the noise residual
        with torch.no_grad():
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

        # perform guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # compute the previous noisy sample x_t -> x_t-1
        latents = scheduler.step(noise_pred, t, latents).prev_sample

    images = decode(vae, latents)
    
    ret = [latents, images]

    return tuple(ret)

def gligen_enable_fuser(unet, enabled=True):
    for module in unet.modules():
        if isinstance(module, GatedSelfAttentionDense):
            module.enabled = enabled

def prepare_gligen_condition(bboxes, phrases, dtype, tokenizer, text_encoder, num_images_per_prompt):
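    """
    Build the GLIGEN grounding inputs from `bboxes`/`phrases` (a summary of the code below):
    returns `boxes` of shape (condition_len, 30, 4), `phrase_embeddings` of shape
    (condition_len, 30, 768), `masks` of shape (condition_len, 30) with 1 for enabled entries,
    and `condition_len = len(bboxes) * num_images_per_prompt * 2`. The two halves correspond to
    the unconditional and conditional branches of classifier-free guidance; the first
    (unconditional) half has all masks zeroed.
    """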
    batch_size = len(bboxes)
    
    assert len(phrases) == len(bboxes)
    max_objs = 30
    
    n_objs = min(max([len(bboxes_item) for bboxes_item in bboxes]), max_objs)
    boxes = torch.zeros((batch_size, max_objs, 4), device=torch_device, dtype=dtype)
    phrase_embeddings = torch.zeros((batch_size, max_objs, 768), device=torch_device, dtype=dtype)
    # masks indicates which of the entries are enabled (1) or ignored (0)
    masks = torch.zeros((batch_size, max_objs), device=torch_device, dtype=dtype)
    
    if n_objs > 0:
        for idx, (bboxes_item, phrases_item) in enumerate(zip(bboxes, phrases)):
            # the length of `bboxes_item` can be smaller than `n_objs` because `n_objs` is the maximum over all items
            bboxes_item = torch.tensor(bboxes_item[:n_objs])
            boxes[idx, :bboxes_item.shape[0]] = bboxes_item

            tokenizer_inputs = tokenizer(phrases_item[:n_objs], padding=True, return_tensors="pt").to(torch_device)
            _phrase_embeddings = text_encoder(**tokenizer_inputs).pooler_output
            phrase_embeddings[idx, :_phrase_embeddings.shape[0]] = _phrase_embeddings
            assert bboxes_item.shape[0] == _phrase_embeddings.shape[0], f"{bboxes_item.shape[0]} != {_phrase_embeddings.shape[0]}"
            
            masks[idx, :bboxes_item.shape[0]] = 1

    # Classifier-free guidance
    repeat_times = num_images_per_prompt * 2
    condition_len = batch_size * repeat_times

    boxes = boxes.repeat(repeat_times, 1, 1)
    phrase_embeddings = phrase_embeddings.repeat(repeat_times, 1, 1)
    masks = masks.repeat(repeat_times, 1)
    masks[:condition_len // 2] = 0
    
    # print("shapes:", boxes.shape, phrase_embeddings.shape, masks.shape)
    
    return boxes, phrase_embeddings, masks, condition_len

@torch.no_grad()
def generate_gligen(model_dict, latents, input_embeddings, num_inference_steps, bboxes, phrases, num_images_per_prompt=1, gligen_scheduled_sampling_beta: float = 0.3, guidance_scale=7.5, 
    frozen_steps=20, frozen_mask=None,
    return_saved_cross_attn=False, saved_cross_attn_keys=None, return_cond_ca_only=False, return_token_ca_only=None, 
    offload_cross_attn_to_cpu=False, offload_latents_to_cpu=True,
    semantic_guidance=False, semantic_guidance_bboxes=None, semantic_guidance_object_positions=None, semantic_guidance_kwargs=None, 
    return_box_vis=False, show_progress=True, save_all_latents=False, scheduler_key='dpm_scheduler', batched_condition=False, dynamic_num_inference_steps=False, fast_after_steps=None, fast_rate=2):
    """
    `bboxes` should be a flat list of boxes (one box per phrase; phrases may be duplicated), rather than a list of lists, unless `batched_condition` is enabled.
    batched_condition:
        Enabled: `bboxes` and `phrases` are lists over the batch dimension, with one item per image in the batch.
        Disabled: `bboxes` and `phrases` specify the boxes/phrases of a single image (no batch dimension).
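
    Example of the non-batched format (illustrative; boxes are assumed to be normalized
    (x_min, y_min, x_max, y_max)):
        bboxes = [[0.1, 0.4, 0.45, 0.9], [0.55, 0.4, 0.9, 0.9]]
        phrases = ["a cat", "a dog"]
    With `batched_condition=True`, wrap each of these in an outer list, one item per image.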
    """
    vae, tokenizer, text_encoder, unet, scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict[scheduler_key], model_dict.dtype
    
    text_embeddings, _, cond_embeddings = process_input_embeddings(input_embeddings)
    
    if latents.dim() == 5:
        # latents_all from the input side, different from the latents_all to be saved
        latents_all_input = latents
        latents = latents[0]
    else:
        latents_all_input = None
    
    # Just in case there are in-place ops
    latents = latents.clone()
    
    if save_all_latents:
        # offload to cpu to save space
        if offload_latents_to_cpu:
            latents_all = [latents.cpu()]
        else:
            latents_all = [latents]
    
    scheduler.set_timesteps(num_inference_steps)
    if fast_after_steps is not None:
        scheduler.timesteps = schedule.get_fast_schedule(scheduler.timesteps, fast_after_steps, fast_rate)
    
    if dynamic_num_inference_steps:
        original_num_inference_steps = scheduler.num_inference_steps
    
    if frozen_mask is not None:
        frozen_mask = frozen_mask.to(dtype=dtype).clamp(0., 1.)

    # 5.1 Prepare GLIGEN variables
    if not batched_condition:
        # Add batch dimension to bboxes and phrases
        bboxes, phrases = [bboxes], [phrases]
    
    boxes, phrase_embeddings, masks, condition_len = prepare_gligen_condition(bboxes, phrases, dtype, tokenizer, text_encoder, num_images_per_prompt)
    
    if semantic_guidance_bboxes and semantic_guidance:
        loss = torch.tensor(10000.)
        # TODO: we can also save necessary tokens only to save memory.
        # offload_guidance_cross_attn_to_cpu does not save too much since we only store attention map for each timestep.
        guidance_cross_attention_kwargs = {
            'offload_cross_attn_to_cpu': False,
            'enable_flash_attn': False,
            'gligen': {
                'boxes': boxes[:condition_len // 2],
                'positive_embeddings': phrase_embeddings[:condition_len // 2],
                'masks': masks[:condition_len // 2],
                'fuser_attn_kwargs': {
                    'enable_flash_attn': False,
                }
            }
        }
    
    if return_saved_cross_attn:
        saved_attns = []
    
    main_cross_attention_kwargs = {
        'offload_cross_attn_to_cpu': offload_cross_attn_to_cpu,
        'return_cond_ca_only': return_cond_ca_only,
        'return_token_ca_only': return_token_ca_only,
        'save_keys': saved_cross_attn_keys,
        'gligen': {
            'boxes': boxes,
            'positive_embeddings': phrase_embeddings,
            'masks': masks
        }
    }
    
    timesteps = scheduler.timesteps

    num_grounding_steps = int(gligen_scheduled_sampling_beta * len(timesteps))
    gligen_enable_fuser(unet, True)

    for index, t in enumerate(tqdm(timesteps, disable=not show_progress)):
        # Scheduled sampling
        if index == num_grounding_steps:
            gligen_enable_fuser(unet, False)
        
        if semantic_guidance_bboxes and semantic_guidance:
            with torch.enable_grad():
                latents, loss = latent_backward_guidance(scheduler, unet, cond_embeddings, index, semantic_guidance_bboxes, semantic_guidance_object_positions, t, latents, loss, cross_attention_kwargs=guidance_cross_attention_kwargs, **semantic_guidance_kwargs)
        # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
        latent_model_input = torch.cat([latents] * 2)

        latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)

        main_cross_attention_kwargs['save_attn_to_dict'] = {}

        # predict the noise residual
        noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, 
                            cross_attention_kwargs=main_cross_attention_kwargs).sample
        
        if return_saved_cross_attn:
            saved_attns.append(main_cross_attention_kwargs['save_attn_to_dict'])
            
            del main_cross_attention_kwargs['save_attn_to_dict']

        # perform guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        if dynamic_num_inference_steps:
            schedule.dynamically_adjust_inference_steps(scheduler, index, t)

        # compute the previous noisy sample x_t -> x_t-1
        latents = scheduler.step(noise_pred, t, latents).prev_sample
        
        if frozen_mask is not None and index < frozen_steps:
            latents = latents_all_input[index+1] * frozen_mask + latents * (1. - frozen_mask)
        
        # Do not save the latents in the fast steps
        if save_all_latents and (fast_after_steps is None or index < fast_after_steps):
            if offload_latents_to_cpu:
                latents_all.append(latents.cpu())
            else:
                latents_all.append(latents)

    if dynamic_num_inference_steps:
        # Restore num_inference_steps to avoid confusion in the next generation if it is not dynamic
        scheduler.num_inference_steps = original_num_inference_steps

    # Turn off fuser for typical SD
    gligen_enable_fuser(unet, False)
    images = decode(vae, latents)
    
    ret = [latents, images]
    if return_saved_cross_attn:
        ret.append(saved_attns)
    if return_box_vis:
        pil_images = [utils.draw_box(Image.fromarray(image), bboxes_item, phrases_item) for image, bboxes_item, phrases_item in zip(images, bboxes, phrases)]
        ret.append(pil_images)
    if save_all_latents:
        latents_all = torch.stack(latents_all, dim=0)
        ret.append(latents_all)
    
    return tuple(ret)


def get_inverse_timesteps(inverse_scheduler, num_inference_steps, strength):
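    """
    Return the part of the inverse schedule to use for a given `strength` (a summary of the
    code below): with strength=1.0 the full schedule is returned; otherwise the last
    `t_start = num_inference_steps - int(num_inference_steps * strength)` entries are dropped.
    Also returns the effective number of inference steps.
    """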
    # get the original timestep using init_timestep
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

    t_start = max(num_inference_steps - init_timestep, 0)

    # safety check for t_start overflow to prevent an empty timesteps slice
    if t_start == 0:
        return inverse_scheduler.timesteps, num_inference_steps
    timesteps = inverse_scheduler.timesteps[:-t_start]

    return timesteps, num_inference_steps - t_start

@torch.no_grad()
def invert(model_dict, latents, input_embeddings, num_inference_steps, guidance_scale = 7.5):
    """
    latents: encoded from the image, should not have noise (t = 0)
    
    returns inverted_latents for all time steps
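
    Example (sketch; assumes `model_dict`, `image`, `generator`, and `input_embeddings` are prepared elsewhere):
        >>> latents = encode(model_dict, image, generator)
        >>> inverted_latents = invert(model_dict, latents, input_embeddings, num_inference_steps=50, guidance_scale=0.)
        >>> # `inverted_latents` stacks one latent per timestep, noisiest first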
    """
    vae, tokenizer, text_encoder, unet, scheduler, inverse_scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict.scheduler, model_dict.inverse_scheduler, model_dict.dtype
    text_embeddings, uncond_embeddings, cond_embeddings = input_embeddings
    
    inverse_scheduler.set_timesteps(num_inference_steps, device=latents.device)
    # We need to invert all steps because we need them to generate the background.
    timesteps, num_inference_steps = get_inverse_timesteps(inverse_scheduler, num_inference_steps, strength=1.0)

    inverted_latents = [latents.cpu()]
    for t in tqdm(timesteps[:-1]):
        # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
        if guidance_scale > 0.:
            latent_model_input = torch.cat([latents] * 2)

            latent_model_input = inverse_scheduler.scale_model_input(latent_model_input, timestep=t)

            # predict the noise residual
            with torch.no_grad():
                noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            # perform guidance
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        else:
            latent_model_input = latents

            latent_model_input = inverse_scheduler.scale_model_input(latent_model_input, timestep=t)

            # predict the noise residual
            with torch.no_grad():
                noise_pred_uncond = unet(latent_model_input, t, encoder_hidden_states=uncond_embeddings).sample

            # perform guidance
            noise_pred = noise_pred_uncond

        # compute the previous noisy sample x_t -> x_t-1
        latents = inverse_scheduler.step(noise_pred, t, latents).prev_sample
        
        inverted_latents.append(latents.cpu())
    
    assert len(inverted_latents) == len(timesteps)
    # timestep is the first dimension
    inverted_latents = torch.stack(list(reversed(inverted_latents)), dim=0)
    
    return inverted_latents

def generate_partial_frozen(model_dict, latents_all, frozen_mask, input_embeddings, num_inference_steps, frozen_steps, guidance_scale = 7.5, bboxes=None, phrases=None, object_positions=None, semantic_guidance_kwargs=None, offload_guidance_cross_attn_to_cpu=False, use_boxdiff=False):
    vae, tokenizer, text_encoder, unet, scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict.scheduler, model_dict.dtype
    text_embeddings, uncond_embeddings, cond_embeddings = input_embeddings
    
    scheduler.set_timesteps(num_inference_steps)
    frozen_mask = frozen_mask.to(dtype=dtype).clamp(0., 1.)
    
    latents = latents_all[0]
    
    if bboxes:
        # With semantic guidance
        loss = torch.tensor(10000.)

        # offload_guidance_cross_attn_to_cpu does not save too much since we only store attention map for each timestep.
        guidance_cross_attention_kwargs = {
            'offload_cross_attn_to_cpu': offload_guidance_cross_attn_to_cpu,
            # Getting invalid argument on backward, probably due to insufficient shared memory
            'enable_flash_attn': False
        }

    for index, t in enumerate(tqdm(scheduler.timesteps)):
        if bboxes:
            # With semantic guidance, `guidance_attn_keys` should be in `semantic_guidance_kwargs`
            if use_boxdiff:
                latents, loss = boxdiff.latent_backward_guidance_boxdiff(scheduler, unet, cond_embeddings, index, bboxes, object_positions, t, latents, loss, cross_attention_kwargs=guidance_cross_attention_kwargs, **semantic_guidance_kwargs)
            else:
                latents, loss = latent_backward_guidance(scheduler, unet, cond_embeddings, index, bboxes, object_positions, t, latents, loss, cross_attention_kwargs=guidance_cross_attention_kwargs, **semantic_guidance_kwargs)
            
        with torch.no_grad():
            # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
            latent_model_input = torch.cat([latents] * 2)

            latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)

            # predict the noise residual
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            # perform guidance
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = scheduler.step(noise_pred, t, latents).prev_sample
            
            if index < frozen_steps:
                latents = latents_all[index+1] * frozen_mask + latents * (1. - frozen_mask)

    # decode the image latents with the VAE (same as `decode` above)
    images = decode(vae, latents)
    
    ret = [latents, images]

    return tuple(ret)