File size: 32,834 Bytes
8ee7393
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Dict, Any
from diffusers.models.attention_processor import Attention, AttnProcessor2_0
from diffusers.utils import logging, is_torch_version, deprecate
from diffusers.utils.torch_utils import fourier_filter
# UNet is a diffusers PeftAdapterMixin instance.
from diffusers.loaders.peft import PeftAdapterMixin
from peft import LoraConfig, get_peft_model
import peft.tuners.lora as peft_lora
from peft.tuners.lora.dora import DoraLinearLayer
from einops import rearrange
import math, re
import numpy as np
from peft.tuners.tuners_utils import BaseTunerLayer


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

def dummy_func(*args, **kwargs):
    pass

# Revised from RevGrad, by removing the grad negation.
class ScaleGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, alpha_, debug=False):
        ctx.save_for_backward(alpha_, debug)
        output = input_
        if debug:
            print(f"input: {input_.abs().mean().item()}")
        return output

    @staticmethod
    def backward(ctx, grad_output):  # pragma: no cover
        # saved_tensors returns a tuple of tensors.
        alpha_, debug = ctx.saved_tensors
        if ctx.needs_input_grad[0]:
            grad_output2 = grad_output * alpha_
            if debug:
                print(f"grad_output2: {grad_output2.abs().mean().item()}")
        else:
            grad_output2 = None
        return grad_output2, None, None

class GradientScaler(nn.Module):
    def __init__(self, alpha=1., debug=False, *args, **kwargs):
        """
        A gradient scaling layer.
        This layer has no parameters, and simply scales the gradient in the backward pass.
        """
        super().__init__(*args, **kwargs)

        self._alpha = torch.tensor(alpha, requires_grad=False)
        self._debug = torch.tensor(debug, requires_grad=False)

    def forward(self, input_):
        _debug = self._debug if hasattr(self, '_debug') else False
        return ScaleGrad.apply(input_, self._alpha.to(input_.device), _debug)

def gen_gradient_scaler(alpha, debug=False):
    if alpha == 1:
        return nn.Identity()
    if alpha > 0:
        return GradientScaler(alpha, debug=debug)
    else:
        assert alpha == 0
        # Don't use lambda function here, otherwise the object can't be pickled.
        return torch.detach

def split_indices_by_instance(indices, as_dict=False):
    indices_B, indices_N = indices
    unique_indices_B = torch.unique(indices_B)
    if not as_dict:
        indices_by_instance = [ (indices_B[indices_B == uib], indices_N[indices_B == uib]) for uib in unique_indices_B ]
    else:
        indices_by_instance = { uib.item(): indices_N[indices_B == uib] for uib in unique_indices_B }
    return indices_by_instance

# If do_sum, returned emb_attns is 3D. Otherwise 4D.
# indices are applied on the first 2 dims of attn_mat.
def sel_emb_attns_by_indices(attn_mat, indices, all_token_weights=None, do_sum=True, do_mean=False):
    indices_by_instance = split_indices_by_instance(indices)
    
    # emb_attns[0]: [1, 9, 8, 64]
    # 8: 8 attention heads. Last dim 64: number of image tokens.
    emb_attns   = [ attn_mat[inst_indices].unsqueeze(0) for inst_indices in indices_by_instance ]
    if all_token_weights is not None:
        # all_token_weights: [4, 77].
        # token_weights_by_instance[0]: [1, 9, 1, 1].
        token_weights = [ all_token_weights[inst_indices].reshape(1, -1, 1, 1) for inst_indices in indices_by_instance ]
    else:
        token_weights = [ 1 ] * len(indices_by_instance)

    # Apply token weights.
    emb_attns = [ emb_attns[i] * token_weights[i] for i in range(len(indices_by_instance)) ]

    # sum among K_subj_i subj embeddings -> [1, 8, 64]
    if do_sum:
        emb_attns   = [ emb_attns[i].sum(dim=1) for i in range(len(indices_by_instance)) ]
    elif do_mean:
        emb_attns   = [ emb_attns[i].mean(dim=1) for i in range(len(indices_by_instance)) ]

    emb_attns = torch.cat(emb_attns, dim=0)
    return emb_attns

