File size: 20,538 Bytes
231edce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
import sys
from functools import partial
import torch
from torch import nn
from torch.autograd import Function as Function

from .attention import MultiScaleAttention, attention_pool
from .common import Mlp, TwoStreamFusion, drop_path
from .utils import round_width


class ReversibleMViT(nn.Module):
    """
    Reversible model builder. This builds the reversible transformer encoder
    and allows reversible training.

    Karttikeya Mangalam, Haoqi Fan, Yanghao Li, Chao-Yuan Wu, Bo Xiong,
    Christoph Feichtenhofer, Jitendra Malik
    "Reversible Vision Transformers"

    https://openaccess.thecvf.com/content/CVPR2022/papers/Mangalam_Reversible_Vision_Transformers_CVPR_2022_paper.pdf
    """

    def __init__(self, config, model):
        """
        Build the list of reversible and stage-transition blocks.

        The `__init__` method of any subclass should also contain these
            arguments.
        Args:
            config (CfgNode): model building configs, details are in the
                comments of the config file.
            model (nn.Module): parent MViT module this module forms
                a reversible encoder in. Its `patch_dims`, `pool_q`,
                `pool_kv`, `stride_q` and `stride_kv` attributes are read.
        Raises:
            NotImplementedError: if `MVIT.NORM` is not "layernorm".
        """

        super().__init__()
        self.cfg = config

        embed_dim = self.cfg.MVIT.EMBED_DIM
        depth = self.cfg.MVIT.DEPTH
        num_heads = self.cfg.MVIT.NUM_HEADS
        mlp_ratio = self.cfg.MVIT.MLP_RATIO
        qkv_bias = self.cfg.MVIT.QKV_BIAS

        drop_path_rate = self.cfg.MVIT.DROPPATH_RATE
        self.dropout = config.MVIT.DROPOUT_RATE
        self.pre_q_fusion = self.cfg.MVIT.REV.PRE_Q_FUSION
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]  # stochastic depth decay rule

        input_size = model.patch_dims

        self.layers = nn.ModuleList([])
        # When True, forward() skips the custom reversible backward even in
        # training mode (see vanilla_backward).
        self.no_custom_backward = False

        if self.cfg.MVIT.NORM == "layernorm":
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        else:
            raise NotImplementedError("Only supports layernorm.")

        # Per-layer multipliers, 1.0 unless overridden by the config entries
        # of the form [layer_index, multiplier].
        dim_mul, head_mul = torch.ones(depth + 1), torch.ones(depth + 1)
        for entry in self.cfg.MVIT.DIM_MUL:
            dim_mul[entry[0]] = entry[1]
        for entry in self.cfg.MVIT.HEAD_MUL:
            head_mul[entry[0]] = entry[1]

        pool_q = model.pool_q
        pool_kv = model.pool_kv
        stride_q = model.stride_q
        stride_kv = model.stride_kv

        for i in range(depth):

            num_heads = round_width(num_heads, head_mul[i])

            # Upsampling inside the MHPA, input to the Q-pooling block is lower C dimension
            # This localizes the feature changes in a single block, making more computation reversible.
            embed_dim = round_width(
                embed_dim, dim_mul[i - 1] if i > 0 else 1.0, divisor=num_heads
            )
            dim_out = round_width(
                embed_dim,
                dim_mul[i],
                divisor=round_width(num_heads, head_mul[i + 1]),
            )

            if i in self.cfg.MVIT.REV.BUFFER_LAYERS:
                layer_type = StageTransitionBlock
                input_mult = 2 if "concat" in self.pre_q_fusion else 1
            else:
                layer_type = ReversibleBlock
                input_mult = 1

            dimout_correction = (
                2 if (input_mult == 2 and "concat" in self.pre_q_fusion) else 1
            )

            self.layers.append(
                layer_type(
                    dim=embed_dim
                    * input_mult,  # added only for concat fusion before Qpooling layers
                    input_size=input_size,
                    dim_out=dim_out * input_mult // dimout_correction,
                    num_heads=num_heads,
                    cfg=self.cfg,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    kernel_q=pool_q[i] if len(pool_q) > i else [],
                    kernel_kv=pool_kv[i] if len(pool_kv) > i else [],
                    stride_q=stride_q[i] if len(stride_q) > i else [],
                    stride_kv=stride_kv[i] if len(stride_kv) > i else [],
                    layer_id=i,
                    pre_q_fusion=self.pre_q_fusion,
                )
            )
            # F is the attention block
            self.layers[-1].F.thw = input_size

            # Shrink the tracked input size by this layer's Q stride, if any.
            # The length check mirrors the guard used for the constructor
            # arguments above so a short stride_q list cannot raise IndexError.
            if len(stride_q) > i and len(stride_q[i]) > 0:
                input_size = [
                    size // stride
                    for size, stride in zip(input_size, stride_q[i])
                ]

    @staticmethod
    def vanilla_backward(h, layers, buffer):
        """
        Using rev layers without rev backpropagation. Debugging purposes only.
        Activated with self.no_custom_backward.

        Args:
            h (Tensor): input whose last dimension holds the two streams
                concatenated.
            layers (iterable): callables taking and returning two tensors.
            buffer: unused; present so the signature matches
                `RevBackProp.apply`.
        Returns:
            Tensor: the two output streams concatenated on the last dimension.
        """

        # split into hidden states (h) and attention_output (a)
        h, a = torch.chunk(h, 2, dim=-1)
        for layer in layers:
            a, h = layer(a, h)

        return torch.cat([a, h], dim=-1)

    def forward(self, x):
        """
        Run the input through all layers.

        Consecutive reversible blocks are grouped and executed together
        (through `RevBackProp` in training unless `no_custom_backward` is
        set, otherwise through `vanilla_backward`); stage-transition blocks
        are executed individually in between. Dropout is applied to the
        final output.
        """

        # process the layers in a reversible stack and an irreversible stack.
        stack = []
        for l_i, layer in enumerate(self.layers):
            if isinstance(layer, StageTransitionBlock):
                stack.append(("StageTransition", l_i))
            else:
                if len(stack) == 0 or stack[-1][0] == "StageTransition":
                    stack.append(("Reversible", []))
                stack[-1][1].append(l_i)

        for layer_seq in stack:

            if layer_seq[0] == "StageTransition":
                x = self.layers[layer_seq[1]](x)

            else:
                # duplicate the features to form the two reversible streams
                x = torch.cat([x, x], dim=-1)

                # no need for custom backprop in eval/model stat log
                if not self.training or self.no_custom_backward:
                    executing_fn = ReversibleMViT.vanilla_backward
                else:
                    executing_fn = RevBackProp.apply

                x = executing_fn(
                    x,
                    self.layers[layer_seq[1][0] : layer_seq[1][-1] + 1],
                    [],  # buffer activations
                )

        # Apply dropout
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)

        return x


