"""
MAE Model    Script  ver: Oct 23rd 15:00

# References:
Based on MAE code.
https://github.com/facebookresearch/mae

timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
DeiT: https://github.com/facebookresearch/deit


July 16th
Added patchify_decoder to form (B, N, D)
Added a parameter for MAE to import a segmentation network as the decoder
"""
from functools import partial

import torch
import torch.nn as nn

from timm.models.vision_transformer import PatchEmbed, Block
from Backbone.VPT_structure import VPT_ViT
from SSL_structures.pos_embed import get_2d_sincos_pos_embed


class MaskedAutoencoderViT(VPT_ViT):
    """
    Masked Autoencoder with VisionTransformer backbone
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False,
                 prompt_mode=None, Prompt_Token_num=20, basic_state_dict=None, decoder=None, decoder_rep_dim=None):

        #     model = MaskedAutoencoderViT(
        #         patch_size=16, embed_dim=768, depth=12, num_heads=12,
        #         decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        #         mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)

        if prompt_mode is None:
            super().__init__()
            # MAE encoder specifics (this part is the same as a standard ViT)
            # --------------------------------------------------------------------------
            self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)  # BCHW -> BNC
            num_patches = self.patch_embed.num_patches

            # the learnable cls token is still kept, but no classification head is needed
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
            # set and freeze encoder_pos_embed,  use the fixed sin-cos embedding for tokens + mask_token
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)
            # Encoder blocks
            self.blocks = nn.ModuleList([  # qk_scale=None fixme related to timm version
                Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
                for i in range(depth)])
            self.norm = norm_layer(embed_dim)

            self.prompt_mode = prompt_mode
            # --------------------------------------------------------------------------

        else:
            super().__init__(img_size=img_size, patch_size=patch_size, in_chans=in_chans,
                             embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio,
                             norm_layer=norm_layer, Prompt_Token_num=Prompt_Token_num, VPT_type=prompt_mode,
                             basic_state_dict=None)  # Firstly, set the Encoder state_dict to None here.
            num_patches = self.patch_embed.num_patches  # set patch_embed of VPT
            # set and freeze encoder_pos_embed,  use the fixed sin-cos embedding for tokens + mask_token
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)

            self.prompt_mode = prompt_mode
            # Freeze Encoder parameters except for the Prompt Tokens
            self.Freeze()

        # MAE decoder specifics
        # --------------------------------------------------------------------------
        # if the feature dimensions of the encoder and decoder differ, use decoder_embed to align them
        if embed_dim != decoder_embed_dim:
            self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
        else:
            self.decoder_embed = nn.Identity()

        if decoder is not None:
            self.decoder = decoder
            # set mask_token (learnable mask token for reconstruction)
            self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
            # The Decoder uses an FC layer to reconstruct the image, unlike the Encoder which uses a CNN to split patches
            self.decoder_pred = nn.Linear(decoder_rep_dim, patch_size ** 2 * in_chans, bias=True)  # decoder to patch

        else:
            self.decoder = None  # if no decoder is given, build Transformer blocks as in the encoder but with a different channel number (original MAE)
            # set mask_token (learnable mask token for reconstruction)
            self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

            # set and freeze decoder_pos_embed,  use the fixed sin-cos embedding for tokens + mask_token
            self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim),
                                                  requires_grad=False)
            self.decoder_blocks = nn.ModuleList([Block(decoder_embed_dim, decoder_num_heads, mlp_ratio,
                                                       qkv_bias=True, norm_layer=norm_layer)
                                                 for i in range(decoder_depth)])
            # qk_scale=None fixme related to timm version
            self.decoder_norm = norm_layer(decoder_embed_dim)

            # The Decoder uses an FC layer to reconstruct the image, unlike the Encoder which uses a CNN to split patches
            self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_chans, bias=True)  # decoder to patch

        # --------------------------------------------------------------------------
        # whether or not to use norm_pix_loss
        self.norm_pix_loss = norm_pix_loss
        # parameter initialization
        self.initialize_weights()

        # load basic state_dict of backbone for Transfer-learning-based tuning
        if basic_state_dict is not None:
            self.load_state_dict(basic_state_dict, False)

    def initialize_weights(self):
        # initialization
        # initialize a 2d positional encoding of (embed_dim, grid) by sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1],
                                            int(self.patch_embed.num_patches ** .5),
                                            cls_token=True)
        # return: pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        if self.decoder is None:
            # initialize a 2d positional encoding of (embed_dim, grid) by sin-cos embedding
            decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1],
                                                        int(self.patch_embed.num_patches ** .5),
                                                        cls_token=True)
            self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))  # xavier_uniform keeps the input/output variance the same in both forward and backward passes

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # initialize nn.Linear and nn.LayerNorm
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs):
        """
        Encode image to patch tokens

        input:
        imgs: (B, 3, H, W)

        output:
        x: (B, num_patches, patch_size**2 *3) AKA [B, num_patches, flatten_dim]
        """
        # patch_size
        p = self.patch_embed.patch_size[0]
        # assert H == W and that the image size is divisible by the patch size
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
        # number of patches per row or column
        h = w = imgs.shape[2] // p

        # use reshape to split patch [B, C, H, W] -> [B, C, h_p, p, w_p, p]
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        # ReArrange dimensions [B, C, h_p, p, w_p, p] -> [B, h_p, w_p, p, p, C]
        x = torch.einsum('nchpwq->nhwpqc', x)
        # ReArrange dimensions [B, h_p, w_p, p, p, C] -> [B, num_patches, flatten_dim]
        x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3))
        return x

    def patchify_decoder(self, imgs, patch_size=None):  # TODO important goal here: implement pre-training with this!
        """
        Break image to patch tokens

        fixme: note that the default patch_size here should follow the decoder's network configuration

        input:
        imgs: (B, CLS, H, W)

        output:
        x: (B, num_patches, -1) AKA [B, num_patches, -1]
        """
        # patch_size
        patch_size = self.patch_embed.patch_size[0] if patch_size is None else patch_size

        # assert H == W and that the image size is divisible by the patch size
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % patch_size == 0
        # number of patches per row or column
        h = w = imgs.shape[2] // patch_size

        # use reshape to split patch [B, C, H, W] -> [B, C, h_p, patch_size, w_p, patch_size]
        x = imgs.reshape(shape=(imgs.shape[0], -1, h, patch_size, w, patch_size))

        # ReArrange dimensions [B, C, h_p, patch_size, w_p, patch_size] -> [B, h_p, w_p, patch_size, patch_size, C]
        x = torch.einsum('nchpwq->nhwpqc', x)
        # ReArrange dimensions [B, h_p, w_p, patch_size, patch_size, C] -> [B, num_patches, flatten_dim]
        x = x.reshape(shape=(imgs.shape[0], h * w, -1))
        return x

    def unpatchify(self, x, patch_size=None):
        """
        Decoding encoded patch tokens

        input:
        x: (B, num_patches, patch_size**2 *3) AKA [B, num_patches, flatten_dim]

        output:
        imgs: (B, 3, H, W)
        """
        # patch_size
        p = self.patch_embed.patch_size[0] if patch_size is None else patch_size

        # square root of num_patches (CLS token must not be included)
        h = w = int(x.shape[1] ** .5)
        # assert num_patches is without CLS token
        assert h * w == x.shape[1]

        # ReArrange dimensions [B, num_patches, flatten_dim] -> [B, h_p, w_p, p, p, C]
        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        # ReArrange dimensions [B, h_p, w_p, p, p, C] -> [B, C, h_p, p, w_p, p]
        x = torch.einsum('nhwpqc->nchpwq', x)
        # use reshape to compose patch [B, C, h_p, p, w_p, p] -> [B, C, H, W]
        imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
        return imgs

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.

        Note: torch.argsort returns, along the specified dim, the indices of the original
        positions after sorting the values of the tensor in ascending order.

        input:
        x: [B, num_patches, D], sequence of Tokens

        output: x_remained, mask, ids_restore
        x_remained: [B, num_patches * (1-mask_ratio), D], sequence of Tokens
        mask: [B, num_patches], binary mask
        ids_restore: [B, num_patches], idx of restoring all position
        """
        B, num_patches, D = x.shape  # batch, length, dim
        # number of positions to keep
        len_keep = int(num_patches * (1 - mask_ratio))
        # build a random tensor [B, num_patches] used to rank the positions
        noise = torch.rand(B, num_patches, device=x.device)  # noise in [0, 1]

        # for each sequence in the batch, sort the noise in ascending order and record the original position indices
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        # argsort the index matrix again to obtain, for each original position, its rank in the sorted order
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]

        # build the indices of the patches to keep
        # ids_keep.unsqueeze(-1).repeat(1, 1, D):
        # [B, num_patches] -> [B, keep_patches] -> [B, keep_patches, 1], each entry being the idx of the original patch -> [B, keep_patches, D]

        # torch.gather builds a new tensor from the indices: x_remained [B, keep_patches, D] holds the kept positions (called x_masked in the original MAE code)
        x_remained = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([B, num_patches], device=x.device)
        mask[:, :len_keep] = 0  # the first len_keep entries are 0 (keep), the rest are 1 (remove)

        # unshuffle according to the noise ranking so that the 0s land on the kept positions
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_remained, mask, ids_restore  # x_remained is called x_masked in the original MAE code

    def forward_encoder(self, imgs, mask_ratio):
        """
        :param imgs: [B, C, H, W], sequence of imgs
        :param mask_ratio: mask_ratio

        :return: Encoder output: encoded tokens, mask position, restore idxs
        x: [B, 1 + num_patches * (1-mask_ratio), D], sequence of Tokens (including the cls token)
        mask: [B, num_patches], binary mask
        ids_restore: [B, num_patches], idx of restoring all position
        """
        if self.prompt_mode is None:  # ViT
            # embed patches
            x = self.patch_embed(imgs)  # BCHW -> BNC

            # add pos embed w/o cls token
            x = x + self.pos_embed[:, 1:, :]  # add pos embed before concatenate the cls token

            # masking: length -> length * (1-mask_ratio)
            # x_remained: [B, num_patches * (1-mask_ratio), D], sequence of Tokens
            x, mask, ids_restore = self.random_masking(x, mask_ratio)

            # append cls token
            cls_token = self.cls_token + self.pos_embed[:, :1, :]
            cls_tokens = cls_token.expand(x.shape[0], -1, -1)  # expand the cls token to the batch size
            x = torch.cat((cls_tokens, x), dim=1)

            # apply Transformer Encoders
            for blk in self.blocks:
                x = blk(x)

        else:  # VPT
            x = self.patch_embed(imgs)
            # add pos embed before concatenate the cls token
            x = x + self.pos_embed[:, 1:, :]
            # masking: length -> length * (1-mask_ratio)
            # x_remained: [B, num_patches * (1-mask_ratio), D], sequence of Tokens
            x, mask, ids_restore = self.random_masking(x, mask_ratio)

            # append cls token
            cls_token = self.cls_token + self.pos_embed[:, :1, :]
            cls_tokens = cls_token.expand(x.shape[0], -1, -1)  # expand the cls token to the batch size
            x = torch.cat((cls_tokens, x), dim=1)

            if self.VPT_type == "Deep":
                Prompt_Token_num = self.Prompt_Tokens.shape[1]
                for i in range(len(self.blocks)):
                    # concatenate Prompt_Tokens
                    Prompt_Tokens = self.Prompt_Tokens[i].unsqueeze(0)
                    # firstly concatenate
                    x = torch.cat((x, Prompt_Tokens.expand(x.shape[0], -1, -1)), dim=1)
                    num_tokens = x.shape[1]
                    # lastly remove, a good trick
                    x = self.blocks[i](x)[:, :num_tokens - Prompt_Token_num]

            else:  # self.VPT_type == "Shallow"
                Prompt_Token_num = self.Prompt_Tokens.shape[1]
                # concatenate Prompt_Tokens
                Prompt_Tokens = self.Prompt_Tokens.expand(x.shape[0], -1, -1)
                x = torch.cat((x, Prompt_Tokens), dim=1)
                num_tokens = x.shape[1]
                # A whole sequential process
                x = self.blocks(x)[:, :num_tokens - Prompt_Token_num]

        # last norm of Transformer
        x = self.norm(x)

        # Encoder output: encoded tokens, mask position, restore idxs
        return x, mask, ids_restore

    def forward_decoder(self, x, ids_restore):
        """
        :param x: [B, 1 + num_patches * (1-mask_ratio), D], sequence of Tokens (including the cls token)
        :param ids_restore: restore idxs for torch.gather(mask, dim=1, index=ids_restore)

        :return: Decoder output: reconstructed patch tokens
        x: [B, num_patches, patch_size**2 * 3], sequence of reconstructed patch tokens
        """
        if self.decoder is None:
            # embed tokens: [B, num_encoded_tokens, embed_dim] -> [B, num_encoded_tokens, D_Decoder]
            x = self.decoder_embed(x)  # map to the decoder channel dimension

            # append mask tokens to sequence as place holder: [B, num_patches + 1 - num_encoded_tokens, D_Decoder]
            # the number of mask tokens needed is whatever it takes to fill the sequence back up to num_patches
            mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
            # here ids_restore.shape[1] + 1 - x.shape[1] really means ids_restore.shape[1] - (x.shape[1] - 1), because the CLS token is excluded

            # -> [B, num_patches, D_Decoder]
            x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # strip the cls token before restoring positions in the Decoder

            # unshuffle to restore the position of tokens
            x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))
            # torch.gather builds a new tensor from the indices: x_ [B, num_patches, D_Decoder] has the positions restored, though the masked values are still placeholders

            # append back the cls token at the first -> [B,1+num_patches,D_Decoder]
            x = torch.cat([x[:, :1, :], x_], dim=1)

            # add pos embed
            x = x + self.decoder_pos_embed

            # apply Transformer blocks
            for blk in self.decoder_blocks:
                x = blk(x)
            x = self.decoder_norm(x)

            # Reconstruction projection [B, num_patches, D_Decoder] -> [B, num_patches, p*p*3]
            x = self.decoder_pred(x)

            # remove cls token
            x = x[:, 1:, :]

        else:
            # append mask tokens to sequence as place holder: [B, num_patches + 1 - num_encoded_tokens, D]
            # the number of mask tokens needed is whatever it takes to fill the sequence back up to num_patches
            mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
            # here ids_restore.shape[1] + 1 - x.shape[1] really means ids_restore.shape[1] - (x.shape[1] - 1), because the CLS token is excluded

            # -> [B, num_patches, D]
            x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # strip the cls token before restoring positions in the Decoder

            # unshuffle to restore the position of tokens
            x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))
            # torch.gather builds a new tensor from the indices: x_ [B, num_patches, D] has the positions restored, though the masked values are still placeholders

            # embed tokens: [B, num_encoded_tokens, D_Encoder] -> [B, num_encoded_tokens, D_Decoder]
            x_ = self.decoder_embed(x_)

            # unpatchify to turn the tokens [B, N, Enc] back into an image [B, C, H, W]
            x = self.unpatchify(x_)  # restore an image from the Encoder tokens

            # apply decoder module to segment the output of encoder
            x = self.decoder(x)  # [B, CLS, H, W]
            # the segmentation output is transformed to [B, N, Dec]
            x = self.patchify_decoder(x)  # TODO design this step in a meaningful way

            # Convert the number of channels to match image for loss function
            x = self.decoder_pred(x)  # [B, N, Dec] -> [B, N, p*p*3]

        return x

    def forward_loss(self, imgs, pred, mask):  # placing the loss inside the model turns the model into a training framework
        """
        MSE reconstruction loss against the original image, averaged over the masked (removed) patches

        Input:
        imgs: [B, 3, H, W], Encoder input image
        pred: [B, num_patches, p*p*3], Decoder reconstructed image
        mask: [B, num_patches], 0 is keep, 1 is remove,
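
        With norm_pix_loss, each target patch is first normalized to zero mean and unit variance,
        i.e. target = (target - mean) / sqrt(var + 1e-6), computed per patch, before the MSE.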

        """
        target = self.patchify(imgs)

        if self.norm_pix_loss:  # normalize the target image patches
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6) ** .5

        # MSE loss
        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [B, num_patches], mean loss per patch

        # binary mask, 1 for removed patches
        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

    def forward(self, imgs, mask_ratio=0.75):
        # Encoder to obtain latent tokens
        latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
        # Decoder to obtain Reconstructed image patches
        pred = self.forward_decoder(latent, ids_restore)  # [B, num_patches, p*p*3]
        # MSE reconstruction loss against the original image, averaged over the masked patches
        loss = self.forward_loss(imgs, pred, mask)
        # print(loss)  # todo the original code printed the loss here to watch for loss explosion; possible pitfall
        return loss, pred, mask


def mae_vit_base_patch16_dec512d8b(dec_idx=None, **kwargs):
    print("Decoder:", dec_idx)

    model = MaskedAutoencoderViT(
        patch_size=16, embed_dim=768, depth=12, num_heads=12,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def mae_vit_large_patch16_dec512d8b(dec_idx=None, **kwargs):
    print("Decoder:", dec_idx)

    model = MaskedAutoencoderViT(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def mae_vit_huge_patch14_dec512d8b(dec_idx=None, **kwargs):
    print("Decoder:", dec_idx)

    model = MaskedAutoencoderViT(
        patch_size=14, embed_dim=1280, depth=32, num_heads=16,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def mae_vit_base_patch16_decoder(dec_idx=None, num_classes=3, img_size=224, **kwargs):
    # num_classes gives a one-hot segmentation output rather than a reconstruction; we still need to design how to turn it back into an image for pre-training

    if dec_idx == 'swin_unet':
        decoder_embed_dim = 768
        decoder_rep_dim = 16 * 16 * 3

        from SSL_structures.Swin_Unet_main.networks.vision_transformer import SwinUnet as ViT_seg
        decoder = ViT_seg(num_classes=num_classes, **kwargs)

    elif dec_idx == 'transunet':
        decoder_embed_dim = 768
        decoder_rep_dim = 16 * 16 * 3

        transunet_name = 'R50-ViT-B_16'
        transunet_patches_size = 16
        from SSL_structures.TransUNet_main.networks.vit_seg_modeling import CONFIGS as CONFIGS_Transunet_seg
        from SSL_structures.TransUNet_main.networks.vit_seg_modeling import VisionTransformer as Transunet_seg

        config_vit = CONFIGS_Transunet_seg[transunet_name]
        config_vit.n_classes = num_classes
        config_vit.n_skip = 3

        if transunet_name.find('R50') != -1:
            config_vit.patches.grid = (
                int(img_size / transunet_patches_size), int(img_size / transunet_patches_size))
        decoder = Transunet_seg(config_vit, num_classes=config_vit.n_classes)

    elif dec_idx == 'UTNetV2':
        decoder_embed_dim = 768
        decoder_rep_dim = 16 * 16 * 3

        from SSL_structures.UtnetV2.utnetv2 import UTNetV2 as UTNetV2_seg
        decoder = UTNetV2_seg(in_chan=3, num_classes=num_classes)

    else:
        print('no effective decoder!')
        return -1

    print('dec_idx: ', dec_idx)

    model = MaskedAutoencoderViT(
        patch_size=16, embed_dim=768, depth=12, num_heads=12,
        decoder_embed_dim=decoder_embed_dim, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), decoder_rep_dim=decoder_rep_dim, decoder=decoder,
        **kwargs)
    return model


def mae_vit_large_patch16_decoder(dec_idx=None, num_classes=3, img_size=224, **kwargs):
    # num_classes gives a one-hot segmentation output rather than a reconstruction; we still need to design how to turn it back into an image for pre-training

    if dec_idx == 'swin_unet':
        decoder_embed_dim = 768
        decoder_rep_dim = 16 * 16 * 3

        from SSL_structures.Swin_Unet_main.networks.vision_transformer import SwinUnet as ViT_seg
        decoder = ViT_seg(num_classes=num_classes, **kwargs)

    elif dec_idx == 'transunet':
        decoder_embed_dim = 768
        decoder_rep_dim = 16 * 16 * 3

        transunet_name = 'R50-ViT-B_16'
        transunet_patches_size = 16
        from SSL_structures.TransUNet_main.networks.vit_seg_modeling import CONFIGS as CONFIGS_Transunet_seg
        from SSL_structures.TransUNet_main.networks.vit_seg_modeling import VisionTransformer as Transunet_seg

        config_vit = CONFIGS_Transunet_seg[transunet_name]
        config_vit.n_classes = num_classes
        config_vit.n_skip = 3

        if transunet_name.find('R50') != -1:
            config_vit.patches.grid = (
                int(img_size / transunet_patches_size), int(img_size / transunet_patches_size))
        decoder = Transunet_seg(config_vit, num_classes=config_vit.n_classes)

    elif dec_idx == 'UTNetV2':
        decoder_embed_dim = 768
        decoder_rep_dim = 16 * 16 * 3

        from SSL_structures.UtnetV2.utnetv2 import UTNetV2 as UTNetV2_seg
        decoder = UTNetV2_seg(in_chan=3, num_classes=num_classes)

    else:
        print('no effective decoder!')
        return -1

    print('dec_idx: ', dec_idx)

    model = MaskedAutoencoderViT(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        decoder_embed_dim=decoder_embed_dim, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), decoder_rep_dim=decoder_rep_dim, decoder=decoder,
        **kwargs)
    return model


def mae_vit_huge_patch14_decoder(dec_idx=None, num_classes=3, img_size=224, **kwargs):
    # num_classes gives a one-hot segmentation output rather than a reconstruction; we still need to design how to turn it back into an image for pre-training

    if dec_idx == 'swin_unet':
        decoder_embed_dim = 588  # 1280  14*14*3
        decoder_rep_dim = 14 * 14 * 3

        from SSL_structures.Swin_Unet_main.networks.vision_transformer import SwinUnet as ViT_seg
        decoder = ViT_seg(num_classes=num_classes, **kwargs)

    elif dec_idx == 'transunet':
        decoder_embed_dim = 768
        decoder_rep_dim = 16 * 16 * 3

        transunet_name = 'R50-ViT-B_16'
        transunet_patches_size = 16
        from SSL_structures.TransUNet_main.networks.vit_seg_modeling import CONFIGS as CONFIGS_Transunet_seg
        from SSL_structures.TransUNet_main.networks.vit_seg_modeling import VisionTransformer as Transunet_seg

        config_vit = CONFIGS_Transunet_seg[transunet_name]
        config_vit.n_classes = num_classes
        config_vit.n_skip = 3

        if transunet_name.find('R50') != -1:
            config_vit.patches.grid = (
                int(img_size / transunet_patches_size), int(img_size / transunet_patches_size))
        decoder = Transunet_seg(config_vit, num_classes=config_vit.n_classes)

    elif dec_idx == 'UTNetV2':
        decoder_embed_dim = 768
        decoder_rep_dim = 14 * 14 * 3

        from SSL_structures.UtnetV2.utnetv2 import UTNetV2 as UTNetV2_seg
        decoder = UTNetV2_seg(in_chan=3, num_classes=num_classes)

    else:
        print('no effective decoder!')
        return -1

    print('dec_idx: ', dec_idx)

    model = MaskedAutoencoderViT(
        patch_size=14, embed_dim=1280, depth=32, num_heads=16,
        decoder_embed_dim=decoder_embed_dim, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), decoder_rep_dim=decoder_rep_dim, decoder=decoder,
        **kwargs)
    return model


# set recommended archs
mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b  # decoder: 512 dim, 8 blocks

# Equipped with segmentation decoders
mae_vit_base_patch16_decoder = mae_vit_base_patch16_decoder  # decoder: 768 dim, HYF
mae_vit_large_patch16_decoder = mae_vit_large_patch16_decoder  # decoder: 768 dim, HYF
mae_vit_huge_patch14_decoder = mae_vit_huge_patch14_decoder  # decoder: 768 dim, HYF


if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    img_size = 224
    num_classes = 3
    x = torch.rand(8, 3, img_size, img_size, device=device)

    # model = mae_vit_base_patch16(img_size=224, decoder=None)  # decoder_embed_dim=512
    model = mae_vit_base_patch16_decoder(prompt_mode='Deep', Prompt_Token_num=20, basic_state_dict=None,
                                         dec_idx='UTNetV2', img_size=img_size)

    model.to(device)

    loss, pred, mask_patch_indicators = model(x)

    print(loss, '\n')

    print(loss.shape, '\n')

    print(pred.shape, '\n')

    print(mask_patch_indicators.shape, '\n')
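
    # Minimal illustrative sanity check (not part of the original demo): patchify followed by
    # unpatchify is a pure reshape round trip, so it should reproduce the input image exactly.
    tokens = model.patchify(x)            # [B, 196, 16*16*3]
    restored = model.unpatchify(tokens)   # [B, 3, 224, 224]
    print(torch.allclose(restored, x))    # expected: True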