# Slow implementation equivalent to F.scaled_dot_product_attention.
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
                                 shrink_cross_attn=False, cross_attn_shrink_factor=0.5, 
                                 is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
    B, L, S = query.size(0), query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    # 1: head (to be broadcasted). L: query length. S: key length.
    attn_bias = torch.zeros(B, 1, L, S, device=query.device, dtype=query.dtype)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(B, 1, L, S, device=query.device, dtype=torch.bool).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias.to(query.dtype)

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask

    if enable_gqa:
        key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
        value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)

    attn_weight = query @ key.transpose(-2, -1) * scale_factor

    if shrink_cross_attn:
        cross_attn_scale = cross_attn_shrink_factor
    else:
        cross_attn_scale = 1

    # attn_bias: [1, 1, 4096, 77], the same size as a single-head attn_weight.
    attn_weight += attn_bias
    attn_score = attn_weight
    attn_weight = torch.softmax(attn_weight, dim=-1)
    # NOTE: After scaling, the "probabilities" of the subject embeddings will sum to < 1.
    # But this is intended, as we want to scale down the impact of the subject embeddings 
    # in the computed attention output tensors.
    attn_weight = attn_weight * cross_attn_scale
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    output = attn_weight @ value
    return output, attn_score, attn_weight

# All layers share the same attention processor instance.
class AttnProcessor_LoRA_Capture(nn.Module):
    r"""
    Revised from AttnProcessor2_0
    """
    # lora_proj_layers is a dict of lora_layer_name -> lora_proj_layer.
    def __init__(self, capture_ca_activations: bool = False, enable_lora: bool = False, 
                 lora_uses_dora=True, lora_proj_layers=None, 
                 lora_rank: int = 192, lora_alpha: float = 16,
                 cross_attn_shrink_factor: float = 0.5,
                 q_lora_updates_query=False, attn_proc_idx=-1):
        super().__init__()

        self.global_enable_lora = enable_lora
        self.attn_proc_idx = attn_proc_idx
        # reset_attn_cache_and_flags() sets the local (call-specific) self.enable_lora flag.
        # By default, shrink_cross_attn is False. Later in layers 22, 23, 24 it will be set to True.
        self.reset_attn_cache_and_flags(capture_ca_activations, False, enable_lora)
        self.lora_rank = lora_rank
        self.lora_alpha = lora_alpha
        self.lora_scale = self.lora_alpha / self.lora_rank
        self.cross_attn_shrink_factor = cross_attn_shrink_factor
        self.q_lora_updates_query = q_lora_updates_query

        self.to_q_lora = self.to_k_lora = self.to_v_lora = self.to_out_lora = None
        if self.global_enable_lora:
            for lora_layer_name, lora_proj_layer in lora_proj_layers.items():
                if lora_layer_name == 'q':
                    self.to_q_lora   = peft_lora.Linear(lora_proj_layer, 'default', r=lora_rank, lora_alpha=lora_alpha, 
                                                        use_dora=lora_uses_dora, lora_dropout=0.1)
                elif lora_layer_name == 'k':
                    self.to_k_lora   = peft_lora.Linear(lora_proj_layer, 'default', r=lora_rank, lora_alpha=lora_alpha, 
                                                        use_dora=lora_uses_dora, lora_dropout=0.1)
                elif lora_layer_name == 'v':
                    self.to_v_lora   = peft_lora.Linear(lora_proj_layer, 'default', r=lora_rank, lora_alpha=lora_alpha, 
                                                        use_dora=lora_uses_dora, lora_dropout=0.1)
                elif lora_layer_name == 'out':
                    self.to_out_lora = peft_lora.Linear(lora_proj_layer, 'default', r=lora_rank, lora_alpha=lora_alpha, 
                                                        use_dora=lora_uses_dora, lora_dropout=0.1)

    # LoRA layers can be enabled/disabled dynamically.
    def reset_attn_cache_and_flags(self, capture_ca_activations, shrink_cross_attn, enable_lora):
        self.capture_ca_activations = capture_ca_activations
        self.shrink_cross_attn      = shrink_cross_attn
        self.cached_activations     = {}
        # Only enable LoRA for the next call(s) if global_enable_lora is set to True.
        self.enable_lora = enable_lora and self.global_enable_lora

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        img_mask: Optional[torch.Tensor] = None,
        subj_indices: Optional[Tuple[torch.IntTensor, torch.IntTensor]] = None,
        debug: bool = False,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)
        
        # hidden_states: [1, 4096, 320]
        residual = hidden_states
        # attn.spatial_norm is None.
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            # Collapse the spatial dimensions to a single token dimension.
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        # NOTE: there's a inconsistency between q lora and k, v loras. 
        # k, v loras are directly applied to key and value (currently k, v loras are never enabled), 
        # while q lora is applied to query2, and we keep the query unchanged. 
        if self.enable_lora and self.to_q_lora is not None:
            # query2 will be used in ldm/util.py:calc_elastic_matching_loss() to get more accurate
            # cross attention scores between the latent images of the sc and mc instances.
            query2 = self.to_q_lora(hidden_states)
            # If not q_lora_updates_query, only query2 will be impacted by the LoRA layer.
            # The query, and thus the attention score and attn_out, will be the same 
            # as the original ones. 
            if self.q_lora_updates_query:
                query = query2
        else:
            query2 = query

        scale = 1 / math.sqrt(query.size(-1))

        is_cross_attn = (encoder_hidden_states is not None)
        if (not is_cross_attn) and (img_mask is not None):
            # NOTE: we assume the image is square. But this will fail if the image is not square.
            # hidden_states: [BS, 4096, 320]. img_mask: [BS, 1, 64, 64]
            # Scale the mask to the same size as hidden_states.
            mask_size = int(math.sqrt(hidden_states.shape[-2]))
            img_mask = F.interpolate(img_mask, size=(mask_size, mask_size), mode='nearest')
            if (img_mask.sum(dim=(2, 3)) == 0).any():
                img_mask = None
            else:
                # img_mask: [2, 1, 64, 64] -> [2, 4096]
                img_mask = rearrange(img_mask, 'b ... -> b (...)').contiguous()
                # max_neg_value = -torch.finfo(hidden_states.dtype).max
                # img_mask: [2, 4096] -> [2, 1, 1, 4096]
                img_mask = rearrange(img_mask.bool(), 'b j -> b () () j')
                # attn_score: [16, 4096, 4096]. img_mask will be broadcasted to [16, 4096, 4096].
                # So some rows in dim 1 (e.g. [0, :, 4095]) of attn_score will be masked out (all elements in [0, :, 4095] is -inf).
                # But not all elements in [0, 4095, :] is -inf. Since the softmax is done along dim 2, this is fine.
                # attn_score.masked_fill_(~img_mask, max_neg_value)
                # NOTE: If there's an attention mask, it will be replaced by img_mask.
                attention_mask = img_mask

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        if self.enable_lora and self.to_k_lora is not None:
            key = self.to_k_lora(encoder_hidden_states)
        else:
            key = attn.to_k(encoder_hidden_states)

        if self.enable_lora and self.to_v_lora is not None:
            value = self.to_v_lora(encoder_hidden_states)
        else:
            value = attn.to_v(encoder_hidden_states)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
            query2 = attn.norm_q(query2)
        if attn.norm_k is not None:
            key = attn.norm_k(key)
        
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query  = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        query2 = query2.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key   = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if debug and self.attn_proc_idx >= 0:
            breakpoint()

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        if is_cross_attn and (self.capture_ca_activations or self.shrink_cross_attn):
            hidden_states, attn_score, attn_prob = \
                scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, 
                                             dropout_p=0.0, shrink_cross_attn=self.shrink_cross_attn,
                                             cross_attn_shrink_factor=self.cross_attn_shrink_factor)
        else:
            # Use the faster implementation of scaled_dot_product_attention 
            # when not capturing the activations or suppressing the subject attention.
            hidden_states = \
                F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
            attn_prob = attn_score = None

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        if self.enable_lora and self.to_out_lora is not None:
            hidden_states = self.to_out_lora(hidden_states)
        else:
            hidden_states = attn.to_out[0](hidden_states)

        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        if is_cross_attn and self.capture_ca_activations:
            # cached q will be used in ddpm.py:calc_comp_fg_bg_preserve_loss(), in which two qs will multiply each other.
            # So sqrt(scale) will scale the product of two qs by scale.
            # ANCHOR[id=attention_caching]
            # query: [2, 8, 4096, 40] -> [2, 320, 4096]
            self.cached_activations['q'] = \
                rearrange(query, 'b h n d -> b (h d) n').contiguous() * math.sqrt(scale)
            self.cached_activations['q2'] = \
                rearrange(query2, 'b h n d -> b (h d) n').contiguous() * math.sqrt(scale)
            self.cached_activations['k'] = \
                rearrange(key, 'b h n d -> b (h d) n').contiguous() * math.sqrt(scale)
            self.cached_activations['v'] = \
                rearrange(value, 'b h n d -> b (h d) n').contiguous() * math.sqrt(scale)
            # attn_prob, attn_score: [2, 8, 4096, 77]
            self.cached_activations['attn'] = attn_prob
            self.cached_activations['attnscore'] = attn_score
            # attn_out: [b, n, h * d] -> [b, h * d, n]
            # [2, 4096, 320] -> [2, 320, 4096].
            self.cached_activations['attn_out'] = hidden_states.permute(0, 2, 1).contiguous()

        return hidden_states