class RevBackProp(Function):
    """
    Custom Backpropagation function to allow (A) flushing memory in forward
    and (B) activation recomputation reversibly in backward for gradient calculation.

    Inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py
    """

    @staticmethod
    def forward(
        ctx,
        x,
        layers,
        buffer_layers,  # List of layer ids for int activation to buffer
    ):
        """
        Reversible Forward pass. Any intermediate activations from `buffer_layers` are
        cached in ctx for forward pass. This is not necessary for standard usecases.
        Each reversible layer implements its own forward pass logic.

        Args:
            x (Tensor): the two streams concatenated on the last dimension.
            layers (iterable): layers called as `layer(X_1, X_2)`; each must
                expose `layer_id` and, for backward, `backward_pass`.
            buffer_layers (iterable of int): ids of layers whose outputs are
                saved. The argument itself is not modified.
        Returns:
            Tensor: the two output streams concatenated on the last dimension.
        """
        # Work on a sorted copy so the caller's list is left untouched.
        buffer_layers = sorted(buffer_layers)

        X_1, X_2 = torch.chunk(x, 2, dim=-1)

        intermediate = []

        for layer in layers:

            X_1, X_2 = layer(X_1, X_2)

            if layer.layer_id in buffer_layers:
                intermediate.extend([X_1.detach(), X_2.detach()])

        if len(buffer_layers) == 0:
            all_tensors = [X_1.detach(), X_2.detach()]
        else:
            # The sorted ids are saved as a tensor ahead of the buffered
            # activations so backward can map a layer id to its slot.
            intermediate = [torch.LongTensor(buffer_layers), *intermediate]
            all_tensors = [X_1.detach(), X_2.detach(), *intermediate]

        ctx.save_for_backward(*all_tensors)
        ctx.layers = layers

        return torch.cat([X_1, X_2], dim=-1)

    @staticmethod
    def backward(ctx, dx):
        """
        Reversible Backward pass. Any intermediate activations from `buffer_layers` are
        recovered from ctx. Each layer implements its own logic for backward pass (both
        activation recomputation and grad calculation).

        Returns:
            tuple: gradient w.r.t. `x`, then `None` for the two non-tensor
            forward arguments.
        """
        dX_1, dX_2 = torch.chunk(dx, 2, dim=-1)

        # retrieve params from ctx for backward
        X_1, X_2, *int_tensors = ctx.saved_tensors

        if len(int_tensors) != 0:
            # first saved extra tensor holds the buffered layer ids
            buffer_layers = int_tensors[0].tolist()

        else:
            # no buffering
            buffer_layers = []

        layers = ctx.layers

        # Walk the layers last-to-first, recomputing inputs from outputs.
        for layer in layers[::-1]:

            if layer.layer_id in buffer_layers:

                # Buffered outputs for this layer sit at positions
                # 2 * slot + 1 and 2 * slot + 2 (position 0 is the id tensor).
                slot = buffer_layers.index(layer.layer_id)
                X_1, X_2, dX_1, dX_2 = layer.backward_pass(
                    Y_1=int_tensors[slot * 2 + 1],
                    Y_2=int_tensors[slot * 2 + 2],
                    dY_1=dX_1,
                    dY_2=dX_2,
                )

            else:

                X_1, X_2, dX_1, dX_2 = layer.backward_pass(
                    Y_1=X_1,
                    Y_2=X_2,
                    dY_1=dX_1,
                    dY_2=dX_2,
                )

        dx = torch.cat([dX_1, dX_2], dim=-1)

        del int_tensors
        del dX_1, dX_2, X_1, X_2

        return dx, None, None