def CrossAttnUpBlock2D_forward_capture(
    self,
    hidden_states: torch.Tensor,
    res_hidden_states_tuple: Tuple[torch.Tensor, ...],
    temb: Optional[torch.Tensor] = None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    upsample_size: Optional[int] = None,
    attention_mask: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if cross_attention_kwargs is not None:
        if cross_attention_kwargs.get("scale", None) is not None:
            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

    self.cached_outfeats = {}
    res_hidden_states_gradscale = getattr(self, "res_hidden_states_gradscale", 1)
    capture_outfeats            = getattr(self, "capture_outfeats", False)
    layer_idx = 0
    res_grad_scaler = gen_gradient_scaler(res_hidden_states_gradscale)

    for resnet, attn in zip(self.resnets, self.attentions):
        # pop res hidden states
        res_hidden_states = res_hidden_states_tuple[-1]
        res_hidden_states_tuple = res_hidden_states_tuple[:-1]

        # Scale down the magnitudes of gradients to res_hidden_states 
        # by res_hidden_states_gradscale=0.2, to match the scale of the cross-attn layer outputs.
        res_hidden_states = res_grad_scaler(res_hidden_states)

        hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

        if self.training and self.gradient_checkpointing:
            def create_custom_forward(module, return_dict=None):
                def custom_forward(*inputs):
                    if return_dict is not None:
                        return module(*inputs, return_dict=return_dict)
                    else:
                        return module(*inputs)

                return custom_forward

            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
            hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(resnet),
                hidden_states,
                temb,
                **ckpt_kwargs,
            )
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                return_dict=False,
            )[0]
        else:
            # resnet: ResnetBlock2D instance.
            #LINK diffusers.models.resnet.ResnetBlock2D
            # up_blocks.3.resnets.2.conv_shortcut is a module within ResnetBlock2D,
            # it's not transforming the UNet shortcut features.
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                return_dict=False,
            )[0]

        if capture_outfeats:
            self.cached_outfeats[layer_idx] = hidden_states
            layer_idx += 1

    if self.upsamplers is not None:
        for upsampler in self.upsamplers:
            hidden_states = upsampler(hidden_states, upsample_size)

    return hidden_states


# Adapted from ConsistentIDPipeline:set_ip_adapter().
# attn_lora_layer_names: candidates are subsets of ['q', 'k', 'v', 'out'].
def set_up_attn_processors(unet, use_attn_lora, attn_lora_layer_names=['q', 'k', 'v', 'out'], 
                           lora_rank=192, lora_scale_down=8, cross_attn_shrink_factor=0.5,
                           q_lora_updates_query=False):
    attn_procs = {}
    attn_capture_procs = {}
    unet_modules = dict(unet.named_modules())
    attn_opt_modules = {}

    attn_proc_idx = 0

    for name, attn_proc in unet.attn_processors.items():
        # Only capture the activations of the last 3 CA layers.
        if not name.startswith("up_blocks.3"):
            # Not the last 3 CA layers. Don't enable LoRA or capture activations. 
            # Then the layer falls back to the original attention mechanism.
            # We still use AttnProcessor_LoRA_Capture, as it can handle img_mask.
            attn_procs[name] = AttnProcessor_LoRA_Capture(
                capture_ca_activations=False, enable_lora=False, attn_proc_idx=-1)
            continue
        # cross_attention_dim: 768.
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if cross_attention_dim is None:
            # Self attention. Don't enable LoRA or capture activations.
            # We replace the default attn_proc with AttnProcessor_LoRA_Capture, 
            # so that it can incorporate img_mask into self-attention.
            attn_procs[name] = AttnProcessor_LoRA_Capture(
                capture_ca_activations=False, enable_lora=False, attn_proc_idx=-1)            
            continue

        # block_id = 3
        # hidden_size: 320
        # hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        # 'up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor' ->
        # 'up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_q'
        lora_layer_dict = {}
        lora_layer_dict['q']    = unet_modules[name[:-9] + "to_q"]
        lora_layer_dict['k']    = unet_modules[name[:-9] + "to_k"]
        lora_layer_dict['v']    = unet_modules[name[:-9] + "to_v"]
        # to_out is a ModuleList(Linear, Dropout).
        lora_layer_dict['out']  = unet_modules[name[:-9] + "to_out"][0]

        lora_proj_layers = {}
        # Only apply LoRA to the specified layers.
        for lora_layer_name in attn_lora_layer_names:
            lora_proj_layers[lora_layer_name] = lora_layer_dict[lora_layer_name]

        attn_capture_proc = AttnProcessor_LoRA_Capture(
            capture_ca_activations=True, enable_lora=use_attn_lora, 
            lora_uses_dora=True, lora_proj_layers=lora_proj_layers,
            # LoRA up is initialized to 0. So no need to worry that the LoRA output may be too large.
            lora_rank=lora_rank, lora_alpha=lora_rank // lora_scale_down,
            cross_attn_shrink_factor=cross_attn_shrink_factor,
            q_lora_updates_query=q_lora_updates_query, attn_proc_idx=attn_proc_idx)
        
        attn_proc_idx += 1
        # attn_procs has to use the original names.
        attn_procs[name] = attn_capture_proc
        # ModuleDict doesn't allow "." in the key.
        name = name.replace(".", "_")
        attn_capture_procs[name] = attn_capture_proc
    
        if use_attn_lora:
            for subname, module in attn_capture_proc.named_modules():
                if isinstance(module, peft_lora.LoraLayer):
                    # ModuleDict doesn't allow "." in the key.
                    lora_path = name + "_" + subname.replace(".", "_")
                    attn_opt_modules[lora_path + "_lora_A"] = module.lora_A
                    attn_opt_modules[lora_path + "_lora_B"] = module.lora_B
                    # lora_uses_dora is always True, so we don't check it here.
                    attn_opt_modules[lora_path + "_lora_magnitude_vector"] = module.lora_magnitude_vector
                    # We will manage attn adapters directly. By default, LoraLayer is an instance of BaseTunerLayer,
                    # so according to the code logic in diffusers/loaders/peft.py,
                    # they will be managed by the diffusers PeftAdapterMixin instance, through the
                    # enable_adapters(), and set_adapter() methods.
                    # Therefore, we disable these calls on module. 
                    # disable_adapters() is a property and changing it will cause exceptions.
                    module.enable_adapters  = dummy_func
                    module.set_adapter      = dummy_func

    unet.set_attn_processor(attn_procs)

    print(f"Set up {len(attn_capture_procs)} CrossAttn processors on {attn_capture_procs.keys()}.")
    print(f"Set up {len(attn_opt_modules)} attn LoRA params: {attn_opt_modules.keys()}.")
    return attn_capture_procs, attn_opt_modules