class StageTransitionBlock(nn.Module):
    """
    Blocks for changing the feature dimensions in MViT (using Q-pooling).
    See Section 3.3.1 in paper for details.
    """

    def __init__(
        self,
        dim,
        input_size,
        dim_out,
        num_heads,
        mlp_ratio,
        qkv_bias,
        drop_path,
        kernel_q,
        kernel_kv,
        stride_q,
        stride_kv,
        cfg,
        norm_layer=nn.LayerNorm,
        pre_q_fusion=None,
        layer_id=0,
    ):
        """
        Uses the same structure of F and G functions as Reversible Block except
        without using reversible forward (and backward) pass.

        Args:
            dim (int): input channel dimension of the attention sub-block.
            input_size: forwarded to the attention sub-block.
            dim_out (int): output channel dimension of the attention
                sub-block and channel dimension of the MLP sub-block.
            num_heads (int): number of attention heads.
            mlp_ratio (float): hidden-width multiplier of the MLP sub-block.
            qkv_bias: accepted for signature parity; not used in this class.
            drop_path (float): drop-path probability applied to the output.
            kernel_q, kernel_kv, stride_q, stride_kv: pooling kernels and
                strides forwarded to the attention sub-block.
            cfg (CfgNode): model config; `MVIT.CLS_EMBED_ON`,
                `MVIT.POOL_FIRST`, `MVIT.MODE` and `MVIT.REV.RES_PATH` are
                read here.
            norm_layer: normalization layer constructor.
            pre_q_fusion: fusion mode handed to `TwoStreamFusion`.
            layer_id (int): index of this layer in the encoder.
        Raises:
            NotImplementedError: if `MVIT.REV.RES_PATH` is neither "max"
                nor "conv".
        """
        super().__init__()

        self.drop_path_rate = drop_path

        embed_dim = dim

        self.F = AttentionSubBlock(
            dim=embed_dim,
            input_size=input_size,
            num_heads=num_heads,
            cfg=cfg,
            dim_out=dim_out,
            kernel_q=kernel_q,
            kernel_kv=kernel_kv,
            stride_q=stride_q,
            stride_kv=stride_kv,
            norm_layer=norm_layer,
        )

        self.G = MLPSubblock(
            dim=dim_out,
            mlp_ratio=mlp_ratio,
            norm_layer=norm_layer,
        )

        self.layer_id = layer_id

        self.is_proj = False
        self.has_cls_embed = cfg.MVIT.CLS_EMBED_ON

        self.is_conv = False
        self.pool_first = cfg.MVIT.POOL_FIRST
        self.mode = cfg.MVIT.MODE
        self.pre_q_fuse = TwoStreamFusion(pre_q_fusion, dim=dim)

        if cfg.MVIT.REV.RES_PATH == "max":
            self.res_conv = False
            # Max-pool for the residual branch; kernel, stride and padding
            # are all derived from the stride of the attention Q-pooling op.
            self.pool_skip = nn.MaxPool3d(
                [s + 1 if s > 1 else s for s in self.F.attn.pool_q.stride],
                self.F.attn.pool_q.stride,
                [int(k // 2) for k in self.F.attn.pool_q.stride],
                ceil_mode=False,
            )

        elif cfg.MVIT.REV.RES_PATH == "conv":
            self.res_conv = True
        else:
            raise NotImplementedError

        # Add a linear projection in residual branch
        if embed_dim != dim_out:
            self.is_proj = True
            self.res_proj = nn.Linear(embed_dim, dim_out, bias=True)

    def forward(
        self,
        x,
    ):
        """
        Forward logic is similar to MultiScaleBlock with Q-pooling.

        The fused input is sent through the attention sub-block F, added to
        a pooled (and, if the widths differ, projected) copy of itself, then
        passed through the MLP sub-block G with a residual connection.
        Drop-path is applied to the final sum.
        """
        x = self.pre_q_fuse(x)

        # fork tensor for residual connections
        x_res = x

        # This uses conv to pool the residual hidden features
        # but done before pooling only if not pool_first
        if self.is_proj and not self.pool_first:
            x_res = self.res_proj(x_res)

        if self.res_conv:

            # Pooling the hidden features with the same conv as Q
            N, L, C = x_res.shape

            # This handling is the same as that of q in MultiScaleAttention
            if self.mode == "conv_unshared":
                fold_dim = 1
            else:
                fold_dim = self.F.attn.num_heads

            # Output is (N, fold_dim, L, C // fold_dim)
            x_res = x_res.reshape(N, L, fold_dim, C // fold_dim).permute(
                0, 2, 1, 3
            )

            x_res, _ = attention_pool(
                x_res,
                self.F.attn.pool_q,
                thw_shape=self.F.thw,
                has_cls_embed=self.has_cls_embed,
                norm=self.F.attn.norm_q
                if hasattr(self.F.attn, "norm_q")
                else None,
            )
            x_res = x_res.permute(0, 2, 1, 3).reshape(N, x_res.shape[2], C)

        else:
            # Pooling the hidden features with max op.
            # The input size lives on the attention sub-block (`self.F.thw`,
            # assigned by the encoder builder), the same value the conv
            # path above uses.
            x_res, _ = attention_pool(
                x_res,
                self.pool_skip,
                thw_shape=self.F.thw,
                has_cls_embed=self.has_cls_embed,
            )

        # If pool_first then project to higher dim now
        if self.is_proj and self.pool_first:
            x_res = self.res_proj(x_res)

        x = self.F(x)
        x = x_res + x
        x = x + self.G(x)

        x = drop_path(x, drop_prob=self.drop_path_rate, training=self.training)

        return x


class ReversibleBlock(nn.Module):
    """
    Reversible Blocks for Reversible Vision Transformer and also
    for state-preserving blocks in Reversible MViT. See Section
    3.3.2 in paper for details.

    The block keeps two streams (X_1, X_2). Its inputs can be recomputed
    from its outputs, which `backward_pass` uses instead of stored
    activations.
    """

    def __init__(
        self,
        dim,
        input_size,
        dim_out,
        num_heads,
        mlp_ratio,
        qkv_bias,
        drop_path,
        kernel_q,
        kernel_kv,
        stride_q,
        stride_kv,
        cfg,
        norm_layer=nn.LayerNorm,
        layer_id=0,
        **kwargs
    ):
        """
        Block is composed entirely of function F (Attention
        sub-block) and G (MLP sub-block) including layernorm.

        Args:
            dim (int): channel dimension of both sub-blocks' inputs.
            input_size: forwarded to the attention sub-block.
            dim_out (int): output dimension of the attention sub-block.
            num_heads (int): number of attention heads.
            mlp_ratio (float): hidden-width multiplier of the MLP sub-block.
            qkv_bias: accepted but not used in this constructor.
            drop_path (float): drop-path probability applied to the outputs
                of both F and G.
            kernel_q, kernel_kv, stride_q, stride_kv: pooling kernels and
                strides forwarded to the attention sub-block.
            cfg (CfgNode): model config, forwarded to the attention sub-block.
            norm_layer: normalization layer constructor.
            layer_id (int): index of this layer in the encoder.
            **kwargs: extra keyword arguments are accepted and ignored.
        """
        super().__init__()

        self.drop_path_rate = drop_path

        self.F = AttentionSubBlock(
            dim=dim,
            input_size=input_size,
            num_heads=num_heads,
            cfg=cfg,
            dim_out=dim_out,
            kernel_q=kernel_q,
            kernel_kv=kernel_kv,
            stride_q=stride_q,
            stride_kv=stride_kv,
            norm_layer=norm_layer,
        )

        self.G = MLPSubblock(
            dim=dim,
            mlp_ratio=mlp_ratio,
            norm_layer=norm_layer,
        )

        self.layer_id = layer_id

        # Seeds recorded during forward, keyed by "attn", "droppath" and
        # "FFN", and replayed in backward_pass.
        self.seeds = {}

    def seed_cuda(self, key):
        """
        Fix seeds to allow for stochastic elements such as
        dropout to be reproduced exactly in activation
        recomputation in the backward pass.

        Draws a new seed, stores it in `self.seeds[key]` and reseeds
        torch with it via `torch.manual_seed`.
        """

        # randomize seeds
        # use cuda generator if available
        if (
            hasattr(torch.cuda, "default_generators")
            and len(torch.cuda.default_generators) > 0
        ):
            # GPU
            device_idx = torch.cuda.current_device()
            seed = torch.cuda.default_generators[device_idx].seed()
        else:
            # CPU
            seed = int(torch.seed() % sys.maxsize)

        self.seeds[key] = seed
        torch.manual_seed(self.seeds[key])

    def forward(self, X_1, X_2):
        """
        forward pass equations:
        Y_1 = X_1 + Attention(X_2), F = Attention
        Y_2 = X_2 + MLP(Y_1), G = MLP

        A fresh seed is recorded before F, before the first drop-path and
        before G. The "droppath" seed is re-applied before the second
        drop-path, so both drop-path calls start from the same seed.

        Returns:
            tuple: (Y_1, Y_2).
        """

        self.seed_cuda("attn")
        # Y_1 : attn_output
        f_X_2 = self.F(X_2)

        self.seed_cuda("droppath")
        f_X_2_dropped = drop_path(
            f_X_2, drop_prob=self.drop_path_rate, training=self.training
        )

        # Y_1 = X_1 + f(X_2)
        Y_1 = X_1 + f_X_2_dropped

        # free memory
        del X_1

        self.seed_cuda("FFN")
        g_Y_1 = self.G(Y_1)

        # reuse the drop-path seed recorded above
        torch.manual_seed(self.seeds["droppath"])
        g_Y_1_dropped = drop_path(
            g_Y_1, drop_prob=self.drop_path_rate, training=self.training
        )

        # Y_2 = X_2 + g(Y_1)
        Y_2 = X_2 + g_Y_1_dropped

        del X_2

        return Y_1, Y_2

    def backward_pass(
        self,
        Y_1,
        Y_2,
        dY_1,
        dY_2,
    ):
        """
        equation for activation recomputation:
        X_2 = Y_2 - G(Y_1), G = MLP
        X_1 = Y_1 - F(X_2), F = Attention

        Args:
            Y_1, Y_2 (Tensor): outputs of this block's forward pass.
                `Y_1.requires_grad` is set to True in place.
            dY_1, dY_2 (Tensor): gradients w.r.t. Y_1 and Y_2.
        Returns:
            tuple: (X_1, X_2, dX_1, dX_2) -- the recomputed block inputs and
            the gradients w.r.t. them. The `.backward()` calls inside also
            accumulate gradients for the parameters of G and F.
        """

        # temporarily record intermediate activation for G
        # and use them for gradient calculation of G
        with torch.enable_grad():

            Y_1.requires_grad = True

            # replay the seeds recorded in forward so stochastic ops repeat
            torch.manual_seed(self.seeds["FFN"])
            g_Y_1 = self.G(Y_1)

            torch.manual_seed(self.seeds["droppath"])
            g_Y_1 = drop_path(
                g_Y_1, drop_prob=self.drop_path_rate, training=self.training
            )

            g_Y_1.backward(dY_2, retain_graph=True)

        # activation recomputation is by design and not part of
        # the computation graph in forward pass.
        with torch.no_grad():

            X_2 = Y_2 - g_Y_1
            del g_Y_1

            # total gradient reaching Y_1: direct path plus the path through G
            dY_1 = dY_1 + Y_1.grad
            Y_1.grad = None

        # record F activations and calc gradients on F
        with torch.enable_grad():
            X_2.requires_grad = True

            torch.manual_seed(self.seeds["attn"])
            f_X_2 = self.F(X_2)

            torch.manual_seed(self.seeds["droppath"])
            f_X_2 = drop_path(
                f_X_2, drop_prob=self.drop_path_rate, training=self.training
            )

            f_X_2.backward(dY_1, retain_graph=True)

        # propagate reverse computed activations at the start of
        # the previous block for backprop.
        with torch.no_grad():

            X_1 = Y_1 - f_X_2

            del f_X_2, Y_1
            # total gradient reaching X_2: direct path plus the path through F
            dY_2 = dY_2 + X_2.grad

            X_2.grad = None
            X_2 = X_2.detach()

        return X_1, X_2, dY_1, dY_2


class MLPSubblock(nn.Module):
    """
    The G function of a block: a normalization layer followed by an MLP.
    """

    def __init__(
        self,
        dim,
        mlp_ratio,
        norm_layer=nn.LayerNorm,
    ):
        """
        Args:
            dim (int): channel dimension of the input and of the norm layer.
            mlp_ratio (float): the MLP hidden width is `int(dim * mlp_ratio)`.
            norm_layer: normalization layer constructor; called with
                `eps=1e-6` and `elementwise_affine=True`.
        """
        super().__init__()

        hidden_width = int(dim * mlp_ratio)
        norm_kwargs = {"eps": 1e-6, "elementwise_affine": True}

        self.norm = norm_layer(dim, **norm_kwargs)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=hidden_width,
            act_layer=nn.GELU,
        )

    def forward(self, x):
        """Normalize `x`, then apply the MLP."""
        normed = self.norm(x)
        return self.mlp(normed)


class AttentionSubBlock(nn.Module):
    """
    The F function of a block: a normalization layer followed by
    multiscale attention.
    """

    def __init__(
        self,
        dim,
        input_size,
        num_heads,
        cfg,
        dim_out=None,
        kernel_q=(1, 1, 1),
        kernel_kv=(1, 1, 1),
        stride_q=(1, 1, 1),
        stride_kv=(1, 1, 1),
        norm_layer=nn.LayerNorm,
    ):
        """
        Args:
            dim (int): input channel dimension.
            input_size: forwarded to `MultiScaleAttention`.
            num_heads (int): number of attention heads.
            cfg (CfgNode): model config; the `MVIT` attention options are
                read from it.
            dim_out (int): output channel dimension of the attention.
            kernel_q, kernel_kv, stride_q, stride_kv: pooling kernels and
                strides forwarded to `MultiScaleAttention`.
            norm_layer: normalization layer constructor.
        """
        super().__init__()

        self.norm = norm_layer(dim, eps=1e-6, elementwise_affine=True)

        # Filled in by the code that builds the encoder, after construction.
        self.thw = None

        # Attention options that come straight from the config.
        mvit_cfg = cfg.MVIT
        cfg_options = dict(
            drop_rate=mvit_cfg.DROPOUT_RATE,
            qkv_bias=mvit_cfg.QKV_BIAS,
            has_cls_embed=mvit_cfg.CLS_EMBED_ON,
            mode=mvit_cfg.MODE,
            pool_first=mvit_cfg.POOL_FIRST,
            rel_pos_spatial=mvit_cfg.REL_POS_SPATIAL,
            rel_pos_temporal=mvit_cfg.REL_POS_TEMPORAL,
            rel_pos_zero_init=mvit_cfg.REL_POS_ZERO_INIT,
            residual_pooling=mvit_cfg.RESIDUAL_POOLING,
            separate_qkv=mvit_cfg.SEPARATE_QKV,
        )

        # Same attention as MViTv2's multiscale attention, with the channel
        # up-projection done inside the attention module.
        self.attn = MultiScaleAttention(
            dim,
            dim_out,
            input_size=input_size,
            num_heads=num_heads,
            kernel_q=kernel_q,
            kernel_kv=kernel_kv,
            stride_q=stride_q,
            stride_kv=stride_kv,
            norm_layer=norm_layer,
            **cfg_options,
        )

    def forward(self, x):
        """Normalize `x`, run attention with `self.thw`, return its first output."""
        normed = self.norm(x)
        attended, _ignored = self.attn(normed, self.thw)
        return attended