# NOTE: cross-attn layers are included in the returned lora_modules.
def set_up_ffn_loras(unet, target_modules_pat, lora_uses_dora=False, lora_rank=192, lora_alpha=16):
    # target_modules_pat = 'up_blocks.3.resnets.[12].conv[a-z0-9_]+'
    # up_blocks.3.resnets.[1~2].conv1, conv2, conv_shortcut
    # Cannot set to conv.+ as it will match added adapter module names, including
    # up_blocks.3.resnets.1.conv1.base_layer, up_blocks.3.resnets.1.conv1.lora_dropout
    if target_modules_pat is not None:
        peft_config = LoraConfig(use_dora=lora_uses_dora, inference_mode=False, r=lora_rank, 
                                 lora_alpha=lora_alpha, lora_dropout=0.1,
                                 target_modules=target_modules_pat)

        # UNet is a diffusers PeftAdapterMixin instance. Using get_peft_model on it will
        # cause weird errors. Instead, we directly use diffusers peft adapter methods.
        unet.add_adapter(peft_config, "recon_loss")
        unet.add_adapter(peft_config, "unet_distill")
        unet.add_adapter(peft_config, "comp_distill")
        unet.enable_adapters()

    # lora_layers contain both the LoRA A and B matrices, as well as the original layers.
    # lora_layers are used to set the flag, not used for optimization.
    # lora_modules contain only the LoRA A and B matrices, so they are used for optimization.
    # NOTE: lora_modules contain both ffn and cross-attn lora modules.
    ffn_lora_layers = {}
    ffn_opt_modules = {}
    for name, module in unet.named_modules():
        if isinstance(module, peft_lora.LoraLayer):
            # We don't want to include cross-attn layers in ffn_lora_layers.
            if target_modules_pat is not None and re.search(target_modules_pat, name):
                ffn_lora_layers[name] = module
                # ModuleDict doesn't allow "." in the key.
                name = name.replace(".", "_")
                # Since ModuleDict doesn't allow "." in the key, we manually collect
                # the LoRA matrices in each module.
                # NOTE: We cannot put every sub-module of module into lora_modules,
                # as base_layer is also a sub-module of module, which we shouldn't optimize.
                # Each value in ffn_opt_modules is a ModuleDict:
                '''
                    (Pdb) ffn_opt_modules['up_blocks_3_resnets_1_conv1_lora_A']
                    ModuleDict(
                    (unet_distill): Conv2d(640, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                    (recon_loss): Conv2d(640, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                    )                
                '''
                ffn_opt_modules[name + "_lora_A"] = module.lora_A
                ffn_opt_modules[name + "_lora_B"] = module.lora_B
                if lora_uses_dora:
                    ffn_opt_modules[name + "_lora_magnitude_vector"] = module.lora_magnitude_vector

    print(f"Set up {len(ffn_lora_layers)} FFN LoRA layers: {ffn_lora_layers.keys()}.")
    print(f"Set up {len(ffn_opt_modules)} FFN LoRA params: {ffn_opt_modules.keys()}.")

    return ffn_lora_layers, ffn_opt_modules

def set_lora_and_capture_flags(unet, unet_lora_modules, attn_capture_procs, 
                               outfeat_capture_blocks, res_hidden_states_gradscale_blocks,
                               use_attn_lora, use_ffn_lora, ffn_lora_adapter_name, capture_ca_activations, 
                               shrink_cross_attn, res_hidden_states_gradscale):
    # For attn capture procs, capture_ca_activations and use_attn_lora are set in reset_attn_cache_and_flags().
    for attn_capture_proc in attn_capture_procs:
        attn_capture_proc.reset_attn_cache_and_flags(capture_ca_activations, shrink_cross_attn, enable_lora=use_attn_lora)
    # outfeat_capture_blocks only contains the last up block, up_blocks[3].
    # It contains 3 FFN layers. We want to capture their output features.
    for block in outfeat_capture_blocks:
        block.capture_outfeats           = capture_ca_activations

    for block in res_hidden_states_gradscale_blocks:
        block.res_hidden_states_gradscale = res_hidden_states_gradscale

    if not use_ffn_lora:
        unet.disable_adapters()
    else:
        # ffn_lora_adapter_name: 'recon_loss', 'unet_distill', 'comp_distill'.
        if ffn_lora_adapter_name is not None:
            unet.set_adapter(ffn_lora_adapter_name)
            # NOTE: Don't forget to enable_adapters(). 
            # The adapters are not enabled by default after set_adapter().
            unet.enable_adapters()
        else:
            breakpoint()

    # During training, disable_adapters() and set_adapter() will set all/inactive adapters with requires_grad=False, 
    # which might cause issues during DDP training.
    # So we restore them to requires_grad=True.
    # During test, unet_lora_modules will be passed as None, so this block will be skipped.
    if unet_lora_modules is not None:
        for param in unet_lora_modules.parameters():
            param.requires_grad = True

def get_captured_activations(capture_ca_activations, attn_capture_procs, outfeat_capture_blocks, 
                             captured_layer_indices=[22, 23, 24], out_dtype=torch.float32):
    captured_activations = { k: {} for k in ('outfeat', 'attn', 'attnscore', 
                                             'q', 'q2', 'k', 'v', 'attn_out') }

    if not capture_ca_activations:
        return captured_activations
    
    all_cached_outfeats = []
    for block in outfeat_capture_blocks:
        all_cached_outfeats.append(block.cached_outfeats)
        # Clear the capture flag and cached outfeats.
        block.cached_outfeats = {}
        block.capture_outfeats = False

    for layer_idx in captured_layer_indices:
        # Subtract 22 to ca_layer_idx to match the layer index in up_blocks[3].cached_outfeats.
        # 23, 24 -> 1, 2 (!! not 0, 1 !!)
        internal_idx = layer_idx - 22
        for k in captured_activations.keys():
            if k == 'outfeat':
                # Currently we only capture one block, up_blocks.3. So we hard-code the index 0.
                captured_activations['outfeat'][layer_idx] = all_cached_outfeats[0][internal_idx].to(out_dtype)
            else:
                # internal_idx is the index of layers in up_blocks.3. 
                # Layers 22, 23 and 24 map to 0, 1 and 2.
                cached_activations = attn_capture_procs[internal_idx].cached_activations
                captured_activations[k][layer_idx] = cached_activations[k].to(out_dtype)

    return captured_activations