Tianyinus committed on
Commit edcf5ee · verified · 1 Parent(s): af70051

init submit

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +1 -0
  2. PuzzleTuning/Archive/PuzzleTuning_demoset.zip +3 -0
  3. PuzzleTuning/Archive/build_CPIA_dataset.sh +35 -0
  4. PuzzleTuning/Archive/warwick_CLS.zip +3 -0
  5. PuzzleTuning/Backbone/GetPromptModel.py +87 -0
  6. PuzzleTuning/Backbone/ResHybrid.py +481 -0
  7. PuzzleTuning/Backbone/Transformer_blocks.py +1631 -0
  8. PuzzleTuning/Backbone/VPT_structure.py +133 -0
  9. PuzzleTuning/Backbone/attention_modules.py +303 -0
  10. PuzzleTuning/Backbone/counterpart_models/README.md +25 -0
  11. PuzzleTuning/Backbone/counterpart_models/conformer.py +483 -0
  12. PuzzleTuning/Backbone/counterpart_models/crossformer.py +624 -0
  13. PuzzleTuning/Backbone/counterpart_models/crossformer_backbone.py +659 -0
  14. PuzzleTuning/Backbone/getmodel.py +392 -0
  15. PuzzleTuning/Counterpart PreTrain Methods/ReadMe.md +2 -0
  16. PuzzleTuning/Counterpart PreTrain Methods/dino-main/LICENSE +201 -0
  17. PuzzleTuning/Counterpart PreTrain Methods/dino-main/README.md +14 -0
  18. PuzzleTuning/Counterpart PreTrain Methods/dino-main/eval_copy_detection.py +301 -0
  19. PuzzleTuning/Counterpart PreTrain Methods/dino-main/eval_image_retrieval.py +201 -0
  20. PuzzleTuning/Counterpart PreTrain Methods/dino-main/eval_knn.py +242 -0
  21. PuzzleTuning/Counterpart PreTrain Methods/dino-main/eval_linear.py +281 -0
  22. PuzzleTuning/Counterpart PreTrain Methods/dino-main/eval_video_segmentation.py +292 -0
  23. PuzzleTuning/Counterpart PreTrain Methods/dino-main/hubconf.py +151 -0
  24. PuzzleTuning/Counterpart PreTrain Methods/dino-main/main_dino.py +515 -0
  25. PuzzleTuning/Counterpart PreTrain Methods/dino-main/run_with_submitit.py +132 -0
  26. PuzzleTuning/Counterpart PreTrain Methods/dino-main/utils.py +829 -0
  27. PuzzleTuning/Counterpart PreTrain Methods/dino-main/video_generation.py +378 -0
  28. PuzzleTuning/Counterpart PreTrain Methods/dino-main/vision_transformer.py +291 -0
  29. PuzzleTuning/Counterpart PreTrain Methods/dino-main/visualize_attention.py +213 -0
  30. PuzzleTuning/Counterpart PreTrain Methods/gcmae/FINETUNE.md +13 -0
  31. PuzzleTuning/Counterpart PreTrain Methods/gcmae/GCMAE.png +0 -0
  32. PuzzleTuning/Counterpart PreTrain Methods/gcmae/LICENSE +399 -0
  33. PuzzleTuning/Counterpart PreTrain Methods/gcmae/LINPROBE.md +15 -0
  34. PuzzleTuning/Counterpart PreTrain Methods/gcmae/PRETRAIN.md +21 -0
  35. PuzzleTuning/Counterpart PreTrain Methods/gcmae/README.md +15 -0
  36. PuzzleTuning/Counterpart PreTrain Methods/gcmae/README_origin.md +39 -0
  37. PuzzleTuning/Counterpart PreTrain Methods/gcmae/VISUAL.md +10 -0
  38. PuzzleTuning/Counterpart PreTrain Methods/gcmae/engine_finetune.py +183 -0
  39. PuzzleTuning/Counterpart PreTrain Methods/gcmae/engine_pretrain.py +99 -0
  40. PuzzleTuning/Counterpart PreTrain Methods/gcmae/lib/NCEAverage.py +94 -0
  41. PuzzleTuning/Counterpart PreTrain Methods/gcmae/lib/NCECriterion.py +38 -0
  42. PuzzleTuning/Counterpart PreTrain Methods/gcmae/lib/__init__.py +1 -0
  43. PuzzleTuning/Counterpart PreTrain Methods/gcmae/lib/__pycache__/NCEAverage.cpython-38.pyc +0 -0
  44. PuzzleTuning/Counterpart PreTrain Methods/gcmae/lib/__pycache__/NCECriterion.cpython-38.pyc +0 -0
  45. PuzzleTuning/Counterpart PreTrain Methods/gcmae/lib/__pycache__/__init__.cpython-38.pyc +0 -0
  46. PuzzleTuning/Counterpart PreTrain Methods/gcmae/lib/__pycache__/alias_multinomial.cpython-38.pyc +0 -0
  47. PuzzleTuning/Counterpart PreTrain Methods/gcmae/lib/__pycache__/normalize.cpython-38.pyc +0 -0
  48. PuzzleTuning/Counterpart PreTrain Methods/gcmae/lib/__pycache__/utils.cpython-38.pyc +0 -0
  49. PuzzleTuning/Counterpart PreTrain Methods/gcmae/lib/alias_multinomial.py +64 -0
  50. PuzzleTuning/Counterpart PreTrain Methods/gcmae/lib/normalize.py +14 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
 *.zip filter=lfs diff=lfs merge=lfs -text
34
 *.zst filter=lfs diff=lfs merge=lfs -text
35
 *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ PuzzleTuning/Counterpart[[:space:]]PreTrain[[:space:]]Methods/gcmae/nohup.out filter=lfs diff=lfs merge=lfs -text
PuzzleTuning/Archive/PuzzleTuning_demoset.zip ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:363cbdc5f8e944f99542a5727d4b2457dd79bbfca7eb70d76857de4b8be92858
3
+ size 4987321
PuzzleTuning/Archive/build_CPIA_dataset.sh ADDED
@@ -0,0 +1,35 @@
1
+ #!/bin/sh
2
+ # go to the dataset location
3
+
4
+ # repair and re-join the split zip archives
5
+
6
+ zip -F L.zip --out L_Scale.zip
7
+ zip -FF L_Scale.zip --out L.zip -fz
8
+ zip -F M.zip --out M_Scale.zip
9
+ zip -FF M_Scale.zip --out M.zip -fz
10
+
11
+ rm -f L_Scale.zip
12
+ rm -f L.z01
13
+ rm -f M_Scale.zip
14
+ rm -f M.z01
15
+ rm -f M.z02
16
+
17
+ # build a directory of datasets
18
+ mkdir datasets
19
+ mv L.zip datasets
20
+ mv M.zip datasets
21
+ mv S.zip datasets
22
+
23
+ cd datasets
24
+ unzip L.zip
25
+ unzip M.zip
26
+ unzip S.zip
27
+
28
+ rm -f L.zip
29
+ rm -f M.zip
30
+ rm -f S.zip
31
+
32
+ mkdir All
33
+ cp -r L/* All/ &
34
+ cp -r M/* All/ &
35
+ cp -r S/* All/
PuzzleTuning/Archive/warwick_CLS.zip ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9813cdc86e2420476e06638965df6040cec0197b148be786f0bf88b020e445f6
3
+ size 15098031
PuzzleTuning/Backbone/GetPromptModel.py ADDED
@@ -0,0 +1,87 @@
1
+ """
2
+ build_promptmodel Script ver: Oct 17th 14:20
3
+
4
+ """
5
+
6
+ try:
7
+ from Backbone.VPT_structure import *
8
+ except ImportError:
9
+ from VPT_structure import *  # fallback when run outside the package (assumed local import path)
10
+
11
+
12
+ def build_promptmodel(num_classes=1000, edge_size=224, model_idx='ViT', patch_size=16,
13
+ Prompt_Token_num=20, VPT_type="Deep", prompt_state_dict=None, base_state_dict='timm'):
14
+ """
15
+ following the https://github.com/sagizty/VPT
16
+ this builds VPT (the prompt version of ViT) with additional prompt tokens;
17
+ at each layer the information becomes [B, N_patch + N_prompt, Dim]
18
+
19
+ During training only the prompt tokens and the head layer are
20
+ set to be learnable while the rest of the Transformer layers are frozen
21
+
22
+ # VPT_type = "Shallow" / "Deep"
23
+ - Shallow: concatenate N_prompt of prompt tokens before the first Transformer Encoder block,
24
+ at each layer the information becomes [B, N_patch + N_prompt, Dim]
25
+ - Deep: concatenate N_prompt of prompt tokens to each Transformer Encoder block,
26
+ this replaces the output prompt tokens learned from the previous encoder.
27
+ """
28
+
29
+ if model_idx[0:3] == 'ViT':
30
+
31
+ if base_state_dict is None:
32
+ basic_state_dict = None
33
+
34
+ elif type(base_state_dict) == str:
35
+ if base_state_dict == 'timm':
36
+ # ViT_Prompt
37
+ import timm
38
+ # from pprint import pprint
39
+ # model_names = timm.list_models('*vit*')
40
+ # pprint(model_names)
41
+
42
+ basic_model = timm.create_model('vit_base_patch' + str(patch_size) + '_' + str(edge_size),
43
+ pretrained=True)
44
+ basic_state_dict = basic_model.state_dict()
45
+ print('in prompt model building, timm ViT loaded for base_state_dict')
46
+
47
+ else:
48
+ basic_state_dict = None
49
+ print('in prompt model building, no valid str for base_state_dict')
50
+
51
+ else: # state dict: collections.OrderedDict
52
+ basic_state_dict = base_state_dict
53
+ print('in prompt model building, a .pth base_state_dict loaded')
54
+
55
+ model = VPT_ViT(img_size=edge_size, patch_size=patch_size, Prompt_Token_num=Prompt_Token_num,
56
+ VPT_type=VPT_type, basic_state_dict=basic_state_dict)
57
+
58
+ model.New_CLS_head(num_classes)
59
+
60
+ if prompt_state_dict is not None:
61
+ try:
62
+ model.load_prompt(prompt_state_dict)
63
+ except:
64
+ print('error in loading the .pth prompt_state_dict')
65
+ else:
66
+ print('in prompt model building, a .pth prompt_state_dict loaded')
67
+
68
+ model.Freeze()
69
+ else:
70
+ print("The model is not difined in the Prompt script!!")
71
+ return -1
72
+
73
+ try:
74
+ img = torch.randn(1, 3, edge_size, edge_size)
75
+ preds = model(img) # (1, class_number)
76
+ print('Build VPT model with in/out shape: ', img.shape, ' -> ', preds.shape)
77
+
78
+ except:
79
+ print("Problem exist in the model defining process!!")
80
+ return -1
81
+ else:
82
+ print('model is ready now!')
83
+ return model
84
+
85
+
86
+ if __name__ == '__main__':
87
+ model = build_promptmodel(prompt_state_dict=None, base_state_dict='timm', num_classes=0)
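A minimal usage sketch of build_promptmodel, assuming torch and timm are installed and the repo's Backbone package is importable; it follows the signature and shapes documented above.
import torch
from Backbone.GetPromptModel import build_promptmodel

# Deep VPT on a timm-pretrained ViT-B/16 at 224x224: build_promptmodel attaches a fresh
# CLS head and freezes everything except the prompt tokens and the head.
model = build_promptmodel(num_classes=2, edge_size=224, model_idx='ViT', patch_size=16,
                          Prompt_Token_num=20, VPT_type="Deep", base_state_dict='timm')

imgs = torch.randn(2, 3, 224, 224)  # dummy batch
preds = model(imgs)                 # expected shape: (2, 2)
print(preds.shape)

# after Freeze(), only the prompt tokens and the CLS head should require gradients
trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(len(trainable), 'trainable parameter tensors')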
PuzzleTuning/Backbone/ResHybrid.py ADDED
@@ -0,0 +1,481 @@
1
+ """
2
+ MSHT
3
+
4
+ Models ver: OCT 27th 20:00 official release
5
+
6
+ by the authors, check our github page:
7
+ https://github.com/sagizty/Multi-Stage-Hybrid-Transformer
8
+
9
+
10
+ ResNet stages' feature map
11
+
12
+ # input = 3, 384, 384
13
+ torch.Size([1, 256, 96, 96])
14
+ torch.Size([1, 512, 48, 48])
15
+ torch.Size([1, 1024, 24, 24])
16
+ torch.Size([1, 2048, 12, 12])
17
+ torch.Size([1, 1000])
18
+
19
+ # input = 3, 224, 224
20
+ torch.Size([1, 256, 56, 56])
21
+ torch.Size([1, 512, 28, 28])
22
+ torch.Size([1, 1024, 14, 14])
23
+ torch.Size([1, 2048, 7, 7])
24
+ torch.Size([1, 1000])
25
+
26
+ ref
27
+ https://note.youdao.com/ynoteshare1/index.html?id=5a7dbe1a71713c317062ddeedd97d98e&type=note
28
+ """
29
+ import torch
30
+ from torch import nn
31
+ from functools import partial
32
+ from torchsummary import summary
33
+ import os
34
+ from Backbone import Transformer_blocks
35
+
36
+
37
+ # ResNet Bottleneck_block_constructor
38
+ class Bottleneck_block_constructor(nn.Module):
39
+
40
+ extention = 4
41
+
42
+ # define the network layers and parameters
43
+ def __init__(self, inplane, midplane, stride, downsample=None):
44
+ super(Bottleneck_block_constructor, self).__init__()
45
+
46
+ outplane = midplane * self.extention
47
+
48
+ self.conv1 = nn.Conv2d(inplane, midplane, kernel_size=1, stride=stride, bias=False)
49
+ self.bn1 = nn.BatchNorm2d(midplane)
50
+
51
+ self.conv2 = nn.Conv2d(midplane, midplane, kernel_size=3, stride=1, padding=1, bias=False)
52
+ self.bn2 = nn.BatchNorm2d(midplane)
53
+
54
+ self.conv3 = nn.Conv2d(midplane, outplane, kernel_size=1, stride=1, bias=False)
55
+ self.bn3 = nn.BatchNorm2d(midplane * self.extention)
56
+
57
+ self.relu = nn.ReLU(inplace=False)
58
+
59
+ self.downsample = downsample
60
+ self.stride = stride
61
+
62
+ def forward(self, x):
63
+
64
+ out = self.relu(self.bn1(self.conv1(x)))
65
+ out = self.relu(self.bn2(self.conv2(out)))
66
+ out = self.relu(self.bn3(self.conv3(out)))
67
+
68
+ if self.downsample is not None:
69
+ residual = self.downsample(x)
70
+ else:
71
+ residual = x
72
+
73
+ out += residual
74
+
75
+ out = self.relu(out)
76
+
77
+ return out
78
+
79
+
80
+ # Hybrid_backbone of ResNets
81
+ class Hybrid_backbone_4(nn.Module):
82
+
83
+ def __init__(self, block_constructor, bottleneck_channels_setting=None, identity_layers_setting=None,
84
+ stage_stride_setting=None, fc_num_classes=None, feature_idx=None):
85
+
86
+ if bottleneck_channels_setting is None:
87
+ bottleneck_channels_setting = [64, 128, 256, 512]
88
+ if identity_layers_setting is None:
89
+ identity_layers_setting = [3, 4, 6, 3]
90
+ if stage_stride_setting is None:
91
+ stage_stride_setting = [1, 2, 2, 2]
92
+
93
+ self.inplane = 64
94
+ self.fc_num_classes = fc_num_classes
95
+ self.feature_idx = feature_idx
96
+
97
+ super(Hybrid_backbone_4, self).__init__()
98
+
99
+ self.block_constructor = block_constructor # Bottleneck_block_constructor
100
+ self.bcs = bottleneck_channels_setting # [64, 128, 256, 512]
101
+ self.ils = identity_layers_setting # [3, 4, 6, 3]
102
+ self.sss = stage_stride_setting # [1, 2, 2, 2]
103
+
104
+ # stem
105
+ # alter the RGB image channels to match inplane
106
+ self.conv1 = nn.Conv2d(3, self.inplane, kernel_size=7, stride=2, padding=3, bias=False)
107
+ self.bn1 = nn.BatchNorm2d(self.inplane)
108
+ self.relu = nn.ReLU()
109
+ self.maxpool = nn.MaxPool2d(kernel_size=3, padding=1, stride=2)
110
+
111
+ # ResNet stages
112
+ self.layer1 = self.make_stage_layer(self.block_constructor, self.bcs[0], self.ils[0], self.sss[0])
113
+ self.layer2 = self.make_stage_layer(self.block_constructor, self.bcs[1], self.ils[1], self.sss[1])
114
+ self.layer3 = self.make_stage_layer(self.block_constructor, self.bcs[2], self.ils[2], self.sss[2])
115
+ self.layer4 = self.make_stage_layer(self.block_constructor, self.bcs[3], self.ils[3], self.sss[3])
116
+
117
+ # cls head
118
+ if self.fc_num_classes is not None:
119
+ self.avgpool = nn.AvgPool2d(7)
120
+ self.fc = nn.Linear(512 * self.block_constructor.extention, fc_num_classes)
121
+
122
+ def forward(self, x):
123
+
124
+ # stem
125
+ out = self.conv1(x)
126
+ out = self.bn1(out)
127
+ out = self.relu(out)
128
+ stem_out = self.maxpool(out)
129
+
130
+ # Resnet block of 4 stages
131
+ stage1_out = self.layer1(stem_out)
132
+ stage2_out = self.layer2(stage1_out)
133
+ stage3_out = self.layer3(stage2_out)
134
+ stage4_out = self.layer4(stage3_out)
135
+
136
+ if self.fc_num_classes is not None:
137
+ # connect to cls head mlp if asked
138
+ fc_out = self.avgpool(stage4_out)
139
+ fc_out = torch.flatten(fc_out, 1)
140
+ fc_out = self.fc(fc_out)
141
+
142
+ # get what we need for different usage
143
+ if self.feature_idx == 'stages':
144
+ if self.fc_num_classes is not None:
145
+ return stage1_out, stage2_out, stage3_out, stage4_out, fc_out
146
+ else:
147
+ return stage1_out, stage2_out, stage3_out, stage4_out
148
+ elif self.feature_idx == 'features':
149
+ if self.fc_num_classes is not None:
150
+ return stem_out, stage1_out, stage2_out, stage3_out, stage4_out, fc_out
151
+ else:
152
+ return stem_out, stage1_out, stage2_out, stage3_out, stage4_out
153
+ else: # self.feature_idx is None
154
+ if self.fc_num_classes is not None:
155
+ return fc_out
156
+ else:
157
+ return stage4_out
158
+
159
+ def make_stage_layer(self, block_constractor, midplane, block_num, stride=1):
160
+ """
161
+ block:
162
+ midplane: usually = output channel / 4
163
+ block_num: number of blocks stacked in this stage
164
+ stride: stride of the ResNet Conv Block
165
+ """
166
+
167
+ block_list = []
168
+
169
+ outplane = midplane * block_constractor.extention # extention
170
+
171
+ if stride != 1 or self.inplane != outplane:
172
+ downsample = nn.Sequential(
173
+ nn.Conv2d(self.inplane, outplane, stride=stride, kernel_size=1, bias=False),
174
+ nn.BatchNorm2d(midplane * block_constractor.extention)
175
+ )
176
+ else:
177
+ downsample = None
178
+
179
+ # Conv Block
180
+ conv_block = block_constractor(self.inplane, midplane, stride=stride, downsample=downsample)
181
+ block_list.append(conv_block)
182
+
183
+ self.inplane = outplane # update inplane for the next stage
184
+
185
+ # Identity Block
186
+ for i in range(1, block_num):
187
+ block_list.append(block_constractor(self.inplane, midplane, stride=1, downsample=None))
188
+
189
+ return nn.Sequential(*block_list) # stack blocks
190
+
191
+
192
+ class Hybrid_backbone_3(nn.Module): # 3 stages version
193
+
194
+ def __init__(self, block_constructor, bottleneck_channels_setting=None, identity_layers_setting=None,
195
+ stage_stride_setting=None, fc_num_classes=None, feature_idx=None):
196
+
197
+ if bottleneck_channels_setting is None:
198
+ bottleneck_channels_setting = [64, 128, 256]
199
+ if identity_layers_setting is None:
200
+ identity_layers_setting = [3, 4, 6]
201
+ if stage_stride_setting is None:
202
+ stage_stride_setting = [1, 2, 2]
203
+
204
+ self.inplane = 64
205
+ self.fc_num_classes = fc_num_classes
206
+ self.feature_idx = feature_idx
207
+
208
+ super(Hybrid_backbone_3, self).__init__()
209
+
210
+ self.block_constructor = block_constructor # Bottleneck_block_constructor
211
+ self.bcs = bottleneck_channels_setting # [64, 128, 256]
212
+ self.ils = identity_layers_setting # [3, 4, 6]
213
+ self.sss = stage_stride_setting # [1, 2, 2]
214
+
215
+ # stem
216
+ self.conv1 = nn.Conv2d(3, self.inplane, kernel_size=7, stride=2, padding=3, bias=False)
217
+ self.bn1 = nn.BatchNorm2d(self.inplane)
218
+ self.relu = nn.ReLU()
219
+ self.maxpool = nn.MaxPool2d(kernel_size=3, padding=1, stride=2)
220
+
221
+ # ResNet 3 stages
222
+ self.layer1 = self.make_stage_layer(self.block_constructor, self.bcs[0], self.ils[0], self.sss[0])
223
+ self.layer2 = self.make_stage_layer(self.block_constructor, self.bcs[1], self.ils[1], self.sss[1])
224
+ self.layer3 = self.make_stage_layer(self.block_constructor, self.bcs[2], self.ils[2], self.sss[2])
225
+
226
+ if self.fc_num_classes is not None:
227
+ self.avgpool = nn.AvgPool2d(24)  # feature-map edge is 14 for 224 input and 24 for 384 input (fixed to 384 here)
228
+ self.fc = nn.Linear(self.bcs[-1] * self.block_constructor.extention, fc_num_classes)
229
+
230
+ def forward(self, x):
231
+ # stem: conv + bn + relu + maxpool
232
+ out = self.conv1(x)
233
+ out = self.bn1(out)
234
+ out = self.relu(out)
235
+ stem_out = self.maxpool(out)
236
+
237
+ # Resnet 3 stages
238
+ stage1_out = self.layer1(stem_out)
239
+ stage2_out = self.layer2(stage1_out)
240
+ stage3_out = self.layer3(stage2_out)
241
+
242
+ if self.fc_num_classes is not None:
243
+ fc_out = self.avgpool(stage3_out)
244
+ fc_out = torch.flatten(fc_out, 1)
245
+ fc_out = self.fc(fc_out)
246
+
247
+ if self.feature_idx == 'stages':
248
+ if self.fc_num_classes is not None:
249
+ return stage1_out, stage2_out, stage3_out, fc_out
250
+ else:
251
+ return stage1_out, stage2_out, stage3_out
252
+ elif self.feature_idx == 'features':
253
+ if self.fc_num_classes is not None:
254
+ return stem_out, stage1_out, stage2_out, stage3_out, fc_out
255
+ else:
256
+ return stem_out, stage1_out, stage2_out, stage3_out
257
+ else: # self.feature_idx is None
258
+ if self.fc_num_classes is not None:
259
+ return fc_out
260
+ else:
261
+ return stage3_out
262
+
263
+ def make_stage_layer(self, block_constractor, midplane, block_num, stride=1):
264
+ """
265
+ block:
266
+ midplane:
267
+ block_num:
268
+ stride:
269
+ """
270
+
271
+ block_list = []
272
+
273
+ outplane = midplane * block_constractor.extention # extention
274
+
275
+ if stride != 1 or self.inplane != outplane:
276
+ downsample = nn.Sequential(
277
+ nn.Conv2d(self.inplane, outplane, stride=stride, kernel_size=1, bias=False),
278
+ nn.BatchNorm2d(midplane * block_constractor.extention)
279
+ )
280
+ else:
281
+ downsample = None
282
+
283
+ # Conv Block
284
+ conv_block = block_constractor(self.inplane, midplane, stride=stride, downsample=downsample)
285
+ block_list.append(conv_block)
286
+
287
+ self.inplane = outplane
288
+
289
+ # Identity Block
290
+ for i in range(1, block_num):
291
+ block_list.append(block_constractor(self.inplane, midplane, stride=1, downsample=None))
292
+
293
+ return nn.Sequential(*block_list)
294
+
295
+
296
+ def Hybrid_a(backbone, img_size=224, patch_size=1, in_chans=3, num_classes=1000, embed_dim=768, depth=8,
297
+ num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, drop_rate=0., attn_drop_rate=0.,
298
+ drop_path_rate=0., norm_layer=None, act_layer=None):
299
+ # directly stack CNNs and Transformer blocks
300
+ embed_layer = partial(Transformer_blocks.Hybrid_feature_map_Embed, backbone=backbone)
301
+
302
+ Hybrid_model = Transformer_blocks.VisionTransformer(img_size, patch_size, in_chans, num_classes, embed_dim, depth,
303
+ num_heads, mlp_ratio, qkv_bias, representation_size,
304
+ drop_rate, attn_drop_rate, drop_path_rate, embed_layer,
305
+ norm_layer, act_layer)
306
+
307
+ return Hybrid_model
308
+
309
+
310
+ def create_model(model_idx, edge_size, pretrained=True, num_classes=2, drop_rate=0., attn_drop_rate=0.,
311
+ drop_path_rate=0., use_cls_token=True, use_pos_embedding=True, use_att_module='SimAM'):
312
+ """
313
+ get one of MSHT models
314
+
315
+ :param model_idx: the model we are going to use. by the format of Model_size_other_info
316
+ :param edge_size: the input edge size of the dataloader
317
+ :param pretrained: whether the backbone CNN is initialized randomly or from its official pretrained weights
318
+ :param num_classes: the number of classes in your dataset
319
+
320
+ :param drop_rate: the dropout probability of the proposed models
321
+ :param attn_drop_rate: the dropout probability right after the MHSA or MHGA block
322
+ :param drop_path_rate: the probability of stochastic depth
323
+
324
+ :param use_cls_token: To use the class token
325
+ :param use_pos_embedding: To use the positional embedding
326
+ :param use_att_module: which attention module to use in the FGD Focus block
327
+ # use_att_module in ['SimAM', 'CBAM', 'SE']: the different attention modules applied in the ablation study
328
+
329
+ :return: prepared model
330
+ """
331
+
332
+ if pretrained:
333
+ from torchvision import models
334
+ backbone_weights = models.resnet50(pretrained=True).state_dict()
335
+ # True loads the pretrained ResNet50 weights, False initializes randomly
336
+ else:
337
+ backbone_weights = None
338
+
339
+ if model_idx[0:11] == 'Hybrid1_224' and edge_size == 224: # ablation study: no focus depth=8 edge_size == 224
340
+ backbone = Hybrid_backbone_4(block_constructor=Bottleneck_block_constructor,
341
+ bottleneck_channels_setting=[64, 128, 256, 512],
342
+ identity_layers_setting=[3, 4, 6, 3],
343
+ stage_stride_setting=[1, 2, 2, 2],
344
+ fc_num_classes=None,
345
+ feature_idx=None)
346
+
347
+ if pretrained:
348
+ try:
349
+ backbone.load_state_dict(backbone_weights, False)
350
+ except:
351
+ print("backbone not loaded")
352
+ else:
353
+ print("backbone loaded")
354
+
355
+ model = Hybrid_a(backbone, img_size=edge_size, patch_size=1, in_chans=3, num_classes=num_classes, embed_dim=768,
356
+ depth=8, num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None,
357
+ drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate,
358
+ norm_layer=None, act_layer=None)
359
+
360
+ elif model_idx[0:11] == 'Hybrid1_384' and edge_size == 384: # ablation study: no focus depth=8 edge_size == 384
361
+ backbone = Hybrid_backbone_4(block_constructor=Bottleneck_block_constructor,
362
+ bottleneck_channels_setting=[64, 128, 256, 512],
363
+ identity_layers_setting=[3, 4, 6, 3],
364
+ stage_stride_setting=[1, 2, 2, 2],
365
+ fc_num_classes=None,
366
+ feature_idx=None)
367
+
368
+ if pretrained:
369
+ try:
370
+ backbone.load_state_dict(backbone_weights, False)
371
+ except:
372
+ print("backbone not loaded")
373
+ else:
374
+ print("backbone loaded")
375
+
376
+ model = Hybrid_a(backbone, img_size=edge_size, patch_size=1, in_chans=3, num_classes=num_classes, embed_dim=768,
377
+ depth=8, num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None,
378
+ drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate,
379
+ norm_layer=None, act_layer=None)
380
+
381
+ elif model_idx[0:11] == 'Hybrid2_224' and edge_size == 224: # Proposed model ablation study: edge_size==224
382
+ backbone = Hybrid_backbone_4(block_constructor=Bottleneck_block_constructor,
383
+ bottleneck_channels_setting=[64, 128, 256, 512],
384
+ identity_layers_setting=[3, 4, 6, 3],
385
+ stage_stride_setting=[1, 2, 2, 2],
386
+ fc_num_classes=None,
387
+ feature_idx='stages')
388
+ if pretrained:
389
+ try:
390
+ backbone.load_state_dict(backbone_weights, False)
391
+ except:
392
+ print("backbone not loaded")
393
+ else:
394
+ print("backbone loaded")
395
+
396
+ model = Transformer_blocks.Stage_wise_hybrid_Transformer(backbone, num_classes=num_classes,
397
+ drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
398
+ drop_path_rate=drop_path_rate,
399
+ use_cls_token=use_cls_token,
400
+ use_pos_embedding=use_pos_embedding,
401
+ use_att_module=use_att_module,
402
+ stage_size=(56, 28, 14, 7),
403
+ stage_dim=[256, 512, 1024, 2048])
404
+
405
+ elif model_idx[0:11] == 'Hybrid2_384' and edge_size == 384: # Proposed model 384 !!!
406
+ backbone = Hybrid_backbone_4(block_constructor=Bottleneck_block_constructor,
407
+ bottleneck_channels_setting=[64, 128, 256, 512],
408
+ identity_layers_setting=[3, 4, 6, 3],
409
+ stage_stride_setting=[1, 2, 2, 2],
410
+ fc_num_classes=None,
411
+ feature_idx='stages')
412
+ if pretrained:
413
+ try:
414
+ backbone.load_state_dict(backbone_weights, False)
415
+ except:
416
+ print("backbone not loaded")
417
+ else:
418
+ print("backbone loaded")
419
+
420
+ model = Transformer_blocks.Stage_wise_hybrid_Transformer(backbone, num_classes=num_classes,
421
+ drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
422
+ drop_path_rate=drop_path_rate,
423
+ use_cls_token=use_cls_token,
424
+ use_pos_embedding=use_pos_embedding,
425
+ use_att_module=use_att_module,
426
+ stage_size=(96, 48, 24, 12),
427
+ stage_dim=[256, 512, 1024, 2048])
428
+
429
+ elif model_idx[0:11] == 'Hybrid3_224' and edge_size == 224: # Proposed model ablation study: edge_size==224
430
+ backbone = Hybrid_backbone_3(block_constructor=Bottleneck_block_constructor,
431
+ bottleneck_channels_setting=[64, 128, 256],
432
+ identity_layers_setting=[3, 4, 6],
433
+ stage_stride_setting=[1, 2, 2],
434
+ fc_num_classes=None,
435
+ feature_idx='stages')
436
+ if pretrained:
437
+ try:
438
+ backbone.load_state_dict(backbone_weights, False)
439
+ except:
440
+ print("backbone not loaded")
441
+ else:
442
+ print("backbone loaded")
443
+
444
+ model = Transformer_blocks.Stage_wise_hybrid_Transformer(backbone, num_classes=num_classes,
445
+ drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
446
+ drop_path_rate=drop_path_rate,
447
+ use_cls_token=use_cls_token,
448
+ use_pos_embedding=use_pos_embedding,
449
+ use_att_module=use_att_module,
450
+ stage_size=(56, 28, 14),
451
+ stage_dim=[256, 512, 1024])
452
+
453
+ elif model_idx[0:11] == 'Hybrid3_384' and edge_size == 384: # Proposed model 384 !!!
454
+ backbone = Hybrid_backbone_3(block_constructor=Bottleneck_block_constructor,
455
+ bottleneck_channels_setting=[64, 128, 256],
456
+ identity_layers_setting=[3, 4, 6],
457
+ stage_stride_setting=[1, 2, 2],
458
+ fc_num_classes=None,
459
+ feature_idx='stages')
460
+ if pretrained:
461
+ try:
462
+ backbone.load_state_dict(backbone_weights, False)
463
+ except:
464
+ print("backbone not loaded")
465
+ else:
466
+ print("backbone loaded")
467
+
468
+ model = Transformer_blocks.Stage_wise_hybrid_Transformer(backbone, num_classes=num_classes,
469
+ drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
470
+ drop_path_rate=drop_path_rate,
471
+ use_cls_token=use_cls_token,
472
+ use_pos_embedding=use_pos_embedding,
473
+ use_att_module=use_att_module,
474
+ stage_size=(96, 48, 24),
475
+ stage_dim=[256, 512, 1024])
476
+
477
+ else:
478
+ print('not a valid hybrid model')
479
+ return -1
480
+
481
+ return model
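A minimal usage sketch of create_model, assuming torch and torchvision are installed and the repo's Backbone package is importable; 'Hybrid2_384' is the proposed MSHT configuration at 384x384.
import torch
from Backbone.ResHybrid import create_model

# the ResNet50 stage features feed the stage-wise hybrid Transformer
model = create_model('Hybrid2_384', edge_size=384, pretrained=True, num_classes=2)

imgs = torch.randn(2, 3, 384, 384)  # dummy batch at the matching edge size
preds = model(imgs)                 # expected shape: (2, 2)
print(preds.shape)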
PuzzleTuning/Backbone/Transformer_blocks.py ADDED
@@ -0,0 +1,1631 @@
1
+ """
2
+ Transformer blocks script ver: OCT 28th 15:00
3
+
4
+ bug fix: the 'Cross-attn' name is used in MHGA for comparability
5
+
6
+ by the authors, check our github page:
7
+ https://github.com/sagizty/Multi-Stage-Hybrid-Transformer
8
+
9
+ based on:timm
10
+ https://www.freeaihub.com/post/94067.html
11
+
12
+ """
13
+
14
+ import math
15
+ import logging
16
+ from functools import partial
17
+ from collections import OrderedDict
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+
23
+ import numpy as np
24
+
25
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
26
+
27
+ from timm.models.layers import StdConv2dSame, DropPath, to_2tuple, trunc_normal_
28
+
29
+ from .attention_modules import simam_module, cbam_module, se_module
30
+
31
+
32
+ class FFN(nn.Module): # Mlp from timm
33
+ """
34
+ FFN (from timm)
35
+
36
+ :param in_features:
37
+ :param hidden_features:
38
+ :param out_features:
39
+ :param act_layer:
40
+ :param drop:
41
+ """
42
+
43
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
44
+ super().__init__()
45
+
46
+ out_features = out_features or in_features
47
+ hidden_features = hidden_features or in_features
48
+
49
+ self.fc1 = nn.Linear(in_features, hidden_features)
50
+ self.act = act_layer()
51
+
52
+ self.fc2 = nn.Linear(hidden_features, out_features)
53
+ self.drop = nn.Dropout(drop)
54
+
55
+ def forward(self, x):
56
+ x = self.fc1(x)
57
+ x = self.act(x)
58
+ x = self.drop(x)
59
+
60
+ x = self.fc2(x)
61
+ x = self.drop(x)
62
+
63
+ return x
64
+
65
+
66
+ class Attention(nn.Module): # qkv Transform + MSA(MHSA) (Attention from timm)
67
+ """
68
+ qkv Transform + MSA(MHSA) (from timm)
69
+
70
+ # input x.shape = batch, patch_number, patch_dim
71
+ # output x.shape = batch, patch_number, patch_dim
72
+
73
+ :param dim: dim=CNN feature dim, because the patch size is 1x1
74
+ :param num_heads:
75
+ :param qkv_bias:
76
+ :param qk_scale: by default head_dim ** -0.5 (square root)
77
+ :param attn_drop: dropout rate after MHSA
78
+ :param proj_drop:
79
+
80
+ """
81
+
82
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
83
+ super().__init__()
84
+ self.num_heads = num_heads
85
+ head_dim = dim // num_heads
86
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
87
+ self.scale = qk_scale or head_dim ** -0.5
88
+
89
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
90
+
91
+ self.attn_drop = nn.Dropout(attn_drop)
92
+ self.proj = nn.Linear(dim, dim)
93
+ self.proj_drop = nn.Dropout(proj_drop)
94
+
95
+ def forward(self, x):
96
+ # input x.shape = batch, patch_number, patch_dim
97
+ batch, patch_number, patch_dim = x.shape
98
+
99
+ # mlp transform + head split [N, P, D] -> [N, P, 3D] -> [N, P, 3, H, D/H] -> [3, N, H, P, D/H]
100
+ qkv = self.qkv(x).reshape(batch, patch_number, 3, self.num_heads, patch_dim //
101
+ self.num_heads).permute(2, 0, 3, 1, 4)
102
+ # 3 [N, H, P, D/H]
103
+ q, k, v = qkv[0], qkv[1], qkv[2]
104
+
105
+ # [N, H, P, D/H] -> [N, H, P, D/H]
106
+ attn = (q @ k.transpose(-2, -1)) * self.scale
107
+ attn = attn.softmax(dim=-1)
108
+
109
+ attn = self.attn_drop(attn) # Dropout
110
+
111
+ # head fusion [N, H, P, D/H] -> [N, P, H, D/H] -> [N, P, D]
112
+ x = (attn @ v).transpose(1, 2).reshape(batch, patch_number, patch_dim)
113
+
114
+ x = self.proj(x)
115
+ x = self.proj_drop(x) # mlp
116
+
117
+ # output x.shape = batch, patch_number, patch_dim
118
+ return x
119
+
120
+
121
+ class Encoder_Block(nn.Module): # Transformer Block from timm
122
+
123
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
124
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
125
+ """
126
+ # input x.shape = batch, patch_number, patch_dim
127
+ # output x.shape = batch, patch_number, patch_dim
128
+
129
+ :param dim: dim
130
+ :param num_heads:
131
+ :param mlp_ratio: FFN
132
+ :param qkv_bias:
133
+ :param qk_scale: by default head_dim ** -0.5 (square root)
134
+ :param drop:
135
+ :param attn_drop: dropout rate after Attention
136
+ :param drop_path: dropout rate after sd
137
+ :param act_layer: FFN act
138
+ :param norm_layer: Pre Norm
139
+ """
140
+ super().__init__()
141
+ # Pre Norm
142
+ self.norm1 = norm_layer(dim) # Transformer used the nn.LayerNorm
143
+ self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
144
+ proj_drop=drop)
145
+ # NOTE from timm: drop path for stochastic depth, we shall see if this is better than dropout here
146
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() # stochastic depth
147
+
148
+ # Add & Norm
149
+ self.norm2 = norm_layer(dim)
150
+
151
+ # FFN
152
+ mlp_hidden_dim = int(dim * mlp_ratio)
153
+ self.mlp = FFN(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
154
+
155
+ def forward(self, x):
156
+ x = x + self.drop_path(self.attn(self.norm1(x)))
157
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
158
+ return x
159
+
160
+
161
+ class Guided_Attention(nn.Module): # q1 k1 v0 Transform + MSA(MHSA) (based on timm Attention)
162
+ """
163
+ note that q and k are the guided information from the Focus module
164
+ qkv Transform + MSA(MHSA) (from timm)
165
+
166
+ # 3 input of x.shape = batch, patch_number, patch_dim
167
+ # 1 output of x.shape = batch, patch_number, patch_dim
168
+
169
+ :param dim: dim = CNN feature dim, because the patch size is 1x1
170
+ :param num_heads:
171
+ :param qkv_bias:
172
+ :param qk_scale: by default head_dim ** -0.5 (squre root)
173
+ :param attn_drop:
174
+ :param proj_drop:
175
+
176
+ """
177
+
178
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
179
+ super().__init__()
180
+ self.num_heads = num_heads
181
+ head_dim = dim // num_heads
182
+ self.scale = qk_scale or head_dim ** -0.5
183
+
184
+ self.qT = nn.Linear(dim, dim, bias=qkv_bias)
185
+ self.kT = nn.Linear(dim, dim, bias=qkv_bias)
186
+ self.vT = nn.Linear(dim, dim, bias=qkv_bias)
187
+
188
+ self.attn_drop = nn.Dropout(attn_drop)
189
+
190
+ self.proj = nn.Linear(dim, dim)
191
+ self.proj_drop = nn.Dropout(proj_drop)
192
+
193
+ def forward(self, q_encoder, k_encoder, v_input):
194
+ # 3 input of x.shape = batch, patch_number, patch_dim
195
+ batch, patch_number, patch_dim = v_input.shape
196
+
197
+ q = self.qT(q_encoder).reshape(batch, patch_number, 1, self.num_heads,
198
+ patch_dim // self.num_heads).permute(2, 0, 3, 1, 4)
199
+ k = self.kT(k_encoder).reshape(batch, patch_number, 1, self.num_heads,
200
+ patch_dim // self.num_heads).permute(2, 0, 3, 1, 4)
201
+ v = self.vT(v_input).reshape(batch, patch_number, 1, self.num_heads,
202
+ patch_dim // self.num_heads).permute(2, 0, 3, 1, 4)
203
+ q = q[0]
204
+ k = k[0]
205
+ v = v[0]
206
+
207
+ attn = (q @ k.transpose(-2, -1)) * self.scale
208
+ attn = attn.softmax(dim=-1)
209
+
210
+ attn = self.attn_drop(attn) # Dropout
211
+
212
+ x = (attn @ v).transpose(1, 2).reshape(batch, patch_number, patch_dim)
213
+
214
+ x = self.proj(x)
215
+ x = self.proj_drop(x) # mlp Dropout
216
+
217
+ # output of x.shape = batch, patch_number, patch_dim
218
+ return x
219
+
220
+
221
+ class Decoder_Block(nn.Module):
222
+ # FGD Decoder (Transformer encoder + Guided Attention block)
223
+ def __init__(self, dim, num_heads=8, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
224
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
225
+ """
226
+ # input x.shape = batch, patch_number, patch_dim
227
+ # output x.shape = batch, patch_number, patch_dim
228
+
229
+ :param dim: dim=CNN feature dim, because the patch size is 1x1
230
+ :param num_heads: multi-head
231
+ :param mlp_ratio: FFN expand ratio
232
+ :param qkv_bias: qkv MLP bias
233
+ :param qk_scale: by default head_dim ** -0.5 (square root)
234
+ :param drop: dropout rate of the MLP right after MHSA
235
+ :param attn_drop: dropout rate after attention block
236
+ :param drop_path: dropout rate for stochastic depth
237
+ :param act_layer: FFN act
238
+ :param norm_layer: Pre Norm strategy with norm layer
239
+ """
240
+ super().__init__()
241
+ # Pre Norm
242
+ self.norm0 = norm_layer(dim) # nn.LayerNorm
243
+ self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
244
+ proj_drop=drop)
245
+ # stochastic depth
246
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
247
+
248
+ # Pre Norm
249
+ self.norm1 = norm_layer(dim)
250
+
251
+ # FFN1
252
+ mlp_hidden_dim = int(dim * mlp_ratio)
253
+ self.FFN1 = FFN(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
254
+
255
+ # Guided_Attention
256
+ self.Cross_attn = Guided_Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
257
+ attn_drop=attn_drop, proj_drop=drop)
258
+
259
+ # Add & Norm
260
+ self.norm2 = norm_layer(dim)
261
+ # FFN2
262
+ self.FFN2 = FFN(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
263
+
264
+ # Add & Norm
265
+ self.norm3 = norm_layer(dim)
266
+
267
+ def forward(self, q_encoder, k_encoder, v_input):
268
+ v_self = v_input + self.drop_path(self.attn(self.norm0(v_input)))
269
+
270
+ v_self = v_self + self.drop_path(self.FFN1(self.norm1(v_self)))
271
+
272
+ # norm layer for v only, the normalization of q and k is inside FGD Focus block
273
+ v_self = v_self + self.drop_path(self.Cross_attn(q_encoder, k_encoder, self.norm2(v_self)))
274
+
275
+ v_self = v_self + self.drop_path(self.FFN2(self.norm3(v_self)))
276
+
277
+ return v_self
278
+
279
+
280
+ '''
281
+ # testing example
282
+
283
+ model=Decoder_Block(dim=768)
284
+ k = torch.randn(7, 49, 768)
285
+ q = torch.randn(7, 49, 768)
286
+ v = torch.randn(7, 49, 768)
287
+ x = model(k,q,v)
288
+ print(x.shape)
289
+ '''
290
+
291
+
292
+ # MViT modules
293
+ # from https://github.com/facebookresearch/SlowFast/slowfast/models/attention.py
294
+ def attention_pool(tensor, pool, thw_shape, has_cls_embed=True, norm=None):
295
+ """
296
+ attention pooling constructor
297
+
298
+ input:
299
+ tensor of (B, Head, N, C) or (B, N, C)
300
+ thw_shape: T, H, W, corresponding to the 2D CNN feature-map shape; T is the number of video frames
301
+
302
+ numpy.prod(T, H, W) == N (Num_patches) - 1 (the -1 accounts for the cls token, if present)
303
+
304
+ output:
305
+ tensor of (B, Head, N_O, C) or (B, N_O, C)
306
+ thw_shape: T_O, H_O, W_O
307
+
308
+ :param tensor: input feature patches
309
+ :param pool: pooling/conv layer
310
+ :param thw_shape: reconstruction feature map shape
311
+ :param has_cls_embed: if cls token is used
312
+ :param norm: norm layer
313
+
314
+ """
315
+ if pool is None: # no pool
316
+ return tensor, thw_shape
317
+
318
+ tensor_dim = tensor.ndim
319
+
320
+ # fix dim: [B, Head, N, C]
321
+ # N is Num_patches in Transformer modeling
322
+
323
+ if tensor_dim == 4:
324
+ pass
325
+ elif tensor_dim == 3: # [B, N, C] -> [B, Head(1), N, C]
326
+ tensor = tensor.unsqueeze(1)
327
+ else:
328
+ raise NotImplementedError(f"Unsupported input dimension {tensor.shape}")
329
+
330
+ if has_cls_embed:
331
+ cls_tok, tensor = tensor[:, :, :1, :], tensor[:, :, 1:, :]
332
+
333
+ B, Head, N, C = tensor.shape
334
+ T, H, W = thw_shape # numpy.prod(T, H, W) == N (Num_patches) - 1 (the -1 accounts for the cls token, if present)
335
+
336
+ # [B, Head, N, C] -> [B * Head, T, H, W, C] -> [B * Head, C, T, H, W]
337
+ tensor = (tensor.reshape(B * Head, T, H, W, C).permute(0, 4, 1, 2, 3).contiguous())
338
+ # use tensor.contiguous() to keep its memory layout contiguous
339
+
340
+ # [B * Head, C, T, H, W] -> [B * Head, C, T_O, H_O, W_O]
341
+ tensor = pool(tensor) # 3D Pooling/ 3D Conv
342
+
343
+ # output T, H, W
344
+ thw_shape = [tensor.shape[2], tensor.shape[3], tensor.shape[4]]
345
+ # output Num_patches: numpy.prob(T, H, W)
346
+ N_pooled = tensor.shape[2] * tensor.shape[3] * tensor.shape[4]
347
+
348
+ # [B * Head, C, T_O, H_O, W_O] -> [B, Head, C, N_O(T_O*H_O*W_O)] -> [B, Head, N_O, C]
349
+ tensor = tensor.reshape(B, Head, C, N_pooled).transpose(2, 3)
350
+
351
+ if has_cls_embed:
352
+ # [B, Head, N_O, C] -> [B, Head, N_O+1(cls token), C]
353
+ tensor = torch.cat((cls_tok, tensor), dim=2)
354
+
355
+ # norm
356
+ if norm is not None:
357
+ tensor = norm(tensor)
358
+
359
+ # Assert tensor_dim in [3, 4]
360
+ if tensor_dim == 4: # [B, Head, N_O, C] multi-head
361
+ pass
362
+ else: # tensor_dim == 3: this is a single Head
363
+ tensor = tensor.squeeze(1) # [B, N_O, C]
364
+
365
+ return tensor, thw_shape
366
+
367
+
368
+ '''
369
+ # case 1 single-head no pooling scale
370
+ x = torch.randn(1, 197, 768)
371
+ thw_shape = [1, 14, 14]
372
+ pool = nn.MaxPool3d((1, 1, 1), (1, 1, 1), (0, 0, 0), ceil_mode=False)
373
+ y, thw = attention_pool(x, pool, thw_shape)
374
+
375
+ print(y.shape) # torch.Size([1, 197, 768])
376
+ print(thw) # [1, 14, 14]
377
+
378
+
379
+ # case 2 multi-head no pooling scale
380
+ x = torch.randn(1, 8, 197, 96) # [B, Head, N_O, C] multi-head
381
+ thw_shape = [1, 14, 14]
382
+ pool = nn.MaxPool3d((1, 1, 1), (1, 1, 1), (0, 0, 0), ceil_mode=False)
383
+ y, thw = attention_pool(x, pool, thw_shape)
384
+
385
+ print(y.shape) # torch.Size([1, 8, 197, 96])
386
+ print(thw) # [1, 14, 14]
387
+
388
+
389
+ # case 3 pooling scale
390
+ x = torch.randn(1, 197, 768)
391
+ thw_shape = [1, 14, 14]
392
+ pool = nn.MaxPool3d((1, 2, 2), (1, 2, 2), (0, 0, 0), ceil_mode=False)
393
+ y, thw = attention_pool(x, pool, thw_shape)
394
+
395
+ print(y.shape) # torch.Size([1, 50, 768])
396
+ print(thw) # [1, 7, 7]
397
+
398
+
399
+ # case 4 multi-head pooling scale
400
+ x = torch.randn(1, 8, 197, 96) # [B, Head, N_O, C] multi-head
401
+ thw_shape = [1, 14, 14]
402
+ pool = nn.MaxPool3d((1, 2, 2), (1, 2, 2), (0, 0, 0), ceil_mode=False)
403
+ y, thw = attention_pool(x, pool, thw_shape)
404
+
405
+ print(y.shape) # torch.Size([1, 8, 50, 96])
406
+ print(thw) # [1, 7, 7]
407
+ '''
408
+
409
+
410
+ class MultiScaleAttention(nn.Module): # Attention module
411
+ """
412
+ Attention module constructor
413
+
414
+ input:
415
+ tensor of (B, N, C)
416
+ thw_shape: T, H, W, corresponding to the 2D CNN feature-map shape; T is the number of video frames
417
+
418
+ numpy.prod(T, H, W) == N (Num_patches) - 1 (the -1 accounts for the cls token, if present)
419
+
420
+ output:
421
+ tensor of (B, N_O, C)
422
+ thw_shape: T_O, H_O, W_O
423
+
424
+ :param dim: Transformer feature dim
425
+ :param num_heads: Transformer heads
426
+ :param qkv_bias: projecting bias
427
+ :param drop_rate: dropout rate after attention calculation and mlp
428
+
429
+ :param kernel_q: pooling kernel size for q
430
+ :param kernel_kv: pooling kernel size for k and v
431
+ :param stride_q: pooling kernel stride for q
432
+ :param stride_kv: pooling kernel stride for k and v
433
+
434
+ :param norm_layer: norm layer
435
+ :param has_cls_embed: if cls token is used
436
+ :param mode: mode for attention pooling(downsampling) Options include `conv`, `avg`, and `max`.
437
+ :param pool_first: perform pooling (downsampling) before the linear projection
438
+
439
+ """
440
+
441
+ def __init__(
442
+ self,
443
+ dim,
444
+ num_heads=8,
445
+ qkv_bias=False,
446
+ drop_rate=0.0,
447
+ kernel_q=(1, 1, 1),
448
+ kernel_kv=(1, 1, 1),
449
+ stride_q=(1, 1, 1),
450
+ stride_kv=(1, 1, 1),
451
+ norm_layer=nn.LayerNorm,
452
+ has_cls_embed=True,
453
+ # Options include `conv`, `avg`, and `max`.
454
+ mode="conv",
455
+ # If True, perform pool before projection.
456
+ pool_first=False,
457
+ ):
458
+ super().__init__()
459
+
460
+ self.pool_first = pool_first
461
+ self.drop_rate = drop_rate
462
+ self.num_heads = num_heads
463
+ head_dim = dim // num_heads
464
+ self.scale = head_dim ** -0.5 # square root
465
+ self.has_cls_embed = has_cls_embed
466
+
467
+ padding_q = [int(q // 2) for q in kernel_q] # pad with half the kernel size, rounded down
468
+ padding_kv = [int(kv // 2) for kv in kernel_kv]
469
+
470
+ # projecting mlp
471
+ self.q = nn.Linear(dim, dim, bias=qkv_bias)
472
+ self.k = nn.Linear(dim, dim, bias=qkv_bias)
473
+ self.v = nn.Linear(dim, dim, bias=qkv_bias)
474
+ self.proj = nn.Linear(dim, dim)
475
+
476
+ if drop_rate > 0.0:
477
+ self.proj_drop = nn.Dropout(drop_rate)
478
+
479
+ # Skip pooling with kernel and stride size of (1, 1, 1).
480
+ if np.prod(kernel_q) == 1 and np.prod(stride_q) == 1:
481
+ kernel_q = () # clear
482
+ if np.prod(kernel_kv) == 1 and np.prod(stride_kv) == 1:
483
+ kernel_kv = ()
484
+
485
+ if mode in ("avg", "max"): # use nn.MaxPool3d or nn.AvgPool3d
486
+ pool_op = nn.MaxPool3d if mode == "max" else nn.AvgPool3d
487
+ self.pool_q = (
488
+ pool_op(kernel_q, stride_q, padding_q, ceil_mode=False)
489
+ if len(kernel_q) > 0
490
+ else None # Skip pooling if kernel is cleared
491
+ )
492
+ self.pool_k = (
493
+ pool_op(kernel_kv, stride_kv, padding_kv, ceil_mode=False)
494
+ if len(kernel_kv) > 0
495
+ else None
496
+ )
497
+ self.pool_v = (
498
+ pool_op(kernel_kv, stride_kv, padding_kv, ceil_mode=False)
499
+ if len(kernel_kv) > 0
500
+ else None
501
+ )
502
+
503
+ elif mode == "conv": # use nn.Conv3d with depth wise conv and fixed channel setting
504
+ self.pool_q = (
505
+ nn.Conv3d(
506
+ head_dim,
507
+ head_dim,
508
+ kernel_q,
509
+ stride=stride_q,
510
+ padding=padding_q,
511
+ groups=head_dim,
512
+ bias=False,
513
+ )
514
+ if len(kernel_q) > 0
515
+ else None
516
+ )
517
+ self.norm_q = norm_layer(head_dim) if len(kernel_q) > 0 else None
518
+
519
+ self.pool_k = (
520
+ nn.Conv3d(
521
+ head_dim,
522
+ head_dim,
523
+ kernel_kv,
524
+ stride=stride_kv,
525
+ padding=padding_kv,
526
+ groups=head_dim,
527
+ bias=False,
528
+ )
529
+ if len(kernel_kv) > 0
530
+ else None
531
+ )
532
+ self.norm_k = norm_layer(head_dim) if len(kernel_kv) > 0 else None
533
+
534
+ self.pool_v = (
535
+ nn.Conv3d(
536
+ head_dim,
537
+ head_dim,
538
+ kernel_kv,
539
+ stride=stride_kv,
540
+ padding=padding_kv,
541
+ groups=head_dim,
542
+ bias=False,
543
+ )
544
+ if len(kernel_kv) > 0
545
+ else None
546
+ )
547
+ self.norm_v = norm_layer(head_dim) if len(kernel_kv) > 0 else None
548
+ else:
549
+ raise NotImplementedError(f"Unsupported mode {mode}")
550
+
551
+ def forward(self, x, thw_shape):
552
+ """
553
+ x: Transformer feature patches
554
+ thw_shape: reconstruction feature map shape
555
+ """
556
+
557
+ B, N, C = x.shape
558
+
559
+ # step 1: duplicate projecting + head split: [B, N, C] -> [B, H, N, C/H]
560
+
561
+ if self.pool_first: # step a.1 embedding
562
+ # head split [B, N, C] -> [B, N, H, C/H] -> [B, H, N, C/H]
563
+ x = x.reshape(B, N, self.num_heads, C // self.num_heads).permute(
564
+ 0, 2, 1, 3
565
+ )
566
+ q = k = v = x
567
+
568
+ else: # step b.1 projecting first
569
+ # mlp transform + head split: [B, N, C] -> [B, N, H, C/H] -> [B, H, N, C/H]
570
+ # TODO: sharing one mlp projection here might allow better interaction, but separate mlps save computation
571
+ q = k = v = x
572
+ q = (
573
+ self.q(q)
574
+ .reshape(B, N, self.num_heads, C // self.num_heads)
575
+ .permute(0, 2, 1, 3)
576
+ )
577
+ k = (
578
+ self.k(k)
579
+ .reshape(B, N, self.num_heads, C // self.num_heads)
580
+ .permute(0, 2, 1, 3)
581
+ )
582
+ v = (
583
+ self.v(v)
584
+ .reshape(B, N, self.num_heads, C // self.num_heads)
585
+ .permute(0, 2, 1, 3)
586
+ )
587
+
588
+ # step 2: calculate attention_pool feature sequence and its shape
589
+ # [B, H, N0, C/H] -> [B, H, N1, C/H]
590
+ q, q_shape = attention_pool(
591
+ q,
592
+ self.pool_q,
593
+ thw_shape,
594
+ has_cls_embed=self.has_cls_embed,
595
+ norm=self.norm_q if hasattr(self, "norm_q") else None,
596
+ )
597
+ k, k_shape = attention_pool(
598
+ k,
599
+ self.pool_k,
600
+ thw_shape,
601
+ has_cls_embed=self.has_cls_embed,
602
+ norm=self.norm_k if hasattr(self, "norm_k") else None,
603
+ )
604
+ v, v_shape = attention_pool(
605
+ v,
606
+ self.pool_v,
607
+ thw_shape,
608
+ has_cls_embed=self.has_cls_embed,
609
+ norm=self.norm_v if hasattr(self, "norm_v") else None,
610
+ )
611
+
612
+ if self.pool_first: # step a.3 MLP projecting
613
+ # calculate patch number, q_N, k_N, v_N
614
+ q_N = (
615
+ np.prod(q_shape) + 1
616
+ if self.has_cls_embed
617
+ else np.prod(q_shape)
618
+ )
619
+ k_N = (
620
+ np.prod(k_shape) + 1
621
+ if self.has_cls_embed
622
+ else np.prod(k_shape)
623
+ )
624
+ v_N = (
625
+ np.prod(v_shape) + 1
626
+ if self.has_cls_embed
627
+ else np.prod(v_shape)
628
+ )
629
+
630
+ # [B, H, N1, C/H] -> [B, N1, H, C/H] -> [B, N1, C] -> MLP
631
+ # -> [B, N1, C] -> [B, N1, H, C/H] -> [B, H, N1, C/H]
632
+ q = q.permute(0, 2, 1, 3).reshape(B, q_N, C)
633
+ q = (
634
+ self.q(q)
635
+ .reshape(B, q_N, self.num_heads, C // self.num_heads)
636
+ .permute(0, 2, 1, 3)
637
+ )
638
+
639
+ v = v.permute(0, 2, 1, 3).reshape(B, v_N, C)
640
+ v = (
641
+ self.v(v)
642
+ .reshape(B, v_N, self.num_heads, C // self.num_heads)
643
+ .permute(0, 2, 1, 3)
644
+ )
645
+
646
+ k = k.permute(0, 2, 1, 3).reshape(B, k_N, C)
647
+ k = (
648
+ self.k(k)
649
+ .reshape(B, k_N, self.num_heads, C // self.num_heads)
650
+ .permute(0, 2, 1, 3)
651
+ )
652
+
653
+ # step 3: attention calculation
654
+ # multi-head self attention [B, H, N1, C/H] -> [B, H, N1, C/H]
655
+ attn = (q @ k.transpose(-2, -1)) * self.scale
656
+ attn = attn.softmax(dim=-1)
657
+
658
+ # head squeeze [B, H, N1, C/H] -> [B, N1, H, C/H] -> [B, N1, C]
659
+ N = q.shape[2]
660
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
661
+
662
+ # step 4: mlp stabilization and dropout [B, N1, C] -> [B, N1, C]
663
+ x = self.proj(x)
664
+ if self.drop_rate > 0.0:
665
+ x = self.proj_drop(x)
666
+
667
+ return x, q_shape
668
+
669
+
670
+ '''
671
+ # case 1
672
+ model = MultiScaleAttention(768)
673
+ x = torch.randn(1, 197, 768)
674
+ y, thw = model(x, [1, 14, 14])
675
+ print(y.shape)
676
+
677
+
678
+ # case 2
679
+ kernel_q = (1, 2, 2)
680
+ kernel_kv = (1, 2, 2)
681
+ stride_q = (1, 2, 2)
682
+ stride_kv = (1, 2, 2)
683
+ # MultiScaleAttention pads with half the kernel size, rounded down
684
+
685
+ model = MultiScaleAttention(768, kernel_q=kernel_q, kernel_kv=kernel_kv, stride_q=stride_q, stride_kv=stride_kv)
686
+ x = torch.randn(1, 197, 768)
687
+ y, thw = model(x, [1, 14, 14])
688
+
689
+ print(y.shape) # torch.Size([1, 65, 768]): 7*7 without padding becomes 8*8 with padding, plus the cls token
690
+ '''
691
+
692
+
693
+ class MultiScaleBlock(nn.Module): # MViT Encoder
694
+ """
695
+ Attention module constructor
696
+
697
+ input:
698
+ tensor of (B, N, C)
699
+ thw_shape: T, H, W, corresponding to the 2D CNN feature-map shape; T is the number of video frames
700
+
701
+ numpy.prod(T, H, W) == N (Num_patches) - 1 (the -1 accounts for the cls token, if present)
702
+
703
+ output:
704
+ tensor of (B, N_O, C)
705
+ thw_shape: T_O, H_O, W_O
706
+
707
+ :param dim: Transformer feature dim
708
+ :param dim_out:
709
+
710
+ :param num_heads: Transformer heads
711
+ :param mlp_ratio: FFN hidden expansion
712
+ :param qkv_bias: projecting bias
713
+ :param drop_rate: dropout rate after attention calculation and mlp
714
+ :param drop_path: dropout rate for SD
715
+ :param act_layer: FFN act
716
+ :param norm_layer: Pre Norm
717
+
718
+ :param up_rate:
719
+ :param kernel_q: pooling kernel size for q
720
+ :param kernel_kv: pooling kernel size for k and v
721
+ :param stride_q: pooling kernel stride for q
722
+ :param stride_kv: pooling kernel stride for k and v
723
+
724
+ :param has_cls_embed: if cls token is used
725
+ :param mode: mode for attention pooling(downsampling) Options include `conv`, `avg`, and `max`.
726
+ :param pool_first: perform pooling (downsampling) before the linear projection
727
+
728
+ """
729
+
730
+ def __init__(
731
+ self,
732
+ dim,
733
+ dim_out,
734
+ num_heads=8,
735
+ mlp_ratio=4.0,
736
+ qkv_bias=False,
737
+ drop_rate=0.0,
738
+ drop_path=0.0,
739
+ act_layer=nn.GELU,
740
+ norm_layer=nn.LayerNorm,
741
+ up_rate=None,
742
+ kernel_q=(1, 1, 1),
743
+ kernel_kv=(1, 1, 1),
744
+ stride_q=(1, 1, 1),
745
+ stride_kv=(1, 1, 1),
746
+ has_cls_embed=True,
747
+ mode="conv",
748
+ pool_first=False,
749
+ ):
750
+ super().__init__()
751
+
752
+ self.has_cls_embed = has_cls_embed
753
+
754
+ # step 1: Attention projecting
755
+ self.dim = dim
756
+ self.dim_out = dim_out
757
+ self.norm1 = norm_layer(dim) # pre-norm
758
+
759
+ self.attn = MultiScaleAttention(
760
+ dim,
761
+ num_heads=num_heads,
762
+ qkv_bias=qkv_bias,
763
+ drop_rate=drop_rate,
764
+ kernel_q=kernel_q,
765
+ kernel_kv=kernel_kv,
766
+ stride_q=stride_q,
767
+ stride_kv=stride_kv,
768
+ norm_layer=nn.LayerNorm,
769
+ has_cls_embed=self.has_cls_embed,
770
+ mode=mode,
771
+ pool_first=pool_first,
772
+ )
773
+
774
+ self.drop_path = (DropPath(drop_path) if drop_path > 0.0 else nn.Identity())
775
+
776
+ # residual connection for Attention projecting
777
+ kernel_skip = kernel_q # fixme ori: [s + 1 if s > 1 else s for s in stride_q]
778
+ stride_skip = stride_q
779
+ padding_skip = [int(skip // 2) for skip in kernel_skip] # pad with half the kernel size, rounded down
780
+
781
+ self.pool_skip = (
782
+ nn.MaxPool3d(kernel_skip, stride_skip, padding_skip, ceil_mode=False)
783
+ if len(kernel_skip) > 0
784
+ else None)
785
+
786
+ self.norm2 = norm_layer(dim) # pre-norm
787
+
788
+ # step 2: FFN projecting
789
+ mlp_hidden_dim = int(dim * mlp_ratio)
790
+
791
+ # here use FFN to encode feature into abstractive information in the dimension
792
+ # TODO: check the use case for up_rate, and merge the following lines
793
+ if up_rate is not None and up_rate > 1:
794
+ mlp_dim_out = dim * up_rate
795
+ else:
796
+ mlp_dim_out = dim_out
797
+
798
+ self.mlp = FFN(
799
+ in_features=dim,
800
+ hidden_features=mlp_hidden_dim,
801
+ out_features=mlp_dim_out,
802
+ act_layer=act_layer,
803
+ drop=drop_rate,
804
+ )
805
+
806
+ # residual connection for FFN projecting
807
+ if dim != dim_out:
808
+ self.proj = nn.Linear(dim, dim_out)
809
+
810
+ def forward(self, x, thw_shape):
811
+ # step 1: Attention projecting
812
+ x_block, thw_shape_new = self.attn(self.norm1(x), thw_shape)
813
+ # residual connection for Attention projecting
814
+ x_res, _ = attention_pool(x, self.pool_skip, thw_shape, has_cls_embed=self.has_cls_embed)
815
+ x = x_res + self.drop_path(x_block)
816
+
817
+ # step 2: FFN projecting
818
+ x_norm = self.norm2(x)
819
+ x_mlp = self.mlp(x_norm)
820
+ # residual connection for FFN projecting
821
+ if self.dim != self.dim_out:
822
+ x = self.proj(x_norm)
823
+ x = x + self.drop_path(x_mlp)
824
+
825
+ return x, thw_shape_new
826
+
827
+
828
+ '''
829
+ # case 1
830
+ model = MultiScaleBlock(768,1024)
831
+ x = torch.randn(1, 197, 768)
832
+ y, thw = model(x, [1, 14, 14])
833
+ print(y.shape) # torch.Size([1, 197, 1024])
834
+
835
+
836
+ # case 2
837
+ kernel_q = (1, 2, 2)
838
+ kernel_kv = (1, 2, 2)
839
+ stride_q = (1, 2, 2)
840
+ stride_kv = (1, 2, 2)
841
+ # MultiScaleAttention pads by half the kernel size, rounded down
842
+
843
+ model = MultiScaleBlock(768, 1024, kernel_q=kernel_q, kernel_kv=kernel_kv, stride_q=stride_q, stride_kv=stride_kv)
844
+ x = torch.randn(1, 197, 768)
845
+ y, thw = model(x, [1, 14, 14])
846
+
847
+ print(y.shape)  # outputs torch.Size([1, 65, 1024]): without padding it would be 7*7, padding makes it 8*8, plus the cls token
848
+ '''
849
+
850
+
851
+ class PatchEmbed(nn.Module): # PatchEmbed from timm
852
+ """
853
+ Image to Patch Embedding
854
+ """
855
+
856
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
857
+ super().__init__()
858
+ img_size = to_2tuple(img_size)
859
+ patch_size = to_2tuple(patch_size)
860
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
861
+ self.img_size = img_size
862
+ self.patch_size = patch_size
863
+ self.num_patches = num_patches
864
+
865
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
866
+
867
+ def forward(self, x):
868
+ B, C, H, W = x.shape
869
+ # FIXME look at relaxing size constraints
870
+ assert H == self.img_size[0] and W == self.img_size[1], \
871
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
872
+ x = self.proj(x).flatten(2).transpose(1, 2)
873
+
874
+ # x: (B, 14*14, 768)
875
+ return x
876
+
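+ '''
+ # minimal usage sketch (illustrative, assuming the default 224x224 RGB input and 16x16 patches)
+ model = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
+ x = torch.randn(2, 3, 224, 224)
+ y = model(x)
+ print(y.shape)  # torch.Size([2, 196, 768]): (224/16)^2 = 196 patches, each projected to 768 dims
+ '''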
877
+
878
+ class Hybrid_feature_map_Embed(nn.Module): # HybridEmbed from timm
879
+ """
880
+ CNN Feature Map Embedding; requires a backbone and is kept here just for reference
881
+ Extract feature map from CNN, flatten, project to embedding dim.
882
+
883
+ # input x.shape = batch, feature_dim, feature_size[0], feature_size[1]
884
+ # output x.shape = batch, patch_number, patch_dim
885
+ """
886
+
887
+ def __init__(self, backbone, img_size=224, patch_size=1, feature_size=None, feature_dim=None,
888
+ in_chans=3, embed_dim=768):
889
+ super().__init__()
890
+
891
+ assert isinstance(backbone, nn.Module)
892
+
893
+ img_size = to_2tuple(img_size)
894
+ patch_size = to_2tuple(patch_size)
895
+ self.img_size = img_size
896
+ self.patch_size = patch_size
897
+ self.backbone = backbone
898
+
899
+ if feature_size is None or feature_dim is None: # backbone output feature_size
900
+ with torch.no_grad():
901
+ # NOTE Most reliable way of determining output dims is to run forward pass
902
+ training = backbone.training
903
+ if training:
904
+ backbone.eval()
905
+ o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
906
+ if isinstance(o, (list, tuple)):
907
+ o = o[-1] # last feature if backbone outputs list/tuple of features
908
+ feature_size = o.shape[-2:]
909
+ feature_dim = o.shape[1]
910
+ backbone.train(training)
911
+ else:
912
+ feature_size = to_2tuple(feature_size)
913
+ '''
914
+ if hasattr(self.backbone, 'feature_info'):
915
+ feature_dim = self.backbone.feature_info.channels()[-1]
916
+ else:
917
+ feature_dim = self.backbone.num_features
918
+ '''
919
+
920
+ assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
921
+
922
+ self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1]) # patchify
923
+
924
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
925
+
926
+ self.proj = nn.Conv2d(in_channels=feature_dim, out_channels=embed_dim,
927
+ kernel_size=patch_size, stride=patch_size)
928
+
929
+ def forward(self, x):
930
+ x = self.backbone(x)
931
+ if isinstance(x, (list, tuple)):
932
+ x = x[-1] # last feature if backbone outputs list/tuple of features
933
+
934
+ x = self.proj(x).flatten(2).transpose(1, 2)  # shape = (batch, num_patches, embed_dim)
935
+ """
936
+ x.shape: batch, feature_dim, feature_size[0], feature_size[1]
937
+ proj(x).shape: batch, embed_dim, patch_height_num, patch_width_num
938
+ flatten(2).shape: batch, embed_dim, patch_num
939
+ .transpose(1, 2).shape: batch feature_patch_number feature_patch_dim
940
+ """
941
+ # output: x.shape = batch, patch_number, patch_dim
942
+ return x
943
+
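+ '''
+ # minimal usage sketch, assuming a torchvision (>= 0.13) ResNet-50 trunk as the CNN backbone;
+ # any CNN whose forward returns a (B, C, H', W') feature map (or a list/tuple of them) would work
+ import torchvision
+ resnet = torchvision.models.resnet50(weights=None)
+ trunk = nn.Sequential(*list(resnet.children())[:-2])  # drop avgpool + fc, output: (B, 2048, 7, 7)
+ model = Hybrid_feature_map_Embed(trunk, img_size=224, patch_size=1, embed_dim=768)
+ x = torch.randn(2, 3, 224, 224)
+ print(model(x).shape)  # torch.Size([2, 49, 768]): 7*7 feature positions projected to embed_dim
+ '''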
944
+
945
+ class Last_feature_map_Embed(nn.Module):
946
+ """
947
+ use this block to connect last CNN stage to the first Transformer block
948
+ Extract feature map from CNN, flatten, project to embedding dim.
949
+
950
+ # input x.shape = batch, feature_dim, feature_size[0], feature_size[1]
951
+ # output x.shape = batch, patch_number, patch_dim
952
+ """
953
+
954
+ def __init__(self, patch_size=1, feature_size=(7, 7), feature_dim=2048, embed_dim=768,
955
+ Attention_module=None):
956
+ super().__init__()
957
+
958
+ # Attention module
959
+ if Attention_module is not None:
960
+ if Attention_module == 'SimAM':
961
+ self.Attention_module = simam_module(e_lambda=1e-4)
962
+ elif Attention_module == 'CBAM':
963
+ self.Attention_module = cbam_module(gate_channels=feature_dim)
964
+ elif Attention_module == 'SE':
965
+ self.Attention_module = se_module(channel=feature_dim)
966
+ else:
967
+ self.Attention_module = None
968
+
969
+ patch_size = to_2tuple(patch_size)
970
+ self.patch_size = patch_size
971
+
972
+ feature_size = to_2tuple(feature_size)
973
+
974
+ # feature map should be matching the size
975
+ assert feature_size[0] % self.patch_size[0] == 0 and feature_size[1] % self.patch_size[1] == 0
976
+
977
+ self.grid_size = (feature_size[0] // self.patch_size[0], feature_size[1] // self.patch_size[1]) # patch
978
+
979
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
980
+
981
+ # use the conv to split the patch by the following design:
982
+ self.proj = nn.Conv2d(in_channels=feature_dim, out_channels=embed_dim,
983
+ kernel_size=self.patch_size, stride=self.patch_size)
984
+
985
+ def forward(self, x):
986
+ if self.Attention_module is not None:
987
+ x = self.Attention_module(x)
988
+
989
+ if isinstance(x, (list, tuple)):
990
+ x = x[-1] # last feature if backbone outputs list/tuple of features
991
+
992
+ x = self.proj(x).flatten(2).transpose(1, 2)
993
+ """
994
+ x.shape: batch, feature_dim, feature_size[0], feature_size[1]
995
+ proj(x).shape: batch, embed_dim, patch_height_num, patch_width_num
996
+ flatten(2).shape: batch, embed_dim, patch_num
997
+ .transpose(1, 2).shape: batch feature_patch_number feature_patch_dim
998
+ """
999
+ # output format: x.shape = batch, patch_number, patch_dim
1000
+ return x
1001
+
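+ '''
+ # minimal usage sketch on a ResNet-50 stage-4 feature map (B, 2048, 7, 7)
+ model = Last_feature_map_Embed(patch_size=1, feature_size=(7, 7), feature_dim=2048,
+                                embed_dim=768, Attention_module='SimAM')
+ x = torch.randn(2, 2048, 7, 7)
+ print(model(x).shape)  # torch.Size([2, 49, 768]): 49 patch tokens ready for the first decoder
+ '''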
1002
+
1003
+ class Focus_Embed(nn.Module):  # Attention-guided module for hybridizing the early-stage CNN features
1004
+ """
1005
+ FGD Focus module
1006
+ Extract feature maps from the CNN, flatten, and project them to the embedding dim for use as attention guidance
1007
+
1008
+ input: x.shape = batch, feature_dim, feature_size[0], feature_size[1]
1009
+
1010
+ Firstly, an attention block is used to stabilize the feature projecting process
1011
+
1012
+ Secondly, for each feature map, the focus takes 2 paths: gaze and glance
1013
+ in the gaze path, max pooling is applied to get prominent information
1014
+ in the glance path, average pooling is applied to get general information
1015
+
1016
+ after the dual pooling paths, 2 separate CNNs are used to project the dimension
1017
+ Finally, flatten and transpose are applied
1018
+
1019
+ output 2 attention guidance: gaze, glance
1020
+ x.shape = batch, patch_number, patch_dim
1021
+
1022
+
1023
+ ref:
1024
+ ResNet50's feature map from different stages (edge size of 224)
1025
+ stage 1 output feature map: torch.Size([b, 256, 56, 56])
1026
+ stage 2 output feature map: torch.Size([b, 512, 28, 28])
1027
+ stage 3 output feature map: torch.Size([b, 1024, 14, 14])
1028
+ stage 4 output feature map: torch.Size([b, 2048, 7, 7])
1029
+ """
1030
+
1031
+ def __init__(self, patch_size=1, target_feature_size=(7, 7), feature_size=(56, 56), feature_dim=256, embed_dim=768,
1032
+ Attention_module=None, norm_layer=nn.LayerNorm):
1033
+ super().__init__()
1034
+ patch_size = to_2tuple(patch_size)
1035
+ feature_size = to_2tuple(feature_size) # patch size of the current feature map
1036
+
1037
+ target_feature_size = to_2tuple(target_feature_size) # patch size of the last feature map
1038
+
1039
+ # check that the feature map can be patchified to target_feature_size
1040
+ assert feature_size[0] % target_feature_size[0] == 0 and feature_size[1] % target_feature_size[1] == 0
1041
+
1042
+ # check that the target feature map can be patchified into patches
1043
+ assert target_feature_size[0] % patch_size[0] == 0 and target_feature_size[1] % patch_size[1] == 0
1044
+
1045
+ # Attention block
1046
+ if Attention_module is not None:
1047
+ if Attention_module == 'SimAM':
1048
+ self.Attention_module = simam_module(e_lambda=1e-4)
1049
+ elif Attention_module == 'CBAM':
1050
+ self.Attention_module = cbam_module(gate_channels=feature_dim)
1051
+ elif Attention_module == 'SE':
1052
+ self.Attention_module = se_module(channel=feature_dim)
1053
+ else:
1054
+ self.Attention_module = None
1055
+
1056
+ # split focus ROI
1057
+ self.focus_size = (feature_size[0] // target_feature_size[0], feature_size[1] // target_feature_size[1])
1058
+ self.num_focus = self.focus_size[0] * self.focus_size[1]
1059
+ # by kernel_size=focus_size, stride=focus_size design
1060
+ # output_size = target_feature_size = 7x7, so as to match the smallest (last-stage) feature map
1061
+
1062
+ self.gaze = nn.MaxPool2d(self.focus_size, stride=self.focus_size)
1063
+ self.glance = nn.AvgPool2d(self.focus_size, stride=self.focus_size)
1064
+ # x.shape: batch, feature_dim, target_feature_size[0], target_feature_size[1]
1065
+
1066
+ # split patch
1067
+ self.grid_size = (target_feature_size[0] // patch_size[0], target_feature_size[1] // patch_size[1])
1068
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
1069
+
1070
+ # use CNN to project dim to patch_dim
1071
+ self.gaze_proj = nn.Conv2d(in_channels=feature_dim, out_channels=embed_dim,
1072
+ kernel_size=patch_size, stride=patch_size)
1073
+ self.glance_proj = nn.Conv2d(in_channels=feature_dim, out_channels=embed_dim,
1074
+ kernel_size=patch_size, stride=patch_size)
1075
+
1076
+ self.norm_q = norm_layer(embed_dim) # Transformer nn.LayerNorm
1077
+ self.norm_k = norm_layer(embed_dim) # Transformer nn.LayerNorm
1078
+
1079
+ def forward(self, x):
1080
+ if self.Attention_module is not None:
1081
+ x = self.Attention_module(x)
1082
+
1083
+ if isinstance(x, (list, tuple)):
1084
+ x = x[-1] # last feature if backbone outputs list/tuple of features
1085
+
1086
+ q = self.norm_q(self.gaze_proj(self.gaze(x)).flatten(2).transpose(1, 2))
1087
+ k = self.norm_k(self.glance_proj(self.glance(x)).flatten(2).transpose(1, 2))
1088
+ """
1089
+ x.shape: batch, feature_dim, feature_size[0], feature_size[1]
1090
+ gaze/glance(x).shape: batch, feature_dim, target_feature_size[0], target_feature_size[1]
1091
+ proj(x).shape: batch, embed_dim, patch_height_num, patch_width_num
1092
+ flatten(2).shape: batch, embed_dim, patch_num
1093
+ .transpose(1, 2).shape: batch feature_patch_number feature_patch_dim
1094
+ """
1095
+ # output x.shape = batch, patch_number, patch_dim
1096
+ return q, k
1097
+
1098
+
1099
+ '''
1100
+ # test sample
1101
+ model = Focus_Embed()
1102
+ x = torch.randn(4, 256, 56, 56)
1103
+ y1,y2 = model(x)
1104
+ print(y1.shape)
1105
+ print(y2.shape)
1106
+ '''
1107
+
1108
+
1109
+ class Focus_SEmbed(nn.Module):  # Attention-guided module for hybridizing the early-stage CNN features
1110
+ """
1111
+
1112
+ self focus (q=k) based on FGD Focus block
1113
+
1114
+ Extract feature maps from the CNN, flatten, and project them to the embedding dim for use as attention guidance
1115
+
1116
+ input: x.shape = batch, feature_dim, feature_size[0], feature_size[1]
1117
+
1118
+ Firstly, an attention block is used to stabilize the feature projecting process
1119
+
1120
+ Secondly, for each feature map, the focus takes a single pooling path
1121
+ where max pooling (the gaze path) is applied to get prominent information
1122
+
1123
+ after the pooling process, 1 CNN is used to project the dimension
1124
+ Finally, flatten and transpose are applied
1125
+
1126
+ output 2 identical attention guidance tensors (q = k)
1127
+ x.shape = batch, patch_number, patch_dim
1128
+ """
1129
+
1130
+ def __init__(self, patch_size=1, target_feature_size=(7, 7), feature_size=(56, 56), feature_dim=256, embed_dim=768,
1131
+ Attention_module=None, norm_layer=nn.LayerNorm):
1132
+ super().__init__()
1133
+ patch_size = to_2tuple(patch_size)
1134
+ feature_size = to_2tuple(feature_size)
1135
+
1136
+ target_feature_size = to_2tuple(target_feature_size)
1137
+
1138
+ assert feature_size[0] % target_feature_size[0] == 0 and feature_size[1] % target_feature_size[1] == 0
1139
+
1140
+ assert target_feature_size[0] % patch_size[0] == 0 and target_feature_size[1] % patch_size[1] == 0
1141
+
1142
+ if Attention_module is not None:
1143
+ if Attention_module == 'SimAM':
1144
+ self.Attention_module = simam_module(e_lambda=1e-4)
1145
+ elif Attention_module == 'CBAM':
1146
+ self.Attention_module = cbam_module(gate_channels=feature_dim)
1147
+ elif Attention_module == 'SE':
1148
+ self.Attention_module = se_module(channel=feature_dim)
1149
+ else:
1150
+ self.Attention_module = None
1151
+
1152
+ self.focus_size = (feature_size[0] // target_feature_size[0], feature_size[1] // target_feature_size[1])
1153
+ self.num_focus = self.focus_size[0] * self.focus_size[1]
1154
+
1155
+ self.gaze = nn.MaxPool2d(self.focus_size, stride=self.focus_size)
1156
+
1157
+ self.grid_size = (target_feature_size[0] // patch_size[0], target_feature_size[1] // patch_size[1])
1158
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
1159
+
1160
+ self.proj = nn.Conv2d(in_channels=feature_dim, out_channels=embed_dim, kernel_size=patch_size,
1161
+ stride=patch_size)
1162
+
1163
+ self.norm_f = norm_layer(embed_dim)
1164
+
1165
+ def forward(self, x):
1166
+ if self.Attention_module is not None:
1167
+ x = self.Attention_module(x)
1168
+
1169
+ if isinstance(x, (list, tuple)):
1170
+ x = x[-1] # last feature if backbone outputs list/tuple of features
1171
+
1172
+ q = self.norm_f(self.proj(self.gaze(x)).flatten(2).transpose(1, 2))
1173
+ k = q
1174
+ """
1175
+ x.shape: batch, feature_dim, feature_size[0], feature_size[1]
1176
+ gaze/glance(x).shape: batch, feature_dim, target_feature_size[0], target_feature_size[1]
1177
+ proj(x).shape: batch, embed_dim, patch_height_num, patch_width_num
1178
+ flatten(2).shape: batch, embed_dim, patch_num
1179
+ .transpose(1, 2).shape: batch feature_patch_number feature_patch_dim
1180
+ """
1181
+ # output x.shape = batch, patch_number, patch_dim
1182
+ return q, k
1183
+
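+ '''
+ # minimal usage sketch: self-focus guidance (q = k) from the stage-1 feature map (B, 256, 56, 56)
+ model = Focus_SEmbed(target_feature_size=(7, 7), feature_size=(56, 56), feature_dim=256, embed_dim=768)
+ x = torch.randn(2, 256, 56, 56)
+ q, k = model(x)
+ print(q.shape, torch.equal(q, k))  # torch.Size([2, 49, 768]) True
+ '''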
1184
+
1185
+ class Focus_Aggressive(nn.Module):  # Attention-guided module for hybridizing the early-stage CNN features
1186
+ """
1187
+ Aggressive CNN Focus based on FGD Focus block
1188
+
1189
+ Extract feature maps from the CNN, flatten, and project them to the embedding dim for use as attention guidance
1190
+
1191
+ input: x.shape = batch, feature_dim, feature_size[0], feature_size[1]
1192
+
1193
+ Firstly, an attention block is used to stabilize the feature projecting process
1194
+
1195
+ Secondly, 2 CNNs are used to project the dimension
1196
+
1197
+ Finally, flatten and transpose are applied
1198
+
1199
+ output 2 attention guidance: gaze, glance
1200
+ x.shape = batch, patch_number, patch_dim
1201
+
1202
+ """
1203
+
1204
+ def __init__(self, patch_size=1, target_feature_size=(7, 7), feature_size=(56, 56), feature_dim=256, embed_dim=768,
1205
+ Attention_module=None, norm_layer=nn.LayerNorm):
1206
+ super().__init__()
1207
+ patch_size = to_2tuple(patch_size) # patch size of the last feature map
1208
+ feature_size = to_2tuple(feature_size)
1209
+
1210
+ target_feature_size = to_2tuple(target_feature_size)
1211
+
1212
+ assert feature_size[0] % target_feature_size[0] == 0 and feature_size[1] % target_feature_size[1] == 0
1213
+
1214
+ assert target_feature_size[0] % patch_size[0] == 0 and target_feature_size[1] % patch_size[1] == 0
1215
+
1216
+ if Attention_module is not None:
1217
+ if Attention_module == 'SimAM':
1218
+ self.Attention_module = simam_module(e_lambda=1e-4)
1219
+ elif Attention_module == 'CBAM':
1220
+ self.Attention_module = cbam_module(gate_channels=feature_dim)
1221
+ elif Attention_module == 'SE':
1222
+ self.Attention_module = se_module(channel=feature_dim)
1223
+ else:
1224
+ self.Attention_module = None
1225
+
1226
+ self.focus_size = (feature_size[0] // target_feature_size[0], feature_size[1] // target_feature_size[1])
1227
+
1228
+ self.grid_size = (self.focus_size[0] * patch_size[0], self.focus_size[1] * patch_size[1])
1229
+ self.num_patches = (feature_size[0] // self.grid_size[0]) * (feature_size[1] // self.grid_size[1])
1230
+
1231
+ self.gaze_proj = nn.Conv2d(in_channels=feature_dim, out_channels=embed_dim,
1232
+ kernel_size=self.grid_size, stride=self.grid_size)
1233
+ self.glance_proj = nn.Conv2d(in_channels=feature_dim, out_channels=embed_dim,
1234
+ kernel_size=self.grid_size, stride=self.grid_size)
1235
+
1236
+ self.norm_q = norm_layer(embed_dim)
1237
+ self.norm_k = norm_layer(embed_dim)
1238
+
1239
+ def forward(self, x):
1240
+ if self.Attention_module is not None:
1241
+ x = self.Attention_module(x)
1242
+
1243
+ if isinstance(x, (list, tuple)):
1244
+ x = x[-1] # last feature if backbone outputs list/tuple of features
1245
+
1246
+ q = self.norm_q(self.gaze_proj(x).flatten(2).transpose(1, 2))
1247
+ k = self.norm_k(self.glance_proj(x).flatten(2).transpose(1, 2))
1248
+ """
1249
+ x.shape: batch, feature_dim, feature_size[0], feature_size[1]
1250
+ proj(x).shape: batch, embed_dim, patch_height_num, patch_width_num
1251
+ flatten(2).shape: batch, embed_dim, patch_num
1252
+ .transpose(1, 2).shape: batch feature_patch_number feature_patch_dim
1253
+ """
1254
+ # output x.shape = batch, patch_number, patch_dim
1255
+ return q, k
1256
+
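+ '''
+ # minimal usage sketch: strided-conv guidance from the stage-2 feature map (B, 512, 28, 28)
+ model = Focus_Aggressive(feature_size=(28, 28), feature_dim=512, embed_dim=768)
+ x = torch.randn(2, 512, 28, 28)
+ q, k = model(x)
+ print(q.shape, k.shape)  # torch.Size([2, 49, 768]) twice: the 4x4 strided conv maps 28x28 down to 7x7
+ '''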
1257
+
1258
+ class Focus_SAggressive(nn.Module):  # Attention-guided module for hybridizing the early-stage CNN features
1259
+ """
1260
+ Aggressive CNN self Focus
1261
+ Extract feature maps from the CNN, flatten, and project them to the embedding dim for use as attention guidance
1262
+
1263
+ input: x.shape = batch, feature_dim, feature_size[0], feature_size[1]
1264
+
1265
+ Firstly, an attention block is used to stabilize the feature projecting process
1266
+
1267
+ Secondly, 1 CNN is used to project the dimension
1268
+
1269
+ Finally, flatten and transpose are applied
1270
+
1271
+ output 2 identical attention guidance tensors (q = k)
1272
+ x.shape = batch, patch_number, patch_dim
1273
+ """
1274
+
1275
+ def __init__(self, patch_size=1, target_feature_size=(7, 7), feature_size=(56, 56), feature_dim=256, embed_dim=768,
1276
+ Attention_module=None, norm_layer=nn.LayerNorm):
1277
+ super().__init__()
1278
+ patch_size = to_2tuple(patch_size)
1279
+ feature_size = to_2tuple(feature_size)
1280
+
1281
+ target_feature_size = to_2tuple(target_feature_size)
1282
+
1283
+ assert feature_size[0] % target_feature_size[0] == 0 and feature_size[1] % target_feature_size[1] == 0
1284
+
1285
+ assert target_feature_size[0] % patch_size[0] == 0 and target_feature_size[1] % patch_size[1] == 0
1286
+
1287
+ if Attention_module is not None:
1288
+ if Attention_module == 'SimAM':
1289
+ self.Attention_module = simam_module(e_lambda=1e-4)
1290
+ elif Attention_module == 'CBAM':
1291
+ self.Attention_module = cbam_module(gate_channels=feature_dim)
1292
+ elif Attention_module == 'SE':
1293
+ self.Attention_module = se_module(channel=feature_dim)
1294
+ else:
1295
+ self.Attention_module = None
1296
+
1297
+ self.focus_size = (feature_size[0] // target_feature_size[0], feature_size[1] // target_feature_size[1])
1298
+
1299
+ self.grid_size = (self.focus_size[0] * patch_size[0], self.focus_size[1] * patch_size[1])
1300
+ self.num_patches = (feature_size[0] // self.grid_size[0]) * (feature_size[1] // self.grid_size[1])
1301
+
1302
+ self.proj = nn.Conv2d(in_channels=feature_dim, out_channels=embed_dim,
1303
+ kernel_size=self.grid_size, stride=self.grid_size)
1304
+
1305
+ self.norm_f = norm_layer(embed_dim)
1306
+
1307
+ def forward(self, x):
1308
+ if self.Attention_module is not None:
1309
+ x = self.Attention_module(x)
1310
+
1311
+ if isinstance(x, (list, tuple)):
1312
+ x = x[-1] # last feature if backbone outputs list/tuple of features
1313
+
1314
+ q = self.norm_f(self.proj(x).flatten(2).transpose(1, 2))
1315
+ k = q
1316
+ """
1317
+ x.shape: batch, feature_dim, feature_size[0], feature_size[1]
1318
+ proj(x).shape: batch, embed_dim, patch_height_num, patch_width_num
1319
+ flatten(2).shape: batch, embed_dim, patch_num
1320
+ .transpose(1, 2).shape: batch feature_patch_number feature_patch_dim
1321
+ """
1322
+ # output x.shape = batch, patch_number, patch_dim
1323
+ return q, k
1324
+
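+ '''
+ # minimal usage sketch: aggressive self-focus (q = k) on the stage-3 feature map (B, 1024, 14, 14)
+ model = Focus_SAggressive(feature_size=(14, 14), feature_dim=1024, embed_dim=768)
+ x = torch.randn(2, 1024, 14, 14)
+ q, k = model(x)
+ print(q.shape, torch.equal(q, k))  # torch.Size([2, 49, 768]) True
+ '''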
1325
+
1326
+ class VisionTransformer(nn.Module): # From timm to review the ViT and ViT_resn5
1327
+ """
1328
+ Vision Transformer
1329
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
1330
+ - https://arxiv.org/abs/2010.11929
1331
+ Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
1332
+ - https://arxiv.org/abs/2012.12877
1333
+ """
1334
+
1335
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
1336
+ num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, drop_rate=0., attn_drop_rate=0.,
1337
+ drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None, act_layer=None):
1338
+ """
1339
+ Args:
1340
+ img_size (int, tuple): input image size
1341
+ patch_size (int, tuple): patch size
1342
+ in_chans (int): number of input channels
1343
+ num_classes (int): number of classes for classification head
1344
+ embed_dim (int): embedding dimension
1345
+ depth (int): depth of transformer
1346
+ num_heads (int): number of attention heads
1347
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
1348
+ qkv_bias (bool): enable bias for qkv if True
1349
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
1350
+ drop_rate (float): dropout rate
1351
+ attn_drop_rate (float): attention dropout rate
1352
+ drop_path_rate (float): stochastic depth rate
1353
+ embed_layer (nn.Module): patch embedding layer
1354
+ norm_layer: (nn.Module): normalization layer
1355
+ """
1356
+ super().__init__()
1357
+ self.num_classes = num_classes
1358
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
1359
+ self.num_tokens = 1
1360
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
1361
+ act_layer = act_layer or nn.GELU
1362
+
1363
+ self.patch_embed = embed_layer(
1364
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
1365
+ num_patches = self.patch_embed.num_patches
1366
+
1367
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
1368
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
1369
+ self.pos_drop = nn.Dropout(p=drop_rate)
1370
+
1371
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
1372
+
1373
+ self.blocks = nn.Sequential(*[
1374
+ Encoder_Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
1375
+ attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
1376
+ for i in range(depth)])
1377
+
1378
+ self.norm = norm_layer(embed_dim)
1379
+
1380
+ # Representation layer
1381
+ if representation_size:
1382
+ self.num_features = representation_size
1383
+ self.pre_logits = nn.Sequential(OrderedDict([
1384
+ ('fc', nn.Linear(embed_dim, representation_size)),
1385
+ ('act', nn.Tanh())
1386
+ ]))
1387
+ else:
1388
+ self.pre_logits = nn.Identity()
1389
+
1390
+ # Classifier head(s)
1391
+ self.head = nn.Linear(self.num_features, self.num_classes) if self.num_classes > 0 else nn.Identity()
1392
+ self.head_dist = None
1393
+
1394
+ def forward_features(self, x):
1395
+ x = self.patch_embed(x)
1396
+ # print(x.shape,self.pos_embed.shape)
1397
+ cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
1398
+ x = torch.cat((cls_token, x), dim=1)
1399
+ x = self.pos_drop(x + self.pos_embed)
1400
+
1401
+ x = self.blocks(x)
1402
+ x = self.norm(x)
1403
+ return self.pre_logits(x[:, 0]) # use cls token for cls head
1404
+
1405
+ def forward(self, x):
1406
+ x = self.forward_features(x)
1407
+ x = self.head(x)
1408
+ return x
1409
+
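+ '''
+ # minimal usage sketch: plain ViT-Base/16 end to end, here with a 3-class head
+ model = VisionTransformer(img_size=224, patch_size=16, num_classes=3, embed_dim=768, depth=12, num_heads=12)
+ x = torch.randn(2, 3, 224, 224)
+ print(model(x).shape)  # torch.Size([2, 3])
+ '''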
1410
+
1411
+ class Stage_wise_hybrid_Transformer(nn.Module):
1412
+ """
1413
+ MSHT: Multi Stage Backbone Transformer
1414
+ Stem + 4 ResNet stages (the Backbone) are used as the backbone
1415
+ then, patch embedding of the last feature map connects the CNN output to the decoder1 input
1416
+
1417
+ horizontally, each of the 4 ResNet stages has its feature map connected to a Focus module
1418
+ which is used as attention guidance for the FGD decoder
1419
+ """
1420
+
1421
+ def __init__(self, backbone, num_classes=1000, patch_size=1, embed_dim=768, depth=4, num_heads=8, mlp_ratio=4.,
1422
+ qkv_bias=True, representation_size=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
1423
+ use_cls_token=True, use_pos_embedding=True, use_att_module='SimAM', stage_size=(56, 28, 14, 7),
1424
+ stage_dim=(256, 512, 1024, 2048), norm_layer=None, act_layer=None):
1425
+ """
1426
+ Args:
1427
+ backbone (nn.Module): input backbone = stem + 4 ResNet stages
1428
+ num_classes (int): number of classes for classification head
1429
+ patch_size (int, tuple): patch size
1430
+ embed_dim (int): embedding dimension
1431
+ depth (int): depth of transformer
1432
+ num_heads (int): number of attention heads
1433
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
1434
+ qkv_bias (bool): enable bias for qkv if True
1435
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
1436
+ drop_rate (float): dropout rate
1437
+ attn_drop_rate (float): attention dropout rate
1438
+ drop_path_rate (float): stochastic depth rate
1439
+
1440
+ use_cls_token(bool): classification token
1441
+ use_pos_embedding(bool): use positional embedding
1442
+ use_att_module(str or None): use which attention module in embedding
1443
+
1444
+ stage_size (int, tuple): the stage feature map size of ResNet stages
1445
+ stage_dim (int, tuple): the stage feature map dimension of ResNet stages
1446
+ norm_layer: (nn.Module): normalization layer
1447
+ """
1448
+ super().__init__()
1449
+ self.num_classes = num_classes
1450
+ if len(stage_dim) != len(stage_size):
1451
+ raise TypeError('stage_dim and stage_size mismatch!')
1452
+ else:
1453
+ self.stage_num = len(stage_dim)
1454
+
1455
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
1456
+
1457
+ self.cls_token_num = 1 if use_cls_token else 0
1458
+ self.use_pos_embedding = use_pos_embedding
1459
+
1460
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
1461
+ act_layer = act_layer or nn.GELU
1462
+
1463
+ # backbone CNN
1464
+ self.backbone = backbone
1465
+
1466
+ # Attention module
1467
+ if use_att_module is not None:
1468
+ if use_att_module in ['SimAM', 'CBAM', 'SE']:
1469
+ Attention_module = use_att_module
1470
+ else:
1471
+ Attention_module = None
1472
+ else:
1473
+ Attention_module = None
1474
+
1475
+ self.patch_embed = Last_feature_map_Embed(patch_size=patch_size, feature_size=stage_size[-1],
1476
+ feature_dim=stage_dim[-1], embed_dim=self.embed_dim,
1477
+ Attention_module=Attention_module)
1478
+ num_patches = self.patch_embed.num_patches
1479
+
1480
+ # global sharing cls token and positional embedding
1481
+ self.cls_token_0 = nn.Parameter(torch.zeros(1, 1, embed_dim)) # like message token
1482
+ if self.use_pos_embedding:
1483
+ self.pos_embed_0 = nn.Parameter(torch.zeros(1, num_patches + self.cls_token_num, embed_dim))
1484
+
1485
+ '''
1486
+ self.cls_token_1 = nn.Parameter(torch.zeros(1, 1, embed_dim))
1487
+ self.pos_embed_1 = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
1488
+
1489
+ self.cls_token_2 = nn.Parameter(torch.zeros(1, 1, embed_dim))
1490
+ self.pos_embed_2 = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
1491
+
1492
+ self.cls_token_3 = nn.Parameter(torch.zeros(1, 1, embed_dim))
1493
+ self.pos_embed_3 = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
1494
+
1495
+ self.cls_token_4 = nn.Parameter(torch.zeros(1, 1, embed_dim))
1496
+ self.pos_embed_4 = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
1497
+ '''
1498
+
1499
+ self.pos_drop = nn.Dropout(p=drop_rate)
1500
+ # stochastic depth
1501
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
1502
+
1503
+ self.dec1 = Decoder_Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
1504
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[0], norm_layer=norm_layer,
1505
+ act_layer=act_layer)
1506
+ self.Fo1 = Focus_Embed(patch_size=patch_size, target_feature_size=stage_size[-1], feature_size=stage_size[0],
1507
+ feature_dim=stage_dim[0], embed_dim=embed_dim, Attention_module=Attention_module,
1508
+ norm_layer=norm_layer)
1509
+
1510
+ self.dec2 = Decoder_Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
1511
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[1], norm_layer=norm_layer,
1512
+ act_layer=act_layer)
1513
+ self.Fo2 = Focus_Embed(patch_size=patch_size, target_feature_size=stage_size[-1], feature_size=stage_size[1],
1514
+ feature_dim=stage_dim[1], embed_dim=embed_dim, Attention_module=Attention_module,
1515
+ norm_layer=norm_layer)
1516
+
1517
+ self.dec3 = Decoder_Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
1518
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[2], norm_layer=norm_layer,
1519
+ act_layer=act_layer)
1520
+ self.Fo3 = Focus_Embed(patch_size=patch_size, target_feature_size=stage_size[-1], feature_size=stage_size[2],
1521
+ feature_dim=stage_dim[2], embed_dim=embed_dim, Attention_module=Attention_module,
1522
+ norm_layer=norm_layer)
1523
+
1524
+ if self.stage_num == 4:
1525
+ self.dec4 = Decoder_Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
1526
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[3], norm_layer=norm_layer,
1527
+ act_layer=act_layer)
1528
+ self.Fo4 = Focus_Embed(patch_size=patch_size, target_feature_size=stage_size[-1],
1529
+ feature_size=stage_size[-1],
1530
+ feature_dim=stage_dim[-1], embed_dim=embed_dim, Attention_module=Attention_module,
1531
+ norm_layer=norm_layer)
1532
+
1533
+ self.norm = norm_layer(embed_dim)
1534
+
1535
+ # Representation layer
1536
+ if representation_size:
1537
+ self.num_features = representation_size
1538
+ self.pre_logits = nn.Sequential(OrderedDict([
1539
+ ('fc', nn.Linear(embed_dim, representation_size)),
1540
+ ('act', nn.Tanh())
1541
+ ]))
1542
+ else:
1543
+ self.pre_logits = nn.Identity()
1544
+
1545
+ # Classifier head(s)
1546
+ self.head = nn.Linear(self.num_features, self.num_classes) if self.num_classes > 0 else nn.Identity()
1547
+ self.head_dist = None
1548
+
1549
+ def forward_features(self, x):
1550
+ if self.stage_num == 3:
1551
+ stage1_out, stage2_out, stage3_out = self.backbone(x)
1552
+ # embedding the last feature map
1553
+ x = self.patch_embed(stage3_out)
1554
+
1555
+ elif self.stage_num == 4:
1556
+ stage1_out, stage2_out, stage3_out, stage4_out = self.backbone(x)
1557
+ # embedding the last feature map
1558
+ x = self.patch_embed(stage4_out)
1559
+ else:
1560
+ raise TypeError('stage_dim is not legal !')
1561
+
1562
+ # get guidance info
1563
+ s1_q, s1_k = self.Fo1(stage1_out)
1564
+ s2_q, s2_k = self.Fo2(stage2_out)
1565
+ s3_q, s3_k = self.Fo3(stage3_out)
1566
+ if self.stage_num == 4:
1567
+ s4_q, s4_k = self.Fo4(stage4_out)
1568
+
1569
+ if self.cls_token_num != 0: # concat cls token
1570
+ # process the(cls token / message token)
1571
+ cls_token_0 = self.cls_token_0.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
1572
+ x = torch.cat((cls_token_0, x), dim=1)  # prepend the classification token (cls head patch)
1573
+
1574
+ s1_q = torch.cat((cls_token_0, s1_q), dim=1)
1575
+ s1_k = torch.cat((cls_token_0, s1_k), dim=1)
1576
+ s2_q = torch.cat((cls_token_0, s2_q), dim=1)
1577
+ s2_k = torch.cat((cls_token_0, s2_k), dim=1)
1578
+ s3_q = torch.cat((cls_token_0, s3_q), dim=1)
1579
+ s3_k = torch.cat((cls_token_0, s3_k), dim=1)
1580
+ if self.stage_num == 4:
1581
+ s4_q = torch.cat((cls_token_0, s4_q), dim=1)
1582
+ s4_k = torch.cat((cls_token_0, s4_k), dim=1)
1583
+
1584
+ if self.use_pos_embedding:
1585
+
1586
+ s1_q = self.pos_drop(s1_q + self.pos_embed_0)
1587
+ s1_k = self.pos_drop(s1_k + self.pos_embed_0)
1588
+ s2_q = self.pos_drop(s2_q + self.pos_embed_0)
1589
+ s2_k = self.pos_drop(s2_k + self.pos_embed_0)
1590
+ s3_q = self.pos_drop(s3_q + self.pos_embed_0)
1591
+ s3_k = self.pos_drop(s3_k + self.pos_embed_0)
1592
+ if self.stage_num == 4:
1593
+ s4_q = self.pos_drop(s4_q + self.pos_embed_0)
1594
+ s4_k = self.pos_drop(s4_k + self.pos_embed_0)
1595
+
1596
+ # add positional encoding information
1597
+ x = self.pos_drop(x + self.pos_embed_0)
1598
+
1599
+ else:
1600
+
1601
+ s1_q = self.pos_drop(s1_q)
1602
+ s1_k = self.pos_drop(s1_k)
1603
+ s2_q = self.pos_drop(s2_q)
1604
+ s2_k = self.pos_drop(s2_k)
1605
+ s3_q = self.pos_drop(s3_q)
1606
+ s3_k = self.pos_drop(s3_k)
1607
+ if self.stage_num == 4:
1608
+ s4_q = self.pos_drop(s4_q)
1609
+ s4_k = self.pos_drop(s4_k)
1610
+
1611
+ # stem's feature map
1612
+ x = self.pos_drop(x)
1613
+
1614
+ # Decoder module use the guidance to help global modeling process
1615
+
1616
+ x = self.dec1(s1_q, s1_k, x)
1617
+
1618
+ x = self.dec2(s2_q, s2_k, x)
1619
+
1620
+ x = self.dec3(s3_q, s3_k, x)
1621
+
1622
+ if self.stage_num == 4:
1623
+ x = self.dec4(s4_q, s4_k, x)
1624
+
1625
+ x = self.norm(x)
1626
+ return self.pre_logits(x[:, 0]) # take the first cls token
1627
+
1628
+ def forward(self, x):
1629
+ x = self.forward_features(x) # connect the cls token to the cls head
1630
+ x = self.head(x)
1631
+ return x
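+
+ '''
+ # minimal usage sketch, assuming a backbone whose forward returns the four ResNet stage feature
+ # maps listed in the docstring; dummy_stages below is only a stand-in for the real hybrid backbone
+ def dummy_stages(x):
+     b = x.shape[0]
+     return (torch.randn(b, 256, 56, 56), torch.randn(b, 512, 28, 28),
+             torch.randn(b, 1024, 14, 14), torch.randn(b, 2048, 7, 7))
+
+ model = Stage_wise_hybrid_Transformer(dummy_stages, num_classes=2)
+ print(model(torch.randn(2, 3, 224, 224)).shape)  # torch.Size([2, 2])
+ '''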
PuzzleTuning/Backbone/VPT_structure.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ VPT Script ver: Oct 17th 14:30
3
+
4
+ based on
5
+ timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
6
+ """
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ from timm.models.vision_transformer import VisionTransformer, PatchEmbed
12
+
13
+
14
+ class VPT_ViT(VisionTransformer):
15
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
16
+ num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
17
+ embed_layer=PatchEmbed, norm_layer=None, act_layer=None, Prompt_Token_num=1,
18
+ VPT_type="Shallow", basic_state_dict=None):
19
+
20
+ # Recreate ViT
21
+ super().__init__(img_size=img_size, patch_size=patch_size, in_chans=in_chans, num_classes=num_classes,
22
+ embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio,
23
+ qkv_bias=qkv_bias, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
24
+ drop_path_rate=drop_path_rate, embed_layer=embed_layer,
25
+ norm_layer=norm_layer, act_layer=act_layer)
26
+
27
+ # load basic state_dict
28
+ if basic_state_dict is not None:
29
+ self.load_state_dict(basic_state_dict, False)
30
+
31
+ self.VPT_type = VPT_type
32
+ if VPT_type == "Deep":
33
+ self.Prompt_Tokens = nn.Parameter(torch.zeros(depth, Prompt_Token_num, embed_dim))
34
+ else: # "Shallow"
35
+ self.Prompt_Tokens = nn.Parameter(torch.zeros(1, Prompt_Token_num, embed_dim))
36
+
37
+ def New_CLS_head(self, new_classes=15):
38
+ if new_classes != 0:
39
+ self.head = nn.Linear(self.embed_dim, new_classes)
40
+ else:
41
+ self.head = nn.Identity()
42
+
43
+ def Freeze(self):
44
+ for param in self.parameters():
45
+ param.requires_grad = False
46
+
47
+ self.Prompt_Tokens.requires_grad = True
48
+ try:
49
+ for param in self.head.parameters():
50
+ param.requires_grad = True
51
+ except:
52
+ pass
53
+
54
+ def UnFreeze(self):
55
+ for param in self.parameters():
56
+ param.requires_grad = True
57
+
58
+ def obtain_prompt(self):
59
+ prompt_state_dict = {'head': self.head.state_dict(),
60
+ 'Prompt_Tokens': self.Prompt_Tokens}
61
+ # print(prompt_state_dict)
62
+ return prompt_state_dict
63
+
64
+ def load_prompt(self, prompt_state_dict):
65
+ try:
66
+ self.head.load_state_dict(prompt_state_dict['head'], False)
67
+ except:
68
+ print('head not match, so skip head')
69
+ else:
70
+ print('prompt head match')
71
+
72
+ if self.Prompt_Tokens.shape == prompt_state_dict['Prompt_Tokens'].shape:
73
+
74
+ # device check
75
+ Prompt_Tokens = nn.Parameter(prompt_state_dict['Prompt_Tokens'].cpu())
76
+ Prompt_Tokens.to(torch.device(self.Prompt_Tokens.device))
77
+
78
+ self.Prompt_Tokens = Prompt_Tokens
79
+
80
+ else:
81
+ print('\n !!! cannot load prompt')
82
+ print('shape of model req prompt', self.Prompt_Tokens.shape)
83
+ print('shape of model given prompt', prompt_state_dict['Prompt_Tokens'].shape)
84
+ print('')
85
+
86
+ def forward_features(self, x):
87
+ x = self.patch_embed(x)
88
+ # print(x.shape,self.pos_embed.shape)
89
+ cls_token = self.cls_token.expand(x.shape[0], -1, -1)
90
+
91
+ # concatenate CLS token
92
+ x = torch.cat((cls_token, x), dim=1)
93
+ x = self.pos_drop(x + self.pos_embed)
94
+
95
+ if self.VPT_type == "Deep":
96
+
97
+ Prompt_Token_num = self.Prompt_Tokens.shape[1]
98
+
99
+ for i in range(len(self.blocks)):
100
+ # concatenate Prompt_Tokens
101
+ Prompt_Tokens = self.Prompt_Tokens[i].unsqueeze(0)
102
+ # firstly concatenate
103
+ x = torch.cat((x, Prompt_Tokens.expand(x.shape[0], -1, -1)), dim=1)
104
+ num_tokens = x.shape[1]
105
+ # lastly remove, a genius trick
106
+ x = self.blocks[i](x)[:, :num_tokens - Prompt_Token_num]
107
+
108
+ else: # self.VPT_type == "Shallow"
109
+ Prompt_Token_num = self.Prompt_Tokens.shape[1]
110
+
111
+ # concatenate Prompt_Tokens
112
+ Prompt_Tokens = self.Prompt_Tokens.expand(x.shape[0], -1, -1)
113
+ x = torch.cat((x, Prompt_Tokens), dim=1)
114
+ num_tokens = x.shape[1]
115
+ # Sequentially process
116
+ x = self.blocks(x)[:, :num_tokens - Prompt_Token_num]
117
+
118
+ x = self.norm(x)
119
+ return x
120
+
121
+ def forward(self, x):
122
+
123
+ x = self.forward_features(x)
124
+
125
+ # use cls token for cls head
126
+ try:
127
+ x = self.pre_logits(x[:, 0, :])
128
+ except:
129
+ x = self.fc_norm(x[:, 0, :])
130
+ else:
131
+ pass
132
+ x = self.head(x)
133
+ return x
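+
+ '''
+ # minimal usage sketch: prompt-tune a frozen ViT-B/16 with 20 deep prompt tokens for a 5-class task;
+ # pretrained weights would be passed in via basic_state_dict (a timm ViT state_dict)
+ model = VPT_ViT(Prompt_Token_num=20, VPT_type="Deep")
+ model.New_CLS_head(new_classes=5)
+ model.Freeze()  # only Prompt_Tokens and the new head remain trainable
+ print(model(torch.randn(2, 3, 224, 224)).shape)  # torch.Size([2, 5])
+ prompt_state = model.obtain_prompt()  # lightweight {'head', 'Prompt_Tokens'} dict for saving
+ '''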
PuzzleTuning/Backbone/attention_modules.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ attention modules in ['SimAM', 'CBAM', 'SE', 'GAM'] were applied in the ablation study
3
+
4
+ ver: Dec 24th 15:00
5
+
6
+
7
+ ref:
8
+ https://github.com/xmu-xiaoma666/External-Attention-pytorch
9
+ """
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import math
14
+ import torch.nn.functional as F
15
+ from torch.nn import init
16
+
17
+
18
+ # help func
19
+ class BasicConv(nn.Module):
20
+ def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True,
21
+ bn=True, bias=False):
22
+ super(BasicConv, self).__init__()
23
+ self.out_channels = out_planes
24
+ self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
25
+ dilation=dilation, groups=groups, bias=bias)
26
+ self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
27
+ self.relu = nn.ReLU() if relu else None
28
+
29
+ def forward(self, x):
30
+ x = self.conv(x)
31
+ if self.bn is not None:
32
+ x = self.bn(x)
33
+ if self.relu is not None:
34
+ x = self.relu(x)
35
+ return x
36
+
37
+
38
+ class Flatten(nn.Module):
39
+ def forward(self, x):
40
+ return x.view(x.size(0), -1)
41
+
42
+
43
+ class ChannelGate(nn.Module):
44
+ def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
45
+ super(ChannelGate, self).__init__()
46
+ self.gate_channels = gate_channels
47
+ self.mlp = nn.Sequential(
48
+ Flatten(),
49
+ nn.Linear(gate_channels, int(gate_channels // reduction_ratio)),
50
+ nn.ReLU(),
51
+ nn.Linear(int(gate_channels // reduction_ratio), gate_channels)
52
+ )
53
+ self.pool_types = pool_types
54
+
55
+ def forward(self, x):
56
+ channel_att_sum = None
57
+ for pool_type in self.pool_types:
58
+ if pool_type == 'avg':
59
+ avg_pool = F.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
60
+ channel_att_raw = self.mlp(avg_pool)
61
+ elif pool_type == 'max':
62
+ max_pool = F.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
63
+ channel_att_raw = self.mlp(max_pool)
64
+ elif pool_type == 'lp':
65
+ lp_pool = F.lp_pool2d(x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
66
+ channel_att_raw = self.mlp(lp_pool)
67
+ elif pool_type == 'lse':
68
+ # LSE pool only
69
+ lse_pool = logsumexp_2d(x)
70
+ channel_att_raw = self.mlp(lse_pool)
71
+
72
+ if channel_att_sum is None:
73
+ channel_att_sum = channel_att_raw
74
+ else:
75
+ channel_att_sum = channel_att_sum + channel_att_raw
76
+
77
+ scale = F.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
78
+ return x * scale
79
+
80
+
81
+ def logsumexp_2d(tensor):
82
+ tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
83
+ s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
84
+ outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
85
+ return outputs
86
+
87
+
88
+ class ChannelPool(nn.Module):
89
+ def forward(self, x):
90
+ return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)
91
+
92
+
93
+ class SpatialGate(nn.Module):
94
+ def __init__(self):
95
+ super(SpatialGate, self).__init__()
96
+ kernel_size = 7
97
+ self.compress = ChannelPool()
98
+ self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=int((kernel_size - 1) // 2), relu=False)
99
+
100
+ def forward(self, x):
101
+ x_compress = self.compress(x)
102
+ x_out = self.spatial(x_compress)
103
+ scale = F.sigmoid(x_out) # broadcasting
104
+ return x * scale
105
+
106
+
107
+ # attention modules:
108
+ class cbam_module(nn.Module):
109
+ """
110
+ module:CBAM
111
+
112
+ input and output: b, c, h, w
113
+
114
+ paper:
115
+ https://arxiv.org/abs/1807.06521
116
+ code:
117
+ https://github.com/ZjjConan/SimAM/blob/master/networks/attentions
118
+ """
119
+
120
+ def __init__(self, gate_channels, reduction=16, pool_types=['avg', 'max'], no_spatial=False):
121
+ super(cbam_module, self).__init__()
122
+ self.ChannelGate = ChannelGate(gate_channels, reduction, pool_types)
123
+ self.no_spatial = no_spatial
124
+ if not no_spatial:
125
+ self.SpatialGate = SpatialGate()
126
+
127
+ @staticmethod
128
+ def get_module_name():
129
+ return "cbam"
130
+
131
+ def forward(self, x):
132
+ x_out = self.ChannelGate(x)
133
+ if not self.no_spatial:
134
+ x_out = self.SpatialGate(x_out)
135
+ return x_out
136
+
137
+
138
+ class se_module(nn.Module):
139
+ """
140
+ module: SE
141
+
142
+ input and output: b, c, h, w
143
+
144
+ from paper Squeeze-and-Excitation Networks
145
+ SE-Net https://arxiv.org/abs/1709.01507
146
+ code:
147
+ https://github.com/ZjjConan/SimAM/blob/master/networks/attentions
148
+ """
149
+
150
+ def __init__(self, channel, reduction=16):
151
+ super(se_module, self).__init__()
152
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
153
+ self.fc = nn.Sequential(
154
+ nn.Linear(channel, int(channel // reduction), bias=False),
155
+ nn.ReLU(inplace=True),
156
+ nn.Linear(int(channel // reduction), channel, bias=False),
157
+ nn.Sigmoid()
158
+ )
159
+
160
+ @staticmethod
161
+ def get_module_name():
162
+ return "se"
163
+
164
+ def forward(self, x):
165
+ b, c, _, _ = x.size()
166
+ y = self.avg_pool(x).view(b, c)
167
+ y = self.fc(y).view(b, c, 1, 1)
168
+ return x * y
169
+
170
+
171
+ class simam_module(torch.nn.Module):
172
+ """
173
+ module:SimAM
174
+
175
+ input and output: b, c, h, w
176
+
177
+ paper:(ICML)
178
+ SimAM: A Simple, Parameter-Free Attention Module for Convolutional Neural Networks
179
+ code:
180
+ https://github.com/ZjjConan/SimAM/blob/master/networks/attentions/simam_module.py
181
+ """
182
+
183
+ def __init__(self, channels=None, e_lambda=1e-4):
184
+ super(simam_module, self).__init__()
185
+
186
+ self.activaton = nn.Sigmoid()
187
+ self.e_lambda = e_lambda
188
+
189
+ def __repr__(self):
190
+ s = self.__class__.__name__ + '('
191
+ s += ('lambda=%f)' % self.e_lambda)
192
+ return s
193
+
194
+ @staticmethod
195
+ def get_module_name():
196
+ return "simam"
197
+
198
+ def forward(self, x):
199
+ b, c, h, w = x.size()
200
+
201
+ n = w * h - 1
202
+
203
+ x_minus_mu_square = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2)
204
+ y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2, 3], keepdim=True) / n + self.e_lambda)) + 0.5
205
+
206
+ return x * self.activaton(y)
207
+
208
+
209
+ class ResidualAttention(nn.Module):
210
+ """
211
+ module: ResidualAttention
212
+
213
+ input and output: b, c, h, w
214
+
215
+ Paper:ICCV 2021 Residual Attention: A Simple but Effective Method for Multi-Label Recognition
216
+ code:https://github.com/xmu-xiaoma666/External-Attention-pytorch/blob/master/attention/ResidualAttention.py
217
+ """
218
+
219
+ def __init__(self, channel=512, num_class=1000, la=0.2):
220
+ super().__init__()
221
+ self.la = la
222
+ self.fc = nn.Conv2d(in_channels=channel, out_channels=num_class, kernel_size=1, stride=1, bias=False)
223
+
224
+ def forward(self, x):
225
+ b, c, h, w = x.shape
226
+ y_raw = self.fc(x).flatten(2) # b,num_class,hxw
227
+ y_avg = torch.mean(y_raw, dim=2) # b,num_class
228
+ y_max = torch.max(y_raw, dim=2)[0] # b,num_class
229
+ score = y_avg + self.la * y_max
230
+ return score
231
+
232
+
233
+ class eca_module(nn.Module):
234
+ """Constructs a ECA module.
235
+
236
+ Args:
237
+ channel: Number of channels of the input feature map
238
+ k_size: Adaptive selection of kernel size
239
+ """
240
+ def __init__(self, channel, k_size=3):
241
+ super(eca_module, self).__init__()
242
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
243
+ self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
244
+ self.sigmoid = nn.Sigmoid()
245
+
246
+ def forward(self, x):
247
+ # x: input features with shape [b, c, h, w]
248
+ b, c, h, w = x.size()
249
+
250
+ # feature descriptor on the global spatial information
251
+ y = self.avg_pool(x)
252
+
253
+ # Two different branches of ECA module
254
+ y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
255
+
256
+ # Multi-scale information fusion
257
+ y = self.sigmoid(y)
258
+
259
+ return x * y.expand_as(x)
260
+
261
+
262
+ class GAM_Attention(nn.Module):
263
+ """
264
+ module:GAM
265
+
266
+ input= b, in_channels, h, w
267
+ output= b, out_channels, h, w
268
+
269
+ paper:
270
+ Global Attention Mechanism: Retain Information to Enhance Channel-Spatial Interactions
271
+ https://arxiv.org/abs/2112.05561
272
+ code:
273
+ https://mp.weixin.qq.com/s/VL6rXjyUDmHToYTqM32hUg
274
+ """
275
+ def __init__(self, in_channels, out_channels, rate=4):
276
+ super(GAM_Attention, self).__init__()
277
+
278
+ self.channel_attention = nn.Sequential(
279
+ nn.Linear(in_channels, int(in_channels / rate)),
280
+ nn.ReLU(inplace=True),
281
+ nn.Linear(int(in_channels / rate), in_channels)
282
+ )
283
+
284
+ self.spatial_attention = nn.Sequential(
285
+ nn.Conv2d(in_channels, int(in_channels / rate), kernel_size=7, padding=3),
286
+ nn.BatchNorm2d(int(in_channels / rate)),
287
+ nn.ReLU(inplace=True),
288
+ nn.Conv2d(int(in_channels / rate), out_channels, kernel_size=7, padding=3),
289
+ nn.BatchNorm2d(out_channels)
290
+ )
291
+
292
+ def forward(self, x):
293
+ b, c, h, w = x.shape
294
+ x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
295
+ x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)
296
+ x_channel_att = x_att_permute.permute(0, 3, 1, 2)
297
+
298
+ x = x * x_channel_att
299
+
300
+ x_spatial_att = self.spatial_attention(x).sigmoid()
301
+ out = x * x_spatial_att
302
+
303
+ return out
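+
+ '''
+ # minimal usage sketch: shape checks for the channel/spatial attention blocks above
+ # (input and output are both (B, C, H, W))
+ x = torch.randn(2, 64, 32, 32)
+ print(simam_module()(x).shape)                                  # torch.Size([2, 64, 32, 32])
+ print(se_module(channel=64)(x).shape)                           # torch.Size([2, 64, 32, 32])
+ print(cbam_module(gate_channels=64)(x).shape)                   # torch.Size([2, 64, 32, 32])
+ print(eca_module(channel=64)(x).shape)                          # torch.Size([2, 64, 32, 32])
+ print(GAM_Attention(in_channels=64, out_channels=64)(x).shape)  # torch.Size([2, 64, 32, 32])
+ '''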
PuzzleTuning/Backbone/counterpart_models/README.md ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Recent SOTA works in fine-grained Tasks
2
+
3
+
4
+ CrossFormer
5
+
6
+ Paper:
7
+ https://arxiv.org/pdf/2108.00154.pdf
8
+
9
+
10
+ Code from:
11
+ https://github.com/cheerss/CrossFormer
12
+
13
+
14
+
15
+ Conformer
16
+ Paper:
17
+ https://arxiv.org/pdf/2105.03889.pdf
18
+
19
+
20
+ Code from:
21
+ https://github.com/pengzhiliang/Conformer/blob/main/conformer.py
22
+
23
+
24
+ both work will be compared with official pretrained backbone
25
+ and a new MLP head (classification head).
PuzzleTuning/Backbone/counterpart_models/conformer.py ADDED
@@ -0,0 +1,483 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ From Conformer, with an alteration: the conv and trans cls heads were changed to vote together
3
+ ver: DEC 1st 16:00 official release
4
+
5
+ ref: https://github.com/pengzhiliang/Conformer/blob/main/conformer.py
6
+ """
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from functools import partial
12
+
13
+ from timm.models.layers import DropPath, trunc_normal_
14
+
15
+
16
+ class Mlp(nn.Module): # FFN
17
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
18
+ super().__init__()
19
+ out_features = out_features or in_features
20
+ hidden_features = hidden_features or in_features
21
+ self.fc1 = nn.Linear(in_features, hidden_features)
22
+ self.act = act_layer()
23
+ self.fc2 = nn.Linear(hidden_features, out_features)
24
+ self.drop = nn.Dropout(drop)
25
+
26
+ def forward(self, x):
27
+ x = self.fc1(x)
28
+ x = self.act(x)
29
+ x = self.drop(x)
30
+ x = self.fc2(x)
31
+ x = self.drop(x)
32
+ return x
33
+
34
+
35
+ class Attention(nn.Module): # MHSA
36
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
37
+ super().__init__()
38
+ self.num_heads = num_heads
39
+ head_dim = dim // num_heads
40
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
41
+ self.scale = qk_scale or head_dim ** -0.5
42
+
43
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
44
+ self.attn_drop = nn.Dropout(attn_drop)
45
+ self.proj = nn.Linear(dim, dim)
46
+ self.proj_drop = nn.Dropout(proj_drop)
47
+
48
+ def forward(self, x):
49
+ B, N, C = x.shape # N is patch number, C is patch dimension
50
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) # re arrange
51
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
52
+
53
+ attn = (q @ k.transpose(-2, -1)) * self.scale
54
+ attn = attn.softmax(dim=-1)
55
+ attn = self.attn_drop(attn)
56
+
57
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
58
+ x = self.proj(x)
59
+ x = self.proj_drop(x)
60
+ return x
61
+
62
+
63
+ class Block(nn.Module): # Encoder from ViT
64
+
65
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
66
+ drop_path=0., act_layer=nn.GELU, norm_layer=partial(nn.LayerNorm, eps=1e-6)):
67
+ super().__init__()
68
+ # pre norm 1
69
+ self.norm1 = norm_layer(dim)
70
+ # MHSA
71
+ self.attn = Attention(
72
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
73
+
74
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
75
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
76
+
77
+ # pre norm 2
78
+ self.norm2 = norm_layer(dim)
79
+
80
+ # FFN(MLP)
81
+ mlp_hidden_dim = int(dim * mlp_ratio)
82
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
83
+
84
+ def forward(self, x):
85
+ x = x + self.drop_path(self.attn(self.norm1(x)))
86
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
87
+ return x
88
+
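+ '''
+ # minimal usage sketch: one transformer encoder block leaves the token shape unchanged
+ blk = Block(dim=384, num_heads=6)
+ x = torch.randn(2, 197, 384)
+ print(blk(x).shape)  # torch.Size([2, 197, 384])
+ '''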
89
+
90
+ class ConvBlock(nn.Module): # ResNet bottleneck Convblock actually
91
+
92
+ def __init__(self, inplanes, outplanes, stride=1, res_conv=False, act_layer=nn.ReLU, groups=1,
93
+ norm_layer=partial(nn.BatchNorm2d, eps=1e-6), drop_block=None, drop_path=None):
94
+ super(ConvBlock, self).__init__()
95
+
96
+ expansion = 4
97
+ med_planes = outplanes // expansion
98
+
99
+ self.conv1 = nn.Conv2d(inplanes, med_planes, kernel_size=1, stride=1, padding=0, bias=False)
100
+ self.bn1 = norm_layer(med_planes)
101
+ self.act1 = act_layer(inplace=True)
102
+
103
+ self.conv2 = nn.Conv2d(med_planes, med_planes, kernel_size=3, stride=stride, groups=groups, padding=1,
104
+ bias=False)
105
+ self.bn2 = norm_layer(med_planes)
106
+ self.act2 = act_layer(inplace=True)
107
+
108
+ self.conv3 = nn.Conv2d(med_planes, outplanes, kernel_size=1, stride=1, padding=0, bias=False)
109
+ self.bn3 = norm_layer(outplanes)
110
+ self.act3 = act_layer(inplace=True)
111
+
112
+ if res_conv:
113
+ self.residual_conv = nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=stride, padding=0, bias=False)
114
+ self.residual_bn = norm_layer(outplanes)
115
+
116
+ self.res_conv = res_conv
117
+ self.drop_block = drop_block
118
+ self.drop_path = drop_path
119
+
120
+ def zero_init_last_bn(self):
121
+ nn.init.zeros_(self.bn3.weight)
122
+
123
+ def forward(self, x, x_t=None, return_x_2=True):
124
+ residual = x
125
+
126
+ x = self.conv1(x)
127
+ x = self.bn1(x)
128
+ if self.drop_block is not None:
129
+ x = self.drop_block(x)
130
+ x = self.act1(x)
131
+
132
+ x = self.conv2(x) if x_t is None else self.conv2(x + x_t)
133
+ x = self.bn2(x)
134
+ if self.drop_block is not None:
135
+ x = self.drop_block(x)
136
+ x2 = self.act2(x)
137
+
138
+ x = self.conv3(x2)
139
+ x = self.bn3(x)
140
+ if self.drop_block is not None:
141
+ x = self.drop_block(x)
142
+
143
+ if self.drop_path is not None:
144
+ x = self.drop_path(x)
145
+
146
+ if self.res_conv:
147
+ residual = self.residual_conv(residual)
148
+ residual = self.residual_bn(residual)
149
+
150
+ x += residual
151
+ x = self.act3(x)
152
+
153
+ if return_x_2:
154
+ return x, x2
155
+ else:
156
+ return x
157
+
158
+
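+ # Minimal, illustrative sketch of the bottleneck ConvBlock; the sizes are
+ # assumed to match the first Conformer stage (64 -> 256 channels at 56x56).
+ def _convblock_shape_demo():
+     cb = ConvBlock(inplanes=64, outplanes=256, res_conv=True)
+     fm = torch.randn(2, 64, 56, 56)
+     out, mid = cb(fm)                    # mid is the 3x3 output before channel expansion
+     assert out.shape == (2, 256, 56, 56)
+     assert mid.shape == (2, 64, 56, 56)  # med_planes = outplanes // expansion
+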
159
+ class FCUDown(nn.Module):
160
+ """ CNN feature maps -> Transformer patch embeddings
161
+ """
162
+
163
+ def __init__(self, inplanes, outplanes, dw_stride, act_layer=nn.GELU,
164
+ norm_layer=partial(nn.LayerNorm, eps=1e-6)):
165
+ super(FCUDown, self).__init__()
166
+ self.dw_stride = dw_stride
167
+
168
+ self.conv_project = nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, padding=0) # fix dimension
169
+ self.sample_pooling = nn.AvgPool2d(kernel_size=dw_stride, stride=dw_stride) # fix feature map size
170
+
171
+ self.ln = norm_layer(outplanes)
172
+ self.act = act_layer()
173
+
174
+ def forward(self, x, x_t):
175
+ x = self.conv_project(x) # [N, C, H, W]
176
+
177
+ x = self.sample_pooling(x).flatten(2).transpose(1, 2)
178
+ x = self.ln(x)
179
+ x = self.act(x)
180
+
181
+ x = torch.cat([x_t[:, 0][:, None, :], x], dim=1) # concatenate class token from x_t
182
+
183
+ return x
184
+
185
+
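+ # Minimal, illustrative sketch of FCUDown: a CNN feature map is 1x1-projected,
+ # average-pooled by dw_stride and prepended with the transformer cls token.
+ # Sizes are assumed (64-channel 56x56 map, 768-dim tokens, dw_stride=4).
+ def _fcu_down_shape_demo():
+     down = FCUDown(inplanes=64, outplanes=768, dw_stride=4)
+     fm = torch.randn(2, 64, 56, 56)      # CNN branch feature map
+     x_t = torch.randn(2, 197, 768)       # transformer tokens (cls + 14*14 patches)
+     tokens = down(fm, x_t)
+     assert tokens.shape == (2, 14 * 14 + 1, 768)
+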
186
+ class FCUUp(nn.Module):
187
+ """ Transformer patch embeddings -> CNN feature maps
188
+ by an interpolation (upsampling) operation
189
+ """
190
+
191
+ def __init__(self, inplanes, outplanes, up_stride, act_layer=nn.ReLU,
192
+ norm_layer=partial(nn.BatchNorm2d, eps=1e-6), ):
193
+ super(FCUUp, self).__init__()
194
+
195
+ self.up_stride = up_stride
196
+ self.conv_project = nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, padding=0)
197
+ self.bn = norm_layer(outplanes)
198
+ self.act = act_layer()
199
+
200
+ def forward(self, x, H, W): # interpolate tokens back to an (H * up_stride, W * up_stride) feature map
201
+ B, _, C = x.shape
202
+
203
+ # [N, 197, 384] -> [N, 196, 384] -> [N, 384, 196] -> [N, 384, 14, 14]
204
+ x_r = x[:, 1:].transpose(1, 2).reshape(B, C, H, W) # drop cls token of x_t
205
+
206
+ x_r = self.act(self.bn(self.conv_project(x_r)))
207
+
208
+ return F.interpolate(x_r, size=(H * self.up_stride, W * self.up_stride)) # interpolate operation
209
+
210
+
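+ # Minimal, illustrative sketch of FCUUp, the inverse bridge: patch tokens
+ # (minus cls) are reshaped to a 14x14 map, projected and upsampled.
+ # Sizes are assumed (768-dim tokens, 64 output channels, up_stride=4).
+ def _fcu_up_shape_demo():
+     up = FCUUp(inplanes=768, outplanes=64, up_stride=4)
+     x_t = torch.randn(2, 197, 768)
+     fm = up(x_t, 14, 14)                 # 196 patch tokens -> 14x14 map -> 56x56
+     assert fm.shape == (2, 64, 56, 56)
+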
211
+ class Med_ConvBlock(nn.Module): # essentially a ResNet bottleneck identity block
212
+ """ special case for Convblock without down sampling,
213
+ """
214
+
215
+ def __init__(self, inplanes, act_layer=nn.ReLU, groups=1, norm_layer=partial(nn.BatchNorm2d, eps=1e-6),
216
+ drop_block=None, drop_path=None):
217
+
218
+ super(Med_ConvBlock, self).__init__()
219
+
220
+ expansion = 4
221
+ med_planes = inplanes // expansion
222
+
223
+ self.conv1 = nn.Conv2d(inplanes, med_planes, kernel_size=1, stride=1, padding=0, bias=False)
224
+ self.bn1 = norm_layer(med_planes)
225
+ self.act1 = act_layer(inplace=True)
226
+
227
+ self.conv2 = nn.Conv2d(med_planes, med_planes, kernel_size=3, stride=1, groups=groups, padding=1, bias=False)
228
+ self.bn2 = norm_layer(med_planes)
229
+ self.act2 = act_layer(inplace=True)
230
+
231
+ self.conv3 = nn.Conv2d(med_planes, inplanes, kernel_size=1, stride=1, padding=0, bias=False)
232
+ self.bn3 = norm_layer(inplanes)
233
+ self.act3 = act_layer(inplace=True)
234
+
235
+ self.drop_block = drop_block
236
+ self.drop_path = drop_path
237
+
238
+ def zero_init_last_bn(self):
239
+ nn.init.zeros_(self.bn3.weight)
240
+
241
+ def forward(self, x):
242
+ residual = x
243
+
244
+ x = self.conv1(x)
245
+ x = self.bn1(x)
246
+ if self.drop_block is not None:
247
+ x = self.drop_block(x)
248
+ x = self.act1(x)
249
+
250
+ x = self.conv2(x)
251
+ x = self.bn2(x)
252
+ if self.drop_block is not None:
253
+ x = self.drop_block(x)
254
+ x = self.act2(x)
255
+
256
+ x = self.conv3(x)
257
+ x = self.bn3(x)
258
+ if self.drop_block is not None:
259
+ x = self.drop_block(x)
260
+
261
+ if self.drop_path is not None:
262
+ x = self.drop_path(x)
263
+
264
+ x += residual
265
+ x = self.act3(x)
266
+
267
+ return x
268
+
269
+
270
+ class ConvTransBlock(nn.Module):
271
+ """
272
+ Basic module for ConvTransformer: keeps feature maps for the CNN branch and patch embeddings for the transformer encoder branch
273
+ """
274
+
275
+ def __init__(self, inplanes, outplanes, res_conv, stride, dw_stride, embed_dim, num_heads=12, mlp_ratio=4.,
276
+ qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
277
+ last_fusion=False, num_med_block=0, groups=1):
278
+
279
+ super(ConvTransBlock, self).__init__()
280
+ expansion = 4
281
+ # ConvBlock
282
+ self.cnn_block = ConvBlock(inplanes=inplanes, outplanes=outplanes, res_conv=res_conv, stride=stride,
283
+ groups=groups)
284
+
285
+ if last_fusion:
286
+ self.fusion_block = ConvBlock(inplanes=outplanes, outplanes=outplanes, stride=2, res_conv=True,
287
+ groups=groups)
288
+ else:
289
+ self.fusion_block = ConvBlock(inplanes=outplanes, outplanes=outplanes, groups=groups)
290
+
291
+ # identity block
292
+ if num_med_block > 0:
293
+ self.med_block = []
294
+ for i in range(num_med_block):
295
+ self.med_block.append(Med_ConvBlock(inplanes=outplanes, groups=groups))
296
+
297
+ self.med_block = nn.ModuleList(self.med_block) # nn.ModuleList
298
+
299
+ # FCU
300
+ self.squeeze_block = FCUDown(inplanes=outplanes // expansion, outplanes=embed_dim, dw_stride=dw_stride)
301
+
302
+ self.expand_block = FCUUp(inplanes=embed_dim, outplanes=outplanes // expansion, up_stride=dw_stride)
303
+
304
+ # Transformer Encoder block
305
+ self.trans_block = Block(
306
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
307
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate)
308
+
309
+ self.dw_stride = dw_stride
310
+ self.embed_dim = embed_dim
311
+ self.num_med_block = num_med_block
312
+ self.last_fusion = last_fusion
313
+
314
+ def forward(self, x, x_t):
315
+ x, x2 = self.cnn_block(x)
316
+
317
+ _, _, H, W = x2.shape
318
+
319
+ x_st = self.squeeze_block(x2, x_t)
320
+
321
+ x_t = self.trans_block(x_st + x_t)
322
+
323
+ if self.num_med_block > 0:
324
+ for m in self.med_block:
325
+ x = m(x)
326
+
327
+ x_t_r = self.expand_block(x_t, H // self.dw_stride, W // self.dw_stride)
328
+ x = self.fusion_block(x, x_t_r, return_x_2=False)
329
+
330
+ return x, x_t
331
+
332
+
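+ # Minimal, illustrative sketch of ConvTransBlock: the CNN feature map and the
+ # transformer tokens are updated jointly through the two FCU bridges.
+ # Sizes are assumed to match an early Conformer stage (256 ch, 56x56, dim 768).
+ def _conv_trans_block_shape_demo():
+     ctb = ConvTransBlock(inplanes=256, outplanes=256, res_conv=False, stride=1,
+                          dw_stride=4, embed_dim=768)
+     fm = torch.randn(2, 256, 56, 56)     # CNN branch
+     x_t = torch.randn(2, 197, 768)       # transformer branch (cls + 14*14 tokens)
+     fm, x_t = ctb(fm, x_t)
+     assert fm.shape == (2, 256, 56, 56) and x_t.shape == (2, 197, 768)
+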
333
+ class Conformer(nn.Module):
334
+
335
+ def __init__(self, patch_size=16, in_chans=3, num_classes=1000, base_channel=64, channel_ratio=4, num_med_block=0,
336
+ embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None,
337
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.):
338
+
339
+ # Transformer
340
+ super().__init__()
341
+ self.num_classes = num_classes
342
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
343
+ assert depth % 3 == 0
344
+
345
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
346
+ self.trans_dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
347
+
348
+ # Classifier head
349
+ self.trans_norm = nn.LayerNorm(embed_dim)
350
+ self.trans_cls_head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
351
+ self.pooling = nn.AdaptiveAvgPool2d(1)
352
+ self.conv_cls_head = nn.Linear(int(256 * channel_ratio), num_classes)
353
+ self.cls_head = nn.Linear(int(2 * num_classes), num_classes)
354
+
355
+ # Stem stage: get the feature maps by conv block (copied from resnet.py)
356
+ self.conv1 = nn.Conv2d(in_chans, 64, kernel_size=7, stride=2, padding=3, bias=False) # 1 / 2 [112, 112]
357
+ self.bn1 = nn.BatchNorm2d(64)
358
+ self.act1 = nn.ReLU(inplace=True)
359
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) # 1 / 4 [56, 56]
360
+
361
+ # 1 stage
362
+ stage_1_channel = int(base_channel * channel_ratio)
363
+ trans_dw_stride = patch_size // 4
364
+ self.conv_1 = ConvBlock(inplanes=64, outplanes=stage_1_channel, res_conv=True, stride=1)
365
+ # embedding
366
+ self.trans_patch_conv = nn.Conv2d(64, embed_dim, kernel_size=trans_dw_stride, stride=trans_dw_stride, padding=0)
367
+ self.trans_1 = Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
368
+ qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=self.trans_dpr[0],
369
+ )
370
+
371
+ # 2~4 stage
372
+ init_stage = 2
373
+ fin_stage = depth // 3 + 1
374
+ for i in range(init_stage, fin_stage):
375
+ self.add_module('conv_trans_' + str(i),
376
+ ConvTransBlock(
377
+ stage_1_channel, stage_1_channel, False, 1, dw_stride=trans_dw_stride,
378
+ embed_dim=embed_dim,
379
+ num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
380
+ drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
381
+ drop_path_rate=self.trans_dpr[i - 1],
382
+ num_med_block=num_med_block
383
+ )
384
+ )
385
+
386
+ stage_2_channel = int(base_channel * channel_ratio * 2)
387
+ # 5~8 stage
388
+ init_stage = fin_stage # 5
389
+ fin_stage = fin_stage + depth // 3 # 9
390
+ for i in range(init_stage, fin_stage):
391
+ s = 2 if i == init_stage else 1
392
+ in_channel = stage_1_channel if i == init_stage else stage_2_channel
393
+ res_conv = True if i == init_stage else False
394
+ self.add_module('conv_trans_' + str(i),
395
+ ConvTransBlock(
396
+ in_channel, stage_2_channel, res_conv, s, dw_stride=trans_dw_stride // 2,
397
+ embed_dim=embed_dim,
398
+ num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
399
+ drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
400
+ drop_path_rate=self.trans_dpr[i - 1],
401
+ num_med_block=num_med_block
402
+ )
403
+ )
404
+
405
+ stage_3_channel = int(base_channel * channel_ratio * 2 * 2)
406
+ # 9~12 stage
407
+ init_stage = fin_stage # 9
408
+ fin_stage = fin_stage + depth // 3 # 13
409
+ for i in range(init_stage, fin_stage):
410
+ s = 2 if i == init_stage else 1
411
+ in_channel = stage_2_channel if i == init_stage else stage_3_channel
412
+ res_conv = True if i == init_stage else False
413
+ last_fusion = True if i == depth else False
414
+ self.add_module('conv_trans_' + str(i),
415
+ ConvTransBlock(
416
+ in_channel, stage_3_channel, res_conv, s, dw_stride=trans_dw_stride // 4,
417
+ embed_dim=embed_dim,
418
+ num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
419
+ drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
420
+ drop_path_rate=self.trans_dpr[i - 1],
421
+ num_med_block=num_med_block, last_fusion=last_fusion
422
+ )
423
+ )
424
+ self.fin_stage = fin_stage
425
+
426
+ trunc_normal_(self.cls_token, std=.02)
427
+
428
+ self.apply(self._init_weights)
429
+
430
+ def _init_weights(self, m):
431
+ if isinstance(m, nn.Linear):
432
+ trunc_normal_(m.weight, std=.02)
433
+ if isinstance(m, nn.Linear) and m.bias is not None:
434
+ nn.init.constant_(m.bias, 0)
435
+ elif isinstance(m, nn.LayerNorm):
436
+ nn.init.constant_(m.bias, 0)
437
+ nn.init.constant_(m.weight, 1.0)
438
+ elif isinstance(m, nn.Conv2d):
439
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
440
+ elif isinstance(m, nn.BatchNorm2d):
441
+ nn.init.constant_(m.weight, 1.)
442
+ nn.init.constant_(m.bias, 0.)
443
+ elif isinstance(m, nn.GroupNorm):
444
+ nn.init.constant_(m.weight, 1.)
445
+ nn.init.constant_(m.bias, 0.)
446
+
447
+ @torch.jit.ignore
448
+ def no_weight_decay(self):
449
+ return {'cls_token'}
450
+
451
+ def forward(self, x):
452
+ B = x.shape[0]
453
+ cls_tokens = self.cls_token.expand(B, -1, -1)
454
+
455
+ # pdb.set_trace()
456
+ # stem stage [N, 3, 224, 224] -> [N, 64, 56, 56]
457
+ x_base = self.maxpool(self.act1(self.bn1(self.conv1(x))))
458
+
459
+ # 1 stage
460
+ x = self.conv_1(x_base, return_x_2=False)
461
+ # embedding: [N, 64, 56, 56] -> [N, d, p, p] -> [N, d, p^2] -> [N, p^2, d] -> [N, p^2 + 1, d]
462
+ x_t = self.trans_patch_conv(x_base).flatten(2).transpose(1, 2)
463
+ x_t = torch.cat([cls_tokens, x_t], dim=1)
464
+ x_t = self.trans_1(x_t)
465
+
466
+ # 2 ~ final
467
+ for i in range(2, self.fin_stage):
468
+ x, x_t = eval('self.conv_trans_' + str(i))(x, x_t)
469
+
470
+ # conv classification
471
+ x_p = self.pooling(x).flatten(1)
472
+ conv_cls = self.conv_cls_head(x_p)
473
+
474
+ # trans classification
475
+ x_t = self.trans_norm(x_t)
476
+ tran_cls = self.trans_cls_head(x_t[:, 0])
477
+
478
+ # class-level vote: concatenate the conv and transformer logits and fuse them with a linear layer
479
+ cls = torch.cat([conv_cls, tran_cls], dim=1)
480
+ cls = self.cls_head(cls)
481
+ return cls
482
+
483
+ # return [conv_cls, tran_cls]
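+
+ # Minimal, illustrative end-to-end usage sketch with the defaults of this file
+ # (embed_dim=768, depth=12, channel_ratio=4); num_classes=2 is an assumed value.
+ def _conformer_demo():
+     model = Conformer(num_classes=2)
+     images = torch.randn(2, 3, 224, 224)
+     logits = model(images)               # conv head + transformer head, fused by cls_head
+     assert logits.shape == (2, 2)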
PuzzleTuning/Backbone/counterpart_models/crossformer.py ADDED
@@ -0,0 +1,624 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.utils.checkpoint as checkpoint
4
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
5
+
6
+
7
+ class Mlp(nn.Module):
8
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
9
+ super().__init__()
10
+ out_features = out_features or in_features
11
+ hidden_features = hidden_features or in_features
12
+ self.fc1 = nn.Linear(in_features, hidden_features)
13
+ self.act = act_layer()
14
+ self.fc2 = nn.Linear(hidden_features, out_features)
15
+ self.drop = nn.Dropout(drop)
16
+
17
+ def forward(self, x):
18
+ x = self.fc1(x)
19
+ x = self.act(x)
20
+ x = self.drop(x)
21
+ x = self.fc2(x)
22
+ x = self.drop(x)
23
+ return x
24
+
25
+ class DynamicPosBias(nn.Module):
26
+ def __init__(self, dim, num_heads, residual):
27
+ super().__init__()
28
+ self.residual = residual
29
+ self.num_heads = num_heads
30
+ self.pos_dim = dim // 4
31
+ self.pos_proj = nn.Linear(2, self.pos_dim)
32
+ self.pos1 = nn.Sequential(
33
+ nn.LayerNorm(self.pos_dim),
34
+ nn.ReLU(inplace=True),
35
+ nn.Linear(self.pos_dim, self.pos_dim),
36
+ )
37
+ self.pos2 = nn.Sequential(
38
+ nn.LayerNorm(self.pos_dim),
39
+ nn.ReLU(inplace=True),
40
+ nn.Linear(self.pos_dim, self.pos_dim)
41
+ )
42
+ self.pos3 = nn.Sequential(
43
+ nn.LayerNorm(self.pos_dim),
44
+ nn.ReLU(inplace=True),
45
+ nn.Linear(self.pos_dim, self.num_heads)
46
+ )
47
+ def forward(self, biases):
48
+ if self.residual:
49
+ pos = self.pos_proj(biases) # 2Wh-1 * 2Ww-1, heads
50
+ pos = pos + self.pos1(pos)
51
+ pos = pos + self.pos2(pos)
52
+ pos = self.pos3(pos)
53
+ else:
54
+ pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases))))
55
+ return pos
56
+
57
+ def flops(self, N):
58
+ flops = N * 2 * self.pos_dim
59
+ flops += N * self.pos_dim * self.pos_dim
60
+ flops += N * self.pos_dim * self.pos_dim
61
+ flops += N * self.pos_dim * self.num_heads
62
+ return flops
63
+
64
+ class Attention(nn.Module):
65
+ r""" Multi-head self attention module with dynamic position bias.
66
+
67
+ Args:
68
+ dim (int): Number of input channels.
69
+ group_size (tuple[int]): The height and width of the group.
70
+ num_heads (int): Number of attention heads.
71
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
72
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
73
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
74
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
75
+ """
76
+
77
+ def __init__(self, dim, group_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.,
78
+ position_bias=True):
79
+
80
+ super().__init__()
81
+ self.dim = dim
82
+ self.group_size = group_size # Wh, Ww
83
+ self.num_heads = num_heads
84
+ head_dim = dim // num_heads
85
+ self.scale = qk_scale or head_dim ** -0.5
86
+ self.position_bias = position_bias
87
+
88
+ if position_bias:
89
+ self.pos = DynamicPosBias(self.dim // 4, self.num_heads, residual=False)
90
+
91
+ # generate mother-set
92
+ position_bias_h = torch.arange(1 - self.group_size[0], self.group_size[0])
93
+ position_bias_w = torch.arange(1 - self.group_size[1], self.group_size[1])
94
+ biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Wh-1, 2Ww-1
95
+ biases = biases.flatten(1).transpose(0, 1).float()
96
+ self.register_buffer("biases", biases)
97
+
98
+ # get pair-wise relative position index for each token inside the group
99
+ coords_h = torch.arange(self.group_size[0])
100
+ coords_w = torch.arange(self.group_size[1])
101
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
102
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
103
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
104
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
105
+ relative_coords[:, :, 0] += self.group_size[0] - 1 # shift to start from 0
106
+ relative_coords[:, :, 1] += self.group_size[1] - 1
107
+ relative_coords[:, :, 0] *= 2 * self.group_size[1] - 1
108
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
109
+ self.register_buffer("relative_position_index", relative_position_index)
110
+
111
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
112
+ self.attn_drop = nn.Dropout(attn_drop)
113
+ self.proj = nn.Linear(dim, dim)
114
+ self.proj_drop = nn.Dropout(proj_drop)
115
+
116
+ self.softmax = nn.Softmax(dim=-1)
117
+
118
+ def forward(self, x, mask=None):
119
+ """
120
+ Args:
121
+ x: input features with shape of (num_groups*B, N, C)
122
+ mask: (0/-inf) mask with shape of (num_groups, Wh*Ww, Wh*Ww) or None
123
+ """
124
+ B_, N, C = x.shape
125
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
126
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
127
+
128
+ q = q * self.scale
129
+ attn = (q @ k.transpose(-2, -1))
130
+
131
+ if self.position_bias:
132
+ pos = self.pos(self.biases) # 2Wh-1 * 2Ww-1, heads
133
+ # select position bias
134
+ relative_position_bias = pos[self.relative_position_index.view(-1)].view(
135
+ self.group_size[0] * self.group_size[1], self.group_size[0] * self.group_size[1], -1) # Wh*Ww,Wh*Ww,nH
136
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
137
+ attn = attn + relative_position_bias.unsqueeze(0)
138
+
139
+ if mask is not None:
140
+ nW = mask.shape[0]
141
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
142
+ attn = attn.view(-1, self.num_heads, N, N)
143
+ attn = self.softmax(attn)
144
+ else:
145
+ attn = self.softmax(attn)
146
+
147
+ attn = self.attn_drop(attn)
148
+
149
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
150
+ x = self.proj(x)
151
+ x = self.proj_drop(x)
152
+ return x
153
+
154
+ def extra_repr(self) -> str:
155
+ return f'dim={self.dim}, group_size={self.group_size}, num_heads={self.num_heads}'
156
+
157
+ def flops(self, N):
158
+ # calculate flops for 1 group with token length of N
159
+ flops = 0
160
+ # qkv = self.qkv(x)
161
+ flops += N * self.dim * 3 * self.dim
162
+ # attn = (q @ k.transpose(-2, -1))
163
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
164
+ # x = (attn @ v)
165
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
166
+ # x = self.proj(x)
167
+ flops += N * self.dim * self.dim
168
+ if self.position_bias:
169
+ flops += self.pos.flops(N)
170
+ return flops
171
+
172
+
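+ # Minimal, illustrative sketch of the group attention above: each GxG group of
+ # embeddings attends internally with a dynamically generated relative position
+ # bias. Sizes are assumed (dim 96, 3 heads, 7x7 groups).
+ def _group_attention_demo():
+     attn = Attention(dim=96, group_size=(7, 7), num_heads=3)
+     groups = torch.randn(8, 49, 96)      # (num_groups * B, G*G, C)
+     out = attn(groups)
+     assert out.shape == groups.shape
+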
173
+ class CrossFormerBlock(nn.Module):
174
+ r""" CrossFormer Block.
175
+
176
+ Args:
177
+ dim (int): Number of input channels.
178
+ input_resolution (tuple[int]): Input resolution.
179
+ num_heads (int): Number of attention heads.
180
+ group_size (int): Group size.
181
+ lsda_flag (int): use SDA or LDA, 0 for SDA and 1 for LDA.
182
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
183
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
184
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
185
+ drop (float, optional): Dropout rate. Default: 0.0
186
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
187
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
188
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
189
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
190
+ """
191
+
192
+ def __init__(self, dim, input_resolution, num_heads, group_size=7, lsda_flag=0,
193
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
194
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm, num_patch_size=1):
195
+ super().__init__()
196
+ self.dim = dim
197
+ self.input_resolution = input_resolution
198
+ self.num_heads = num_heads
199
+ self.group_size = group_size
200
+ self.lsda_flag = lsda_flag
201
+ self.mlp_ratio = mlp_ratio
202
+ self.num_patch_size = num_patch_size
203
+ if min(self.input_resolution) <= self.group_size:
204
+ # if group size is larger than input resolution, we don't partition groups
205
+ self.lsda_flag = 0
206
+ self.group_size = min(self.input_resolution)
207
+
208
+ self.norm1 = norm_layer(dim)
209
+
210
+ self.attn = Attention(
211
+ dim, group_size=to_2tuple(self.group_size), num_heads=num_heads,
212
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
213
+ position_bias=True)
214
+
215
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
216
+ self.norm2 = norm_layer(dim)
217
+ mlp_hidden_dim = int(dim * mlp_ratio)
218
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
219
+
220
+ attn_mask = None
221
+ self.register_buffer("attn_mask", attn_mask)
222
+
223
+ def forward(self, x):
224
+ H, W = self.input_resolution
225
+ B, L, C = x.shape
226
+ assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W)
227
+
228
+ shortcut = x
229
+ x = self.norm1(x)
230
+ x = x.view(B, H, W, C)
231
+
232
+ # group embeddings
233
+ G = self.group_size
234
+ if self.lsda_flag == 0: # 0 for SDA
235
+ x = x.reshape(B, H // G, G, W // G, G, C).permute(0, 1, 3, 2, 4, 5)
236
+ else: # 1 for LDA
237
+ x = x.reshape(B, G, H // G, G, W // G, C).permute(0, 2, 4, 1, 3, 5)
238
+ x = x.reshape(B * H * W // G**2, G**2, C)
239
+
240
+ # multi-head self-attention
241
+ x = self.attn(x, mask=self.attn_mask) # nW*B, G*G, C
242
+
243
+ # ungroup embeddings
244
+ x = x.reshape(B, H // G, W // G, G, G, C)
245
+ if self.lsda_flag == 0:
246
+ x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, H, W, C)
247
+ else:
248
+ x = x.permute(0, 3, 1, 4, 2, 5).reshape(B, H, W, C)
249
+ x = x.view(B, H * W, C)
250
+
251
+ # FFN
252
+ x = shortcut + self.drop_path(x)
253
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
254
+
255
+ return x
256
+
257
+ def extra_repr(self) -> str:
258
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
259
+ f"group_size={self.group_size}, lsda_flag={self.lsda_flag}, mlp_ratio={self.mlp_ratio}"
260
+
261
+ def flops(self):
262
+ flops = 0
263
+ H, W = self.input_resolution
264
+ # norm1
265
+ flops += self.dim * H * W
266
+ # LSDA
267
+ nW = H * W / self.group_size / self.group_size
268
+ flops += nW * self.attn.flops(self.group_size * self.group_size)
269
+ # mlp
270
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
271
+ # norm2
272
+ flops += self.dim * H * W
273
+ return flops
274
+
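+ # Minimal, illustrative sketch of one CrossFormerBlock: lsda_flag=0 groups
+ # adjacent 7x7 embeddings (SDA), lsda_flag=1 groups dilated embeddings (LDA);
+ # both keep the [B, H*W, C] shape. Sizes are assumed (56x56 tokens, dim 96).
+ def _crossformer_block_demo():
+     x = torch.randn(2, 56 * 56, 96)
+     sda = CrossFormerBlock(dim=96, input_resolution=(56, 56), num_heads=3, lsda_flag=0)
+     lda = CrossFormerBlock(dim=96, input_resolution=(56, 56), num_heads=3, lsda_flag=1)
+     assert sda(x).shape == x.shape and lda(x).shape == x.shape
+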
275
+ class PatchMerging(nn.Module):
276
+ r""" Patch Merging Layer.
277
+
278
+ Args:
279
+ input_resolution (tuple[int]): Resolution of input feature.
280
+ dim (int): Number of input channels.
281
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
282
+ """
283
+
284
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm, patch_size=[2], num_input_patch_size=1):
285
+ super().__init__()
286
+ self.input_resolution = input_resolution
287
+ self.dim = dim
288
+ self.reductions = nn.ModuleList()
289
+ self.patch_size = patch_size
290
+ self.norm = norm_layer(dim)
291
+
292
+ for i, ps in enumerate(patch_size):
293
+ if i == len(patch_size) - 1:
294
+ out_dim = 2 * dim // 2 ** i
295
+ else:
296
+ out_dim = 2 * dim // 2 ** (i + 1)
297
+ stride = 2
298
+ padding = (ps - stride) // 2
299
+ self.reductions.append(nn.Conv2d(dim, out_dim, kernel_size=ps,
300
+ stride=stride, padding=padding))
301
+
302
+ def forward(self, x):
303
+ """
304
+ x: B, H*W, C
305
+ """
306
+ H, W = self.input_resolution
307
+ B, L, C = x.shape
308
+ assert L == H * W, "input feature has wrong size"
309
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
310
+
311
+ x = self.norm(x)
312
+ x = x.view(B, H, W, C).permute(0, 3, 1, 2)
313
+
314
+ xs = []
315
+ for i in range(len(self.reductions)):
316
+ tmp_x = self.reductions[i](x).flatten(2).transpose(1, 2)
317
+ xs.append(tmp_x)
318
+ x = torch.cat(xs, dim=2)
319
+ return x
320
+
321
+ def extra_repr(self) -> str:
322
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
323
+
324
+ def flops(self):
325
+ H, W = self.input_resolution
326
+ flops = H * W * self.dim
327
+ for i, ps in enumerate(self.patch_size):
328
+ if i == len(self.patch_size) - 1:
329
+ out_dim = 2 * self.dim // 2 ** i
330
+ else:
331
+ out_dim = 2 * self.dim // 2 ** (i + 1)
332
+ flops += (H // 2) * (W // 2) * ps * ps * out_dim * self.dim
333
+ return flops
334
+
335
+
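+ # Minimal, illustrative sketch of PatchMerging: parallel strided convolutions
+ # halve the resolution and double the channels. patch_size=[2, 4] is an
+ # assumed multi-kernel setting (the default is [2]).
+ def _patch_merging_demo():
+     pm = PatchMerging(input_resolution=(56, 56), dim=96, patch_size=[2, 4])
+     x = torch.randn(2, 56 * 56, 96)
+     y = pm(x)
+     assert y.shape == (2, 28 * 28, 192)  # 2x spatial downsample, 2x channels
+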
336
+ class Stage(nn.Module):
337
+ """ CrossFormer blocks for one stage.
338
+
339
+ Args:
340
+ dim (int): Number of input channels.
341
+ input_resolution (tuple[int]): Input resolution.
342
+ depth (int): Number of blocks.
343
+ num_heads (int): Number of attention heads.
344
+ group_size (int): variable G in the paper, one group has GxG embeddings
345
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
346
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
347
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
348
+ drop (float, optional): Dropout rate. Default: 0.0
349
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
350
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
351
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
352
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
353
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
354
+ """
355
+
356
+ def __init__(self, dim, input_resolution, depth, num_heads, group_size,
357
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
358
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
359
+ patch_size_end=[4], num_patch_size=None):
360
+
361
+ super().__init__()
362
+ self.dim = dim
363
+ self.input_resolution = input_resolution
364
+ self.depth = depth
365
+ self.use_checkpoint = use_checkpoint
366
+
367
+ # build blocks
368
+ self.blocks = nn.ModuleList()
369
+ for i in range(depth):
370
+ lsda_flag = 0 if (i % 2 == 0) else 1
371
+ self.blocks.append(CrossFormerBlock(dim=dim, input_resolution=input_resolution,
372
+ num_heads=num_heads, group_size=group_size,
373
+ lsda_flag=lsda_flag,
374
+ mlp_ratio=mlp_ratio,
375
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
376
+ drop=drop, attn_drop=attn_drop,
377
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
378
+ norm_layer=norm_layer,
379
+ num_patch_size=num_patch_size))
380
+
381
+ # patch merging layer
382
+ if downsample is not None:
383
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer,
384
+ patch_size=patch_size_end, num_input_patch_size=num_patch_size)
385
+ else:
386
+ self.downsample = None
387
+
388
+ def forward(self, x):
389
+ for blk in self.blocks:
390
+ if self.use_checkpoint:
391
+ x = checkpoint.checkpoint(blk, x)
392
+ else:
393
+ x = blk(x)
394
+ if self.downsample is not None:
395
+ x = self.downsample(x)
396
+ return x
397
+
398
+ def extra_repr(self) -> str:
399
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
400
+
401
+ def flops(self):
402
+ flops = 0
403
+ for blk in self.blocks:
404
+ flops += blk.flops()
405
+ if self.downsample is not None:
406
+ flops += self.downsample.flops()
407
+ return flops
408
+
409
+
410
+ class PatchEmbed(nn.Module):
411
+ r""" Image to Patch Embedding
412
+
413
+ Args:
414
+ img_size (int): Image size. Default: 224.
415
+ patch_size (int): Patch token size. Default: [4].
416
+ in_chans (int): Number of input image channels. Default: 3.
417
+ embed_dim (int): Number of linear projection output channels. Default: 96.
418
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
419
+ """
420
+
421
+ def __init__(self, img_size=224, patch_size=[4], in_chans=3, embed_dim=96, norm_layer=None):
422
+ super().__init__()
423
+ img_size = to_2tuple(img_size)
424
+ # patch_size = to_2tuple(patch_size)
425
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[0]] # token grid given by the base patch size (stride)
426
+ self.img_size = img_size
427
+ self.patch_size = patch_size
428
+ self.patches_resolution = patches_resolution
429
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
430
+
431
+ self.in_chans = in_chans
432
+ self.embed_dim = embed_dim
433
+
434
+ self.projs = nn.ModuleList()
435
+ for i, ps in enumerate(patch_size):
436
+ if i == len(patch_size) - 1:
437
+ dim = embed_dim // 2 ** i
438
+ else:
439
+ dim = embed_dim // 2 ** (i + 1)
440
+ stride = patch_size[0]
441
+ padding = (ps - patch_size[0]) // 2
442
+ self.projs.append(nn.Conv2d(in_chans, dim, kernel_size=ps, stride=stride, padding=padding))
443
+ if norm_layer is not None:
444
+ self.norm = norm_layer(embed_dim)
445
+ else:
446
+ self.norm = None
447
+
448
+ def forward(self, x):
449
+ B, C, H, W = x.shape
450
+ # FIXME look at relaxing size constraints
451
+ assert H == self.img_size[0] and W == self.img_size[1], \
452
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
453
+ xs = []
454
+ for i in range(len(self.projs)):
455
+ tx = self.projs[i](x).flatten(2).transpose(1, 2)
456
+ xs.append(tx) # B Ph*Pw C
457
+ x = torch.cat(xs, dim=2)
458
+ if self.norm is not None:
459
+ x = self.norm(x)
460
+ return x
461
+
462
+ def flops(self):
463
+ Ho, Wo = self.patches_resolution
464
+ flops = 0
465
+ for i, ps in enumerate(self.patch_size):
466
+ if i == len(self.patch_size) - 1:
467
+ dim = self.embed_dim // 2 ** i
468
+ else:
469
+ dim = self.embed_dim // 2 ** (i + 1)
470
+ flops += Ho * Wo * dim * self.in_chans * (self.patch_size[i] * self.patch_size[i])
471
+ if self.norm is not None:
472
+ flops += Ho * Wo * self.embed_dim
473
+ return flops
474
+
475
+
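+ # Minimal, illustrative sketch of the cross-scale PatchEmbed: every kernel
+ # size embeds the same stride-4 grid and the results are concatenated along
+ # channels. patch_size=[4, 8] is an assumed two-kernel setting (default [4]).
+ def _patch_embed_demo():
+     pe = PatchEmbed(img_size=224, patch_size=[4, 8], in_chans=3, embed_dim=96)
+     x = torch.randn(2, 3, 224, 224)
+     tokens = pe(x)
+     assert tokens.shape == (2, 56 * 56, 96)  # 48 + 48 channels from the two kernels
+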
476
+ class CrossFormer(nn.Module):
477
+ r""" CrossFormer
478
+ A PyTorch impl of : `CrossFormer: A Versatile Vision Transformer Based on Cross-scale Attention` -
479
+
480
+ Args:
481
+ img_size (int | tuple(int)): Input image size. Default 224
482
+ patch_size (int | tuple(int)): Patch size. Default: 4
483
+ in_chans (int): Number of input image channels. Default: 3
484
+ num_classes (int): Number of classes for classification head. Default: 1000
485
+ embed_dim (int): Patch embedding dimension. Default: 96
486
+ depths (tuple(int)): Depth of each stage.
487
+ num_heads (tuple(int)): Number of attention heads in different layers.
488
+ group_size (int): Group size. Default: 7
489
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
490
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
491
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
492
+ drop_rate (float): Dropout rate. Default: 0
493
+ attn_drop_rate (float): Attention dropout rate. Default: 0
494
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
495
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
496
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
497
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
498
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
499
+ """
500
+
501
+ def __init__(self, img_size=224, patch_size=[4], in_chans=3, num_classes=1000,
502
+ embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
503
+ group_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
504
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
505
+ norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
506
+ use_checkpoint=False, merge_size=[[2], [2], [2]], **kwargs):
507
+ super().__init__()
508
+
509
+ self.num_classes = num_classes
510
+ self.num_layers = len(depths)
511
+ self.embed_dim = embed_dim
512
+ self.ape = ape
513
+ self.patch_norm = patch_norm
514
+ self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
515
+ self.mlp_ratio = mlp_ratio
516
+
517
+ # split image into non-overlapping patches
518
+ self.patch_embed = PatchEmbed(
519
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
520
+ norm_layer=norm_layer if self.patch_norm else None)
521
+ num_patches = self.patch_embed.num_patches
522
+ patches_resolution = self.patch_embed.patches_resolution
523
+ self.patches_resolution = patches_resolution
524
+
525
+ # absolute position embedding
526
+ if self.ape:
527
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
528
+ trunc_normal_(self.absolute_pos_embed, std=.02)
529
+
530
+ self.pos_drop = nn.Dropout(p=drop_rate)
531
+
532
+ # stochastic depth
533
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
534
+
535
+ # build layers
536
+ self.layers = nn.ModuleList()
537
+
538
+ num_patch_sizes = [len(patch_size)] + [len(m) for m in merge_size]
539
+ for i_layer in range(self.num_layers):
540
+ patch_size_end = merge_size[i_layer] if i_layer < self.num_layers - 1 else None
541
+ num_patch_size = num_patch_sizes[i_layer]
542
+ layer = Stage(dim=int(embed_dim * 2 ** i_layer),
543
+ input_resolution=(patches_resolution[0] // (2 ** i_layer),
544
+ patches_resolution[1] // (2 ** i_layer)),
545
+ depth=depths[i_layer],
546
+ num_heads=num_heads[i_layer],
547
+ group_size=group_size[i_layer],
548
+ mlp_ratio=self.mlp_ratio,
549
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
550
+ drop=drop_rate, attn_drop=attn_drop_rate,
551
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
552
+ norm_layer=norm_layer,
553
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
554
+ use_checkpoint=use_checkpoint,
555
+ patch_size_end=patch_size_end,
556
+ num_patch_size=num_patch_size)
557
+ self.layers.append(layer)
558
+
559
+ self.norm = norm_layer(self.num_features)
560
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
561
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
562
+
563
+ self.apply(self._init_weights)
564
+
565
+ def _init_weights(self, m):
566
+ if isinstance(m, nn.Linear):
567
+ trunc_normal_(m.weight, std=.02)
568
+ if isinstance(m, nn.Linear) and m.bias is not None:
569
+ nn.init.constant_(m.bias, 0)
570
+ elif isinstance(m, nn.LayerNorm):
571
+ nn.init.constant_(m.bias, 0)
572
+ nn.init.constant_(m.weight, 1.0)
573
+
574
+ @torch.jit.ignore
575
+ def no_weight_decay(self):
576
+ return {'absolute_pos_embed'}
577
+
578
+ @torch.jit.ignore
579
+ def no_weight_decay_keywords(self):
580
+ return {'relative_position_bias_table'}
581
+
582
+ def forward_features(self, x):
583
+ x = self.patch_embed(x)
584
+ if self.ape:
585
+ x = x + self.absolute_pos_embed
586
+ x = self.pos_drop(x)
587
+
588
+ for layer in self.layers:
589
+ x = layer(x)
590
+
591
+ x = self.norm(x) # B L C
592
+ x = self.avgpool(x.transpose(1, 2)) # B C 1
593
+ x = torch.flatten(x, 1)
594
+ return x
595
+
596
+ def forward(self, x):
597
+ x = self.forward_features(x)
598
+ x = self.head(x)
599
+ return x
600
+
601
+ def flops(self):
602
+ flops = 0
603
+ flops += self.patch_embed.flops()
604
+ for i, layer in enumerate(self.layers):
605
+ flops += layer.flops()
606
+ flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
607
+ flops += self.num_features * self.num_classes
608
+ return flops
609
+
610
+
611
+ class cross_former_cls_head_warp(nn.Module):
612
+ def __init__(self, backbone, num_classes):
613
+ super().__init__()
614
+ embed_dim = 96
615
+ depths = [2, 2, 18, 2]
616
+ num_layers = len(depths)
617
+ num_features = int(embed_dim * 2 ** (num_layers - 1))
618
+ self.backbone = backbone
619
+ self.head = nn.Linear(num_features, num_classes)
620
+
621
+ def forward(self, x):
622
+ x = self.backbone(x)
623
+ x = self.head(x)
624
+ return x
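+
+ # Minimal, illustrative usage sketch of the classifier; the settings below are
+ # this constructor's defaults spelled out, except that group_size is given per
+ # stage (a list) because the constructor indexes it by layer. num_classes=2 is
+ # an assumed value.
+ def _crossformer_demo():
+     model = CrossFormer(img_size=224, patch_size=[4], num_classes=2,
+                         embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
+                         group_size=[7, 7, 7, 7])
+     logits = model(torch.randn(2, 3, 224, 224))
+     assert logits.shape == (2, 2)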
PuzzleTuning/Backbone/counterpart_models/crossformer_backbone.py ADDED
@@ -0,0 +1,659 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch.utils.checkpoint as checkpoint
6
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
7
+ from mmdet.utils import get_root_logger
8
+ from mmcv.runner import load_checkpoint
9
+
10
+ NEG_INF = -1000000
11
+
12
+ class Mlp(nn.Module):
13
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
14
+ super().__init__()
15
+ out_features = out_features or in_features
16
+ hidden_features = hidden_features or in_features
17
+ self.fc1 = nn.Linear(in_features, hidden_features)
18
+ self.act = act_layer()
19
+ self.fc2 = nn.Linear(hidden_features, out_features)
20
+ self.drop = nn.Dropout(drop)
21
+
22
+ def forward(self, x):
23
+ x = self.fc1(x)
24
+ x = self.act(x)
25
+ x = self.drop(x)
26
+ x = self.fc2(x)
27
+ x = self.drop(x)
28
+ return x
29
+
30
+
31
+ class DynamicPosBias(nn.Module):
32
+ def __init__(self, dim, num_heads, residual):
33
+ super().__init__()
34
+ self.residual = residual
35
+ self.num_heads = num_heads
36
+ self.pos_dim = dim // 4
37
+ self.pos_proj = nn.Linear(2, self.pos_dim)
38
+ self.pos1 = nn.Sequential(
39
+ nn.LayerNorm(self.pos_dim),
40
+ nn.ReLU(inplace=True),
41
+ nn.Linear(self.pos_dim, self.pos_dim),
42
+ )
43
+ self.pos2 = nn.Sequential(
44
+ nn.LayerNorm(self.pos_dim),
45
+ nn.ReLU(inplace=True),
46
+ nn.Linear(self.pos_dim, self.pos_dim)
47
+ )
48
+ self.pos3 = nn.Sequential(
49
+ nn.LayerNorm(self.pos_dim),
50
+ nn.ReLU(inplace=True),
51
+ nn.Linear(self.pos_dim, self.num_heads)
52
+ )
53
+ def forward(self, biases):
54
+ if self.residual:
55
+ pos = self.pos_proj(biases) # 2Gh-1 * 2Gw-1, heads
56
+ pos = pos + self.pos1(pos)
57
+ pos = pos + self.pos2(pos)
58
+ pos = self.pos3(pos)
59
+ else:
60
+ pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases))))
61
+ return pos
62
+
63
+ def flops(self, N):
64
+ flops = N * 2 * self.pos_dim
65
+ flops += N * self.pos_dim * self.pos_dim
66
+ flops += N * self.pos_dim * self.pos_dim
67
+ flops += N * self.pos_dim * self.num_heads
68
+ return flops
69
+
70
+ class Attention(nn.Module):
71
+ r""" Multi-head self attention module with relative position bias.
72
+
73
+ Args:
74
+ dim (int): Number of input channels.
75
+ num_heads (int): Number of attention heads.
76
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
77
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
78
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
79
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
80
+ """
81
+
82
+ def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.,
83
+ position_bias=True):
84
+
85
+ super().__init__()
86
+ self.dim = dim
87
+ self.num_heads = num_heads
88
+ head_dim = dim // num_heads
89
+ self.scale = qk_scale or head_dim ** -0.5
90
+ self.position_bias = position_bias
91
+ if self.position_bias:
92
+ self.pos = DynamicPosBias(self.dim // 4, self.num_heads, residual=False)
93
+
94
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
95
+ self.attn_drop = nn.Dropout(attn_drop)
96
+ self.proj = nn.Linear(dim, dim)
97
+ self.proj_drop = nn.Dropout(proj_drop)
98
+
99
+ self.softmax = nn.Softmax(dim=-1)
100
+
101
+ def forward(self, x, H, W, mask=None):
102
+ """
103
+ Args:
104
+ x: input features with shape of (num_windows*B, N, C)
105
+ mask: (0/-inf) mask with shape of (num_windows, Gh*Gw, Gh*Gw) or None
106
+ """
107
+ group_size = (H, W)
108
+ B_, N, C = x.shape
109
+ assert H*W == N
110
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous()
111
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
112
+
113
+ q = q * self.scale
114
+ attn = (q @ k.transpose(-2, -1)) # (num_windows*B, N, N), N = Gh*Gw
115
+
116
+ if self.position_bias:
117
+ # generate mother-set
118
+ position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device)
119
+ position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device)
120
+ biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2Gw-1
121
+ biases = biases.flatten(1).transpose(0, 1).contiguous().float()
122
+
123
+ # get pair-wise relative position index for each token inside the window
124
+ coords_h = torch.arange(group_size[0], device=attn.device)
125
+ coords_w = torch.arange(group_size[1], device=attn.device)
126
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw
127
+ coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw
128
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw
129
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2
130
+ relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0
131
+ relative_coords[:, :, 1] += group_size[1] - 1
132
+ relative_coords[:, :, 0] *= 2 * group_size[1] - 1
133
+ relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw
134
+
135
+ pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads
136
+ # select position bias
137
+ relative_position_bias = pos[relative_position_index.view(-1)].view(
138
+ group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH
139
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw
140
+ attn = attn + relative_position_bias.unsqueeze(0)
141
+
142
+ if mask is not None:
143
+ nG = mask.shape[0]
144
+ attn = attn.view(B_ // nG, nG, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) # (B, nG, nHead, N, N)
145
+ attn = attn.view(-1, self.num_heads, N, N)
146
+ attn = self.softmax(attn)
147
+ else:
148
+ attn = self.softmax(attn)
149
+
150
+ attn = self.attn_drop(attn)
151
+
152
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
153
+ x = self.proj(x)
154
+ x = self.proj_drop(x)
155
+ return x
156
+
157
+ def extra_repr(self) -> str:
158
+ return f'dim={self.dim}, num_heads={self.num_heads}'
159
+
160
+ def flops(self, N):
161
+ # calculate flops for 1 window with token length of N
162
+ flops = 0
163
+ excluded_flops = 0
164
+ # qkv = self.qkv(x)
165
+ flops += N * self.dim * 3 * self.dim
166
+ # attn = (q @ k.transpose(-2, -1))
167
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
168
+ excluded_flops += self.num_heads * N * (self.dim // self.num_heads) * N
169
+ # x = (attn @ v)
170
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
171
+ excluded_flops += self.num_heads * N * N * (self.dim // self.num_heads)
172
+ # x = self.proj(x)
173
+ flops += N * self.dim * self.dim
174
+ if self.position_bias:
175
+ flops += self.pos.flops(N)
176
+ return flops, excluded_flops
177
+
178
+
179
+ class CrossFormerBlock(nn.Module):
180
+ r""" CrossFormer Block.
181
+
182
+ Args:
183
+ dim (int): Number of input channels.
184
+ input_resolution (tuple[int]): Input resolution.
185
+ num_heads (int): Number of attention heads.
186
+ group_size (int): Group size.
187
+ lsda_flag (int): use SDA or LDA, 0 for SDA and 1 for LDA.
188
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
189
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
190
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
191
+ drop (float, optional): Dropout rate. Default: 0.0
192
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
193
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
194
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
195
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
196
+ """
197
+
198
+ def __init__(self, dim, input_resolution, num_heads, group_size=7, interval=8, lsda_flag=0,
199
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
200
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm, num_patch_size=1):
201
+ super().__init__()
202
+ self.dim = dim
203
+ self.input_resolution = input_resolution
204
+ self.num_heads = num_heads
205
+ self.group_size = group_size
206
+ self.interval = interval
207
+ self.lsda_flag = lsda_flag
208
+ self.mlp_ratio = mlp_ratio
209
+ self.num_patch_size = num_patch_size
210
+
211
+ self.norm1 = norm_layer(dim)
212
+
213
+ self.attn = Attention(
214
+ dim, num_heads=num_heads,
215
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
216
+ position_bias=True)
217
+
218
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
219
+ self.norm2 = norm_layer(dim)
220
+ mlp_hidden_dim = int(dim * mlp_ratio)
221
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
222
+
223
+ def forward(self, x, H, W):
224
+ B, L, C = x.shape
225
+ assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W)
226
+
227
+ if min(H, W) <= self.group_size:
228
+ # if window size is larger than input resolution, we don't partition windows
229
+ self.lsda_flag = 0
230
+ self.group_size = min(H, W)
231
+
232
+ shortcut = x
233
+ x = self.norm1(x)
234
+ x = x.view(B, H, W, C)
235
+
236
+ # padding
237
+ size_div = self.interval if self.lsda_flag == 1 else self.group_size
238
+ pad_l = pad_t = 0
239
+ pad_r = (size_div - W % size_div) % size_div
240
+ pad_b = (size_div - H % size_div) % size_div
241
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
242
+ _, Hp, Wp, _ = x.shape
243
+
244
+ mask = torch.zeros((1, Hp, Wp, 1), device=x.device)
245
+ if pad_b > 0:
246
+ mask[:, -pad_b:, :, :] = -1
247
+ if pad_r > 0:
248
+ mask[:, :, -pad_r:, :] = -1
249
+
250
+ # group embeddings and generate attn_mask
251
+ if self.lsda_flag == 0: # SDA
252
+ G = Gh = Gw = self.group_size
253
+ x = x.reshape(B, Hp // G, G, Wp // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous()
254
+ x = x.reshape(B * Hp * Wp // G**2, G**2, C)
255
+ nG = Hp * Wp // G**2
256
+ # attn_mask
257
+ if pad_r > 0 or pad_b > 0:
258
+ mask = mask.reshape(1, Hp // G, G, Wp // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous()
259
+ mask = mask.reshape(nG, 1, G * G)
260
+ attn_mask = torch.zeros((nG, G * G, G * G), device=x.device)
261
+ attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF)
262
+ else:
263
+ attn_mask = None
264
+ else: # LDA
265
+ I, Gh, Gw = self.interval, Hp // self.interval, Wp // self.interval
266
+ x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous()
267
+ x = x.reshape(B * I * I, Gh * Gw, C)
268
+ nG = I ** 2
269
+ # attn_mask
270
+ if pad_r > 0 or pad_b > 0:
271
+ mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous()
272
+ mask = mask.reshape(nG, 1, Gh * Gw)
273
+ attn_mask = torch.zeros((nG, Gh * Gw, Gh * Gw), device=x.device)
274
+ attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF)
275
+ else:
276
+ attn_mask = None
277
+
278
+ # multi-head self-attention
279
+ x = self.attn(x, Gh, Gw, mask=attn_mask) # nG*B, G*G, C
280
+
281
+ # ungroup embeddings
282
+ if self.lsda_flag == 0:
283
+ x = x.reshape(B, Hp // G, Wp // G, G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() # B, Hp//G, G, Wp//G, G, C
284
+ else:
285
+ x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C
286
+ x = x.reshape(B, Hp, Wp, C)
287
+
288
+ # remove padding
289
+ if pad_r > 0 or pad_b > 0:
290
+ x = x[:, :H, :W, :].contiguous()
291
+ x = x.view(B, H * W, C)
292
+
293
+ # FFN
294
+ x = shortcut + self.drop_path(x)
295
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
296
+
297
+ return x
298
+
299
+ def extra_repr(self) -> str:
300
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
301
+ f"group_size={self.group_size}, lsda_flag={self.lsda_flag}, mlp_ratio={self.mlp_ratio}"
302
+
303
+ def flops(self):
304
+ flops = 0
305
+ H, W = self.input_resolution
306
+ # norm1
307
+ flops += self.dim * H * W
308
+ # Attention
309
+ size_div = self.interval if self.lsda_flag == 1 else self.group_size
310
+ Hp = math.ceil(H / size_div) * size_div
311
+ Wp = math.ceil(W / size_div) * size_div
312
+ Gh = Hp / size_div if self.lsda_flag == 1 else self.group_size
313
+ Gw = Wp / size_div if self.lsda_flag == 1 else self.group_size
314
+ nG = Hp * Wp / Gh / Gw
315
+ attn_flops, attn_excluded_flops = self.attn.flops(Gh * Gw)
316
+ flops += nG * attn_flops
317
+ excluded_flops = nG * attn_excluded_flops
318
+ # mlp
319
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
320
+ # norm2
321
+ flops += self.dim * H * W
322
+ return flops, excluded_flops
323
+
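+ # Minimal, illustrative sketch of the detection-backbone block: H and W are
+ # passed to forward and padded up to a multiple of the group size / interval,
+ # so arbitrary feature-map sizes are accepted. A 50x50 map is an assumed size.
+ def _backbone_block_demo():
+     blk = CrossFormerBlock(dim=96, input_resolution=(50, 50), num_heads=3,
+                            group_size=7, interval=8, lsda_flag=1)
+     x = torch.randn(2, 50 * 50, 96)
+     y = blk(x, 50, 50)                   # padded to 56x56 internally, cropped back afterwards
+     assert y.shape == x.shape
+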
324
+ class PatchMerging(nn.Module):
325
+ r""" Patch Merging Layer.
326
+
327
+ Args:
328
+ input_resolution (tuple[int]): Resolution of input feature.
329
+ dim (int): Number of input channels.
330
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
331
+ """
332
+
333
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm, patch_size=[2], num_input_patch_size=1):
334
+ super().__init__()
335
+ self.input_resolution = input_resolution
336
+ self.dim = dim
337
+ self.reductions = nn.ModuleList()
338
+ self.patch_size = patch_size
339
+ self.norm = norm_layer(dim)
340
+
341
+ for i, ps in enumerate(patch_size):
342
+ if i == len(patch_size) - 1:
343
+ out_dim = 2 * dim // 2 ** i
344
+ else:
345
+ out_dim = 2 * dim // 2 ** (i + 1)
346
+ stride = 2
347
+ padding = (ps - stride) // 2
348
+ self.reductions.append(nn.Conv2d(dim, out_dim, kernel_size=ps,
349
+ stride=stride, padding=padding))
350
+
351
+ def forward(self, x, H, W):
352
+ """
353
+ x: B, H*W, C
354
+ """
355
+ B, L, C = x.shape
356
+ assert L == H * W, "input feature has wrong size"
357
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
358
+
359
+ x = self.norm(x)
360
+ x = x.view(B, H, W, C).permute(0, 3, 1, 2).contiguous()
361
+
362
+ xs = []
363
+ for i in range(len(self.reductions)):
364
+ tmp_x = self.reductions[i](x).flatten(2).transpose(1, 2).contiguous()
365
+ xs.append(tmp_x)
366
+ x = torch.cat(xs, dim=2)
367
+ return x
368
+
369
+ def extra_repr(self) -> str:
370
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
371
+
372
+ def flops(self):
373
+ H, W = self.input_resolution
374
+ flops = H * W * self.dim
375
+ for i, ps in enumerate(self.patch_size):
376
+ if i == len(self.patch_size) - 1:
377
+ out_dim = 2 * self.dim // 2 ** i
378
+ else:
379
+ out_dim = 2 * self.dim // 2 ** (i + 1)
380
+ flops += (H // 2) * (W // 2) * ps * ps * out_dim * self.dim
381
+ return flops
382
+
383
+
384
+ class Stage(nn.Module):
385
+ """ CrossFormer blocks for one stage.
386
+
387
+ Args:
388
+ dim (int): Number of input channels.
389
+ input_resolution (tuple[int]): Input resolution.
390
+ depth (int): Number of blocks.
391
+ num_heads (int): Number of attention heads.
392
+ group_size (int): Group size.
393
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
394
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
395
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
396
+ drop (float, optional): Dropout rate. Default: 0.0
397
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
398
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
399
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
400
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
401
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
402
+ """
403
+
404
+ def __init__(self, dim, input_resolution, depth, num_heads, group_size, interval,
405
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
406
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
407
+ patch_size_end=[4], num_patch_size=None):
408
+
409
+ super().__init__()
410
+ self.dim = dim
411
+ self.depth = depth
412
+ self.use_checkpoint = use_checkpoint
413
+
414
+ # build blocks
415
+ self.blocks = nn.ModuleList()
416
+ for i in range(depth):
417
+ lsda_flag = 0 if (i % 2 == 0) else 1
418
+ self.blocks.append(CrossFormerBlock(dim=dim, input_resolution=input_resolution,
419
+ num_heads=num_heads, group_size=group_size, interval=interval,
420
+ lsda_flag=lsda_flag,
421
+ mlp_ratio=mlp_ratio,
422
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
423
+ drop=drop, attn_drop=attn_drop,
424
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
425
+ norm_layer=norm_layer,
426
+ num_patch_size=num_patch_size))
427
+
428
+ # patch merging layer
429
+ if downsample is not None:
430
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer,
431
+ patch_size=patch_size_end, num_input_patch_size=num_patch_size)
432
+ else:
433
+ self.downsample = None
434
+
435
+ def forward(self, x, H, W):
436
+ for blk in self.blocks:
437
+ if self.use_checkpoint:
438
+ x = checkpoint.checkpoint(blk, x)
439
+ else:
440
+ x = blk(x, H, W)
441
+
442
+ B, _, C = x.shape
443
+ feat = x.view(B, H, W, C).permute(0, 3, 1, 2).contiguous()
444
+ if self.downsample is not None:
445
+ x = self.downsample(x, H, W)
446
+ return feat, x
447
+
448
+ def extra_repr(self) -> str:
449
+ return f"dim={self.dim}, depth={self.depth}"
450
+
451
+ def flops(self):
452
+ flops = 0
453
+ excluded_flops = 0
454
+ for blk in self.blocks:
455
+ blk_flops, blk_excluded_flops = blk.flops()
456
+ flops += blk_flops
457
+ excluded_flops += blk_excluded_flops
458
+ if self.downsample is not None:
459
+ flops += self.downsample.flops()
460
+ return flops, excluded_flops
461
+
462
+
463
+ class PatchEmbed(nn.Module):
464
+ r""" Image to Patch Embedding
465
+
466
+ Args:
467
+ img_size (int): Image size. Default: 224.
468
+ patch_size (int): Patch token size. Default: 4.
469
+ in_chans (int): Number of input image channels. Default: 3.
470
+ embed_dim (int): Number of linear projection output channels. Default: 96.
471
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
472
+ """
473
+
474
+ def __init__(self, img_size=224, patch_size=[4], in_chans=3, embed_dim=96, norm_layer=None):
475
+ super().__init__()
476
+ img_size = to_2tuple(img_size)
477
+ # patch_size = to_2tuple(patch_size)
478
+ patches_resolution = [img_size[0] // 4, img_size[1] // 4] # only for flops calculation
479
+ self.img_size = img_size
480
+ self.patch_size = patch_size
481
+ self.patches_resolution = patches_resolution
482
+
483
+ self.in_chans = in_chans
484
+ self.embed_dim = embed_dim
485
+
486
+ self.projs = nn.ModuleList()
487
+ for i, ps in enumerate(patch_size):
488
+ if i == len(patch_size) - 1:
489
+ dim = embed_dim // 2 ** i
490
+ else:
491
+ dim = embed_dim // 2 ** (i + 1)
492
+ stride = 4
493
+ padding = (ps - 4) // 2
494
+ self.projs.append(nn.Conv2d(in_chans, dim, kernel_size=ps, stride=stride, padding=padding))
495
+ if norm_layer is not None:
496
+ self.norm = norm_layer(embed_dim)
497
+ else:
498
+ self.norm = None
499
+
500
+ def forward(self, x):
501
+ B, C, H, W = x.shape
502
+ xs = []
503
+ for i in range(len(self.projs)):
504
+ tx = self.projs[i](x).flatten(2).transpose(1, 2)
505
+ xs.append(tx) # B Ph*Pw C
506
+ x = torch.cat(xs, dim=2)
507
+ if self.norm is not None:
508
+ x = self.norm(x)
509
+ return x, H, W
510
+
511
+ def flops(self):
512
+ Ho, Wo = self.patches_resolution
513
+ flops = 0
514
+ for i, ps in enumerate(self.patch_size):
515
+ if i == len(self.patch_size) - 1:
516
+ dim = self.embed_dim // 2 ** i
517
+ else:
518
+ dim = self.embed_dim // 2 ** (i + 1)
519
+ flops += Ho * Wo * dim * self.in_chans * (self.patch_size[i] * self.patch_size[i])
520
+ if self.norm is not None:
521
+ flops += Ho * Wo * self.embed_dim
522
+ return flops
523
+
524
+
525
+ class CrossFormer(nn.Module):
526
+ r""" CrossFormer
527
+ A PyTorch impl of: `CrossFormer: A Versatile Vision Transformer Based on Cross-scale Attention` -
528
+
529
+ Args:
530
+ img_size (int | tuple(int)): Input image size. Default 224
531
+ patch_size (int | tuple(int)): Patch size. Default: 4
532
+ in_chans (int): Number of input image channels. Default: 3
533
+ num_classes (int): Number of classes for classification head. Default: 1000
534
+ embed_dim (int): Patch embedding dimension. Default: 96
535
+ depths (tuple(int)): Depth of each stage.
536
+ num_heads (tuple(int)): Number of attention heads in different layers.
537
+ group_size (int): Group size. Default: 7
538
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
539
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
540
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
541
+ drop_rate (float): Dropout rate. Default: 0
542
+ attn_drop_rate (float): Attention dropout rate. Default: 0
543
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
544
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
545
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
546
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
547
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
548
+ """
549
+
550
+ def __init__(self, img_size=224, patch_size=[4], in_chans=3, num_classes=1000,
551
+ embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
552
+ group_size=7, crs_interval=[8, 4, 2, 1], mlp_ratio=4., qkv_bias=True, qk_scale=None,
553
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
554
+ norm_layer=nn.LayerNorm, patch_norm=True,
555
+ use_checkpoint=False, merge_size=[[2], [2], [2]], **kwargs):
556
+ super().__init__()
557
+
558
+ self.num_classes = num_classes
559
+ self.num_layers = len(depths)
560
+ self.embed_dim = embed_dim
561
+ self.patch_norm = patch_norm
562
+ self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
563
+ self.mlp_ratio = mlp_ratio
564
+
565
+ # split image into non-overlapping patches
566
+ self.patch_embed = PatchEmbed(
567
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
568
+ norm_layer=norm_layer if self.patch_norm else None)
569
+ patches_resolution = self.patch_embed.patches_resolution
570
+ self.patches_resolution = patches_resolution # [H//4, W//4] of original image size
571
+
572
+ self.pos_drop = nn.Dropout(p=drop_rate)
573
+
574
+ # stochastic depth
575
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
576
+
577
+ # build layers
578
+ self.layers = nn.ModuleList()
579
+
580
+ num_patch_sizes = [len(patch_size)] + [len(m) for m in merge_size]
581
+ for i_layer in range(self.num_layers):
582
+ patch_size_end = merge_size[i_layer] if i_layer < self.num_layers - 1 else None
583
+ num_patch_size = num_patch_sizes[i_layer]
584
+ layer = Stage(dim=int(embed_dim * 2 ** i_layer),
585
+ input_resolution=(patches_resolution[0] // (2 ** i_layer),
586
+ patches_resolution[1] // (2 ** i_layer)),
587
+ depth=depths[i_layer],
588
+ num_heads=num_heads[i_layer],
589
+ group_size=group_size[i_layer],
590
+ interval=crs_interval[i_layer],
591
+ mlp_ratio=self.mlp_ratio,
592
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
593
+ drop=drop_rate, attn_drop=attn_drop_rate,
594
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
595
+ norm_layer=norm_layer,
596
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
597
+ use_checkpoint=use_checkpoint,
598
+ patch_size_end=patch_size_end,
599
+ num_patch_size=num_patch_size)
600
+ self.layers.append(layer)
601
+
602
+ # # classification
603
+ # self.norm = norm_layer(self.num_features)
604
+ # self.avgpool = nn.AdaptiveAvgPool1d(1)
605
+ # self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
606
+
607
+ self.apply(self._init_weights)
608
+
609
+ def init_weights(self, pretrained=None):
610
+ if isinstance(pretrained, str):
611
+ logger = get_root_logger()
612
+ load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger)
613
+
614
+ def _init_weights(self, m):
615
+ if isinstance(m, nn.Linear):
616
+ trunc_normal_(m.weight, std=.02)
617
+ if isinstance(m, nn.Linear) and m.bias is not None:
618
+ nn.init.constant_(m.bias, 0)
619
+ elif isinstance(m, nn.LayerNorm):
620
+ nn.init.constant_(m.bias, 0)
621
+ nn.init.constant_(m.weight, 1.0)
622
+
623
+ @torch.jit.ignore
624
+ def no_weight_decay(self):
625
+ return {'absolute_pos_embed'}
626
+
627
+ @torch.jit.ignore
628
+ def no_weight_decay_keywords(self):
629
+ return {'relative_position_bias_table'}
630
+
631
+ def forward(self, x):
632
+ x, H, W = self.patch_embed(x)
633
+ x = self.pos_drop(x)
634
+
635
+ outs = []
636
+ for i, layer in enumerate(self.layers):
637
+ feat, x = layer(x, H //4 //(2 ** i), W //4 //(2 ** i))
638
+ outs.append(feat)
639
+
640
+ # # classification
641
+ # x = self.norm(x) # B L C
642
+ # x = self.avgpool(x.transpose(1, 2)) # B C 1
643
+ # x = torch.flatten(x, 1)
644
+ # x = self.head(x)
645
+ # return x
646
+
647
+ return outs
648
+
649
+ def flops(self):
650
+ flops = 0
651
+ excluded_flops = 0
652
+ flops += self.patch_embed.flops()
653
+ for i, layer in enumerate(self.layers):
654
+ layer_flops, layer_excluded_flops = layer.flops()
655
+ flops += layer_flops
656
+ excluded_flops += layer_excluded_flops
657
+ # flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
658
+ # flops += self.num_features * self.num_classes
659
+ return flops, excluded_flops
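
For reference, here is a minimal usage sketch of the backbone defined above (not part of the diff). The import path is assumed from the repository layout, and it further assumes the file's top-level imports resolve; the constructor arguments mirror the CrossFormer-B configuration that getmodel.py below uses. Note that this backbone variant returns a list of per-stage feature maps rather than classification logits.

```python
# Hypothetical sketch: instantiate the CrossFormer-B backbone defined above.
# The import path is an assumption based on the repository layout.
import torch
from Backbone.counterpart_models.crossformer_backbone import CrossFormer

backbone = CrossFormer(
    img_size=224,
    patch_size=[4, 8, 16, 32],
    in_chans=3,
    embed_dim=96,
    depths=[2, 2, 18, 2],
    num_heads=[3, 6, 12, 24],
    group_size=[7, 7, 7, 7],
    crs_interval=[8, 4, 2, 1],
    mlp_ratio=4.,
    drop_path_rate=0.3,
    merge_size=[[2, 4], [2, 4], [2, 4]],
)

with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 224, 224))  # list of per-stage feature maps
for f in feats:
    # roughly (1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14), (1, 768, 7, 7)
    print(f.shape)
```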
PuzzleTuning/Backbone/getmodel.py ADDED
@@ -0,0 +1,392 @@
1
+ """
2
+ get model func Script ver: Dec 5th 14:20
3
+ """
4
+ import os
5
+ import sys
6
+ sys.path.append(os.path.realpath('.'))
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torchvision import models
11
+ from Backbone import ResHybrid
12
+
13
+
14
+ # get model
15
+ def get_model(num_classes=1000, edge_size=224, model_idx=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0,
16
+ pretrained_backbone=True, use_cls_token=True, use_pos_embedding=True, use_att_module='SimAM'):
17
+ """
18
+ :param num_classes: the number of classes required by your classification dataset
19
+ :param edge_size: the input edge size of the dataloader
20
+ :param model_idx: the model we are going to use, in the format of Model_size_other_info
21
+
22
+ :param drop_rate: The dropout probability of the proposed models
23
+ :param attn_drop_rate: The dropout probability of the layer right after the MHSA or MHGA block in the proposed models
24
+ :param drop_path_rate: The probability of stochastic depth
25
+
26
+ :param pretrained_backbone: Whether the backbone CNN is initialized randomly or from its official pretrained weights
27
+
28
+ :param use_cls_token: Whether to use the class token
29
+ :param use_pos_embedding: Whether to use the positional embedding
30
+ :param use_att_module: Which attention module to use in the FGD Focus block
31
+
32
+ :return: prepared model
33
+ """
34
+ if model_idx[0:5] == 'ViT_h':
35
+ # Transfer learning for ViT
36
+ import timm
37
+ from pprint import pprint
38
+ model_names = timm.list_models('*vit*')
39
+ pprint(model_names)
40
+ if edge_size == 224:
41
+ model = timm.create_model('vit_huge_patch14_224_in21k', pretrained=pretrained_backbone, num_classes=num_classes)
42
+ else:
43
+ print('not an available image size with', model_idx)
44
+
45
+ elif model_idx[0:5] == 'ViT_l':
46
+ # Transfer learning for ViT
47
+ import timm
48
+ from pprint import pprint
49
+ model_names = timm.list_models('*vit*')
50
+ pprint(model_names)
51
+ if edge_size == 224:
52
+ model = timm.create_model('vit_large_patch16_224', pretrained=pretrained_backbone, num_classes=num_classes)
53
+ elif edge_size == 384:
54
+ model = timm.create_model('vit_large_patch16_384', pretrained=pretrained_backbone, num_classes=num_classes)
55
+ else:
56
+ print('not an available image size with', model_idx)
57
+
58
+ elif model_idx[0:5] == 'ViT_s':
59
+ # Transfer learning for ViT
60
+ import timm
61
+ from pprint import pprint
62
+ model_names = timm.list_models('*vit*')
63
+ pprint(model_names)
64
+ if edge_size == 224:
65
+ model = timm.create_model('vit_small_patch16_224', pretrained=pretrained_backbone, num_classes=num_classes)
66
+ elif edge_size == 384:
67
+ model = timm.create_model('vit_small_patch16_384', pretrained=pretrained_backbone, num_classes=num_classes)
68
+ else:
69
+ print('not an available image size with', model_idx)
70
+
71
+ elif model_idx[0:5] == 'ViT_t':
72
+ # Transfer learning for ViT
73
+ import timm
74
+ from pprint import pprint
75
+ model_names = timm.list_models('*vit*')
76
+ pprint(model_names)
77
+ if edge_size == 224:
78
+ model = timm.create_model('vit_tiny_patch16_224', pretrained=pretrained_backbone, num_classes=num_classes)
79
+ elif edge_size == 384:
80
+ model = timm.create_model('vit_tiny_patch16_384', pretrained=pretrained_backbone, num_classes=num_classes)
81
+ else:
82
+ print('not an available image size with', model_idx)
83
+
84
+ elif model_idx[0:5] == 'ViT_b' or model_idx[0:3] == 'ViT': # vit_base
85
+ # Transfer learning for ViT
86
+ import timm
87
+ from pprint import pprint
88
+ model_names = timm.list_models('*vit*')
89
+ pprint(model_names)
90
+ if edge_size == 224:
91
+ model = timm.create_model('vit_base_patch16_224', pretrained=pretrained_backbone, num_classes=num_classes)
92
+ elif edge_size == 384:
93
+ model = timm.create_model('vit_base_patch16_384', pretrained=pretrained_backbone, num_classes=num_classes)
94
+ else:
95
+ print('not an available image size with', model_idx)
96
+
97
+ elif model_idx[0:3] == 'vgg':
98
+ # Transfer learning for vgg16_bn
99
+ import timm
100
+ from pprint import pprint
101
+ model_names = timm.list_models('*vgg*')
102
+ pprint(model_names)
103
+ if model_idx[0:8] == 'vgg16_bn':
104
+ model = timm.create_model('vgg16_bn', pretrained=pretrained_backbone, num_classes=num_classes)
105
+ elif model_idx[0:5] == 'vgg16':
106
+ model = timm.create_model('vgg16', pretrained=pretrained_backbone, num_classes=num_classes)
107
+ elif model_idx[0:8] == 'vgg19_bn':
108
+ model = timm.create_model('vgg19_bn', pretrained=pretrained_backbone, num_classes=num_classes)
109
+ elif model_idx[0:5] == 'vgg19':
110
+ model = timm.create_model('vgg19', pretrained=pretrained_backbone, num_classes=num_classes)
111
+
112
+ elif model_idx[0:4] == 'deit': # Transfer learning for DeiT
113
+ import timm
114
+ from pprint import pprint
115
+ model_names = timm.list_models('*deit*')
116
+ pprint(model_names)
117
+ if edge_size == 384:
118
+ model = timm.create_model('deit_base_patch16_384', pretrained=pretrained_backbone, num_classes=2)
119
+ elif edge_size == 224:
120
+ model = timm.create_model('deit_base_patch16_224', pretrained=pretrained_backbone, num_classes=2)
121
+ else:
122
+ pass
123
+
124
+ elif model_idx[0:5] == 'twins': # Transfer learning for twins
125
+ import timm
126
+ from pprint import pprint
127
+
128
+ model_names = timm.list_models('*twins*')
129
+ pprint(model_names)
130
+ model = timm.create_model('twins_pcpvt_base', pretrained=pretrained_backbone, num_classes=num_classes)
131
+
132
+ elif model_idx[0:5] == 'pit_b' and edge_size == 224: # Transfer learning for PiT
133
+ import timm
134
+ from pprint import pprint
135
+
136
+ model_names = timm.list_models('*pit*')
137
+ pprint(model_names)
138
+ model = timm.create_model('pit_b_224', pretrained=pretrained_backbone, num_classes=num_classes)
139
+
140
+ elif model_idx[0:5] == 'gcvit' and edge_size == 224: # Transfer learning for gcvit
141
+ import timm
142
+ from pprint import pprint
143
+
144
+ model_names = timm.list_models('*gcvit*')
145
+ pprint(model_names)
146
+ model = timm.create_model('gcvit_base', pretrained=pretrained_backbone, num_classes=num_classes)
147
+
148
+ elif model_idx[0:6] == 'xcit_s': # Transfer learning for XCiT
149
+ import timm
150
+ from pprint import pprint
151
+ model_names = timm.list_models('*xcit*')
152
+ pprint(model_names)
153
+ if edge_size == 384:
154
+ model = timm.create_model('xcit_small_12_p16_384_dist', pretrained=pretrained_backbone,
155
+ num_classes=num_classes)
156
+ elif edge_size == 224:
157
+ model = timm.create_model('xcit_small_12_p16_224_dist', pretrained=pretrained_backbone,
158
+ num_classes=num_classes)
159
+ else:
160
+ pass
161
+
162
+ elif model_idx[0:6] == 'xcit_m': # Transfer learning for XCiT
163
+ import timm
164
+ from pprint import pprint
165
+ model_names = timm.list_models('*xcit*')
166
+ pprint(model_names)
167
+ if edge_size == 384:
168
+ model = timm.create_model('xcit_medium_24_p16_384_dist', pretrained=pretrained_backbone,
169
+ num_classes=num_classes)
170
+ elif edge_size == 224:
171
+ model = timm.create_model('xcit_medium_24_p16_224_dist', pretrained=pretrained_backbone,
172
+ num_classes=num_classes)
173
+ else:
174
+ pass
175
+
176
+ elif model_idx[0:6] == 'mvitv2': # Transfer learning for MViT v2 small fixme bug in model!
177
+ import timm
178
+ from pprint import pprint
179
+ model_names = timm.list_models('*mvitv2*')
180
+ pprint(model_names)
181
+ model = timm.create_model('mvitv2_small_cls', pretrained=pretrained_backbone, num_classes=num_classes)
182
+
183
+ elif model_idx[0:6] == 'convit' and edge_size == 224: # Transfer learning for ConViT fixme bug in model!
184
+ import timm
185
+ from pprint import pprint
186
+
187
+ model_names = timm.list_models('*convit*')
188
+ pprint(model_names)
189
+ model = timm.create_model('convit_base', pretrained=pretrained_backbone, num_classes=num_classes)
190
+
191
+ elif model_idx[0:6] == 'ResNet': # Transfer learning for the ResNets
192
+ if model_idx[0:8] == 'ResNet34':
193
+ model = models.resnet34(pretrained=pretrained_backbone)
194
+ elif model_idx[0:8] == 'ResNet50':
195
+ model = models.resnet50(pretrained=pretrained_backbone)
196
+ elif model_idx[0:9] == 'ResNet101':
197
+ model = models.resnet101(pretrained=pretrained_backbone)
198
+ else:
199
+ print('this model is not defined in get model')
200
+ return -1
201
+ num_ftrs = model.fc.in_features
202
+ model.fc = nn.Linear(num_ftrs, num_classes)
203
+
204
+ elif model_idx[0:6] == 'Backbone': # ours: MSHT
205
+ # NOTICE: here 'pretrained' only controls whether the backbone CNN is initialized randomly
206
+ # or from its official pretrained weights
207
+ model = ResHybrid.create_model(model_idx, edge_size, pretrained=pretrained_backbone, num_classes=num_classes,
208
+ drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
209
+ drop_path_rate=drop_path_rate, use_cls_token=use_cls_token,
210
+ use_pos_embedding=use_pos_embedding, use_att_module=use_att_module)
211
+
212
+ elif model_idx[0:7] == 'bot_256' and edge_size == 256: # Model: BoT
213
+ import timm
214
+ from pprint import pprint
215
+ model_names = timm.list_models('*bot*')
216
+ pprint(model_names)
217
+ # NOTICE: we found no pretrained weights for BoT in timm
218
+ # ['botnet26t_256', 'botnet50ts_256', 'eca_botnext26ts_256']
219
+ model = timm.create_model('botnet26t_256', pretrained=pretrained_backbone, num_classes=num_classes)
220
+
221
+ elif model_idx[0:8] == 'densenet': # Transfer learning for densenet
222
+ import timm
223
+ from pprint import pprint
224
+
225
+ model_names = timm.list_models('*densenet*')
226
+ pprint(model_names)
227
+ model = timm.create_model('densenet121', pretrained=pretrained_backbone, num_classes=num_classes)
228
+
229
+ elif model_idx[0:8] == 'xception': # Transfer learning for Xception
230
+ import timm
231
+ from pprint import pprint
232
+ model_names = timm.list_models('*xception*')
233
+ pprint(model_names)
234
+ model = timm.create_model('xception', pretrained=pretrained_backbone, num_classes=num_classes)
235
+
236
+ elif model_idx[0:9] == 'pvt_v2_b0': # Transfer learning for PVT v2 (todo not okay with torch summary)
237
+ import timm
238
+ from pprint import pprint
239
+ model_names = timm.list_models('*pvt_v2*')
240
+ pprint(model_names)
241
+ model = timm.create_model('pvt_v2_b0', pretrained=pretrained_backbone, num_classes=num_classes)
242
+
243
+ elif model_idx[0:9] == 'visformer' and edge_size == 224: # Transfer learning for Visformer
244
+ import timm
245
+ from pprint import pprint
246
+ model_names = timm.list_models('*visformer*')
247
+ pprint(model_names)
248
+ model = timm.create_model('visformer_small', pretrained=pretrained_backbone, num_classes=num_classes)
249
+
250
+ elif model_idx[0:9] == 'conformer': # Transfer learning for Conformer base
251
+ from Backbone.counterpart_models import conformer
252
+
253
+ embed_dim = 576
254
+ channel_ratio = 6
255
+
256
+ if pretrained_backbone:
257
+ model = conformer.Conformer(num_classes=1000, patch_size=16, channel_ratio=channel_ratio,
258
+ embed_dim=embed_dim, depth=12, num_heads=9, mlp_ratio=4, qkv_bias=True)
259
+ # this is the path relative to <code>, not <Backbone>
260
+ save_model_path = '../saved_models/Conformer_base_patch16.pth' # fixme model is downloaded at this path
261
+ # downloaded from official model state at https://github.com/pengzhiliang/Conformer
262
+ model.load_state_dict(torch.load(save_model_path), False)
263
+
264
+ model.trans_cls_head = nn.Linear(embed_dim, num_classes)
265
+ model.conv_cls_head = nn.Linear(int(256 * channel_ratio), num_classes)
266
+ model.cls_head = nn.Linear(int(2 * num_classes), num_classes)
267
+
268
+ else:
269
+ model = conformer.Conformer(num_classes=num_classes, patch_size=16, channel_ratio=channel_ratio,
270
+ embed_dim=embed_dim, depth=12, num_heads=9, mlp_ratio=4, qkv_bias=True)
271
+
272
+ elif model_idx[0:9] == 'coat_mini' and edge_size == 224: # Transfer learning for coat_mini
273
+ import timm
274
+ from pprint import pprint
275
+
276
+ model_names = timm.list_models('*coat*')
277
+ pprint(model_names)
278
+ model = timm.create_model('coat_mini', pretrained=pretrained_backbone, num_classes=num_classes)
279
+
280
+ elif model_idx[0:10] == 'swin_b_384' and edge_size == 384: # Transfer learning for Swin Transformer (swin_b_384)
281
+ import timm
282
+ from pprint import pprint
283
+ model_names = timm.list_models('*swin*')
284
+ pprint(model_names) # swin_base_patch4_window12_384 swin_base_patch4_window12_384_in22k
285
+ model = timm.create_model('swin_base_patch4_window12_384', pretrained=pretrained_backbone,
286
+ num_classes=num_classes)
287
+
288
+ elif model_idx[0:10] == 'swin_b_224' and edge_size == 224: # Transfer learning for Swin Transformer (swin_b_224)
289
+ import timm
290
+ from pprint import pprint
291
+ model_names = timm.list_models('*swin*')
292
+ pprint(model_names) # swin_base_patch4_window7_224 swin_base_patch4_window7_224_in22k
293
+ model = timm.create_model('swin_base_patch4_window7_224', pretrained=pretrained_backbone,
294
+ num_classes=num_classes)
295
+
296
+ elif model_idx[0:11] == 'mobilenetv3': # Transfer learning for mobilenetv3
297
+ import timm
298
+ from pprint import pprint
299
+ model_names = timm.list_models('*mobilenet*')
300
+ pprint(model_names)
301
+ model = timm.create_model('mobilenetv3_large_100', pretrained=pretrained_backbone, num_classes=num_classes)
302
+
303
+ elif model_idx[0:11] == 'mobilevit_s': # Transfer learning for mobilevit_s
304
+ import timm
305
+ from pprint import pprint
306
+ model_names = timm.list_models('*mobilevit*')
307
+ pprint(model_names)
308
+ model = timm.create_model('mobilevit_s', pretrained=pretrained_backbone, num_classes=num_classes)
309
+
310
+ elif model_idx[0:11] == 'inceptionv3': # Transfer learning for Inception v3
311
+ import timm
312
+ from pprint import pprint
313
+ model_names = timm.list_models('*inception*')
314
+ pprint(model_names)
315
+ model = timm.create_model('inception_v3', pretrained=pretrained_backbone, num_classes=num_classes)
316
+
317
+ elif model_idx[0:12] == 'cross_former' and edge_size == 224: # Transfer learning for crossformer base
318
+ from Backbone.counterpart_models import crossformer
319
+ backbone = crossformer.CrossFormer(img_size=edge_size,
320
+ patch_size=[4, 8, 16, 32],
321
+ in_chans=3,
322
+ num_classes=0, # get backbone only
323
+ embed_dim=96,
324
+ depths=[2, 2, 18, 2],
325
+ num_heads=[3, 6, 12, 24],
326
+ group_size=[7, 7, 7, 7],
327
+ mlp_ratio=4.,
328
+ qkv_bias=True,
329
+ qk_scale=None,
330
+ drop_rate=0.0,
331
+ drop_path_rate=0.3,
332
+ ape=False,
333
+ patch_norm=True,
334
+ use_checkpoint=False,
335
+ merge_size=[[2, 4], [2, 4], [2, 4]], )
336
+ if pretrained_backbone:
337
+ save_model_path = '../saved_models/crossformer-b.pth' # fixme model is downloaded at this path
338
+ # downloaded from official model state at https://github.com/cheerss/CrossFormer
339
+ backbone.load_state_dict(torch.load(save_model_path)['model'], False)
340
+ model = crossformer.cross_former_cls_head_warp(backbone, num_classes)
341
+
342
+ elif model_idx[0:13] == 'crossvit_base': # Transfer learning for crossvit_base (todo not okay with torch summary)
343
+ import timm
344
+ from pprint import pprint
345
+ model_names = timm.list_models('*crossvit_base*')
346
+ pprint(model_names)
347
+ model = timm.create_model('crossvit_base_240', pretrained=pretrained_backbone, num_classes=num_classes)
348
+
349
+ elif model_idx[0:14] == 'efficientnet_b': # Transfer learning for efficientnet_b3,4
350
+ import timm
351
+ from pprint import pprint
352
+ model_names = timm.list_models('*efficientnet*')
353
+ pprint(model_names)
354
+ model = timm.create_model(model_idx[0:15], pretrained=pretrained_backbone, num_classes=num_classes)
355
+
356
+ elif model_idx[0:14] == 'ResN50_ViT_384': # ResNet50 + ViT hybrid model at 384
357
+ import timm
358
+ from pprint import pprint
359
+ model_names = timm.list_models('*vit_base_resnet*')
360
+ pprint(model_names)
361
+ model = timm.create_model('vit_base_resnet50_384', pretrained=pretrained_backbone, num_classes=num_classes)
362
+
363
+ elif model_idx[0:15] == 'coat_lite_small' and edge_size == 224: # Transfer learning for coat_lite_small
364
+ import timm
365
+ from pprint import pprint
366
+
367
+ model_names = timm.list_models('*coat*')
368
+ pprint(model_names)
369
+ model = timm.create_model('coat_lite_small', pretrained=pretrained_backbone, num_classes=num_classes)
370
+
371
+ elif model_idx[0:17] == 'efficientformer_l' and edge_size == 224: # Transfer learning for EfficientFormer-L
372
+ import timm
373
+ from pprint import pprint
374
+ model_names = timm.list_models('*efficientformer*')
375
+ pprint(model_names)
376
+ model = timm.create_model(model_idx[0:18], pretrained=pretrained_backbone, num_classes=num_classes)
377
+
378
+ else:
379
+ print('\nThe model', model_idx, 'with the edge size of', edge_size)
380
+ print("is not defined in the script!!", '\n')
381
+ return -1
382
+
383
+ try:
384
+ img = torch.randn(1, 3, edge_size, edge_size)
385
+ preds = model(img) # (1, class_number)
386
+ print('test model output:', preds)
387
+ except:
388
+ print("Problem exist in the model defining process!!")
389
+ return -1
390
+ else:
391
+ print('model is ready now!')
392
+ return model
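
A minimal usage sketch for the helper above (not part of the diff; the import path and the `model_idx` string are illustrative assumptions): it builds a timm ViT-base classifier for a 2-class task at 224 x 224 and runs a dummy forward pass.

```python
# Hypothetical sketch: build a ViT-base classifier with get_model (import path assumed).
import torch
from Backbone.getmodel import get_model

model = get_model(num_classes=2, edge_size=224,
                  model_idx='ViT_b16_224_demo',   # hypothetical index, matches the 'ViT_b' branch above
                  pretrained_backbone=True)
if not isinstance(model, torch.nn.Module):        # get_model returns -1 when the combination is undefined
    raise RuntimeError('model_idx / edge_size combination is not defined')

with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))   # shape: (1, 2)
print(logits.shape)
```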
PuzzleTuning/Counterpart PreTrain Methods/ReadMe.md ADDED
@@ -0,0 +1,2 @@
1
+ We have used MoCo v3, DINO, and MAE to pre-train the ViT-base-224 model.
2
+ The official code of each method is included here.
PuzzleTuning/Counterpart PreTrain Methods/dino-main/LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
PuzzleTuning/Counterpart PreTrain Methods/dino-main/README.md ADDED
@@ -0,0 +1,14 @@
1
+ # Self-Supervised Vision Transformers with DINO
2
+
3
+ The original repo of DINO can be found [here](https://github.com/facebookresearch/dino "DINO")
4
+
5
+ Pip requirements: timm == 0.4.9, PyTorch == 1.7.1, Torchvision == 0.8.2, Cuda == 11.0
6
+
7
+ Typical BASH:
8
+ ```console
9
+ python -m torch.distributed.launch \
10
+ --nproc_per_node=2 main_dino.py --arch vit_base --batch_size_per_gpu 128 \
11
+ --lr 1.5e-4 --epochs 100 --data_path /root/autodl-tmp/All \
12
+ --basic_state_dict /root/autodl-tmp/ViT_b16_224_Imagenet.pth \
13
+ --num_workers 32 --output_dir the/path/of/CPIA
14
+ ```
PuzzleTuning/Counterpart PreTrain Methods/dino-main/eval_copy_detection.py ADDED
@@ -0,0 +1,301 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import sys
16
+ import pickle
17
+ import argparse
18
+
19
+ import torch
20
+ from torch import nn
21
+ import torch.distributed as dist
22
+ import torch.backends.cudnn as cudnn
23
+ from torchvision import models as torchvision_models
24
+ from torchvision import transforms as pth_transforms
25
+ from PIL import Image, ImageFile
26
+ import numpy as np
27
+
28
+ import utils
29
+ import vision_transformer as vits
30
+ from eval_knn import extract_features
31
+
32
+
33
+ class CopydaysDataset():
34
+ def __init__(self, basedir):
35
+ self.basedir = basedir
36
+ self.block_names = (
37
+ ['original', 'strong'] +
38
+ ['jpegqual/%d' % i for i in
39
+ [3, 5, 8, 10, 15, 20, 30, 50, 75]] +
40
+ ['crops/%d' % i for i in
41
+ [10, 15, 20, 30, 40, 50, 60, 70, 80]])
42
+ self.nblocks = len(self.block_names)
43
+
44
+ self.query_blocks = range(self.nblocks)
45
+ self.q_block_sizes = np.ones(self.nblocks, dtype=int) * 157
46
+ self.q_block_sizes[1] = 229
47
+ # search only among originals
48
+ self.database_blocks = [0]
49
+
50
+ def get_block(self, i):
51
+ dirname = self.basedir + '/' + self.block_names[i]
52
+ fnames = [dirname + '/' + fname
53
+ for fname in sorted(os.listdir(dirname))
54
+ if fname.endswith('.jpg')]
55
+ return fnames
56
+
57
+ def get_block_filenames(self, subdir_name):
58
+ dirname = self.basedir + '/' + subdir_name
59
+ return [fname
60
+ for fname in sorted(os.listdir(dirname))
61
+ if fname.endswith('.jpg')]
62
+
63
+ def eval_result(self, ids, distances):
64
+ j0 = 0
65
+ for i in range(self.nblocks):
66
+ j1 = j0 + self.q_block_sizes[i]
67
+ block_name = self.block_names[i]
68
+ I = ids[j0:j1] # block size
69
+ sum_AP = 0
70
+ if block_name != 'strong':
71
+ # 1:1 mapping of files to names
72
+ positives_per_query = [[i] for i in range(j1 - j0)]
73
+ else:
74
+ originals = self.get_block_filenames('original')
75
+ strongs = self.get_block_filenames('strong')
76
+
77
+ # check if prefixes match
78
+ positives_per_query = [
79
+ [j for j, bname in enumerate(originals)
80
+ if bname[:4] == qname[:4]]
81
+ for qname in strongs]
82
+
83
+ for qno, Iline in enumerate(I):
84
+ positives = positives_per_query[qno]
85
+ ranks = []
86
+ for rank, bno in enumerate(Iline):
87
+ if bno in positives:
88
+ ranks.append(rank)
89
+ sum_AP += score_ap_from_ranks_1(ranks, len(positives))
90
+
91
+ print("eval on %s mAP=%.3f" % (
92
+ block_name, sum_AP / (j1 - j0)))
93
+ j0 = j1
94
+
95
+
96
+ # from the Holidays evaluation package
97
+ def score_ap_from_ranks_1(ranks, nres):
98
+ """ Compute the average precision of one search.
99
+ ranks = ordered list of ranks of true positives
100
+ nres = total number of positives in dataset
101
+ """
102
+
103
+ # accumulate trapezoids in PR-plot
104
+ ap = 0.0
105
+
106
+ # All have an x-size of:
107
+ recall_step = 1.0 / nres
108
+
109
+ for ntp, rank in enumerate(ranks):
110
+
111
+ # y-size on left side of trapezoid:
112
+ # ntp = nb of true positives so far
113
+ # rank = nb of retrieved items so far
114
+ if rank == 0:
115
+ precision_0 = 1.0
116
+ else:
117
+ precision_0 = ntp / float(rank)
118
+
119
+ # y-size on right side of trapezoid:
120
+ # ntp and rank are increased by one
121
+ precision_1 = (ntp + 1) / float(rank + 1)
122
+
123
+ ap += (precision_1 + precision_0) * recall_step / 2.0
124
+
125
+ return ap
126
+
127
+
128
+ class ImgListDataset(torch.utils.data.Dataset):
129
+ def __init__(self, img_list, transform=None):
130
+ self.samples = img_list
131
+ self.transform = transform
132
+
133
+ def __getitem__(self, i):
134
+ with open(self.samples[i], 'rb') as f:
135
+ img = Image.open(f)
136
+ img = img.convert('RGB')
137
+ if self.transform is not None:
138
+ img = self.transform(img)
139
+ return img, i
140
+
141
+ def __len__(self):
142
+ return len(self.samples)
143
+
144
+
145
+ def is_image_file(s):
146
+ ext = s.split(".")[-1]
147
+ if ext in ['jpg', 'jpeg', 'png', 'ppm', 'bmp', 'pgm', 'tif', 'tiff', 'webp']:
148
+ return True
149
+ return False
150
+
151
+
152
+ @torch.no_grad()
153
+ def extract_features(image_list, model, args):
154
+ transform = pth_transforms.Compose([
155
+ pth_transforms.Resize((args.imsize, args.imsize), interpolation=3),
156
+ pth_transforms.ToTensor(),
157
+ pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
158
+ ])
159
+ tempdataset = ImgListDataset(image_list, transform=transform)
160
+ data_loader = torch.utils.data.DataLoader(tempdataset, batch_size=args.batch_size_per_gpu,
161
+ num_workers=args.num_workers, drop_last=False,
162
+ sampler=torch.utils.data.DistributedSampler(tempdataset, shuffle=False))
163
+ features = None
164
+ for samples, index in utils.MetricLogger(delimiter=" ").log_every(data_loader, 10):
165
+ samples, index = samples.cuda(non_blocking=True), index.cuda(non_blocking=True)
166
+ feats = model.get_intermediate_layers(samples, n=1)[0].clone()
167
+
168
+ cls_output_token = feats[:, 0, :] # [CLS] token
169
+ # GeM with exponent 4 for output patch tokens
170
+ b, h, w, d = len(samples), int(samples.shape[-2] / model.patch_embed.patch_size), int(samples.shape[-1] / model.patch_embed.patch_size), feats.shape[-1]
171
+ feats = feats[:, 1:, :].reshape(b, h, w, d)
172
+ feats = feats.clamp(min=1e-6).permute(0, 3, 1, 2)
173
+ feats = nn.functional.avg_pool2d(feats.pow(4), (h, w)).pow(1. / 4).reshape(b, -1)
174
+ # concatenate [CLS] token and GeM pooled patch tokens
175
+ feats = torch.cat((cls_output_token, feats), dim=1)
176
+
177
+ # init storage feature matrix
178
+ if dist.get_rank() == 0 and features is None:
179
+ features = torch.zeros(len(data_loader.dataset), feats.shape[-1])
180
+ if args.use_cuda:
181
+ features = features.cuda(non_blocking=True)
182
+
183
+ # get indexes from all processes
184
+ y_all = torch.empty(dist.get_world_size(), index.size(0), dtype=index.dtype, device=index.device)
185
+ y_l = list(y_all.unbind(0))
186
+ y_all_reduce = torch.distributed.all_gather(y_l, index, async_op=True)
187
+ y_all_reduce.wait()
188
+ index_all = torch.cat(y_l)
189
+
190
+ # share features between processes
191
+ feats_all = torch.empty(dist.get_world_size(), feats.size(0), feats.size(1),
192
+ dtype=feats.dtype, device=feats.device)
193
+ output_l = list(feats_all.unbind(0))
194
+ output_all_reduce = torch.distributed.all_gather(output_l, feats, async_op=True)
195
+ output_all_reduce.wait()
196
+
197
+ # update storage feature matrix
198
+ if dist.get_rank() == 0:
199
+ if args.use_cuda:
200
+ features.index_copy_(0, index_all, torch.cat(output_l))
201
+ else:
202
+ features.index_copy_(0, index_all.cpu(), torch.cat(output_l).cpu())
203
+ return features # features is still None for every rank which is not 0 (main)
204
+
205
+
206
+ if __name__ == '__main__':
207
+ parser = argparse.ArgumentParser('Copy detection on Copydays')
208
+ parser.add_argument('--data_path', default='/path/to/copydays/', type=str,
209
+ help="See https://lear.inrialpes.fr/~jegou/data.php#copydays")
210
+ parser.add_argument('--whitening_path', default='/path/to/whitening_data/', type=str,
211
+ help="""Path to directory with images used for computing the whitening operator.
212
+ In our paper, we use 20k random images from YFCC100M.""")
213
+ parser.add_argument('--distractors_path', default='/path/to/distractors/', type=str,
214
+ help="Path to directory with distractors images. In our paper, we use 10k random images from YFCC100M.")
215
+ parser.add_argument('--imsize', default=320, type=int, help='Image size (square image)')
216
+ parser.add_argument('--batch_size_per_gpu', default=16, type=int, help='Per-GPU batch-size')
217
+ parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.")
218
+ parser.add_argument('--use_cuda', default=True, type=utils.bool_flag)
219
+ parser.add_argument('--arch', default='vit_base', type=str, help='Architecture')
220
+ parser.add_argument('--patch_size', default=8, type=int, help='Patch resolution of the model.')
221
+ parser.add_argument("--checkpoint_key", default="teacher", type=str,
222
+ help='Key to use in the checkpoint (example: "teacher")')
223
+ parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
224
+ parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
225
+ distributed training; see https://pytorch.org/docs/stable/distributed.html""")
226
+ parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
227
+ args = parser.parse_args()
228
+
229
+ utils.init_distributed_mode(args)
230
+ print("git:\n {}\n".format(utils.get_sha()))
231
+ print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
232
+ cudnn.benchmark = True
233
+
234
+ # ============ building network ... ============
235
+ if "vit" in args.arch:
236
+ model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
237
+ print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.")
238
+ else:
239
+ print(f"Architecture {args.arch} non supported")
240
+ sys.exit(1)
241
+ if args.use_cuda:
242
+ model.cuda()
243
+ model.eval()
244
+ utils.load_pretrained_weights(model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size)
245
+
246
+ dataset = CopydaysDataset(args.data_path)
247
+
248
+ # ============ Extract features ... ============
249
+ # extract features for queries
250
+ queries = []
251
+ for q in dataset.query_blocks:
252
+ queries.append(extract_features(dataset.get_block(q), model, args))
253
+ if utils.get_rank() == 0:
254
+ queries = torch.cat(queries)
255
+ print(f"Extraction of queries features done. Shape: {queries.shape}")
256
+
257
+ # extract features for database
258
+ database = []
259
+ for b in dataset.database_blocks:
260
+ database.append(extract_features(dataset.get_block(b), model, args))
261
+
262
+ # extract features for distractors
263
+ if os.path.isdir(args.distractors_path):
264
+ print("Using distractors...")
265
+ list_distractors = [os.path.join(args.distractors_path, s) for s in os.listdir(args.distractors_path) if is_image_file(s)]
266
+ database.append(extract_features(list_distractors, model, args))
267
+ if utils.get_rank() == 0:
268
+ database = torch.cat(database)
269
+ print(f"Extraction of database and distractors features done. Shape: {database.shape}")
270
+
271
+ # ============ Whitening ... ============
272
+ if os.path.isdir(args.whitening_path):
273
+ print(f"Extracting features on images from {args.whitening_path} for learning the whitening operator.")
274
+ list_whit = [os.path.join(args.whitening_path, s) for s in os.listdir(args.whitening_path) if is_image_file(s)]
275
+ features_for_whitening = extract_features(list_whit, model, args)
276
+ if utils.get_rank() == 0:
277
+ # center
278
+ mean_feature = torch.mean(features_for_whitening, dim=0)
279
+ database -= mean_feature
280
+ queries -= mean_feature
281
+ pca = utils.PCA(dim=database.shape[-1], whit=0.5)
282
+ # compute covariance
283
+ cov = torch.mm(features_for_whitening.T, features_for_whitening) / features_for_whitening.shape[0]
284
+ pca.train_pca(cov.cpu().numpy())
285
+ database = pca.apply(database)
286
+ queries = pca.apply(queries)
287
+
288
+ # ============ Copy detection ... ============
289
+ if utils.get_rank() == 0:
290
+ # l2 normalize the features
291
+ database = nn.functional.normalize(database, dim=1, p=2)
292
+ queries = nn.functional.normalize(queries, dim=1, p=2)
293
+
294
+ # similarity
295
+ similarity = torch.mm(queries, database.T)
296
+ distances, indices = similarity.topk(20, largest=True, sorted=True)
297
+
298
+ # evaluate
299
+ retrieved = dataset.eval_result(indices, distances)
300
+ dist.barrier()
301
+
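For intuition, a small worked example of the trapezoidal average-precision computation in `score_ap_from_ranks_1` above (the numbers are hypothetical, and the import assumes this script is importable as `eval_copy_detection` from the dino-main directory).

```python
# Hypothetical example for score_ap_from_ranks_1 (module import path assumed).
from eval_copy_detection import score_ap_from_ranks_1

# Two true positives retrieved at 0-based ranks 0 and 2, out of nres=3 positives overall.
ap = score_ap_from_ranks_1(ranks=[0, 2], nres=3)
# recall_step = 1/3
# trapezoid at rank 0: (1.0 + 1.0) / 2 * (1/3)  = 0.3333
# trapezoid at rank 2: (0.5 + 2/3) / 2 * (1/3) ~= 0.1944
print(round(ap, 4))  # ~= 0.5278
```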
PuzzleTuning/Counterpart PreTrain Methods/dino-main/eval_image_retrieval.py ADDED
@@ -0,0 +1,201 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import sys
16
+ import pickle
17
+ import argparse
18
+
19
+ import torch
20
+ from torch import nn
21
+ import torch.distributed as dist
22
+ import torch.backends.cudnn as cudnn
23
+ from torchvision import models as torchvision_models
24
+ from torchvision import transforms as pth_transforms
25
+ from PIL import Image, ImageFile
26
+ import numpy as np
27
+
28
+ import utils
29
+ import vision_transformer as vits
30
+ from eval_knn import extract_features
31
+
32
+
33
+ class OxfordParisDataset(torch.utils.data.Dataset):
34
+ def __init__(self, dir_main, dataset, split, transform=None, imsize=None):
35
+ if dataset not in ['roxford5k', 'rparis6k']:
36
+ raise ValueError('Unknown dataset: {}!'.format(dataset))
37
+
38
+ # loading imlist, qimlist, and gnd, in cfg as a dict
39
+ gnd_fname = os.path.join(dir_main, dataset, 'gnd_{}.pkl'.format(dataset))
40
+ with open(gnd_fname, 'rb') as f:
41
+ cfg = pickle.load(f)
42
+ cfg['gnd_fname'] = gnd_fname
43
+ cfg['ext'] = '.jpg'
44
+ cfg['qext'] = '.jpg'
45
+ cfg['dir_data'] = os.path.join(dir_main, dataset)
46
+ cfg['dir_images'] = os.path.join(cfg['dir_data'], 'jpg')
47
+ cfg['n'] = len(cfg['imlist'])
48
+ cfg['nq'] = len(cfg['qimlist'])
49
+ cfg['im_fname'] = config_imname
50
+ cfg['qim_fname'] = config_qimname
51
+ cfg['dataset'] = dataset
52
+ self.cfg = cfg
53
+
54
+ self.samples = cfg["qimlist"] if split == "query" else cfg["imlist"]
55
+ self.transform = transform
56
+ self.imsize = imsize
57
+
58
+ def __len__(self):
59
+ return len(self.samples)
60
+
61
+ def __getitem__(self, index):
62
+ path = os.path.join(self.cfg["dir_images"], self.samples[index] + ".jpg")
63
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
64
+ with open(path, 'rb') as f:
65
+ img = Image.open(f)
66
+ img = img.convert('RGB')
67
+ if self.imsize is not None:
68
+ img.thumbnail((self.imsize, self.imsize), Image.ANTIALIAS)
69
+ if self.transform is not None:
70
+ img = self.transform(img)
71
+ return img, index
72
+
73
+
74
+ def config_imname(cfg, i):
75
+ return os.path.join(cfg['dir_images'], cfg['imlist'][i] + cfg['ext'])
76
+
77
+
78
+ def config_qimname(cfg, i):
79
+ return os.path.join(cfg['dir_images'], cfg['qimlist'][i] + cfg['qext'])
80
+
81
+
82
+ if __name__ == '__main__':
83
+ parser = argparse.ArgumentParser('Image Retrieval on revisited Paris and Oxford')
84
+ parser.add_argument('--data_path', default='/path/to/revisited_paris_oxford/', type=str)
85
+ parser.add_argument('--dataset', default='roxford5k', type=str, choices=['roxford5k', 'rparis6k'])
86
+ parser.add_argument('--multiscale', default=False, type=utils.bool_flag)
87
+ parser.add_argument('--imsize', default=224, type=int, help='Image size')
88
+ parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.")
89
+ parser.add_argument('--use_cuda', default=True, type=utils.bool_flag)
90
+ parser.add_argument('--arch', default='vit_small', type=str, help='Architecture')
91
+ parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.')
92
+ parser.add_argument("--checkpoint_key", default="teacher", type=str,
93
+ help='Key to use in the checkpoint (example: "teacher")')
94
+ parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
95
+ parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
96
+ distributed training; see https://pytorch.org/docs/stable/distributed.html""")
97
+ parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
98
+ args = parser.parse_args()
99
+
100
+ utils.init_distributed_mode(args)
101
+ print("git:\n {}\n".format(utils.get_sha()))
102
+ print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
103
+ cudnn.benchmark = True
104
+
105
+ # ============ preparing data ... ============
106
+ transform = pth_transforms.Compose([
107
+ pth_transforms.ToTensor(),
108
+ pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
109
+ ])
110
+ dataset_train = OxfordParisDataset(args.data_path, args.dataset, split="train", transform=transform, imsize=args.imsize)
111
+ dataset_query = OxfordParisDataset(args.data_path, args.dataset, split="query", transform=transform, imsize=args.imsize)
112
+ sampler = torch.utils.data.DistributedSampler(dataset_train, shuffle=False)
113
+ data_loader_train = torch.utils.data.DataLoader(
114
+ dataset_train,
115
+ sampler=sampler,
116
+ batch_size=1,
117
+ num_workers=args.num_workers,
118
+ pin_memory=True,
119
+ drop_last=False,
120
+ )
121
+ data_loader_query = torch.utils.data.DataLoader(
122
+ dataset_query,
123
+ batch_size=1,
124
+ num_workers=args.num_workers,
125
+ pin_memory=True,
126
+ drop_last=False,
127
+ )
128
+ print(f"train: {len(dataset_train)} imgs / query: {len(dataset_query)} imgs")
129
+
130
+ # ============ building network ... ============
131
+ if "vit" in args.arch:
132
+ model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
133
+ print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.")
134
+ elif "xcit" in args.arch:
135
+ model = torch.hub.load('facebookresearch/xcit:main', args.arch, num_classes=0)
136
+ elif args.arch in torchvision_models.__dict__.keys():
137
+ model = torchvision_models.__dict__[args.arch](num_classes=0)
138
+ else:
139
+ print(f"Architecture {args.arch} non supported")
140
+ sys.exit(1)
141
+ if args.use_cuda:
142
+ model.cuda()
143
+ model.eval()
144
+
145
+ # load pretrained weights
146
+ if os.path.isfile(args.pretrained_weights):
147
+ state_dict = torch.load(args.pretrained_weights, map_location="cpu")
148
+ if args.checkpoint_key is not None and args.checkpoint_key in state_dict:
149
+ print(f"Take key {args.checkpoint_key} in provided checkpoint dict")
150
+ state_dict = state_dict[args.checkpoint_key]
151
+ # remove `module.` prefix
152
+ state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
153
+ # remove `backbone.` prefix induced by multicrop wrapper
154
+ state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
155
+ msg = model.load_state_dict(state_dict, strict=False)
156
+ print('Pretrained weights found at {} and loaded with msg: {}'.format(args.pretrained_weights, msg))
157
+ elif args.arch == "vit_small" and args.patch_size == 16:
158
+ print("Since no pretrained weights have been provided, we load pretrained DINO weights on Google Landmark v2.")
159
+ model.load_state_dict(torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/dino_vitsmall16_googlelandmark_pretrain/dino_vitsmall16_googlelandmark_pretrain.pth"))
160
+ else:
161
+ print("Warning: We use random weights.")
162
+
163
+ ############################################################################
164
+ # Step 1: extract features
165
+ train_features = extract_features(model, data_loader_train, args.use_cuda, multiscale=args.multiscale)
166
+ query_features = extract_features(model, data_loader_query, args.use_cuda, multiscale=args.multiscale)
167
+
168
+ if utils.get_rank() == 0: # only rank 0 will work from now on
169
+ # normalize features
170
+ train_features = nn.functional.normalize(train_features, dim=1, p=2)
171
+ query_features = nn.functional.normalize(query_features, dim=1, p=2)
172
+
173
+ ############################################################################
174
+ # Step 2: similarity
175
+ sim = torch.mm(train_features, query_features.T)
176
+ ranks = torch.argsort(-sim, dim=0).cpu().numpy()
177
+
178
+ ############################################################################
179
+ # Step 3: evaluate
180
+ gnd = dataset_train.cfg['gnd']
181
+ # evaluate ranks
182
+ ks = [1, 5, 10]
183
+ # search for easy & hard
184
+ gnd_t = []
185
+ for i in range(len(gnd)):
186
+ g = {}
187
+ g['ok'] = np.concatenate([gnd[i]['easy'], gnd[i]['hard']])
188
+ g['junk'] = np.concatenate([gnd[i]['junk']])
189
+ gnd_t.append(g)
190
+ mapM, apsM, mprM, prsM = utils.compute_map(ranks, gnd_t, ks)
191
+ # search for hard
192
+ gnd_t = []
193
+ for i in range(len(gnd)):
194
+ g = {}
195
+ g['ok'] = np.concatenate([gnd[i]['hard']])
196
+ g['junk'] = np.concatenate([gnd[i]['junk'], gnd[i]['easy']])
197
+ gnd_t.append(g)
198
+ mapH, apsH, mprH, prsH = utils.compute_map(ranks, gnd_t, ks)
199
+ print('>> {}: mAP M: {}, H: {}'.format(args.dataset, np.around(mapM*100, decimals=2), np.around(mapH*100, decimals=2)))
200
+ print('>> {}: mP@k{} M: {}, H: {}'.format(args.dataset, np.array(ks), np.around(mprM*100, decimals=2), np.around(mprH*100, decimals=2)))
201
+ dist.barrier()
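The retrieval script above boils down to ranking database images by cosine similarity to each query. A minimal sketch of that ranking step on random stand-in features (the sizes and the 384-dim embedding are placeholder assumptions, not the rOxford5k/rParis6k data):

import torch
import torch.nn.functional as F

# stand-in features: rows are images, columns are embedding dimensions
database_features = F.normalize(torch.randn(100, 384), dim=1, p=2)
query_features = F.normalize(torch.randn(10, 384), dim=1, p=2)

sim = torch.mm(database_features, query_features.T)  # database x query cosine similarities
ranks = torch.argsort(-sim, dim=0)                   # per query: database indices, most to least similar
print(ranks.shape)                                   # torch.Size([100, 10])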
PuzzleTuning/Counterpart PreTrain Methods/dino-main/eval_knn.py ADDED
@@ -0,0 +1,242 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import sys
16
+ import argparse
17
+
18
+ import torch
19
+ from torch import nn
20
+ import torch.distributed as dist
21
+ import torch.backends.cudnn as cudnn
22
+ from torchvision import datasets
23
+ from torchvision import transforms as pth_transforms
24
+ from torchvision import models as torchvision_models
25
+
26
+ import utils
27
+ import vision_transformer as vits
28
+
29
+
30
+ def extract_feature_pipeline(args):
31
+ # ============ preparing data ... ============
32
+ transform = pth_transforms.Compose([
33
+ pth_transforms.Resize(256, interpolation=3),
34
+ pth_transforms.CenterCrop(224),
35
+ pth_transforms.ToTensor(),
36
+ pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
37
+ ])
38
+ dataset_train = ReturnIndexDataset(os.path.join(args.data_path, "train"), transform=transform)
39
+ dataset_val = ReturnIndexDataset(os.path.join(args.data_path, "val"), transform=transform)
40
+ sampler = torch.utils.data.DistributedSampler(dataset_train, shuffle=False)
41
+ data_loader_train = torch.utils.data.DataLoader(
42
+ dataset_train,
43
+ sampler=sampler,
44
+ batch_size=args.batch_size_per_gpu,
45
+ num_workers=args.num_workers,
46
+ pin_memory=True,
47
+ drop_last=False,
48
+ )
49
+ data_loader_val = torch.utils.data.DataLoader(
50
+ dataset_val,
51
+ batch_size=args.batch_size_per_gpu,
52
+ num_workers=args.num_workers,
53
+ pin_memory=True,
54
+ drop_last=False,
55
+ )
56
+ print(f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs.")
57
+
58
+ # ============ building network ... ============
59
+ if "vit" in args.arch:
60
+ model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
61
+ print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.")
62
+ elif "xcit" in args.arch:
63
+ model = torch.hub.load('facebookresearch/xcit:main', args.arch, num_classes=0)
64
+ elif args.arch in torchvision_models.__dict__.keys():
65
+ model = torchvision_models.__dict__[args.arch](num_classes=0)
66
+ model.fc = nn.Identity()
67
+ else:
68
+ print(f"Architecture {args.arch} non supported")
69
+ sys.exit(1)
70
+ model.cuda()
71
+ utils.load_pretrained_weights(model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size)
72
+ model.eval()
73
+
74
+ # ============ extract features ... ============
75
+ print("Extracting features for train set...")
76
+ train_features = extract_features(model, data_loader_train, args.use_cuda)
77
+ print("Extracting features for val set...")
78
+ test_features = extract_features(model, data_loader_val, args.use_cuda)
79
+
80
+ if utils.get_rank() == 0:
81
+ train_features = nn.functional.normalize(train_features, dim=1, p=2)
82
+ test_features = nn.functional.normalize(test_features, dim=1, p=2)
83
+
84
+ train_labels = torch.tensor([s[-1] for s in dataset_train.samples]).long()
85
+ test_labels = torch.tensor([s[-1] for s in dataset_val.samples]).long()
86
+ # save features and labels
87
+ if args.dump_features and dist.get_rank() == 0:
88
+ torch.save(train_features.cpu(), os.path.join(args.dump_features, "trainfeat.pth"))
89
+ torch.save(test_features.cpu(), os.path.join(args.dump_features, "testfeat.pth"))
90
+ torch.save(train_labels.cpu(), os.path.join(args.dump_features, "trainlabels.pth"))
91
+ torch.save(test_labels.cpu(), os.path.join(args.dump_features, "testlabels.pth"))
92
+ return train_features, test_features, train_labels, test_labels
93
+
94
+
95
+ @torch.no_grad()
96
+ def extract_features(model, data_loader, use_cuda=True, multiscale=False):
97
+ metric_logger = utils.MetricLogger(delimiter=" ")
98
+ features = None
99
+ for samples, index in metric_logger.log_every(data_loader, 10):
100
+ samples = samples.cuda(non_blocking=True)
101
+ index = index.cuda(non_blocking=True)
102
+ if multiscale:
103
+ feats = utils.multi_scale(samples, model)
104
+ else:
105
+ feats = model(samples).clone()
106
+
107
+ # init storage feature matrix
108
+ if dist.get_rank() == 0 and features is None:
109
+ features = torch.zeros(len(data_loader.dataset), feats.shape[-1])
110
+ if use_cuda:
111
+ features = features.cuda(non_blocking=True)
112
+ print(f"Storing features into tensor of shape {features.shape}")
113
+
114
+ # get indexes from all processes
115
+ y_all = torch.empty(dist.get_world_size(), index.size(0), dtype=index.dtype, device=index.device)
116
+ y_l = list(y_all.unbind(0))
117
+ y_all_reduce = torch.distributed.all_gather(y_l, index, async_op=True)
118
+ y_all_reduce.wait()
119
+ index_all = torch.cat(y_l)
120
+
121
+ # share features between processes
122
+ feats_all = torch.empty(
123
+ dist.get_world_size(),
124
+ feats.size(0),
125
+ feats.size(1),
126
+ dtype=feats.dtype,
127
+ device=feats.device,
128
+ )
129
+ output_l = list(feats_all.unbind(0))
130
+ output_all_reduce = torch.distributed.all_gather(output_l, feats, async_op=True)
131
+ output_all_reduce.wait()
132
+
133
+ # update storage feature matrix
134
+ if dist.get_rank() == 0:
135
+ if use_cuda:
136
+ features.index_copy_(0, index_all, torch.cat(output_l))
137
+ else:
138
+ features.index_copy_(0, index_all.cpu(), torch.cat(output_l).cpu())
139
+ return features
140
+
141
+
142
+ @torch.no_grad()
143
+ def knn_classifier(train_features, train_labels, test_features, test_labels, k, T, num_classes=1000):
144
+ top1, top5, total = 0.0, 0.0, 0
145
+ train_features = train_features.t()
146
+ num_test_images, num_chunks = test_labels.shape[0], 100
147
+ imgs_per_chunk = num_test_images // num_chunks
148
+ retrieval_one_hot = torch.zeros(k, num_classes).to(train_features.device)
149
+ for idx in range(0, num_test_images, imgs_per_chunk):
150
+ # get the features for test images
151
+ features = test_features[
152
+ idx : min((idx + imgs_per_chunk), num_test_images), :
153
+ ]
154
+ targets = test_labels[idx : min((idx + imgs_per_chunk), num_test_images)]
155
+ batch_size = targets.shape[0]
156
+
157
+ # calculate the dot product and compute top-k neighbors
158
+ similarity = torch.mm(features, train_features)
159
+ distances, indices = similarity.topk(k, largest=True, sorted=True)
160
+ candidates = train_labels.view(1, -1).expand(batch_size, -1)
161
+ retrieved_neighbors = torch.gather(candidates, 1, indices)
162
+
163
+ retrieval_one_hot.resize_(batch_size * k, num_classes).zero_()
164
+ retrieval_one_hot.scatter_(1, retrieved_neighbors.view(-1, 1), 1)
165
+ distances_transform = distances.clone().div_(T).exp_()
166
+ probs = torch.sum(
167
+ torch.mul(
168
+ retrieval_one_hot.view(batch_size, -1, num_classes),
169
+ distances_transform.view(batch_size, -1, 1),
170
+ ),
171
+ 1,
172
+ )
173
+ _, predictions = probs.sort(1, True)
174
+
175
+ # find the predictions that match the target
176
+ correct = predictions.eq(targets.data.view(-1, 1))
177
+ top1 = top1 + correct.narrow(1, 0, 1).sum().item()
178
+ top5 = top5 + correct.narrow(1, 0, min(5, k)).sum().item() # top5 does not make sense if k < 5
179
+ total += targets.size(0)
180
+ top1 = top1 * 100.0 / total
181
+ top5 = top5 * 100.0 / total
182
+ return top1, top5
183
+
184
+
185
+ class ReturnIndexDataset(datasets.ImageFolder):
186
+ def __getitem__(self, idx):
187
+ img, lab = super(ReturnIndexDataset, self).__getitem__(idx)
188
+ return img, idx
189
+
190
+
191
+ if __name__ == '__main__':
192
+ parser = argparse.ArgumentParser('Evaluation with weighted k-NN on ImageNet')
193
+ parser.add_argument('--batch_size_per_gpu', default=128, type=int, help='Per-GPU batch-size')
194
+ parser.add_argument('--nb_knn', default=[10, 20, 100, 200], nargs='+', type=int,
195
+ help='Number of NN to use. 20 is usually working the best.')
196
+ parser.add_argument('--temperature', default=0.07, type=float,
197
+ help='Temperature used in the voting coefficient')
198
+ parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.")
199
+ parser.add_argument('--use_cuda', default=True, type=utils.bool_flag,
200
+ help="Should we store the features on GPU? We recommend setting this to False if you encounter OOM")
201
+ parser.add_argument('--arch', default='vit_small', type=str, help='Architecture')
202
+ parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.')
203
+ parser.add_argument("--checkpoint_key", default="teacher", type=str,
204
+ help='Key to use in the checkpoint (example: "teacher")')
205
+ parser.add_argument('--dump_features', default=None,
206
+ help='Path where to save computed features, empty for no saving')
207
+ parser.add_argument('--load_features', default=None, help="""If the features have
208
+ already been computed, where to find them.""")
209
+ parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
210
+ parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
211
+ distributed training; see https://pytorch.org/docs/stable/distributed.html""")
212
+ parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
213
+ parser.add_argument('--data_path', default='/path/to/imagenet/', type=str)
214
+ args = parser.parse_args()
215
+
216
+ utils.init_distributed_mode(args)
217
+ print("git:\n {}\n".format(utils.get_sha()))
218
+ print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
219
+ cudnn.benchmark = True
220
+
221
+ if args.load_features:
222
+ train_features = torch.load(os.path.join(args.load_features, "trainfeat.pth"))
223
+ test_features = torch.load(os.path.join(args.load_features, "testfeat.pth"))
224
+ train_labels = torch.load(os.path.join(args.load_features, "trainlabels.pth"))
225
+ test_labels = torch.load(os.path.join(args.load_features, "testlabels.pth"))
226
+ else:
227
+ # need to extract features !
228
+ train_features, test_features, train_labels, test_labels = extract_feature_pipeline(args)
229
+
230
+ if utils.get_rank() == 0:
231
+ if args.use_cuda:
232
+ train_features = train_features.cuda()
233
+ test_features = test_features.cuda()
234
+ train_labels = train_labels.cuda()
235
+ test_labels = test_labels.cuda()
236
+
237
+ print("Features are ready!\nStart the k-NN classification.")
238
+ for k in args.nb_knn:
239
+ top1, top5 = knn_classifier(train_features, train_labels,
240
+ test_features, test_labels, k, args.temperature)
241
+ print(f"{k}-NN classifier result: Top1: {top1}, Top5: {top5}")
242
+ dist.barrier()
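The weighted k-NN classifier defined above can also be exercised on its own with random features; the sizes below are placeholders (use at least 100 test samples, since the function splits the test set into 100 chunks):

import torch
import torch.nn.functional as F
from eval_knn import knn_classifier

num_classes, dim, k, T = 10, 64, 5, 0.07
train_features = F.normalize(torch.randn(500, dim), dim=1, p=2)
test_features = F.normalize(torch.randn(200, dim), dim=1, p=2)
train_labels = torch.randint(0, num_classes, (500,))
test_labels = torch.randint(0, num_classes, (200,))

top1, top5 = knn_classifier(train_features, train_labels, test_features, test_labels,
                            k, T, num_classes=num_classes)
print(f"top1={top1:.1f}%  top5={top5:.1f}%")  # roughly chance level on random features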
PuzzleTuning/Counterpart PreTrain Methods/dino-main/eval_linear.py ADDED
@@ -0,0 +1,281 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
+ import sys  # needed by the sys.exit() call when the architecture is unsupported
15
+ import argparse
16
+ import json
17
+ from pathlib import Path
18
+
19
+ import torch
20
+ from torch import nn
21
+ import torch.distributed as dist
22
+ import torch.backends.cudnn as cudnn
23
+ from torchvision import datasets
24
+ from torchvision import transforms as pth_transforms
25
+ from torchvision import models as torchvision_models
26
+
27
+ import utils
28
+ import vision_transformer as vits
29
+
30
+
31
+ def eval_linear(args):
32
+ utils.init_distributed_mode(args)
33
+ print("git:\n {}\n".format(utils.get_sha()))
34
+ print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
35
+ cudnn.benchmark = True
36
+
37
+ # ============ building network ... ============
38
+ # if the network is a Vision Transformer (i.e. vit_tiny, vit_small, vit_base)
39
+ if args.arch in vits.__dict__.keys():
40
+ model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
41
+ embed_dim = model.embed_dim * (args.n_last_blocks + int(args.avgpool_patchtokens))
42
+ # if the network is a XCiT
43
+ elif "xcit" in args.arch:
44
+ model = torch.hub.load('facebookresearch/xcit:main', args.arch, num_classes=0)
45
+ embed_dim = model.embed_dim
46
+ # otherwise, we check if the architecture is in torchvision models
47
+ elif args.arch in torchvision_models.__dict__.keys():
48
+ model = torchvision_models.__dict__[args.arch]()
49
+ embed_dim = model.fc.weight.shape[1]
50
+ model.fc = nn.Identity()
51
+ else:
52
+ print(f"Unknow architecture: {args.arch}")
53
+ sys.exit(1)
54
+ model.cuda()
55
+ model.eval()
56
+ # load weights to evaluate
57
+ utils.load_pretrained_weights(model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size)
58
+ print(f"Model {args.arch} built.")
59
+
60
+ linear_classifier = LinearClassifier(embed_dim, num_labels=args.num_labels)
61
+ linear_classifier = linear_classifier.cuda()
62
+ linear_classifier = nn.parallel.DistributedDataParallel(linear_classifier, device_ids=[args.gpu])
63
+
64
+ # ============ preparing data ... ============
65
+ val_transform = pth_transforms.Compose([
66
+ pth_transforms.Resize(256, interpolation=3),
67
+ pth_transforms.CenterCrop(224),
68
+ pth_transforms.ToTensor(),
69
+ pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
70
+ ])
71
+ dataset_val = datasets.ImageFolder(os.path.join(args.data_path, "val"), transform=val_transform)
72
+ val_loader = torch.utils.data.DataLoader(
73
+ dataset_val,
74
+ batch_size=args.batch_size_per_gpu,
75
+ num_workers=args.num_workers,
76
+ pin_memory=True,
77
+ )
78
+
79
+ if args.evaluate:
80
+ utils.load_pretrained_linear_weights(linear_classifier, args.arch, args.patch_size)
81
+ test_stats = validate_network(val_loader, model, linear_classifier, args.n_last_blocks, args.avgpool_patchtokens)
82
+ print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
83
+ return
84
+
85
+ train_transform = pth_transforms.Compose([
86
+ pth_transforms.RandomResizedCrop(224),
87
+ pth_transforms.RandomHorizontalFlip(),
88
+ pth_transforms.ToTensor(),
89
+ pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
90
+ ])
91
+ dataset_train = datasets.ImageFolder(os.path.join(args.data_path, "train"), transform=train_transform)
92
+ sampler = torch.utils.data.distributed.DistributedSampler(dataset_train)
93
+ train_loader = torch.utils.data.DataLoader(
94
+ dataset_train,
95
+ sampler=sampler,
96
+ batch_size=args.batch_size_per_gpu,
97
+ num_workers=args.num_workers,
98
+ pin_memory=True,
99
+ )
100
+ print(f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs.")
101
+
102
+ # set optimizer
103
+ optimizer = torch.optim.SGD(
104
+ linear_classifier.parameters(),
105
+ args.lr * (args.batch_size_per_gpu * utils.get_world_size()) / 256., # linear scaling rule
106
+ momentum=0.9,
107
+ weight_decay=0, # we do not apply weight decay
108
+ )
109
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min=0)
110
+
111
+ # Optionally resume from a checkpoint
112
+ to_restore = {"epoch": 0, "best_acc": 0.}
113
+ utils.restart_from_checkpoint(
114
+ os.path.join(args.output_dir, "checkpoint.pth.tar"),
115
+ run_variables=to_restore,
116
+ state_dict=linear_classifier,
117
+ optimizer=optimizer,
118
+ scheduler=scheduler,
119
+ )
120
+ start_epoch = to_restore["epoch"]
121
+ best_acc = to_restore["best_acc"]
122
+
123
+ for epoch in range(start_epoch, args.epochs):
124
+ train_loader.sampler.set_epoch(epoch)
125
+
126
+ train_stats = train(model, linear_classifier, optimizer, train_loader, epoch, args.n_last_blocks, args.avgpool_patchtokens)
127
+ scheduler.step()
128
+
129
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
130
+ 'epoch': epoch}
131
+ if epoch % args.val_freq == 0 or epoch == args.epochs - 1:
132
+ test_stats = validate_network(val_loader, model, linear_classifier, args.n_last_blocks, args.avgpool_patchtokens)
133
+ print(f"Accuracy at epoch {epoch} of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
134
+ best_acc = max(best_acc, test_stats["acc1"])
135
+ print(f'Max accuracy so far: {best_acc:.2f}%')
136
+ log_stats = {**{k: v for k, v in log_stats.items()},
137
+ **{f'test_{k}': v for k, v in test_stats.items()}}
138
+ if utils.is_main_process():
139
+ with (Path(args.output_dir) / "log.txt").open("a") as f:
140
+ f.write(json.dumps(log_stats) + "\n")
141
+ save_dict = {
142
+ "epoch": epoch + 1,
143
+ "state_dict": linear_classifier.state_dict(),
144
+ "optimizer": optimizer.state_dict(),
145
+ "scheduler": scheduler.state_dict(),
146
+ "best_acc": best_acc,
147
+ }
148
+ torch.save(save_dict, os.path.join(args.output_dir, "checkpoint.pth.tar"))
149
+ print("Training of the supervised linear classifier on frozen features completed.\n"
150
+ "Top-1 test accuracy: {acc:.1f}".format(acc=best_acc))
151
+
152
+
153
+ def train(model, linear_classifier, optimizer, loader, epoch, n, avgpool):
154
+ linear_classifier.train()
155
+ metric_logger = utils.MetricLogger(delimiter=" ")
156
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
157
+ header = 'Epoch: [{}]'.format(epoch)
158
+ for (inp, target) in metric_logger.log_every(loader, 20, header):
159
+ # move to gpu
160
+ inp = inp.cuda(non_blocking=True)
161
+ target = target.cuda(non_blocking=True)
162
+
163
+ # forward
164
+ with torch.no_grad():
165
+ if "vit" in args.arch:
166
+ intermediate_output = model.get_intermediate_layers(inp, n)
167
+ output = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
168
+ if avgpool:
169
+ output = torch.cat((output.unsqueeze(-1), torch.mean(intermediate_output[-1][:, 1:], dim=1).unsqueeze(-1)), dim=-1)
170
+ output = output.reshape(output.shape[0], -1)
171
+ else:
172
+ output = model(inp)
173
+ output = linear_classifier(output)
174
+
175
+ # compute cross entropy loss
176
+ loss = nn.CrossEntropyLoss()(output, target)
177
+
178
+ # compute the gradients
179
+ optimizer.zero_grad()
180
+ loss.backward()
181
+
182
+ # step
183
+ optimizer.step()
184
+
185
+ # log
186
+ torch.cuda.synchronize()
187
+ metric_logger.update(loss=loss.item())
188
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
189
+ # gather the stats from all processes
190
+ metric_logger.synchronize_between_processes()
191
+ print("Averaged stats:", metric_logger)
192
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
193
+
194
+
195
+ @torch.no_grad()
196
+ def validate_network(val_loader, model, linear_classifier, n, avgpool):
197
+ linear_classifier.eval()
198
+ metric_logger = utils.MetricLogger(delimiter=" ")
199
+ header = 'Test:'
200
+ for inp, target in metric_logger.log_every(val_loader, 20, header):
201
+ # move to gpu
202
+ inp = inp.cuda(non_blocking=True)
203
+ target = target.cuda(non_blocking=True)
204
+
205
+ # forward
206
+ with torch.no_grad():
207
+ if "vit" in args.arch:
208
+ intermediate_output = model.get_intermediate_layers(inp, n)
209
+ output = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
210
+ if avgpool:
211
+ output = torch.cat((output.unsqueeze(-1), torch.mean(intermediate_output[-1][:, 1:], dim=1).unsqueeze(-1)), dim=-1)
212
+ output = output.reshape(output.shape[0], -1)
213
+ else:
214
+ output = model(inp)
215
+ output = linear_classifier(output)
216
+ loss = nn.CrossEntropyLoss()(output, target)
217
+
218
+ if linear_classifier.module.num_labels >= 5:
219
+ acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
220
+ else:
221
+ acc1, = utils.accuracy(output, target, topk=(1,))
222
+
223
+ batch_size = inp.shape[0]
224
+ metric_logger.update(loss=loss.item())
225
+ metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
226
+ if linear_classifier.module.num_labels >= 5:
227
+ metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
228
+ if linear_classifier.module.num_labels >= 5:
229
+ print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
230
+ .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
231
+ else:
232
+ print('* Acc@1 {top1.global_avg:.3f} loss {losses.global_avg:.3f}'
233
+ .format(top1=metric_logger.acc1, losses=metric_logger.loss))
234
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
235
+
236
+
237
+ class LinearClassifier(nn.Module):
238
+ """Linear layer to train on top of frozen features"""
239
+ def __init__(self, dim, num_labels=1000):
240
+ super(LinearClassifier, self).__init__()
241
+ self.num_labels = num_labels
242
+ self.linear = nn.Linear(dim, num_labels)
243
+ self.linear.weight.data.normal_(mean=0.0, std=0.01)
244
+ self.linear.bias.data.zero_()
245
+
246
+ def forward(self, x):
247
+ # flatten
248
+ x = x.view(x.size(0), -1)
249
+
250
+ # linear layer
251
+ return self.linear(x)
252
+
253
+
254
+ if __name__ == '__main__':
255
+ parser = argparse.ArgumentParser('Evaluation with linear classification on ImageNet')
256
+ parser.add_argument('--n_last_blocks', default=4, type=int, help="""Concatenate [CLS] tokens
257
+ for the `n` last blocks. We use `n=4` when evaluating ViT-Small and `n=1` with ViT-Base.""")
258
+ parser.add_argument('--avgpool_patchtokens', default=False, type=utils.bool_flag,
259
+ help="""Whether ot not to concatenate the global average pooled features to the [CLS] token.
260
+ We typically set this to False for ViT-Small and to True with ViT-Base.""")
261
+ parser.add_argument('--arch', default='vit_small', type=str, help='Architecture')
262
+ parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.')
263
+ parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.")
264
+ parser.add_argument("--checkpoint_key", default="teacher", type=str, help='Key to use in the checkpoint (example: "teacher")')
265
+ parser.add_argument('--epochs', default=100, type=int, help='Number of epochs of training.')
266
+ parser.add_argument("--lr", default=0.001, type=float, help="""Learning rate at the beginning of
267
+ training (highest LR used during training). The learning rate is linearly scaled
268
+ with the batch size, and specified here for a reference batch size of 256.
269
+ We recommend tweaking the LR depending on the checkpoint evaluated.""")
270
+ parser.add_argument('--batch_size_per_gpu', default=128, type=int, help='Per-GPU batch-size')
271
+ parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
272
+ distributed training; see https://pytorch.org/docs/stable/distributed.html""")
273
+ parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
274
+ parser.add_argument('--data_path', default='/path/to/imagenet/', type=str)
275
+ parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
276
+ parser.add_argument('--val_freq', default=1, type=int, help="Epoch frequency for validation.")
277
+ parser.add_argument('--output_dir', default=".", help='Path to save logs and checkpoints')
278
+ parser.add_argument('--num_labels', default=1000, type=int, help='Number of labels for linear classifier')
279
+ parser.add_argument('--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set')
280
+ args = parser.parse_args()
281
+ eval_linear(args)
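The linear probe head itself is a plain nn.Linear over frozen features; a small sketch with a placeholder feature dimension (1536 corresponds to ViT-Small features concatenated over the default n_last_blocks=4):

import torch
from eval_linear import LinearClassifier

head = LinearClassifier(dim=1536, num_labels=1000)  # 4 * 384 for ViT-S with --n_last_blocks 4
frozen_feats = torch.randn(8, 1536)                 # a batch of frozen backbone features
logits = head(frozen_feats)
print(logits.shape)                                 # torch.Size([8, 1000])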
PuzzleTuning/Counterpart PreTrain Methods/dino-main/eval_video_segmentation.py ADDED
@@ -0,0 +1,292 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Some parts are taken from https://github.com/Liusifei/UVC
16
+ """
17
+ import os
18
+ import copy
19
+ import glob
20
+ import queue
21
+ from urllib.request import urlopen
22
+ import argparse
23
+ import numpy as np
24
+ from tqdm import tqdm
25
+
26
+ import cv2
27
+ import torch
28
+ import torch.nn as nn
29
+ from torch.nn import functional as F
30
+ from PIL import Image
31
+ from torchvision import transforms
32
+
33
+ import utils
34
+ import vision_transformer as vits
35
+
36
+
37
+ @torch.no_grad()
38
+ def eval_video_tracking_davis(args, model, frame_list, video_dir, first_seg, seg_ori, color_palette):
39
+ """
40
+ Evaluate tracking on a video given first frame & segmentation
41
+ """
42
+ video_folder = os.path.join(args.output_dir, video_dir.split('/')[-1])
43
+ os.makedirs(video_folder, exist_ok=True)
44
+
45
+ # The queue stores the n preceding frames
46
+ que = queue.Queue(args.n_last_frames)
47
+
48
+ # first frame
49
+ frame1, ori_h, ori_w = read_frame(frame_list[0])
50
+ # extract first frame feature
51
+ frame1_feat = extract_feature(model, frame1).T # dim x h*w
52
+
53
+ # saving first segmentation
54
+ out_path = os.path.join(video_folder, "00000.png")
55
+ imwrite_indexed(out_path, seg_ori, color_palette)
56
+ mask_neighborhood = None
57
+ for cnt in tqdm(range(1, len(frame_list))):
58
+ frame_tar = read_frame(frame_list[cnt])[0]
59
+
60
+ # we use the first segmentation and the n previous ones
61
+ used_frame_feats = [frame1_feat] + [pair[0] for pair in list(que.queue)]
62
+ used_segs = [first_seg] + [pair[1] for pair in list(que.queue)]
63
+
64
+ frame_tar_avg, feat_tar, mask_neighborhood = label_propagation(args, model, frame_tar, used_frame_feats, used_segs, mask_neighborhood)
65
+
66
+ # pop out the oldest frame if necessary
67
+ if que.qsize() == args.n_last_frames:
68
+ que.get()
69
+ # push current results into queue
70
+ seg = copy.deepcopy(frame_tar_avg)
71
+ que.put([feat_tar, seg])
72
+
73
+ # upsampling & argmax
74
+ frame_tar_avg = F.interpolate(frame_tar_avg, scale_factor=args.patch_size, mode='bilinear', align_corners=False, recompute_scale_factor=False)[0]
75
+ frame_tar_avg = norm_mask(frame_tar_avg)
76
+ _, frame_tar_seg = torch.max(frame_tar_avg, dim=0)
77
+
78
+ # saving to disk
79
+ frame_tar_seg = np.array(frame_tar_seg.squeeze().cpu(), dtype=np.uint8)
80
+ frame_tar_seg = np.array(Image.fromarray(frame_tar_seg).resize((ori_w, ori_h), 0))
81
+ frame_nm = frame_list[cnt].split('/')[-1].replace(".jpg", ".png")
82
+ imwrite_indexed(os.path.join(video_folder, frame_nm), frame_tar_seg, color_palette)
83
+
84
+
85
+ def restrict_neighborhood(h, w):
86
+ # We restrict the set of source nodes considered to a spatial neighborhood of the query node (i.e. ``local attention'')
87
+ mask = torch.zeros(h, w, h, w)
88
+ for i in range(h):
89
+ for j in range(w):
90
+ for p in range(2 * args.size_mask_neighborhood + 1):
91
+ for q in range(2 * args.size_mask_neighborhood + 1):
92
+ if i - args.size_mask_neighborhood + p < 0 or i - args.size_mask_neighborhood + p >= h:
93
+ continue
94
+ if j - args.size_mask_neighborhood + q < 0 or j - args.size_mask_neighborhood + q >= w:
95
+ continue
96
+ mask[i, j, i - args.size_mask_neighborhood + p, j - args.size_mask_neighborhood + q] = 1
97
+
98
+ mask = mask.reshape(h * w, h * w)
99
+ return mask.cuda(non_blocking=True)
100
+
101
+
102
+ def norm_mask(mask):
103
+ c, h, w = mask.size()
104
+ for cnt in range(c):
105
+ mask_cnt = mask[cnt,:,:]
106
+ if(mask_cnt.max() > 0):
107
+ mask_cnt = (mask_cnt - mask_cnt.min())
108
+ mask_cnt = mask_cnt/mask_cnt.max()
109
+ mask[cnt,:,:] = mask_cnt
110
+ return mask
111
+
112
+
113
+ def label_propagation(args, model, frame_tar, list_frame_feats, list_segs, mask_neighborhood=None):
114
+ """
115
+ propagate segs of frames in list_frames to frame_tar
116
+ """
117
+ ## we only need to extract feature of the target frame
118
+ feat_tar, h, w = extract_feature(model, frame_tar, return_h_w=True)
119
+
120
+ return_feat_tar = feat_tar.T # dim x h*w
121
+
122
+ ncontext = len(list_frame_feats)
123
+ feat_sources = torch.stack(list_frame_feats) # nmb_context x dim x h*w
124
+
125
+ feat_tar = F.normalize(feat_tar, dim=1, p=2)
126
+ feat_sources = F.normalize(feat_sources, dim=1, p=2)
127
+
128
+ feat_tar = feat_tar.unsqueeze(0).repeat(ncontext, 1, 1)
129
+ aff = torch.exp(torch.bmm(feat_tar, feat_sources) / 0.1) # nmb_context x h*w (tar: query) x h*w (source: keys)
130
+
131
+ if args.size_mask_neighborhood > 0:
132
+ if mask_neighborhood is None:
133
+ mask_neighborhood = restrict_neighborhood(h, w)
134
+ mask_neighborhood = mask_neighborhood.unsqueeze(0).repeat(ncontext, 1, 1)
135
+ aff *= mask_neighborhood
136
+
137
+ aff = aff.transpose(2, 1).reshape(-1, h * w) # nmb_context*h*w (source: keys) x h*w (tar: queries)
138
+ tk_val, _ = torch.topk(aff, dim=0, k=args.topk)
139
+ tk_val_min, _ = torch.min(tk_val, dim=0)
140
+ aff[aff < tk_val_min] = 0
141
+
142
+ aff = aff / torch.sum(aff, keepdim=True, axis=0)
143
+
144
+ list_segs = [s.cuda() for s in list_segs]
145
+ segs = torch.cat(list_segs)
146
+ nmb_context, C, h, w = segs.shape
147
+ segs = segs.reshape(nmb_context, C, -1).transpose(2, 1).reshape(-1, C).T # C x nmb_context*h*w
148
+ seg_tar = torch.mm(segs, aff)
149
+ seg_tar = seg_tar.reshape(1, C, h, w)
150
+ return seg_tar, return_feat_tar, mask_neighborhood
151
+
152
+
153
+ def extract_feature(model, frame, return_h_w=False):
154
+ """Extract one frame feature everytime."""
155
+ out = model.get_intermediate_layers(frame.unsqueeze(0).cuda(), n=1)[0]
156
+ out = out[:, 1:, :] # we discard the [CLS] token
157
+ h, w = int(frame.shape[1] / model.patch_embed.patch_size), int(frame.shape[2] / model.patch_embed.patch_size)
158
+ dim = out.shape[-1]
159
+ out = out[0].reshape(h, w, dim)
160
+ out = out.reshape(-1, dim)
161
+ if return_h_w:
162
+ return out, h, w
163
+ return out
164
+
165
+
166
+ def imwrite_indexed(filename, array, color_palette):
167
+ """ Save indexed png for DAVIS."""
168
+ if np.atleast_3d(array).shape[2] != 1:
169
+ raise Exception("Saving indexed PNGs requires 2D array.")
170
+
171
+ im = Image.fromarray(array)
172
+ im.putpalette(color_palette.ravel())
173
+ im.save(filename, format='PNG')
174
+
175
+
176
+ def to_one_hot(y_tensor, n_dims=None):
177
+ """
178
+ Take integer y (tensor or variable) with n dims &
179
+ convert it to 1-hot representation with n+1 dims.
180
+ """
181
+ if(n_dims is None):
182
+ n_dims = int(y_tensor.max()+ 1)
183
+ _,h,w = y_tensor.size()
184
+ y_tensor = y_tensor.type(torch.LongTensor).view(-1, 1)
185
+ n_dims = n_dims if n_dims is not None else int(torch.max(y_tensor)) + 1
186
+ y_one_hot = torch.zeros(y_tensor.size()[0], n_dims).scatter_(1, y_tensor, 1)
187
+ y_one_hot = y_one_hot.view(h,w,n_dims)
188
+ return y_one_hot.permute(2, 0, 1).unsqueeze(0)
189
+
190
+
191
+ def read_frame_list(video_dir):
192
+ frame_list = [img for img in glob.glob(os.path.join(video_dir,"*.jpg"))]
193
+ frame_list = sorted(frame_list)
194
+ return frame_list
195
+
196
+
197
+ def read_frame(frame_dir, scale_size=[480]):
198
+ """
199
+ read a single frame & preprocess
200
+ """
201
+ img = cv2.imread(frame_dir)
202
+ ori_h, ori_w, _ = img.shape
203
+ if len(scale_size) == 1:
204
+ if(ori_h > ori_w):
205
+ tw = scale_size[0]
206
+ th = (tw * ori_h) / ori_w
207
+ th = int((th // 64) * 64)
208
+ else:
209
+ th = scale_size[0]
210
+ tw = (th * ori_w) / ori_h
211
+ tw = int((tw // 64) * 64)
212
+ else:
213
+ th, tw = scale_size
214
+ img = cv2.resize(img, (tw, th))
215
+ img = img.astype(np.float32)
216
+ img = img / 255.0
217
+ img = img[:, :, ::-1]
218
+ img = np.transpose(img.copy(), (2, 0, 1))
219
+ img = torch.from_numpy(img).float()
220
+ img = color_normalize(img)
221
+ return img, ori_h, ori_w
222
+
223
+
224
+ def read_seg(seg_dir, factor, scale_size=[480]):
225
+ seg = Image.open(seg_dir)
226
+ _w, _h = seg.size # note PIL.Image.Image's size is (w, h)
227
+ if len(scale_size) == 1:
228
+ if(_w > _h):
229
+ _th = scale_size[0]
230
+ _tw = (_th * _w) / _h
231
+ _tw = int((_tw // 64) * 64)
232
+ else:
233
+ _tw = scale_size[0]
234
+ _th = (_tw * _h) / _w
235
+ _th = int((_th // 64) * 64)
236
+ else:
237
+ _th = scale_size[1]
238
+ _tw = scale_size[0]
239
+ small_seg = np.array(seg.resize((_tw // factor, _th // factor), 0))
240
+ small_seg = torch.from_numpy(small_seg.copy()).contiguous().float().unsqueeze(0)
241
+ return to_one_hot(small_seg), np.asarray(seg)
242
+
243
+
244
+ def color_normalize(x, mean=[0.485, 0.456, 0.406], std=[0.228, 0.224, 0.225]):
245
+ for t, m, s in zip(x, mean, std):
246
+ t.sub_(m)
247
+ t.div_(s)
248
+ return x
249
+
250
+
251
+ if __name__ == '__main__':
252
+ parser = argparse.ArgumentParser('Evaluation with video object segmentation on DAVIS 2017')
253
+ parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.")
254
+ parser.add_argument('--arch', default='vit_small', type=str,
255
+ choices=['vit_tiny', 'vit_small', 'vit_base'], help='Architecture (support only ViT atm).')
256
+ parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.')
257
+ parser.add_argument("--checkpoint_key", default="teacher", type=str, help='Key to use in the checkpoint (example: "teacher")')
258
+ parser.add_argument('--output_dir', default=".", help='Path where to save segmentations')
259
+ parser.add_argument('--data_path', default='/path/to/davis/', type=str)
260
+ parser.add_argument("--n_last_frames", type=int, default=7, help="number of preceeding frames")
261
+ parser.add_argument("--size_mask_neighborhood", default=12, type=int,
262
+ help="We restrict the set of source nodes considered to a spatial neighborhood of the query node")
263
+ parser.add_argument("--topk", type=int, default=5, help="accumulate label from top k neighbors")
264
+ parser.add_argument("--bs", type=int, default=6, help="Batch size, try to reduce if OOM")
265
+ args = parser.parse_args()
266
+
267
+ print("git:\n {}\n".format(utils.get_sha()))
268
+ print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
269
+
270
+ # building network
271
+ model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
272
+ print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.")
273
+ model.cuda()
274
+ utils.load_pretrained_weights(model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size)
275
+ for param in model.parameters():
276
+ param.requires_grad = False
277
+ model.eval()
278
+
279
+ color_palette = []
280
+ for line in urlopen("https://raw.githubusercontent.com/Liusifei/UVC/master/libs/data/palette.txt"):
281
+ color_palette.append([int(i) for i in line.decode("utf-8").split('\n')[0].split(" ")])
282
+ color_palette = np.asarray(color_palette, dtype=np.uint8).reshape(-1,3)
283
+
284
+ video_list = open(os.path.join(args.data_path, "ImageSets/2017/val.txt")).readlines()
285
+ for i, video_name in enumerate(video_list):
286
+ video_name = video_name.strip()
287
+ print(f'[{i}/{len(video_list)}] Begin to segment video {video_name}.')
288
+ video_dir = os.path.join(args.data_path, "JPEGImages/480p/", video_name)
289
+ frame_list = read_frame_list(video_dir)
290
+ seg_path = frame_list[0].replace("JPEGImages", "Annotations").replace("jpg", "png")
291
+ first_seg, seg_ori = read_seg(seg_path, args.patch_size)
292
+ eval_video_tracking_davis(args, model, frame_list, video_dir, first_seg, seg_ori, color_palette)
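At its core, label_propagation above carries soft segmentation labels from source frames to the target frame through a temperature-scaled feature affinity. A simplified sketch on random features (it omits the top-k filtering and the spatial neighborhood mask; all shapes are placeholders):

import torch
import torch.nn.functional as F

dim, hw, C = 384, 16, 3                              # feature dim, h*w locations, object channels
feat_src = F.normalize(torch.randn(dim, hw), dim=0)  # source frame features, dim x h*w
feat_tar = F.normalize(torch.randn(hw, dim), dim=1)  # target frame features, h*w x dim
aff = torch.exp(feat_tar @ feat_src / 0.1)           # target locations x source locations
aff = aff / aff.sum(dim=1, keepdim=True)             # normalize over source locations
seg_src = torch.rand(C, hw)                          # soft labels on the source frame
seg_tar = seg_src @ aff.T                            # labels propagated to the target, C x h*w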
PuzzleTuning/Counterpart PreTrain Methods/dino-main/hubconf.py ADDED
@@ -0,0 +1,151 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+ from torchvision.models.resnet import resnet50
16
+
17
+ import vision_transformer as vits
18
+
19
+ dependencies = ["torch", "torchvision"]
20
+
21
+
22
+ def dino_vits16(pretrained=True, **kwargs):
23
+ """
24
+ ViT-Small/16x16 pre-trained with DINO.
25
+ Achieves 74.5% top-1 accuracy on ImageNet with k-NN classification.
26
+ """
27
+ model = vits.__dict__["vit_small"](patch_size=16, num_classes=0, **kwargs)
28
+ if pretrained:
29
+ state_dict = torch.hub.load_state_dict_from_url(
30
+ url="https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth",
31
+ map_location="cpu",
32
+ )
33
+ model.load_state_dict(state_dict, strict=True)
34
+ return model
35
+
36
+
37
+ def dino_vits8(pretrained=True, **kwargs):
38
+ """
39
+ ViT-Small/8x8 pre-trained with DINO.
40
+ Achieves 78.3% top-1 accuracy on ImageNet with k-NN classification.
41
+ """
42
+ model = vits.__dict__["vit_small"](patch_size=8, num_classes=0, **kwargs)
43
+ if pretrained:
44
+ state_dict = torch.hub.load_state_dict_from_url(
45
+ url="https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth",
46
+ map_location="cpu",
47
+ )
48
+ model.load_state_dict(state_dict, strict=True)
49
+ return model
50
+
51
+
52
+ def dino_vitb16(pretrained=True, **kwargs):
53
+ """
54
+ ViT-Base/16x16 pre-trained with DINO.
55
+ Achieves 76.1% top-1 accuracy on ImageNet with k-NN classification.
56
+ """
57
+ model = vits.__dict__["vit_base"](patch_size=16, num_classes=0, **kwargs)
58
+ if pretrained:
59
+ state_dict = torch.hub.load_state_dict_from_url(
60
+ url="https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth",
61
+ map_location="cpu",
62
+ )
63
+ model.load_state_dict(state_dict, strict=True)
64
+ return model
65
+
66
+
67
+ def dino_vitb8(pretrained=True, **kwargs):
68
+ """
69
+ ViT-Base/8x8 pre-trained with DINO.
70
+ Achieves 77.4% top-1 accuracy on ImageNet with k-NN classification.
71
+ """
72
+ model = vits.__dict__["vit_base"](patch_size=8, num_classes=0, **kwargs)
73
+ if pretrained:
74
+ state_dict = torch.hub.load_state_dict_from_url(
75
+ url="https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth",
76
+ map_location="cpu",
77
+ )
78
+ model.load_state_dict(state_dict, strict=True)
79
+ return model
80
+
81
+
82
+ def dino_resnet50(pretrained=True, **kwargs):
83
+ """
84
+ ResNet-50 pre-trained with DINO.
85
+ Achieves 75.3% top-1 accuracy on ImageNet linear evaluation benchmark (requires to train `fc`).
86
+ """
87
+ model = resnet50(pretrained=False, **kwargs)
88
+ model.fc = torch.nn.Identity()
89
+ if pretrained:
90
+ state_dict = torch.hub.load_state_dict_from_url(
91
+ url="https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain.pth",
92
+ map_location="cpu",
93
+ )
94
+ model.load_state_dict(state_dict, strict=False)
95
+ return model
96
+
97
+
98
+ def dino_xcit_small_12_p16(pretrained=True, **kwargs):
99
+ """
100
+ XCiT-Small-12/16 pre-trained with DINO.
101
+ """
102
+ model = torch.hub.load('facebookresearch/xcit:main', "xcit_small_12_p16", num_classes=0, **kwargs)
103
+ if pretrained:
104
+ state_dict = torch.hub.load_state_dict_from_url(
105
+ url="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p16_pretrain/dino_xcit_small_12_p16_pretrain.pth",
106
+ map_location="cpu",
107
+ )
108
+ model.load_state_dict(state_dict, strict=True)
109
+ return model
110
+
111
+
112
+ def dino_xcit_small_12_p8(pretrained=True, **kwargs):
113
+ """
114
+ XCiT-Small-12/8 pre-trained with DINO.
115
+ """
116
+ model = torch.hub.load('facebookresearch/xcit:main', "xcit_small_12_p8", num_classes=0, **kwargs)
117
+ if pretrained:
118
+ state_dict = torch.hub.load_state_dict_from_url(
119
+ url="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p8_pretrain/dino_xcit_small_12_p8_pretrain.pth",
120
+ map_location="cpu",
121
+ )
122
+ model.load_state_dict(state_dict, strict=True)
123
+ return model
124
+
125
+
126
+ def dino_xcit_medium_24_p16(pretrained=True, **kwargs):
127
+ """
128
+ XCiT-Medium-24/16 pre-trained with DINO.
129
+ """
130
+ model = torch.hub.load('facebookresearch/xcit:main', "xcit_medium_24_p16", num_classes=0, **kwargs)
131
+ if pretrained:
132
+ state_dict = torch.hub.load_state_dict_from_url(
133
+ url="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p16_pretrain/dino_xcit_medium_24_p16_pretrain.pth",
134
+ map_location="cpu",
135
+ )
136
+ model.load_state_dict(state_dict, strict=True)
137
+ return model
138
+
139
+
140
+ def dino_xcit_medium_24_p8(pretrained=True, **kwargs):
141
+ """
142
+ XCiT-Medium-24/8 pre-trained with DINO.
143
+ """
144
+ model = torch.hub.load('facebookresearch/xcit:main', "xcit_medium_24_p8", num_classes=0, **kwargs)
145
+ if pretrained:
146
+ state_dict = torch.hub.load_state_dict_from_url(
147
+ url="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p8_pretrain/dino_xcit_medium_24_p8_pretrain.pth",
148
+ map_location="cpu",
149
+ )
150
+ model.load_state_dict(state_dict, strict=True)
151
+ return model
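With this hubconf.py in place, the pretrained backbones can be fetched through torch.hub; loading from the upstream facebookresearch/dino repo works the same way (network access to the dl.fbaipublicfiles.com weight URLs is assumed):

import torch

# e.g. ViT-Small/16 with the published DINO weights
model = torch.hub.load('facebookresearch/dino:main', 'dino_vits16', pretrained=True)
model.eval()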
PuzzleTuning/Counterpart PreTrain Methods/dino-main/main_dino.py ADDED
@@ -0,0 +1,515 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import argparse
15
+ import os
16
+ import sys
17
+ import datetime
18
+ import time
19
+ import math
20
+ import json
21
+ from pathlib import Path
22
+
23
+ import numpy as np
24
+ from PIL import Image
25
+ import torch
26
+ import torch.nn as nn
27
+ import torch.distributed as dist
28
+ import torch.backends.cudnn as cudnn
29
+ import torch.nn.functional as F
30
+ from torchvision import datasets, transforms
31
+ from torchvision import models as torchvision_models
32
+
33
+ import utils
34
+ import vision_transformer as vits
35
+ from vision_transformer import DINOHead
36
+
37
+ torchvision_archs = sorted(name for name in torchvision_models.__dict__
38
+ if name.islower() and not name.startswith("__")
39
+ and callable(torchvision_models.__dict__[name]))
40
+
41
+ def get_args_parser():
42
+ parser = argparse.ArgumentParser('DINO', add_help=False)
43
+
44
+ # Model parameters
45
+ parser.add_argument('--arch', default='vit_base', type=str,
46
+ help="""Name of architecture to train. For quick experiments with ViTs,
47
+ we recommend using vit_tiny or vit_small.""")
48
+ parser.add_argument('--patch_size', default=16, type=int, help="""Size in pixels
49
+ of input square patches - default 16 (for 16x16 patches). Using smaller
50
+ values leads to better performance but requires more memory. Applies only
51
+ for ViTs (vit_tiny, vit_small and vit_base). If <16, we recommend disabling
52
+ mixed precision training (--use_fp16 false) to avoid instabilities.""")
53
+ parser.add_argument('--input_size', default=224, type=int)
54
+ parser.add_argument('--out_dim', default=65536, type=int, help="""Dimensionality of
55
+ the DINO head output. For complex and large datasets large values (like 65k) work well.""")
56
+ parser.add_argument('--norm_last_layer', default=True, type=utils.bool_flag,
57
+ help="""Whether or not to weight normalize the last layer of the DINO head.
58
+ Not normalizing leads to better performance but can make the training unstable.
59
+ In our experiments, we typically set this parameter to False with vit_small and True with vit_base.""")
60
+ parser.add_argument('--momentum_teacher', default=0.996, type=float, help="""Base EMA
61
+ parameter for teacher update. The value is increased to 1 during training with cosine schedule.
62
+ We recommend setting a higher value with small batches: for example use 0.9995 with batch size of 256.""")
63
+ parser.add_argument('--use_bn_in_head', default=False, type=utils.bool_flag,
64
+ help="Whether to use batch normalizations in projection head (Default: False)")
65
+
66
+ # Temperature teacher parameters
67
+ parser.add_argument('--warmup_teacher_temp', default=0.04, type=float,
68
+ help="""Initial value for the teacher temperature: 0.04 works well in most cases.
69
+ Try decreasing it if the training loss does not decrease.""")
70
+ parser.add_argument('--teacher_temp', default=0.04, type=float, help="""Final value (after linear warmup)
71
+ of the teacher temperature. For most experiments, anything above 0.07 is unstable. We recommend
72
+ starting with the default value of 0.04 and increase this slightly if needed.""")
73
+ parser.add_argument('--warmup_teacher_temp_epochs', default=0, type=int,
74
+ help='Number of warmup epochs for the teacher temperature (Default: 0).')
75
+
76
+ # Training/Optimization parameters
77
+ parser.add_argument('--use_fp16', type=utils.bool_flag, default=True, help="""Whether or not
78
+ to use half precision for training. Improves training time and memory requirements,
79
+ but can provoke instability and slight decay of performance. We recommend disabling
80
+ mixed precision if the loss is unstable, if reducing the patch size or if training with bigger ViTs.""")
81
+ parser.add_argument('--weight_decay', type=float, default=0.04, help="""Initial value of the
82
+ weight decay. With ViT, a smaller value at the beginning of training works well.""")
83
+ parser.add_argument('--weight_decay_end', type=float, default=0.4, help="""Final value of the
84
+ weight decay. We use a cosine schedule for WD and using a larger decay by
85
+ the end of training improves performance for ViTs.""")
86
+ parser.add_argument('--clip_grad', type=float, default=3.0, help="""Maximal parameter
87
+ gradient norm if using gradient clipping. Clipping with norm .3 ~ 1.0 can
88
+ help optimization for larger ViT architectures. 0 for disabling.""")
89
+ parser.add_argument('--batch_size_per_gpu', default=512, type=int,
90
+ help='Per-GPU batch-size : number of distinct images loaded on one GPU.')
91
+ parser.add_argument('--epochs', default=150, type=int, help='Number of epochs of training.')
92
+ parser.add_argument('--freeze_last_layer', default=1, type=int, help="""Number of epochs
93
+ during which we keep the output layer fixed. Typically doing so during
94
+ the first epoch helps training. Try increasing this value if the loss does not decrease.""")
95
+ parser.add_argument("--lr", default=1.5e-4, type=float, help="""Learning rate at the end of
96
+ linear warmup (highest LR used during training). The learning rate is linearly scaled
97
+ with the batch size, and specified here for a reference batch size of 512.""")
98
+ parser.add_argument("--warmup_epochs", default=20, type=int,
99
+ help="Number of epochs for the linear learning-rate warm up.")
100
+ parser.add_argument('--min_lr', type=float, default=1e-6, help="""Target LR at the
101
+ end of optimization. We use a cosine LR schedule with linear warmup.""")
102
+ parser.add_argument('--optimizer', default='adamw', type=str,
103
+ choices=['adamw', 'sgd', 'lars'], help="""Type of optimizer. We recommend using adamw with ViTs.""")
104
+ parser.add_argument('--drop_path_rate', type=float, default=0.1, help="stochastic depth rate")
105
+
106
+ # Multi-crop parameters
107
+ parser.add_argument('--global_crops_scale', type=float, nargs='+', default=(0.4, 1.),
108
+ help="""Scale range of the cropped image before resizing, relatively to the origin image.
109
+ Used for large global view cropping. When disabling multi-crop (--local_crops_number 0), we
110
+ recommend using a wider range of scale ("--global_crops_scale 0.14 1." for example)""")
111
+ parser.add_argument('--local_crops_number', type=int, default=8, help="""Number of small
112
+ local views to generate. Set this parameter to 0 to disable multi-crop training.
113
+ When disabling multi-crop we recommend to use "--global_crops_scale 0.14 1." """)
114
+ parser.add_argument('--local_crops_scale', type=float, nargs='+', default=(0.05, 0.4),
115
+ help="""Scale range of the cropped image before resizing, relative to the original image.
116
+ Used for small local view cropping of multi-crop.""")
117
+
118
+ # Misc
119
+ parser.add_argument('--data_path', default='/root/autodl-tmp/All', type=str,
120
+ help='Please specify path to the ImageNet training data.')
121
+ parser.add_argument('--basic_state_dict', default='/root/autodl-tmp/ViT_b16_224_Imagenet.pth', type=str,
122
+ help='Load in pretrained or un-pretrained model pth')
123
+ parser.add_argument('--output_dir', default="/home/CPIA/saved_models/DINO", type=str, help='Path to save logs and checkpoints.')
124
+ parser.add_argument('--saveckp_freq', default=50, type=int, help='Save checkpoint every x epochs.')
125
+ parser.add_argument('--seed', default=0, type=int, help='Random seed.')
126
+ parser.add_argument('--num_workers', default=32, type=int, help='Number of data loading workers per GPU.')
127
+ parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
128
+ distributed training; see https://pytorch.org/docs/stable/distributed.html""")
129
+ parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
130
+ return parser
131
+
132
+
133
+ def train_dino(args):
134
+ utils.init_distributed_mode(args)
135
+ utils.fix_random_seeds(args.seed)
136
+ print('ok')
137
+ print("git:\n {}\n".format(utils.get_sha()))
138
+ print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
139
+ cudnn.benchmark = True
140
+
141
+ # ============ preparing data ... ============
142
+ transform = DataAugmentationDINO(
143
+ args.global_crops_scale,
144
+ args.local_crops_scale,
145
+ args.local_crops_number,
146
+ )
147
+ dataset = datasets.ImageFolder(args.data_path, transform=transform)
148
+ sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
149
+ data_loader = torch.utils.data.DataLoader(
150
+ dataset,
151
+ sampler=sampler,
152
+ batch_size=args.batch_size_per_gpu,
153
+ num_workers=args.num_workers,
154
+ pin_memory=True,
155
+ drop_last=True,
156
+ )
157
+ print(f"Data loaded: there are {len(dataset)} images.")
158
+
159
+ # ============ building student and teacher networks ... ============
160
+ # we renamed DeiT-S to ViT-S to avoid confusion
161
+ args.arch = args.arch.replace("deit", "vit")
162
+ # if the network is a Vision Transformer (i.e. vit_tiny, vit_small, vit_base)
163
+ if args.arch in vits.__dict__.keys():
164
+ student = vits.__dict__[args.arch](
165
+ patch_size=args.patch_size,
166
+ drop_path_rate=args.drop_path_rate, # stochastic depth
167
+ )
168
+ if args.basic_state_dict is not None: # Transfer-learning
169
+ try:
170
+ basic_state_dict = torch.load(args.basic_state_dict)
171
+ if 'model' in basic_state_dict:
172
+ basic_state_dict = basic_state_dict['model']
173
+ else:
174
+ pass
175
+ student.load_state_dict(basic_state_dict, False)
176
+
177
+ except:
178
+ print('error in args.basic_state_dict:', args.basic_state_dict)
179
+ print('Student PreTuning Restart')  # weights were not loaded
180
+
181
+ else:
182
+ print('Student PreTuning with Transfer-learning with:', args.basic_state_dict)
183
+
184
+ else:
185
+ print('Student PreTuning Restart')
186
+ teacher = vits.__dict__[args.arch](patch_size=args.patch_size)
187
+ if args.basic_state_dict is not None: # Transfer-learning
188
+ try:
189
+ basic_state_dict = torch.load(args.basic_state_dict)
190
+ if 'model' in basic_state_dict:
191
+ basic_state_dict = basic_state_dict['model']
192
+ else:
193
+ pass
194
+ teacher.load_state_dict(basic_state_dict, False)
195
+
196
+ except:
197
+ print('error in args.basic_state_dict:', args.basic_state_dict)
198
+ print('Teacher PreTuning Restart')  # weights were not loaded
199
+
200
+ else:
201
+ print('Teacher PreTuning with Transfer-learning with:', args.basic_state_dict)
202
+
203
+ else:
204
+ print('Teacher PreTuning Restart')
205
+ embed_dim = student.embed_dim
206
+ # if the network is a XCiT
207
+ """elif args.arch in torch.hub.list("facebookresearch/xcit:main"):
208
+ student = torch.hub.load('facebookresearch/xcit:main', args.arch,
209
+ pretrained=False, drop_path_rate=args.drop_path_rate)
210
+ teacher = torch.hub.load('facebookresearch/xcit:main', args.arch, pretrained=False)
211
+ embed_dim = student.embed_dim
212
+ # otherwise, we check if the architecture is in torchvision models
213
+ elif args.arch in torchvision_models.__dict__.keys():
214
+ student = torchvision_models.__dict__[args.arch]()
215
+ teacher = torchvision_models.__dict__[args.arch]()
216
+ embed_dim = student.fc.weight.shape[1]
217
+ else:
218
+ print(f"Unknow architecture: {args.arch}")"""
219
+
220
+ # multi-crop wrapper handles forward with inputs of different resolutions
221
+ student = utils.MultiCropWrapper(student, DINOHead(
222
+ embed_dim,
223
+ args.out_dim,
224
+ use_bn=args.use_bn_in_head,
225
+ norm_last_layer=args.norm_last_layer,
226
+ ))
227
+ teacher = utils.MultiCropWrapper(
228
+ teacher,
229
+ DINOHead(embed_dim, args.out_dim, args.use_bn_in_head),
230
+ )
231
+ # move networks to gpu
232
+ student, teacher = student.cuda(), teacher.cuda()
233
+ # synchronize batch norms (if any)
234
+ if utils.has_batchnorms(student):
235
+ student = nn.SyncBatchNorm.convert_sync_batchnorm(student)
236
+ teacher = nn.SyncBatchNorm.convert_sync_batchnorm(teacher)
237
+
238
+ # we need DDP wrapper to have synchro batch norms working...
239
+ teacher = nn.parallel.DistributedDataParallel(teacher, device_ids=[args.gpu])
240
+ teacher_without_ddp = teacher.module
241
+ else:
242
+ # teacher_without_ddp and teacher are the same thing
243
+ teacher_without_ddp = teacher
244
+ student = nn.parallel.DistributedDataParallel(student, device_ids=[args.gpu])
245
+ # teacher and student start with the same weights
246
+ teacher_without_ddp.load_state_dict(student.module.state_dict())
247
+ # there is no backpropagation through the teacher, so no need for gradients
248
+ for p in teacher.parameters():
249
+ p.requires_grad = False
250
+ print(f"Student and Teacher are built: they are both {args.arch} networks.")
251
+
252
+ # ============ preparing loss ... ============
253
+ dino_loss = DINOLoss(
254
+ args.out_dim,
255
+ args.local_crops_number + 2, # total number of crops = 2 global crops + local_crops_number
256
+ args.warmup_teacher_temp,
257
+ args.teacher_temp,
258
+ args.warmup_teacher_temp_epochs,
259
+ args.epochs,
260
+ ).cuda()
261
+
262
+ # ============ preparing optimizer ... ============
263
+ params_groups = utils.get_params_groups(student)
264
+ if args.optimizer == "adamw":
265
+ optimizer = torch.optim.AdamW(params_groups) # to use with ViTs
266
+ elif args.optimizer == "sgd":
267
+ optimizer = torch.optim.SGD(params_groups, lr=0, momentum=0.9) # lr is set by scheduler
268
+ elif args.optimizer == "lars":
269
+ optimizer = utils.LARS(params_groups) # to use with convnet and large batches
270
+ # for mixed precision training
271
+ fp16_scaler = None
272
+ if args.use_fp16:
273
+ fp16_scaler = torch.cuda.amp.GradScaler()
274
+
275
+ # ============ init schedulers ... ============
276
+ lr_schedule = utils.cosine_scheduler(
277
+ args.lr * (args.batch_size_per_gpu * utils.get_world_size()) / 256., # linear scaling rule
278
+ args.min_lr,
279
+ args.epochs, len(data_loader),
280
+ warmup_epochs=args.warmup_epochs,
281
+ )
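+ # Effective peak LR = args.lr * (batch_size_per_gpu * world_size) / 256, reached after a
+ # linear warmup of warmup_epochs and then decayed towards min_lr along a half cosine.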
282
+ wd_schedule = utils.cosine_scheduler(
283
+ args.weight_decay,
284
+ args.weight_decay_end,
285
+ args.epochs, len(data_loader),
286
+ )
287
+ # momentum parameter is increased to 1. during training with a cosine schedule
288
+ momentum_schedule = utils.cosine_scheduler(args.momentum_teacher, 1,
289
+ args.epochs, len(data_loader))
290
+ print(f"Loss, optimizer and schedulers ready.")
291
+
292
+ # ============ optionally resume training ... ============
293
+ to_restore = {"epoch": 0}
294
+ utils.restart_from_checkpoint(
295
+ os.path.join(args.output_dir, "checkpoint.pth"),
296
+ run_variables=to_restore,
297
+ student=student,
298
+ teacher=teacher,
299
+ optimizer=optimizer,
300
+ fp16_scaler=fp16_scaler,
301
+ dino_loss=dino_loss,
302
+ )
303
+ start_epoch = to_restore["epoch"]
304
+
305
+ start_time = time.time()
306
+ print("Starting DINO training !")
307
+ for epoch in range(start_epoch, args.epochs):
308
+ data_loader.sampler.set_epoch(epoch)
309
+
310
+ # ============ training one epoch of DINO ... ============
311
+ train_stats = train_one_epoch(student, teacher, teacher_without_ddp, dino_loss,
312
+ data_loader, optimizer, lr_schedule, wd_schedule, momentum_schedule,
313
+ epoch, fp16_scaler, args)
314
+
315
+ # ============ writing logs ... ============
316
+ save_dict = {
317
+ 'student': student.state_dict(),
318
+ 'teacher': teacher.state_dict(),
319
+ 'optimizer': optimizer.state_dict(),
320
+ 'epoch': epoch + 1,
321
+ 'args': args,
322
+ 'dino_loss': dino_loss.state_dict(),
323
+ }
324
+ if fp16_scaler is not None:
325
+ save_dict['fp16_scaler'] = fp16_scaler.state_dict()
326
+ utils.save_on_master(save_dict, os.path.join(args.output_dir, 'checkpoint.pth'))
327
+ if args.saveckp_freq and (epoch % args.saveckp_freq == 0 or epoch + 1 == args.epochs):
328
+ initial_setting = os.path.split(args.basic_state_dict)[1].split('.')[0] # 'ViT_b16_224_Imagenet'
329
+ dataset_using = os.path.split(args.data_path)[1]
330
+ utils.save_on_master(save_dict, os.path.join(args.output_dir, f'dino_{initial_setting}_{dataset_using}_checkpoint{epoch:04}.pth'))
331
+
335
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
336
+ 'epoch': epoch}
337
+ if utils.is_main_process():
338
+ with (Path(args.output_dir) / "log.txt").open("a") as f:
339
+ f.write(json.dumps(log_stats) + "\n")
340
+ total_time = time.time() - start_time
341
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
342
+ print('Training time {}'.format(total_time_str))
343
+
344
+
345
+ def train_one_epoch(student, teacher, teacher_without_ddp, dino_loss, data_loader,
346
+ optimizer, lr_schedule, wd_schedule, momentum_schedule, epoch,
347
+ fp16_scaler, args):
348
+ metric_logger = utils.MetricLogger(delimiter=" ")
349
+ header = 'Epoch: [{}/{}]'.format(epoch, args.epochs)
350
+ for it, (images, _) in enumerate(metric_logger.log_every(data_loader, 10, header)):
351
+ # update weight decay and learning rate according to their schedule
352
+ it = len(data_loader) * epoch + it # global training iteration
353
+ for i, param_group in enumerate(optimizer.param_groups):
354
+ param_group["lr"] = lr_schedule[it]
355
+ if i == 0: # only the first group is regularized
356
+ param_group["weight_decay"] = wd_schedule[it]
357
+
358
+ # move images to gpu
359
+ images = [im.cuda(non_blocking=True) for im in images]
360
+ # teacher and student forward passes + compute dino loss
361
+ with torch.cuda.amp.autocast(fp16_scaler is not None):
362
+ teacher_output = teacher(images[:2]) # only the 2 global views pass through the teacher
363
+ student_output = student(images)
364
+ loss = dino_loss(student_output, teacher_output, epoch)
365
+
366
+ if not math.isfinite(loss.item()):
367
+ print("Loss is {}, stopping training".format(loss.item()), force=True)
368
+ sys.exit(1)
369
+
370
+ # student update
371
+ optimizer.zero_grad()
372
+ param_norms = None
373
+ if fp16_scaler is None:
374
+ loss.backward()
375
+ if args.clip_grad:
376
+ param_norms = utils.clip_gradients(student, args.clip_grad)
377
+ utils.cancel_gradients_last_layer(epoch, student,
378
+ args.freeze_last_layer)
379
+ optimizer.step()
380
+ else:
381
+ fp16_scaler.scale(loss).backward()
382
+ if args.clip_grad:
383
+ fp16_scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
384
+ param_norms = utils.clip_gradients(student, args.clip_grad)
385
+ utils.cancel_gradients_last_layer(epoch, student,
386
+ args.freeze_last_layer)
387
+ fp16_scaler.step(optimizer)
388
+ fp16_scaler.update()
389
+
390
+ # EMA update for the teacher
391
+ with torch.no_grad():
392
+ m = momentum_schedule[it] # momentum parameter
393
+ for param_q, param_k in zip(student.module.parameters(), teacher_without_ddp.parameters()):
394
+ param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)
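+ # EMA form: teacher_param <- m * teacher_param + (1 - m) * student_param, where m follows
+ # momentum_schedule (a cosine ramp from args.momentum_teacher up to 1.0), so the teacher
+ # changes more and more slowly as training progresses.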
395
+
396
+ # logging
397
+ torch.cuda.synchronize()
398
+ metric_logger.update(loss=loss.item())
399
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
400
+ metric_logger.update(wd=optimizer.param_groups[0]["weight_decay"])
401
+ # gather the stats from all processes
402
+ metric_logger.synchronize_between_processes()
403
+ print("Averaged stats:", metric_logger)
404
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
405
+
406
+
407
+ class DINOLoss(nn.Module):
408
+ def __init__(self, out_dim, ncrops, warmup_teacher_temp, teacher_temp,
409
+ warmup_teacher_temp_epochs, nepochs, student_temp=0.1,
410
+ center_momentum=0.9):
411
+ super().__init__()
412
+ self.student_temp = student_temp
413
+ self.center_momentum = center_momentum
414
+ self.ncrops = ncrops
415
+ self.register_buffer("center", torch.zeros(1, out_dim))
416
+ # we apply a warm up for the teacher temperature because
417
+ # a too high temperature makes the training instable at the beginning
418
+ self.teacher_temp_schedule = np.concatenate((
419
+ np.linspace(warmup_teacher_temp,
420
+ teacher_temp, warmup_teacher_temp_epochs),
421
+ np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
422
+ ))
423
+
424
+ def forward(self, student_output, teacher_output, epoch):
425
+ """
426
+ Cross-entropy between softmax outputs of the teacher and student networks.
427
+ """
428
+ student_out = student_output / self.student_temp
429
+ student_out = student_out.chunk(self.ncrops)
430
+
431
+ # teacher centering and sharpening
432
+ temp = self.teacher_temp_schedule[epoch]
433
+ teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
434
+ teacher_out = teacher_out.detach().chunk(2)
435
+
436
+ total_loss = 0
437
+ n_loss_terms = 0
438
+ for iq, q in enumerate(teacher_out):
439
+ for v in range(len(student_out)):
440
+ if v == iq:
441
+ # we skip cases where student and teacher operate on the same view
442
+ continue
443
+ loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1)
444
+ total_loss += loss.mean()
445
+ n_loss_terms += 1
446
+ total_loss /= n_loss_terms
447
+ self.update_center(teacher_output)
448
+ return total_loss
449
+
450
+ @torch.no_grad()
451
+ def update_center(self, teacher_output):
452
+ """
453
+ Update center used for teacher output.
454
+ """
455
+ batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
456
+ dist.all_reduce(batch_center)
457
+ batch_center = batch_center / (len(teacher_output) * dist.get_world_size())
458
+
459
+ # ema update
460
+ self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)
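+ # Illustrative numbers (assumed, not from the code): with center_momentum = 0.9 (the default),
+ # each call moves the center 10% of the way towards the current batch mean of the teacher
+ # outputs, i.e. center_new = 0.9 * center_old + 0.1 * batch_center, which is what keeps the
+ # teacher outputs from collapsing onto a single dimension.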
461
+
462
+
463
+ class DataAugmentationDINO(object):
464
+ def __init__(self, global_crops_scale, local_crops_scale, local_crops_number):
465
+ flip_and_color_jitter = transforms.Compose([
466
+ transforms.RandomHorizontalFlip(p=0.5),
467
+ transforms.RandomApply(
468
+ [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
469
+ p=0.8
470
+ ),
471
+ transforms.RandomGrayscale(p=0.2),
472
+ ])
473
+ normalize = transforms.Compose([
474
+ transforms.ToTensor(),
475
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
476
+ ])
477
+
478
+ # first global crop
479
+ self.global_transfo1 = transforms.Compose([
480
+ transforms.RandomResizedCrop(224, scale=global_crops_scale, interpolation=Image.BICUBIC),
481
+ flip_and_color_jitter,
482
+ utils.GaussianBlur(1.0),
483
+ normalize,
484
+ ])
485
+ # second global crop
486
+ self.global_transfo2 = transforms.Compose([
487
+ transforms.RandomResizedCrop(224, scale=global_crops_scale, interpolation=Image.BICUBIC),
488
+ flip_and_color_jitter,
489
+ utils.GaussianBlur(0.1),
490
+ utils.Solarization(0.2),
491
+ normalize,
492
+ ])
493
+ # transformation for the local small crops
494
+ self.local_crops_number = local_crops_number
495
+ self.local_transfo = transforms.Compose([
496
+ transforms.RandomResizedCrop(96, scale=local_crops_scale, interpolation=Image.BICUBIC),
497
+ flip_and_color_jitter,
498
+ utils.GaussianBlur(p=0.5),
499
+ normalize,
500
+ ])
501
+
502
+ def __call__(self, image):
503
+ crops = []
504
+ crops.append(self.global_transfo1(image))
505
+ crops.append(self.global_transfo2(image))
506
+ for _ in range(self.local_crops_number):
507
+ crops.append(self.local_transfo(image))
508
+ return crops
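+ # Minimal usage sketch (hypothetical file name, default crop settings assumed):
+ #   aug = DataAugmentationDINO((0.4, 1.), (0.05, 0.4), local_crops_number=8)
+ #   crops = aug(Image.open("slide_patch.png").convert("RGB"))
+ #   # -> list of 10 tensors: 2 global 224x224 views + 8 local 96x96 views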
509
+
510
+
511
+ if __name__ == '__main__':
512
+ parser = argparse.ArgumentParser('DINO', parents=[get_args_parser()])
513
+ args = parser.parse_args()
514
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
515
+ train_dino(args)
PuzzleTuning/Counterpart PreTrain Methods/dino-main/run_with_submitit.py ADDED
@@ -0,0 +1,132 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ A script to run multinode training with submitit.
16
+ Almost copy-paste from https://github.com/facebookresearch/deit/blob/main/run_with_submitit.py
17
+ """
18
+ import argparse
19
+ import os
20
+ import uuid
21
+ from pathlib import Path
22
+
23
+ import main_dino
24
+ import submitit
25
+
26
+
27
+ def parse_args():
28
+ parser = argparse.ArgumentParser("Submitit for DINO", parents=[main_dino.get_args_parser()])
29
+ parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node")
30
+ parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request")
31
+ parser.add_argument("--timeout", default=2800, type=int, help="Duration of the job")
32
+
33
+ parser.add_argument("--partition", default="learnfair", type=str, help="Partition where to submit")
34
+ parser.add_argument("--use_volta32", action='store_true', help="Big models? Use this")
35
+ parser.add_argument('--comment', default="", type=str,
36
+ help='Comment to pass to scheduler, e.g. priority message')
37
+ return parser.parse_args()
38
+
39
+
40
+ def get_shared_folder() -> Path:
41
+ user = os.getenv("USER")
42
+ if Path("/checkpoint/").is_dir():
43
+ p = Path(f"/checkpoint/{user}/experiments")
44
+ p.mkdir(exist_ok=True)
45
+ return p
46
+ raise RuntimeError("No shared folder available")
47
+
48
+
49
+ def get_init_file():
50
+ # Init file must not exist, but its parent dir must exist.
51
+ os.makedirs(str(get_shared_folder()), exist_ok=True)
52
+ init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init"
53
+ if init_file.exists():
54
+ os.remove(str(init_file))
55
+ return init_file
56
+
57
+
58
+ class Trainer(object):
59
+ def __init__(self, args):
60
+ self.args = args
61
+
62
+ def __call__(self):
63
+ import main_dino
64
+
65
+ self._setup_gpu_args()
66
+ main_dino.train_dino(self.args)
67
+
68
+ def checkpoint(self):
69
+ import os
70
+ import submitit
71
+
72
+ self.args.dist_url = get_init_file().as_uri()
73
+ print("Requeuing ", self.args)
74
+ empty_trainer = type(self)(self.args)
75
+ return submitit.helpers.DelayedSubmission(empty_trainer)
76
+
77
+ def _setup_gpu_args(self):
78
+ import submitit
79
+ from pathlib import Path
80
+
81
+ job_env = submitit.JobEnvironment()
82
+ self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id)))
83
+ self.args.gpu = job_env.local_rank
84
+ self.args.rank = job_env.global_rank
85
+ self.args.world_size = job_env.num_tasks
86
+ print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")
87
+
88
+
89
+ def main():
90
+ args = parse_args()
91
+ if args.output_dir == "":
92
+ args.output_dir = get_shared_folder() / "%j"
93
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
94
+ executor = submitit.AutoExecutor(folder=args.output_dir, slurm_max_num_timeout=30)
95
+
96
+ num_gpus_per_node = args.ngpus
97
+ nodes = args.nodes
98
+ timeout_min = args.timeout
99
+
100
+ partition = args.partition
101
+ kwargs = {}
102
+ if args.use_volta32:
103
+ kwargs['slurm_constraint'] = 'volta32gb'
104
+ if args.comment:
105
+ kwargs['slurm_comment'] = args.comment
106
+
107
+ executor.update_parameters(
108
+ mem_gb=40 * num_gpus_per_node,
109
+ gpus_per_node=num_gpus_per_node,
110
+ tasks_per_node=num_gpus_per_node, # one task per GPU
111
+ cpus_per_task=10,
112
+ nodes=nodes,
113
+ timeout_min=timeout_min, # max is 60 * 72
114
+ # Below are cluster dependent parameters
115
+ slurm_partition=partition,
116
+ slurm_signal_delay_s=120,
117
+ **kwargs
118
+ )
119
+
120
+ executor.update_parameters(name="dino")
121
+
122
+ args.dist_url = get_init_file().as_uri()
123
+
124
+ trainer = Trainer(args)
125
+ job = executor.submit(trainer)
126
+
127
+ print(f"Submitted job_id: {job.job_id}")
128
+ print(f"Logs and checkpoints will be saved at: {args.output_dir}")
129
+
130
+
131
+ if __name__ == "__main__":
132
+ main()
PuzzleTuning/Counterpart PreTrain Methods/dino-main/utils.py ADDED
@@ -0,0 +1,829 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Misc functions.
16
+
17
+ Mostly copy-paste from torchvision references or other public repos like DETR:
18
+ https://github.com/facebookresearch/detr/blob/master/util/misc.py
19
+ """
20
+ import os
21
+ import sys
22
+ import time
23
+ import math
24
+ import random
25
+ import datetime
26
+ import subprocess
27
+ import argparse  # used by bool_flag below
+ import warnings  # used by _no_grad_trunc_normal_ below
+ from collections import defaultdict, deque
28
+
29
+ import numpy as np
30
+ import torch
31
+ from torch import nn
32
+ import torch.distributed as dist
33
+ from PIL import ImageFilter, ImageOps
34
+
35
+
36
+ class GaussianBlur(object):
37
+ """
38
+ Apply Gaussian Blur to the PIL image.
39
+ """
40
+ def __init__(self, p=0.5, radius_min=0.1, radius_max=2.):
41
+ self.prob = p
42
+ self.radius_min = radius_min
43
+ self.radius_max = radius_max
44
+
45
+ def __call__(self, img):
46
+ do_it = random.random() <= self.prob
47
+ if not do_it:
48
+ return img
49
+
50
+ return img.filter(
51
+ ImageFilter.GaussianBlur(
52
+ radius=random.uniform(self.radius_min, self.radius_max)
53
+ )
54
+ )
55
+
56
+
57
+ class Solarization(object):
58
+ """
59
+ Apply Solarization to the PIL image.
60
+ """
61
+ def __init__(self, p):
62
+ self.p = p
63
+
64
+ def __call__(self, img):
65
+ if random.random() < self.p:
66
+ return ImageOps.solarize(img)
67
+ else:
68
+ return img
69
+
70
+
71
+ def load_pretrained_weights(model, pretrained_weights, checkpoint_key, model_name, patch_size):
72
+ if os.path.isfile(pretrained_weights):
73
+ state_dict = torch.load(pretrained_weights, map_location="cpu")
74
+ if checkpoint_key is not None and checkpoint_key in state_dict:
75
+ print(f"Take key {checkpoint_key} in provided checkpoint dict")
76
+ state_dict = state_dict[checkpoint_key]
77
+ # remove `module.` prefix
78
+ state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
79
+ # remove `backbone.` prefix induced by multicrop wrapper
80
+ state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
81
+ msg = model.load_state_dict(state_dict, strict=False)
82
+ print('Pretrained weights found at {} and loaded with msg: {}'.format(pretrained_weights, msg))
83
+ else:
84
+ print("Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.")
85
+ url = None
86
+ if model_name == "vit_small" and patch_size == 16:
87
+ url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
88
+ elif model_name == "vit_small" and patch_size == 8:
89
+ url = "dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth"
90
+ elif model_name == "vit_base" and patch_size == 16:
91
+ url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
92
+ elif model_name == "vit_base" and patch_size == 8:
93
+ url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
94
+ elif model_name == "xcit_small_12_p16":
95
+ url = "dino_xcit_small_12_p16_pretrain/dino_xcit_small_12_p16_pretrain.pth"
96
+ elif model_name == "xcit_small_12_p8":
97
+ url = "dino_xcit_small_12_p8_pretrain/dino_xcit_small_12_p8_pretrain.pth"
98
+ elif model_name == "xcit_medium_24_p16":
99
+ url = "dino_xcit_medium_24_p16_pretrain/dino_xcit_medium_24_p16_pretrain.pth"
100
+ elif model_name == "xcit_medium_24_p8":
101
+ url = "dino_xcit_medium_24_p8_pretrain/dino_xcit_medium_24_p8_pretrain.pth"
102
+ elif model_name == "resnet50":
103
+ url = "dino_resnet50_pretrain/dino_resnet50_pretrain.pth"
104
+ if url is not None:
105
+ print("Since no pretrained weights have been provided, we load the reference pretrained DINO weights.")
106
+ state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
107
+ model.load_state_dict(state_dict, strict=True)
108
+ else:
109
+ print("There is no reference weights available for this model => We use random weights.")
110
+
111
+
112
+ def load_pretrained_linear_weights(linear_classifier, model_name, patch_size):
113
+ url = None
114
+ if model_name == "vit_small" and patch_size == 16:
115
+ url = "dino_deitsmall16_pretrain/dino_deitsmall16_linearweights.pth"
116
+ elif model_name == "vit_small" and patch_size == 8:
117
+ url = "dino_deitsmall8_pretrain/dino_deitsmall8_linearweights.pth"
118
+ elif model_name == "vit_base" and patch_size == 16:
119
+ url = "dino_vitbase16_pretrain/dino_vitbase16_linearweights.pth"
120
+ elif model_name == "vit_base" and patch_size == 8:
121
+ url = "dino_vitbase8_pretrain/dino_vitbase8_linearweights.pth"
122
+ elif model_name == "resnet50":
123
+ url = "dino_resnet50_pretrain/dino_resnet50_linearweights.pth"
124
+ if url is not None:
125
+ print("We load the reference pretrained linear weights.")
126
+ state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)["state_dict"]
127
+ linear_classifier.load_state_dict(state_dict, strict=True)
128
+ else:
129
+ print("We use random linear weights.")
130
+
131
+
132
+ def clip_gradients(model, clip):
133
+ norms = []
134
+ for name, p in model.named_parameters():
135
+ if p.grad is not None:
136
+ param_norm = p.grad.data.norm(2)
137
+ norms.append(param_norm.item())
138
+ clip_coef = clip / (param_norm + 1e-6)
139
+ if clip_coef < 1:
140
+ p.grad.data.mul_(clip_coef)
141
+ return norms
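+ # Note: clipping here is applied per parameter tensor (each tensor's own L2 norm is compared
+ # to `clip`), rather than as one global norm as torch.nn.utils.clip_grad_norm_ would do.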
142
+
143
+
144
+ def cancel_gradients_last_layer(epoch, model, freeze_last_layer):
145
+ if epoch >= freeze_last_layer:
146
+ return
147
+ for n, p in model.named_parameters():
148
+ if "last_layer" in n:
149
+ p.grad = None
150
+
151
+
152
+ def restart_from_checkpoint(ckp_path, run_variables=None, **kwargs):
153
+ """
154
+ Re-start from checkpoint
155
+ """
156
+ if not os.path.isfile(ckp_path):
157
+ return
158
+ print("Found checkpoint at {}".format(ckp_path))
159
+
160
+ # open checkpoint file
161
+ checkpoint = torch.load(ckp_path, map_location="cpu")
162
+
163
+ # key is what to look for in the checkpoint file
164
+ # value is the object to load
165
+ # example: {'state_dict': model}
166
+ for key, value in kwargs.items():
167
+ if key in checkpoint and value is not None:
168
+ try:
169
+ msg = value.load_state_dict(checkpoint[key], strict=False)
170
+ print("=> loaded '{}' from checkpoint '{}' with msg {}".format(key, ckp_path, msg))
171
+ except TypeError:
172
+ try:
173
+ msg = value.load_state_dict(checkpoint[key])
174
+ print("=> loaded '{}' from checkpoint: '{}'".format(key, ckp_path))
175
+ except ValueError:
176
+ print("=> failed to load '{}' from checkpoint: '{}'".format(key, ckp_path))
177
+ else:
178
+ print("=> key '{}' not found in checkpoint: '{}'".format(key, ckp_path))
179
+
180
+ # reload variables important for the run
181
+ if run_variables is not None:
182
+ for var_name in run_variables:
183
+ if var_name in checkpoint:
184
+ run_variables[var_name] = checkpoint[var_name]
185
+
186
+
187
+ def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0):
188
+ warmup_schedule = np.array([])
189
+ warmup_iters = warmup_epochs * niter_per_ep
190
+ if warmup_epochs > 0:
191
+ warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
192
+
193
+ iters = np.arange(epochs * niter_per_ep - warmup_iters)
194
+ schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
195
+
196
+ schedule = np.concatenate((warmup_schedule, schedule))
197
+ assert len(schedule) == epochs * niter_per_ep
198
+ return schedule
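+ # Minimal sketch of the produced schedule (toy numbers, assumed for illustration):
+ #   sched = cosine_scheduler(base_value=0.1, final_value=0.0, epochs=10,
+ #                            niter_per_ep=100, warmup_epochs=2)
+ #   len(sched) == 1000; the first 200 entries ramp linearly 0 -> 0.1, and the remaining
+ #   800 decay from 0.1 towards 0.0 along a half cosine.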
199
+
200
+
201
+ def bool_flag(s):
202
+ """
203
+ Parse boolean arguments from the command line.
204
+ """
205
+ FALSY_STRINGS = {"off", "false", "0"}
206
+ TRUTHY_STRINGS = {"on", "true", "1"}
207
+ if s.lower() in FALSY_STRINGS:
208
+ return False
209
+ elif s.lower() in TRUTHY_STRINGS:
210
+ return True
211
+ else:
212
+ raise argparse.ArgumentTypeError("invalid value for a boolean flag")
213
+
214
+
215
+ def fix_random_seeds(seed=31):
216
+ """
217
+ Fix random seeds.
218
+ """
219
+ torch.manual_seed(seed)
220
+ torch.cuda.manual_seed_all(seed)
221
+ np.random.seed(seed)
222
+
223
+
224
+ class SmoothedValue(object):
225
+ """Track a series of values and provide access to smoothed values over a
226
+ window or the global series average.
227
+ """
228
+
229
+ def __init__(self, window_size=20, fmt=None):
230
+ if fmt is None:
231
+ fmt = "{median:.6f} ({global_avg:.6f})"
232
+ self.deque = deque(maxlen=window_size)
233
+ self.total = 0.0
234
+ self.count = 0
235
+ self.fmt = fmt
236
+
237
+ def update(self, value, n=1):
238
+ self.deque.append(value)
239
+ self.count += n
240
+ self.total += value * n
241
+
242
+ def synchronize_between_processes(self):
243
+ """
244
+ Warning: does not synchronize the deque!
245
+ """
246
+ if not is_dist_avail_and_initialized():
247
+ return
248
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
249
+ dist.barrier()
250
+ dist.all_reduce(t)
251
+ t = t.tolist()
252
+ self.count = int(t[0])
253
+ self.total = t[1]
254
+
255
+ @property
256
+ def median(self):
257
+ d = torch.tensor(list(self.deque))
258
+ return d.median().item()
259
+
260
+ @property
261
+ def avg(self):
262
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
263
+ return d.mean().item()
264
+
265
+ @property
266
+ def global_avg(self):
267
+ return self.total / self.count
268
+
269
+ @property
270
+ def max(self):
271
+ return max(self.deque)
272
+
273
+ @property
274
+ def value(self):
275
+ return self.deque[-1]
276
+
277
+ def __str__(self):
278
+ return self.fmt.format(
279
+ median=self.median,
280
+ avg=self.avg,
281
+ global_avg=self.global_avg,
282
+ max=self.max,
283
+ value=self.value)
284
+
285
+
286
+ def reduce_dict(input_dict, average=True):
287
+ """
288
+ Args:
289
+ input_dict (dict): all the values will be reduced
290
+ average (bool): whether to do average or sum
291
+ Reduce the values in the dictionary from all processes so that all processes
292
+ have the averaged results. Returns a dict with the same fields as
293
+ input_dict, after reduction.
294
+ """
295
+ world_size = get_world_size()
296
+ if world_size < 2:
297
+ return input_dict
298
+ with torch.no_grad():
299
+ names = []
300
+ values = []
301
+ # sort the keys so that they are consistent across processes
302
+ for k in sorted(input_dict.keys()):
303
+ names.append(k)
304
+ values.append(input_dict[k])
305
+ values = torch.stack(values, dim=0)
306
+ dist.all_reduce(values)
307
+ if average:
308
+ values /= world_size
309
+ reduced_dict = {k: v for k, v in zip(names, values)}
310
+ return reduced_dict
311
+
312
+
313
+ class MetricLogger(object):
314
+ def __init__(self, delimiter="\t"):
315
+ self.meters = defaultdict(SmoothedValue)
316
+ self.delimiter = delimiter
317
+
318
+ def update(self, **kwargs):
319
+ for k, v in kwargs.items():
320
+ if isinstance(v, torch.Tensor):
321
+ v = v.item()
322
+ assert isinstance(v, (float, int))
323
+ self.meters[k].update(v)
324
+
325
+ def __getattr__(self, attr):
326
+ if attr in self.meters:
327
+ return self.meters[attr]
328
+ if attr in self.__dict__:
329
+ return self.__dict__[attr]
330
+ raise AttributeError("'{}' object has no attribute '{}'".format(
331
+ type(self).__name__, attr))
332
+
333
+ def __str__(self):
334
+ loss_str = []
335
+ for name, meter in self.meters.items():
336
+ loss_str.append(
337
+ "{}: {}".format(name, str(meter))
338
+ )
339
+ return self.delimiter.join(loss_str)
340
+
341
+ def synchronize_between_processes(self):
342
+ for meter in self.meters.values():
343
+ meter.synchronize_between_processes()
344
+
345
+ def add_meter(self, name, meter):
346
+ self.meters[name] = meter
347
+
348
+ def log_every(self, iterable, print_freq, header=None):
349
+ i = 0
350
+ if not header:
351
+ header = ''
352
+ start_time = time.time()
353
+ end = time.time()
354
+ iter_time = SmoothedValue(fmt='{avg:.6f}')
355
+ data_time = SmoothedValue(fmt='{avg:.6f}')
356
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
357
+ if torch.cuda.is_available():
358
+ log_msg = self.delimiter.join([
359
+ header,
360
+ '[{0' + space_fmt + '}/{1}]',
361
+ 'eta: {eta}',
362
+ '{meters}',
363
+ 'time: {time}',
364
+ 'data: {data}',
365
+ 'max mem: {memory:.0f}'
366
+ ])
367
+ else:
368
+ log_msg = self.delimiter.join([
369
+ header,
370
+ '[{0' + space_fmt + '}/{1}]',
371
+ 'eta: {eta}',
372
+ '{meters}',
373
+ 'time: {time}',
374
+ 'data: {data}'
375
+ ])
376
+ MB = 1024.0 * 1024.0
377
+ for obj in iterable:
378
+ data_time.update(time.time() - end)
379
+ yield obj
380
+ iter_time.update(time.time() - end)
381
+ if i % print_freq == 0 or i == len(iterable) - 1:
382
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
383
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
384
+ if torch.cuda.is_available():
385
+ print(log_msg.format(
386
+ i, len(iterable), eta=eta_string,
387
+ meters=str(self),
388
+ time=str(iter_time), data=str(data_time),
389
+ memory=torch.cuda.max_memory_allocated() / MB))
390
+ else:
391
+ print(log_msg.format(
392
+ i, len(iterable), eta=eta_string,
393
+ meters=str(self),
394
+ time=str(iter_time), data=str(data_time)))
395
+ i += 1
396
+ end = time.time()
397
+ total_time = time.time() - start_time
398
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
399
+ print('{} Total time: {} ({:.6f} s / it)'.format(
400
+ header, total_time_str, total_time / len(iterable)))
401
+
402
+
403
+ def get_sha():
404
+ cwd = os.path.dirname(os.path.abspath(__file__))
405
+
406
+ def _run(command):
407
+ return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
408
+ sha = 'N/A'
409
+ diff = "clean"
410
+ branch = 'N/A'
411
+ try:
412
+ sha = _run(['git', 'rev-parse', 'HEAD'])
413
+ subprocess.check_output(['git', 'diff'], cwd=cwd)
414
+ diff = _run(['git', 'diff-index', 'HEAD'])
415
+ diff = "has uncommitted changes" if diff else "clean"
416
+ branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
417
+ except Exception:
418
+ pass
419
+ message = f"sha: {sha}, status: {diff}, branch: {branch}"
420
+ return message
421
+
422
+
423
+ def is_dist_avail_and_initialized():
424
+ if not dist.is_available():
425
+ return False
426
+ if not dist.is_initialized():
427
+ return False
428
+ return True
429
+
430
+
431
+ def get_world_size():
432
+ if not is_dist_avail_and_initialized():
433
+ return 1
434
+ return dist.get_world_size()
435
+
436
+
437
+ def get_rank():
438
+ if not is_dist_avail_and_initialized():
439
+ return 0
440
+ return dist.get_rank()
441
+
442
+
443
+ def is_main_process():
444
+ return get_rank() == 0
445
+
446
+
447
+ def save_on_master(*args, **kwargs):
448
+ if is_main_process():
449
+ torch.save(*args, **kwargs)
450
+
451
+
452
+ def setup_for_distributed(is_master):
453
+ """
454
+ This function disables printing when not in master process
455
+ """
456
+ import builtins as __builtin__
457
+ builtin_print = __builtin__.print
458
+
459
+ def print(*args, **kwargs):
460
+ force = kwargs.pop('force', False)
461
+ if is_master or force:
462
+ builtin_print(*args, **kwargs)
463
+
464
+ __builtin__.print = print
465
+
466
+
467
+ def init_distributed_mode(args):
468
+ # launched with torch.distributed.launch
469
+ if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
470
+ args.rank = int(os.environ["RANK"])
471
+ args.world_size = int(os.environ['WORLD_SIZE'])
472
+ args.gpu = int(os.environ['LOCAL_RANK'])
473
+ # launched with submitit on a slurm cluster
474
+ elif 'SLURM_PROCID' in os.environ:
475
+ args.rank = int(os.environ['SLURM_PROCID'])
476
+ args.gpu = args.rank % torch.cuda.device_count()
477
+ # launched naively with `python main_dino.py`
478
+ # we manually add MASTER_ADDR and MASTER_PORT to env variables
479
+ elif torch.cuda.is_available():
480
+ print('Will run the code on one GPU.')
481
+ args.rank, args.gpu, args.world_size = 0, 0, 1
482
+ os.environ['MASTER_ADDR'] = '127.0.0.1'
483
+ os.environ['MASTER_PORT'] = '29500'
484
+ else:
485
+ print('Does not support training without GPU.')
486
+ sys.exit(1)
487
+
488
+ dist.init_process_group(
489
+ backend="nccl",
490
+ init_method=args.dist_url,
491
+ world_size=args.world_size,
492
+ rank=args.rank,
493
+ )
494
+
495
+ torch.cuda.set_device(args.gpu)
496
+ print('| distributed init (rank {}): {}'.format(
497
+ args.rank, args.dist_url), flush=True)
498
+ dist.barrier()
499
+ setup_for_distributed(args.rank == 0)
500
+
501
+
502
+ def accuracy(output, target, topk=(1,)):
503
+ """Computes the accuracy over the k top predictions for the specified values of k"""
504
+ maxk = max(topk)
505
+ batch_size = target.size(0)
506
+ _, pred = output.topk(maxk, 1, True, True)
507
+ pred = pred.t()
508
+ correct = pred.eq(target.reshape(1, -1).expand_as(pred))
509
+ return [correct[:k].reshape(-1).float().sum(0) * 100. / batch_size for k in topk]
510
+
511
+
512
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
513
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
514
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
515
+ def norm_cdf(x):
516
+ # Computes standard normal cumulative distribution function
517
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
518
+
519
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
520
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
521
+ "The distribution of values may be incorrect.",
522
+ stacklevel=2)
523
+
524
+ with torch.no_grad():
525
+ # Values are generated by using a truncated uniform distribution and
526
+ # then using the inverse CDF for the normal distribution.
527
+ # Get upper and lower cdf values
528
+ l = norm_cdf((a - mean) / std)
529
+ u = norm_cdf((b - mean) / std)
530
+
531
+ # Uniformly fill tensor with values from [l, u], then translate to
532
+ # [2l-1, 2u-1].
533
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
534
+
535
+ # Use inverse cdf transform for normal distribution to get truncated
536
+ # standard normal
537
+ tensor.erfinv_()
538
+
539
+ # Transform to proper mean, std
540
+ tensor.mul_(std * math.sqrt(2.))
541
+ tensor.add_(mean)
542
+
543
+ # Clamp to ensure it's in the proper range
544
+ tensor.clamp_(min=a, max=b)
545
+ return tensor
546
+
547
+
548
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
549
+ # type: (Tensor, float, float, float, float) -> Tensor
550
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
551
+
552
+
553
+ class LARS(torch.optim.Optimizer):
554
+ """
555
+ Almost copy-paste from https://github.com/facebookresearch/barlowtwins/blob/main/main.py
556
+ """
557
+ def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, eta=0.001,
558
+ weight_decay_filter=None, lars_adaptation_filter=None):
559
+ defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum,
560
+ eta=eta, weight_decay_filter=weight_decay_filter,
561
+ lars_adaptation_filter=lars_adaptation_filter)
562
+ super().__init__(params, defaults)
563
+
564
+ @torch.no_grad()
565
+ def step(self):
566
+ for g in self.param_groups:
567
+ for p in g['params']:
568
+ dp = p.grad
569
+
570
+ if dp is None:
571
+ continue
572
+
573
+ if p.ndim != 1:
574
+ dp = dp.add(p, alpha=g['weight_decay'])
575
+
576
+ if p.ndim != 1:
577
+ param_norm = torch.norm(p)
578
+ update_norm = torch.norm(dp)
579
+ one = torch.ones_like(param_norm)
580
+ q = torch.where(param_norm > 0.,
581
+ torch.where(update_norm > 0,
582
+ (g['eta'] * param_norm / update_norm), one), one)
583
+ dp = dp.mul(q)
584
+
585
+ param_state = self.state[p]
586
+ if 'mu' not in param_state:
587
+ param_state['mu'] = torch.zeros_like(p)
588
+ mu = param_state['mu']
589
+ mu.mul_(g['momentum']).add_(dp)
590
+
591
+ p.add_(mu, alpha=-g['lr'])
592
+
593
+
594
+ class MultiCropWrapper(nn.Module):
595
+ """
596
+ Perform forward pass separately on each resolution input.
597
+ The inputs corresponding to a single resolution are grouped together and a single
598
+ forward pass is run per resolution; hence the number of forward
599
+ passes equals the number of different resolutions used. We then
600
+ concatenate all the output features and run the head forward on these
601
+ concatenated features.
602
+ """
603
+ def __init__(self, backbone, head):
604
+ super(MultiCropWrapper, self).__init__()
605
+ # disable layers dedicated to ImageNet labels classification
606
+ backbone.fc, backbone.head = nn.Identity(), nn.Identity()
607
+ self.backbone = backbone
608
+ self.head = head
609
+
610
+ def forward(self, x):
611
+ # convert to list
612
+ if not isinstance(x, list):
613
+ x = [x]
614
+ idx_crops = torch.cumsum(torch.unique_consecutive(
615
+ torch.tensor([inp.shape[-1] for inp in x]),
616
+ return_counts=True,
617
+ )[1], 0)
618
+ start_idx, output = 0, torch.empty(0).to(x[0].device)
619
+ for end_idx in idx_crops:
620
+ _out = self.backbone(torch.cat(x[start_idx: end_idx]))
621
+ # The output is a tuple with XCiT model. See:
622
+ # https://github.com/facebookresearch/xcit/blob/master/xcit.py#L404-L405
623
+ if isinstance(_out, tuple):
624
+ _out = _out[0]
625
+ # accumulate outputs
626
+ output = torch.cat((output, _out))
627
+ start_idx = end_idx
628
+ # Run the head forward on the concatenated features.
629
+ return self.head(output)
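+ # Example of the grouping (assumed crop sizes): for x holding 2 global 224x224 crops
+ # followed by 8 local 96x96 crops, unique_consecutive over the crop widths gives counts
+ # [2, 8], so idx_crops = [2, 10] and the backbone runs twice (once per resolution) before
+ # a single head forward on the concatenated features.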
630
+
631
+
632
+ def get_params_groups(model):
633
+ regularized = []
634
+ not_regularized = []
635
+ for name, param in model.named_parameters():
636
+ if not param.requires_grad:
637
+ continue
638
+ # we do not regularize biases nor Norm parameters
639
+ if name.endswith(".bias") or len(param.shape) == 1:
640
+ not_regularized.append(param)
641
+ else:
642
+ regularized.append(param)
643
+ return [{'params': regularized}, {'params': not_regularized, 'weight_decay': 0.}]
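+ # Two groups are returned: weights (regularized) and biases / 1-D norm parameters
+ # (weight_decay fixed at 0). In main_dino.train_one_epoch only the first group has its
+ # weight decay driven by wd_schedule.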
644
+
645
+
646
+ def has_batchnorms(model):
647
+ bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
648
+ for name, module in model.named_modules():
649
+ if isinstance(module, bn_types):
650
+ return True
651
+ return False
652
+
653
+
654
+ class PCA():
655
+ """
656
+ Class to compute and apply PCA.
657
+ """
658
+ def __init__(self, dim=256, whit=0.5):
659
+ self.dim = dim
660
+ self.whit = whit
661
+ self.mean = None
662
+
663
+ def train_pca(self, cov):
664
+ """
665
+ Takes a covariance matrix (np.ndarray) as input.
666
+ """
667
+ d, v = np.linalg.eigh(cov)
668
+ eps = d.max() * 1e-5
669
+ n_0 = (d < eps).sum()
670
+ if n_0 > 0:
671
+ d[d < eps] = eps
672
+
673
+ # total energy
674
+ totenergy = d.sum()
675
+
676
+ # sort eigenvectors with eigenvalues order
677
+ idx = np.argsort(d)[::-1][:self.dim]
678
+ d = d[idx]
679
+ v = v[:, idx]
680
+
681
+ print("keeping %.2f %% of the energy" % (d.sum() / totenergy * 100.0))
682
+
683
+ # for the whitening
684
+ d = np.diag(1. / d**self.whit)
685
+
686
+ # principal components
687
+ self.dvt = np.dot(d, v.T)
688
+
689
+ def apply(self, x):
690
+ # input is from numpy
691
+ if isinstance(x, np.ndarray):
692
+ if self.mean is not None:
693
+ x -= self.mean
694
+ return np.dot(self.dvt, x.T).T
695
+
696
+ # input is from torch and is on GPU
697
+ if x.is_cuda:
698
+ if self.mean is not None:
699
+ x -= torch.cuda.FloatTensor(self.mean)
700
+ return torch.mm(torch.cuda.FloatTensor(self.dvt), x.transpose(0, 1)).transpose(0, 1)
701
+
702
+ # input if from torch, on CPU
703
+ if self.mean is not None:
704
+ x -= torch.FloatTensor(self.mean)
705
+ return torch.mm(torch.FloatTensor(self.dvt), x.transpose(0, 1)).transpose(0, 1)
706
+
707
+
708
+ def compute_ap(ranks, nres):
709
+ """
710
+ Computes average precision for given ranked indexes.
711
+ Arguments
712
+ ---------
713
+ ranks : zero-based ranks of positive images
714
+ nres : number of positive images
715
+ Returns
716
+ -------
717
+ ap : average precision
718
+ """
719
+
720
+ # number of images ranked by the system
721
+ nimgranks = len(ranks)
722
+
723
+ # accumulate trapezoids in PR-plot
724
+ ap = 0
725
+
726
+ recall_step = 1. / nres
727
+
728
+ for j in np.arange(nimgranks):
729
+ rank = ranks[j]
730
+
731
+ if rank == 0:
732
+ precision_0 = 1.
733
+ else:
734
+ precision_0 = float(j) / rank
735
+
736
+ precision_1 = float(j + 1) / (rank + 1)
737
+
738
+ ap += (precision_0 + precision_1) * recall_step / 2.
739
+
740
+ return ap
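+ # Worked example (assumed ranks): nres = 2 positives retrieved at zero-based ranks [0, 2]
+ # gives recall_step = 0.5 and ap = (1 + 1)/2 * 0.5 + (1/2 + 2/3)/2 * 0.5 ~= 0.79.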
741
+
742
+
743
+ def compute_map(ranks, gnd, kappas=[]):
744
+ """
745
+ Computes the mAP for a given set of returned results.
746
+ Usage:
747
+ map = compute_map (ranks, gnd)
748
+ computes mean average precision (map) only
749
+ map, aps, pr, prs = compute_map (ranks, gnd, kappas)
750
+ computes mean average precision (map), average precision (aps) for each query
751
+ computes mean precision at kappas (pr), precision at kappas (prs) for each query
752
+ Notes:
753
+ 1) ranks starts from 0, ranks.shape = db_size X #queries
754
+ 2) The junk results (e.g., the query itself) should be declared in the gnd struct array
755
+ 3) If there are no positive images for some query, that query is excluded from the evaluation
756
+ """
757
+
758
+ map = 0.
759
+ nq = len(gnd) # number of queries
760
+ aps = np.zeros(nq)
761
+ pr = np.zeros(len(kappas))
762
+ prs = np.zeros((nq, len(kappas)))
763
+ nempty = 0
764
+
765
+ for i in np.arange(nq):
766
+ qgnd = np.array(gnd[i]['ok'])
767
+
768
+ # no positive images, skip from the average
769
+ if qgnd.shape[0] == 0:
770
+ aps[i] = float('nan')
771
+ prs[i, :] = float('nan')
772
+ nempty += 1
773
+ continue
774
+
775
+ try:
776
+ qgndj = np.array(gnd[i]['junk'])
777
+ except:
778
+ qgndj = np.empty(0)
779
+
780
+ # sorted positions of positive and junk images (0 based)
781
+ pos = np.arange(ranks.shape[0])[np.in1d(ranks[:,i], qgnd)]
782
+ junk = np.arange(ranks.shape[0])[np.in1d(ranks[:,i], qgndj)]
783
+
784
+ k = 0
785
+ ij = 0
786
+ if len(junk):
787
+ # decrease positions of positives based on the number of
788
+ # junk images appearing before them
789
+ ip = 0
790
+ while (ip < len(pos)):
791
+ while (ij < len(junk) and pos[ip] > junk[ij]):
792
+ k += 1
793
+ ij += 1
794
+ pos[ip] = pos[ip] - k
795
+ ip += 1
796
+
797
+ # compute ap
798
+ ap = compute_ap(pos, len(qgnd))
799
+ map = map + ap
800
+ aps[i] = ap
801
+
802
+ # compute precision @ k
803
+ pos += 1 # get it to 1-based
804
+ for j in np.arange(len(kappas)):
805
+ kq = min(max(pos), kappas[j])
806
+ prs[i, j] = (pos <= kq).sum() / kq
807
+ pr = pr + prs[i, :]
808
+
809
+ map = map / (nq - nempty)
810
+ pr = pr / (nq - nempty)
811
+
812
+ return map, aps, pr, prs
813
+
814
+
815
+ def multi_scale(samples, model):
816
+ v = None
817
+ for s in [1, 1/2**(1/2), 1/2]: # we use 3 different scales
818
+ if s == 1:
819
+ inp = samples.clone()
820
+ else:
821
+ inp = nn.functional.interpolate(samples, scale_factor=s, mode='bilinear', align_corners=False)
822
+ feats = model(inp).clone()
823
+ if v is None:
824
+ v = feats
825
+ else:
826
+ v += feats
827
+ v /= 3
828
+ v /= v.norm()
829
+ return v
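+ # Averages the backbone features over three scales (1, 1/sqrt(2), 1/2) and L2-normalizes
+ # the result, giving a multi-scale descriptor intended for the retrieval-style evaluations.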
PuzzleTuning/Counterpart PreTrain Methods/dino-main/video_generation.py ADDED
@@ -0,0 +1,378 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import glob
16
+ import sys
17
+ import argparse
18
+ import cv2
19
+
20
+ from tqdm import tqdm
21
+ import matplotlib.pyplot as plt
22
+ import torch
23
+ import torch.nn as nn
24
+ import torchvision
25
+ from torchvision import transforms as pth_transforms
26
+ import numpy as np
27
+ from PIL import Image
28
+
29
+ import utils
30
+ import vision_transformer as vits
31
+
32
+
33
+ FOURCC = {
34
+ "mp4": cv2.VideoWriter_fourcc(*"MP4V"),
35
+ "avi": cv2.VideoWriter_fourcc(*"XVID"),
36
+ }
37
+ DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
38
+
39
+
40
+ class VideoGenerator:
41
+ def __init__(self, args):
42
+ self.args = args
43
+ # self.model = None
44
+ # Don't need to load model if you only want a video
45
+ if not self.args.video_only:
46
+ self.model = self.__load_model()
47
+
48
+ def run(self):
49
+ if self.args.input_path is None:
50
+ print(f"Provided input path {self.args.input_path} is not valid.")
51
+ sys.exit(1)
52
+ else:
53
+ if self.args.video_only:
54
+ self._generate_video_from_images(
55
+ self.args.input_path, self.args.output_path
56
+ )
57
+ else:
58
+ # If input path exists
59
+ if os.path.exists(self.args.input_path):
60
+ # If input is a video file
61
+ if os.path.isfile(self.args.input_path):
62
+ frames_folder = os.path.join(self.args.output_path, "frames")
63
+ attention_folder = os.path.join(
64
+ self.args.output_path, "attention"
65
+ )
66
+
67
+ os.makedirs(frames_folder, exist_ok=True)
68
+ os.makedirs(attention_folder, exist_ok=True)
69
+
70
+ self._extract_frames_from_video(
71
+ self.args.input_path, frames_folder
72
+ )
73
+
74
+ self._inference(
75
+ frames_folder,
76
+ attention_folder,
77
+ )
78
+
79
+ self._generate_video_from_images(
80
+ attention_folder, self.args.output_path
81
+ )
82
+
83
+ # If input is a folder of already extracted frames
84
+ if os.path.isdir(self.args.input_path):
85
+ attention_folder = os.path.join(
86
+ self.args.output_path, "attention"
87
+ )
88
+
89
+ os.makedirs(attention_folder, exist_ok=True)
90
+
91
+ self._inference(self.args.input_path, attention_folder)
92
+
93
+ self._generate_video_from_images(
94
+ attention_folder, self.args.output_path
95
+ )
96
+
97
+ # If input path doesn't exist
98
+ else:
99
+ print(f"Provided input path {self.args.input_path} doesn't exist.")
100
+ sys.exit(1)
101
+
102
+ def _extract_frames_from_video(self, inp: str, out: str):
103
+ vidcap = cv2.VideoCapture(inp)
104
+ self.args.fps = vidcap.get(cv2.CAP_PROP_FPS)
105
+
106
+ print(f"Video: {inp} ({self.args.fps} fps)")
107
+ print(f"Extracting frames to {out}")
108
+
109
+ success, image = vidcap.read()
110
+ count = 0
111
+ while success:
112
+ cv2.imwrite(
113
+ os.path.join(out, f"frame-{count:04}.jpg"),
114
+ image,
115
+ )
116
+ success, image = vidcap.read()
117
+ count += 1
118
+
119
+ def _generate_video_from_images(self, inp: str, out: str):
120
+ img_array = []
121
+ attention_images_list = sorted(glob.glob(os.path.join(inp, "attn-*.jpg")))
122
+
123
+ # Get size of the first image
124
+ with open(attention_images_list[0], "rb") as f:
125
+ img = Image.open(f)
126
+ img = img.convert("RGB")
127
+ size = (img.width, img.height)
128
+ img_array.append(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
129
+
130
+ print(f"Generating video {size} to {out}")
131
+
132
+ for filename in tqdm(attention_images_list[1:]):
133
+ with open(filename, "rb") as f:
134
+ img = Image.open(f)
135
+ img = img.convert("RGB")
136
+ img_array.append(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
137
+
138
+ out = cv2.VideoWriter(
139
+ os.path.join(out, "video." + self.args.video_format),
140
+ FOURCC[self.args.video_format],
141
+ self.args.fps,
142
+ size,
143
+ )
144
+
145
+ for i in range(len(img_array)):
146
+ out.write(img_array[i])
147
+ out.release()
148
+ print("Done")
149
+
150
+ def _inference(self, inp: str, out: str):
151
+ print(f"Generating attention images to {out}")
152
+
153
+ for img_path in tqdm(sorted(glob.glob(os.path.join(inp, "*.jpg")))):
154
+ with open(img_path, "rb") as f:
155
+ img = Image.open(f)
156
+ img = img.convert("RGB")
157
+
158
+ if self.args.resize is not None:
159
+ transform = pth_transforms.Compose(
160
+ [
161
+ pth_transforms.ToTensor(),
162
+ pth_transforms.Resize(self.args.resize),
163
+ pth_transforms.Normalize(
164
+ (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
165
+ ),
166
+ ]
167
+ )
168
+ else:
169
+ transform = pth_transforms.Compose(
170
+ [
171
+ pth_transforms.ToTensor(),
172
+ pth_transforms.Normalize(
173
+ (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
174
+ ),
175
+ ]
176
+ )
177
+
178
+ img = transform(img)
179
+
180
+ # make the image divisible by the patch size
181
+ w, h = (
182
+ img.shape[1] - img.shape[1] % self.args.patch_size,
183
+ img.shape[2] - img.shape[2] % self.args.patch_size,
184
+ )
185
+ img = img[:, :w, :h].unsqueeze(0)
186
+
187
+ w_featmap = img.shape[-2] // self.args.patch_size
188
+ h_featmap = img.shape[-1] // self.args.patch_size
189
+
190
+ attentions = self.model.get_last_selfattention(img.to(DEVICE))
191
+
192
+ nh = attentions.shape[1] # number of heads
193
+
194
+ # we keep only the output patch attention
195
+ attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
196
+
197
+ # we keep only a certain percentage of the mass
198
+ val, idx = torch.sort(attentions)
199
+ val /= torch.sum(val, dim=1, keepdim=True)
200
+ cumval = torch.cumsum(val, dim=1)
201
+ th_attn = cumval > (1 - self.args.threshold)
202
+ idx2 = torch.argsort(idx)
203
+ for head in range(nh):
204
+ th_attn[head] = th_attn[head][idx2[head]]
205
+ th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()
206
+ # interpolate
207
+ th_attn = (
208
+ nn.functional.interpolate(
209
+ th_attn.unsqueeze(0),
210
+ scale_factor=self.args.patch_size,
211
+ mode="nearest",
212
+ )[0]
213
+ .cpu()
214
+ .numpy()
215
+ )
216
+
217
+ attentions = attentions.reshape(nh, w_featmap, h_featmap)
218
+ attentions = (
219
+ nn.functional.interpolate(
220
+ attentions.unsqueeze(0),
221
+ scale_factor=self.args.patch_size,
222
+ mode="nearest",
223
+ )[0]
224
+ .cpu()
225
+ .numpy()
226
+ )
227
+
228
+ # save attentions heatmaps
229
+ fname = os.path.join(out, "attn-" + os.path.basename(img_path))
230
+ plt.imsave(
231
+ fname=fname,
232
+ arr=sum(
233
+ attentions[i] * 1 / attentions.shape[0]
234
+ for i in range(attentions.shape[0])
235
+ ),
236
+ cmap="inferno",
237
+ format="jpg",
238
+ )
239
+
240
+ def __load_model(self):
241
+ # build model
242
+ model = vits.__dict__[self.args.arch](
243
+ patch_size=self.args.patch_size, num_classes=0
244
+ )
245
+ for p in model.parameters():
246
+ p.requires_grad = False
247
+ model.eval()
248
+ model.to(DEVICE)
249
+
250
+ if os.path.isfile(self.args.pretrained_weights):
251
+ state_dict = torch.load(self.args.pretrained_weights, map_location="cpu")
252
+ if (
253
+ self.args.checkpoint_key is not None
254
+ and self.args.checkpoint_key in state_dict
255
+ ):
256
+ print(
257
+ f"Take key {self.args.checkpoint_key} in provided checkpoint dict"
258
+ )
259
+ state_dict = state_dict[self.args.checkpoint_key]
260
+ state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
261
+ # remove `backbone.` prefix induced by multicrop wrapper
262
+ state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
263
+ msg = model.load_state_dict(state_dict, strict=False)
264
+ print(
265
+ "Pretrained weights found at {} and loaded with msg: {}".format(
266
+ self.args.pretrained_weights, msg
267
+ )
268
+ )
269
+ else:
270
+ print(
271
+ "Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate."
272
+ )
273
+ url = None
274
+ if self.args.arch == "vit_small" and self.args.patch_size == 16:
275
+ url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
276
+ elif self.args.arch == "vit_small" and self.args.patch_size == 8:
277
+ url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth" # model used for visualizations in our paper
278
+ elif self.args.arch == "vit_base" and self.args.patch_size == 16:
279
+ url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
280
+ elif self.args.arch == "vit_base" and self.args.patch_size == 8:
281
+ url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
282
+ if url is not None:
283
+ print(
284
+ "Since no pretrained weights have been provided, we load the reference pretrained DINO weights."
285
+ )
286
+ state_dict = torch.hub.load_state_dict_from_url(
287
+ url="https://dl.fbaipublicfiles.com/dino/" + url
288
+ )
289
+ model.load_state_dict(state_dict, strict=True)
290
+ else:
291
+ print(
292
+ "There is no reference weights available for this model => We use random weights."
293
+ )
294
+ return model
295
+
296
+
297
+ def parse_args():
298
+ parser = argparse.ArgumentParser("Generate self-attention video")
299
+ parser.add_argument(
300
+ "--arch",
301
+ default="vit_small",
302
+ type=str,
303
+ choices=["vit_tiny", "vit_small", "vit_base"],
304
+ help="Architecture (support only ViT atm).",
305
+ )
306
+ parser.add_argument(
307
+ "--patch_size", default=8, type=int, help="Patch resolution of the self.model."
308
+ )
309
+ parser.add_argument(
310
+ "--pretrained_weights",
311
+ default="",
312
+ type=str,
313
+ help="Path to pretrained weights to load.",
314
+ )
315
+ parser.add_argument(
316
+ "--checkpoint_key",
317
+ default="teacher",
318
+ type=str,
319
+ help='Key to use in the checkpoint (example: "teacher")',
320
+ )
321
+ parser.add_argument(
322
+ "--input_path",
323
+ required=True,
324
+ type=str,
325
+ help="""Path to a video file if you want to extract frames
326
+ or to a folder of images already extracted by yourself.
327
+ or to a folder of attention images.""",
328
+ )
329
+ parser.add_argument(
330
+ "--output_path",
331
+ default="./",
332
+ type=str,
333
+ help="""Path to store a folder of frames and / or a folder of attention images.
334
+ and / or a final video. Default to current directory.""",
335
+ )
336
+ parser.add_argument(
337
+ "--threshold",
338
+ type=float,
339
+ default=0.6,
340
+ help="""We visualize masks
341
+ obtained by thresholding the self-attention maps to keep xx percent of the mass.""",
342
+ )
343
+ parser.add_argument(
344
+ "--resize",
345
+ default=None,
346
+ type=int,
347
+ nargs="+",
348
+ help="""Apply a resize transformation to input image(s). Use if OOM error.
349
+ Usage (single or W H): --resize 512, --resize 720 1280""",
350
+ )
351
+ parser.add_argument(
352
+ "--video_only",
353
+ action="store_true",
354
+ help="""Use this flag if you only want to generate a video and not all attention images.
355
+ If used, --input_path must be set to the folder of attention images. Ex: ./attention/""",
356
+ )
357
+ parser.add_argument(
358
+ "--fps",
359
+ default=30.0,
360
+ type=float,
361
+ help="FPS of input / output video. Automatically set if you extract frames from a video.",
362
+ )
363
+ parser.add_argument(
364
+ "--video_format",
365
+ default="mp4",
366
+ type=str,
367
+ choices=["mp4", "avi"],
368
+ help="Format of generated video (mp4 or avi).",
369
+ )
370
+
371
+ return parser.parse_args()
372
+
373
+
374
+ if __name__ == "__main__":
375
+ args = parse_args()
376
+
377
+ vg = VideoGenerator(args)
378
+ vg.run()
PuzzleTuning/Counterpart PreTrain Methods/dino-main/vision_transformer.py ADDED
@@ -0,0 +1,291 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Mostly copy-paste from timm library.
16
+ https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
17
+ """
18
+ import math
19
+ from functools import partial
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+
24
+ from utils import trunc_normal_
25
+
26
+
27
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
28
+ if drop_prob == 0. or not training:
29
+ return x
30
+ keep_prob = 1 - drop_prob
31
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
32
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
33
+ random_tensor.floor_() # binarize
34
+ output = x.div(keep_prob) * random_tensor
35
+ return output
36
+
37
+
38
+ class DropPath(nn.Module):
39
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
40
+ """
41
+ def __init__(self, drop_prob=None):
42
+ super(DropPath, self).__init__()
43
+ self.drop_prob = drop_prob
44
+
45
+ def forward(self, x):
46
+ return drop_path(x, self.drop_prob, self.training)
47
+
48
+
49
+ class Mlp(nn.Module):
50
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
51
+ super().__init__()
52
+ out_features = out_features or in_features
53
+ hidden_features = hidden_features or in_features
54
+ self.fc1 = nn.Linear(in_features, hidden_features)
55
+ self.act = act_layer()
56
+ self.fc2 = nn.Linear(hidden_features, out_features)
57
+ self.drop = nn.Dropout(drop)
58
+
59
+ def forward(self, x):
60
+ x = self.fc1(x)
61
+ x = self.act(x)
62
+ x = self.drop(x)
63
+ x = self.fc2(x)
64
+ x = self.drop(x)
65
+ return x
66
+
67
+
68
+ class Attention(nn.Module):
69
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
70
+ super().__init__()
71
+ self.num_heads = num_heads
72
+ head_dim = dim // num_heads
73
+ self.scale = qk_scale or head_dim ** -0.5
74
+
75
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
76
+ self.attn_drop = nn.Dropout(attn_drop)
77
+ self.proj = nn.Linear(dim, dim)
78
+ self.proj_drop = nn.Dropout(proj_drop)
79
+
80
+ def forward(self, x):
81
+ B, N, C = x.shape
82
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
83
+ q, k, v = qkv[0], qkv[1], qkv[2]
84
+
85
+ attn = (q @ k.transpose(-2, -1)) * self.scale
86
+ attn = attn.softmax(dim=-1)
87
+ attn = self.attn_drop(attn)
88
+
89
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
90
+ x = self.proj(x)
91
+ x = self.proj_drop(x)
92
+ return x, attn
93
+
94
+
95
+ class Block(nn.Module):
96
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
97
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
98
+ super().__init__()
99
+ self.norm1 = norm_layer(dim)
100
+ self.attn = Attention(
101
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
102
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
103
+ self.norm2 = norm_layer(dim)
104
+ mlp_hidden_dim = int(dim * mlp_ratio)
105
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
106
+
107
+ def forward(self, x, return_attention=False):
108
+ y, attn = self.attn(self.norm1(x))
109
+ if return_attention:
110
+ return attn
111
+ x = x + self.drop_path(y)
112
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
113
+ return x
114
+
115
+
116
+ class PatchEmbed(nn.Module):
117
+ """ Image to Patch Embedding
118
+ """
119
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
120
+ super().__init__()
121
+ num_patches = (img_size // patch_size) * (img_size // patch_size)
122
+ self.img_size = img_size
123
+ self.patch_size = patch_size
124
+ self.num_patches = num_patches
125
+
126
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
127
+
128
+ def forward(self, x):
129
+ B, C, H, W = x.shape
130
+ x = self.proj(x).flatten(2).transpose(1, 2)
131
+ return x
132
+
133
+
134
+ class VisionTransformer(nn.Module):
135
+ """ Vision Transformer """
136
+ def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
137
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
138
+ drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs):
139
+ super().__init__()
140
+ self.num_features = self.embed_dim = embed_dim
141
+
142
+ self.patch_embed = PatchEmbed(
143
+ img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
144
+ num_patches = self.patch_embed.num_patches
145
+
146
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
147
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
148
+ self.pos_drop = nn.Dropout(p=drop_rate)
149
+
150
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
151
+ self.blocks = nn.ModuleList([
152
+ Block(
153
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
154
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
155
+ for i in range(depth)])
156
+ self.norm = norm_layer(embed_dim)
157
+
158
+ # Classifier head
159
+ self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
160
+
161
+ trunc_normal_(self.pos_embed, std=.02)
162
+ trunc_normal_(self.cls_token, std=.02)
163
+ self.apply(self._init_weights)
164
+
165
+ def _init_weights(self, m):
166
+ if isinstance(m, nn.Linear):
167
+ trunc_normal_(m.weight, std=.02)
168
+ if isinstance(m, nn.Linear) and m.bias is not None:
169
+ nn.init.constant_(m.bias, 0)
170
+ elif isinstance(m, nn.LayerNorm):
171
+ nn.init.constant_(m.bias, 0)
172
+ nn.init.constant_(m.weight, 1.0)
173
+
174
+ def interpolate_pos_encoding(self, x, w, h):
175
+ npatch = x.shape[1] - 1
176
+ N = self.pos_embed.shape[1] - 1
177
+ if npatch == N and w == h:
178
+ return self.pos_embed
179
+ class_pos_embed = self.pos_embed[:, 0]
180
+ patch_pos_embed = self.pos_embed[:, 1:]
181
+ dim = x.shape[-1]
182
+ w0 = w // self.patch_embed.patch_size
183
+ h0 = h // self.patch_embed.patch_size
184
+ # we add a small number to avoid floating point error in the interpolation
185
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
186
+ w0, h0 = w0 + 0.1, h0 + 0.1
187
+ patch_pos_embed = nn.functional.interpolate(
188
+ patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
189
+ scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
190
+ mode='bicubic',
191
+ )
192
+ assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
193
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
194
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
195
+
196
+ def prepare_tokens(self, x):
197
+ B, nc, w, h = x.shape
198
+ x = self.patch_embed(x) # patch linear embedding
199
+
200
+ # add the [CLS] token to the embed patch tokens
201
+ cls_tokens = self.cls_token.expand(B, -1, -1)
202
+ x = torch.cat((cls_tokens, x), dim=1)
203
+
204
+ # add positional encoding to each token
205
+ x = x + self.interpolate_pos_encoding(x, w, h)
206
+
207
+ return self.pos_drop(x)
208
+
209
+ def forward(self, x):
210
+ x = self.prepare_tokens(x)
211
+ for blk in self.blocks:
212
+ x = blk(x)
213
+ x = self.norm(x)
214
+ return x[:, 0]
215
+
216
+ def get_last_selfattention(self, x):
217
+ x = self.prepare_tokens(x)
218
+ for i, blk in enumerate(self.blocks):
219
+ if i < len(self.blocks) - 1:
220
+ x = blk(x)
221
+ else:
222
+ # return attention of the last block
223
+ return blk(x, return_attention=True)
224
+
225
+ def get_intermediate_layers(self, x, n=1):
226
+ x = self.prepare_tokens(x)
227
+ # we return the output tokens from the `n` last blocks
228
+ output = []
229
+ for i, blk in enumerate(self.blocks):
230
+ x = blk(x)
231
+ if len(self.blocks) - i <= n:
232
+ output.append(self.norm(x))
233
+ return output
234
+
235
+
236
+ def vit_tiny(patch_size=16, **kwargs):
237
+ model = VisionTransformer(
238
+ patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
239
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
240
+ return model
241
+
242
+
243
+ def vit_small(patch_size=16, **kwargs):
244
+ model = VisionTransformer(
245
+ patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
246
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
247
+ return model
248
+
249
+
250
+ def vit_base(patch_size=16, **kwargs):
251
+ model = VisionTransformer(
252
+ patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
253
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
254
+ return model
255
+
256
+
257
+ class DINOHead(nn.Module):
258
+ def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
259
+ super().__init__()
260
+ nlayers = max(nlayers, 1)
261
+ if nlayers == 1:
262
+ self.mlp = nn.Linear(in_dim, bottleneck_dim)
263
+ else:
264
+ layers = [nn.Linear(in_dim, hidden_dim)]
265
+ if use_bn:
266
+ layers.append(nn.BatchNorm1d(hidden_dim))
267
+ layers.append(nn.GELU())
268
+ for _ in range(nlayers - 2):
269
+ layers.append(nn.Linear(hidden_dim, hidden_dim))
270
+ if use_bn:
271
+ layers.append(nn.BatchNorm1d(hidden_dim))
272
+ layers.append(nn.GELU())
273
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim))
274
+ self.mlp = nn.Sequential(*layers)
275
+ self.apply(self._init_weights)
276
+ self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
277
+ self.last_layer.weight_g.data.fill_(1)
278
+ if norm_last_layer:
279
+ self.last_layer.weight_g.requires_grad = False
280
+
281
+ def _init_weights(self, m):
282
+ if isinstance(m, nn.Linear):
283
+ trunc_normal_(m.weight, std=.02)
284
+ if isinstance(m, nn.Linear) and m.bias is not None:
285
+ nn.init.constant_(m.bias, 0)
286
+
287
+ def forward(self, x):
288
+ x = self.mlp(x)
289
+ x = nn.functional.normalize(x, dim=-1, p=2)
290
+ x = self.last_layer(x)
291
+ return x
PuzzleTuning/Counterpart PreTrain Methods/dino-main/visualize_attention.py ADDED
@@ -0,0 +1,213 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import sys
16
+ import argparse
17
+ import cv2
18
+ import random
19
+ import colorsys
20
+ import requests
21
+ from io import BytesIO
22
+
23
+ import skimage.io
24
+ from skimage.measure import find_contours
25
+ import matplotlib.pyplot as plt
26
+ from matplotlib.patches import Polygon
27
+ import torch
28
+ import torch.nn as nn
29
+ import torchvision
30
+ from torchvision import transforms as pth_transforms
31
+ import numpy as np
32
+ from PIL import Image
33
+
34
+ import utils
35
+ import vision_transformer as vits
36
+
37
+
38
+ def apply_mask(image, mask, color, alpha=0.5):
39
+ for c in range(3):
40
+ image[:, :, c] = image[:, :, c] * (1 - alpha * mask) + alpha * mask * color[c] * 255
41
+ return image
42
+
43
+
44
+ def random_colors(N, bright=True):
45
+ """
46
+ Generate random colors.
47
+ """
48
+ brightness = 1.0 if bright else 0.7
49
+ hsv = [(i / N, 1, brightness) for i in range(N)]
50
+ colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv))
51
+ random.shuffle(colors)
52
+ return colors
53
+
54
+
55
+ def display_instances(image, mask, fname="test", figsize=(5, 5), blur=False, contour=True, alpha=0.5):
56
+ fig = plt.figure(figsize=figsize, frameon=False)
57
+ ax = plt.Axes(fig, [0., 0., 1., 1.])
58
+ ax.set_axis_off()
59
+ fig.add_axes(ax)
60
+ ax = plt.gca()
61
+
62
+ N = 1
63
+ mask = mask[None, :, :]
64
+ # Generate random colors
65
+ colors = random_colors(N)
66
+
67
+ # Show area outside image boundaries.
68
+ height, width = image.shape[:2]
69
+ margin = 0
70
+ ax.set_ylim(height + margin, -margin)
71
+ ax.set_xlim(-margin, width + margin)
72
+ ax.axis('off')
73
+ masked_image = image.astype(np.uint32).copy()
74
+ for i in range(N):
75
+ color = colors[i]
76
+ _mask = mask[i]
77
+ if blur:
78
+ _mask = cv2.blur(_mask,(10,10))
79
+ # Mask
80
+ masked_image = apply_mask(masked_image, _mask, color, alpha)
81
+ # Mask Polygon
82
+ # Pad to ensure proper polygons for masks that touch image edges.
83
+ if contour:
84
+ padded_mask = np.zeros((_mask.shape[0] + 2, _mask.shape[1] + 2))
85
+ padded_mask[1:-1, 1:-1] = _mask
86
+ contours = find_contours(padded_mask, 0.5)
87
+ for verts in contours:
88
+ # Subtract the padding and flip (y, x) to (x, y)
89
+ verts = np.fliplr(verts) - 1
90
+ p = Polygon(verts, facecolor="none", edgecolor=color)
91
+ ax.add_patch(p)
92
+ ax.imshow(masked_image.astype(np.uint8), aspect='auto')
93
+ fig.savefig(fname)
94
+ print(f"{fname} saved.")
95
+ return
96
+
97
+
98
+ if __name__ == '__main__':
99
+ parser = argparse.ArgumentParser('Visualize Self-Attention maps')
100
+ parser.add_argument('--arch', default='vit_small', type=str,
101
+ choices=['vit_tiny', 'vit_small', 'vit_base'], help='Architecture (only ViT is supported at the moment).')
102
+ parser.add_argument('--patch_size', default=8, type=int, help='Patch resolution of the model.')
103
+ parser.add_argument('--pretrained_weights', default='', type=str,
104
+ help="Path to pretrained weights to load.")
105
+ parser.add_argument("--checkpoint_key", default="teacher", type=str,
106
+ help='Key to use in the checkpoint (example: "teacher")')
107
+ parser.add_argument("--image_path", default=None, type=str, help="Path of the image to load.")
108
+ parser.add_argument("--image_size", default=(480, 480), type=int, nargs="+", help="Resize image.")
109
+ parser.add_argument('--output_dir', default='.', help='Path where to save visualizations.')
110
+ parser.add_argument("--threshold", type=float, default=None, help="""We visualize masks
111
+ obtained by thresholding the self-attention maps to keep xx% of the mass.""")
112
+ args = parser.parse_args()
113
+
114
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
115
+ # build model
116
+ model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
117
+ for p in model.parameters():
118
+ p.requires_grad = False
119
+ model.eval()
120
+ model.to(device)
121
+ if os.path.isfile(args.pretrained_weights):
122
+ state_dict = torch.load(args.pretrained_weights, map_location="cpu")
123
+ if args.checkpoint_key is not None and args.checkpoint_key in state_dict:
124
+ print(f"Take key {args.checkpoint_key} in provided checkpoint dict")
125
+ state_dict = state_dict[args.checkpoint_key]
126
+ # remove `module.` prefix
127
+ state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
128
+ # remove `backbone.` prefix induced by multicrop wrapper
129
+ state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
130
+ msg = model.load_state_dict(state_dict, strict=False)
131
+ print('Pretrained weights found at {} and loaded with msg: {}'.format(args.pretrained_weights, msg))
132
+ else:
133
+ print("Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.")
134
+ url = None
135
+ if args.arch == "vit_small" and args.patch_size == 16:
136
+ url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
137
+ elif args.arch == "vit_small" and args.patch_size == 8:
138
+ url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth" # model used for visualizations in our paper
139
+ elif args.arch == "vit_base" and args.patch_size == 16:
140
+ url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
141
+ elif args.arch == "vit_base" and args.patch_size == 8:
142
+ url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
143
+ if url is not None:
144
+ print("Since no pretrained weights have been provided, we load the reference pretrained DINO weights.")
145
+ state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
146
+ model.load_state_dict(state_dict, strict=True)
147
+ else:
148
+ print("There is no reference weights available for this model => We use random weights.")
149
+
150
+ # open image
151
+ if args.image_path is None:
152
+ # user has not specified any image - we use our own image
153
+ print("Please use the `--image_path` argument to indicate the path of the image you wish to visualize.")
154
+ print("Since no image path have been provided, we take the first image in our paper.")
155
+ response = requests.get("https://dl.fbaipublicfiles.com/dino/img.png")
156
+ img = Image.open(BytesIO(response.content))
157
+ img = img.convert('RGB')
158
+ elif os.path.isfile(args.image_path):
159
+ with open(args.image_path, 'rb') as f:
160
+ img = Image.open(f)
161
+ img = img.convert('RGB')
162
+ else:
163
+ print(f"Provided image path {args.image_path} is non valid.")
164
+ sys.exit(1)
165
+ transform = pth_transforms.Compose([
166
+ pth_transforms.Resize(args.image_size),
167
+ pth_transforms.ToTensor(),
168
+ pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
169
+ ])
170
+ img = transform(img)
171
+
172
+ # make the image divisible by the patch size
173
+ w, h = img.shape[1] - img.shape[1] % args.patch_size, img.shape[2] - img.shape[2] % args.patch_size
174
+ img = img[:, :w, :h].unsqueeze(0)
175
+
176
+ w_featmap = img.shape[-2] // args.patch_size
177
+ h_featmap = img.shape[-1] // args.patch_size
178
+
179
+ attentions = model.get_last_selfattention(img.to(device))
180
+
181
+ nh = attentions.shape[1] # number of heads
182
+
183
+ # we keep only the output patch attention
184
+ attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
185
+
186
+ if args.threshold is not None:
187
+ # we keep only a certain percentage of the mass
188
+ val, idx = torch.sort(attentions)
189
+ val /= torch.sum(val, dim=1, keepdim=True)
190
+ cumval = torch.cumsum(val, dim=1)
191
+ th_attn = cumval > (1 - args.threshold)
192
+ idx2 = torch.argsort(idx)
193
+ for head in range(nh):
194
+ th_attn[head] = th_attn[head][idx2[head]]
195
+ th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()
196
+ # interpolate
197
+ th_attn = nn.functional.interpolate(th_attn.unsqueeze(0), scale_factor=args.patch_size, mode="nearest")[0].cpu().numpy()
198
+
199
+ attentions = attentions.reshape(nh, w_featmap, h_featmap)
200
+ attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=args.patch_size, mode="nearest")[0].cpu().numpy()
201
+
202
+ # save attentions heatmaps
203
+ os.makedirs(args.output_dir, exist_ok=True)
204
+ torchvision.utils.save_image(torchvision.utils.make_grid(img, normalize=True, scale_each=True), os.path.join(args.output_dir, "img.png"))
205
+ for j in range(nh):
206
+ fname = os.path.join(args.output_dir, "attn-head" + str(j) + ".png")
207
+ plt.imsave(fname=fname, arr=attentions[j], format='png')
208
+ print(f"{fname} saved.")
209
+
210
+ if args.threshold is not None:
211
+ image = skimage.io.imread(os.path.join(args.output_dir, "img.png"))
212
+ for j in range(nh):
213
+ display_instances(image, th_attn[j], fname=os.path.join(args.output_dir, "mask_th" + str(args.threshold) + "_head" + str(j) +".png"), blur=False)
PuzzleTuning/Counterpart PreTrain Methods/gcmae/FINETUNE.md ADDED
@@ -0,0 +1,13 @@
1
+ ## Fine-tune GCMAE
2
+
3
+ ```
4
+ python main_finetune.py \
5
+ --data_path path/to/data \
6
+ --nb_classes 9 \
7
+ --output_dir path/to/output/dir \
8
+ --log_dir path/to/log/dir \
9
+ --batch_size 128 \
10
+ --model vit_base_patch16 \
11
+ --epochs 50 \
12
+ --finetune path/to/pth/path
13
+ ```
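Unlike the linear-probe setting described in LINPROBE.md, fine-tuning loads the pre-trained encoder weights and then trains every parameter together with a new classification head. `main_finetune.py` additionally handles mixup, learning-rate scheduling and distributed training; the sketch below only illustrates the basic checkpoint-loading and end-to-end update idea. The `"model"` checkpoint key and every other name here are assumptions to verify against your own `.pth` file, not the script's actual interface.

```python
# Rough fine-tuning sketch (illustrative only): load pre-trained encoder
# weights, then update *all* parameters together with the new head.
import torch
import torch.nn as nn


def load_pretrained_encoder(model: nn.Module, ckpt_path: str) -> nn.Module:
    ckpt = torch.load(ckpt_path, map_location="cpu")
    state = ckpt.get("model", ckpt)  # assumption: weights may be nested under a "model" key
    print(model.load_state_dict(state, strict=False))  # head keys may be missing; that is expected
    return model


def finetune_step(model, images, labels, optimizer):
    model.train()  # every parameter receives gradients, unlike a linear probe
    loss = nn.functional.cross_entropy(model(images), labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


if __name__ == "__main__":
    toy = nn.Sequential(nn.Flatten(), nn.Linear(3 * 224 * 224, 9))  # stand-in for a ViT with 9 classes
    opt = torch.optim.AdamW(toy.parameters(), lr=1e-4)
    x, y = torch.randn(2, 3, 224, 224), torch.randint(0, 9, (2,))
    print(finetune_step(toy, x, y, opt))
```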
PuzzleTuning/Counterpart PreTrain Methods/gcmae/GCMAE.png ADDED
PuzzleTuning/Counterpart PreTrain Methods/gcmae/LICENSE ADDED
@@ -0,0 +1,399 @@
1
+ Attribution-NonCommercial 4.0 International
2
+
3
+ =======================================================================
4
+
5
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
6
+ does not provide legal services or legal advice. Distribution of
7
+ Creative Commons public licenses does not create a lawyer-client or
8
+ other relationship. Creative Commons makes its licenses and related
9
+ information available on an "as-is" basis. Creative Commons gives no
10
+ warranties regarding its licenses, any material licensed under their
11
+ terms and conditions, or any related information. Creative Commons
12
+ disclaims all liability for damages resulting from their use to the
13
+ fullest extent possible.
14
+
15
+ Using Creative Commons Public Licenses
16
+
17
+ Creative Commons public licenses provide a standard set of terms and
18
+ conditions that creators and other rights holders may use to share
19
+ original works of authorship and other material subject to copyright
20
+ and certain other rights specified in the public license below. The
21
+ following considerations are for informational purposes only, are not
22
+ exhaustive, and do not form part of our licenses.
23
+
24
+ Considerations for licensors: Our public licenses are
25
+ intended for use by those authorized to give the public
26
+ permission to use material in ways otherwise restricted by
27
+ copyright and certain other rights. Our licenses are
28
+ irrevocable. Licensors should read and understand the terms
29
+ and conditions of the license they choose before applying it.
30
+ Licensors should also secure all rights necessary before
31
+ applying our licenses so that the public can reuse the
32
+ material as expected. Licensors should clearly mark any
33
+ material not subject to the license. This includes other CC-
34
+ licensed material, or material used under an exception or
35
+ limitation to copyright. More considerations for licensors:
36
+ wiki.creativecommons.org/Considerations_for_licensors
37
+
38
+ Considerations for the public: By using one of our public
39
+ licenses, a licensor grants the public permission to use the
40
+ licensed material under specified terms and conditions. If
41
+ the licensor's permission is not necessary for any reason--for
42
+ example, because of any applicable exception or limitation to
43
+ copyright--then that use is not regulated by the license. Our
44
+ licenses grant only permissions under copyright and certain
45
+ other rights that a licensor has authority to grant. Use of
46
+ the licensed material may still be restricted for other
47
+ reasons, including because others have copyright or other
48
+ rights in the material. A licensor may make special requests,
49
+ such as asking that all changes be marked or described.
50
+ Although not required by our licenses, you are encouraged to
51
+ respect those requests where reasonable. More_considerations
52
+ for the public:
53
+ wiki.creativecommons.org/Considerations_for_licensees
54
+
55
+ =======================================================================
56
+
57
+ Creative Commons Attribution-NonCommercial 4.0 International Public
58
+ License
59
+
60
+ By exercising the Licensed Rights (defined below), You accept and agree
61
+ to be bound by the terms and conditions of this Creative Commons
62
+ Attribution-NonCommercial 4.0 International Public License ("Public
63
+ License"). To the extent this Public License may be interpreted as a
64
+ contract, You are granted the Licensed Rights in consideration of Your
65
+ acceptance of these terms and conditions, and the Licensor grants You
66
+ such rights in consideration of benefits the Licensor receives from
67
+ making the Licensed Material available under these terms and
68
+ conditions.
69
+
70
+ Section 1 -- Definitions.
71
+
72
+ a. Adapted Material means material subject to Copyright and Similar
73
+ Rights that is derived from or based upon the Licensed Material
74
+ and in which the Licensed Material is translated, altered,
75
+ arranged, transformed, or otherwise modified in a manner requiring
76
+ permission under the Copyright and Similar Rights held by the
77
+ Licensor. For purposes of this Public License, where the Licensed
78
+ Material is a musical work, performance, or sound recording,
79
+ Adapted Material is always produced where the Licensed Material is
80
+ synched in timed relation with a moving image.
81
+
82
+ b. Adapter's License means the license You apply to Your Copyright
83
+ and Similar Rights in Your contributions to Adapted Material in
84
+ accordance with the terms and conditions of this Public License.
85
+
86
+ c. Copyright and Similar Rights means copyright and/or similar rights
87
+ closely related to copyright including, without limitation,
88
+ performance, broadcast, sound recording, and Sui Generis Database
89
+ Rights, without regard to how the rights are labeled or
90
+ categorized. For purposes of this Public License, the rights
91
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
92
+ Rights.
93
+ d. Effective Technological Measures means those measures that, in the
94
+ absence of proper authority, may not be circumvented under laws
95
+ fulfilling obligations under Article 11 of the WIPO Copyright
96
+ Treaty adopted on December 20, 1996, and/or similar international
97
+ agreements.
98
+
99
+ e. Exceptions and Limitations means fair use, fair dealing, and/or
100
+ any other exception or limitation to Copyright and Similar Rights
101
+ that applies to Your use of the Licensed Material.
102
+
103
+ f. Licensed Material means the artistic or literary work, database,
104
+ or other material to which the Licensor applied this Public
105
+ License.
106
+
107
+ g. Licensed Rights means the rights granted to You subject to the
108
+ terms and conditions of this Public License, which are limited to
109
+ all Copyright and Similar Rights that apply to Your use of the
110
+ Licensed Material and that the Licensor has authority to license.
111
+
112
+ h. Licensor means the individual(s) or entity(ies) granting rights
113
+ under this Public License.
114
+
115
+ i. NonCommercial means not primarily intended for or directed towards
116
+ commercial advantage or monetary compensation. For purposes of
117
+ this Public License, the exchange of the Licensed Material for
118
+ other material subject to Copyright and Similar Rights by digital
119
+ file-sharing or similar means is NonCommercial provided there is
120
+ no payment of monetary compensation in connection with the
121
+ exchange.
122
+
123
+ j. Share means to provide material to the public by any means or
124
+ process that requires permission under the Licensed Rights, such
125
+ as reproduction, public display, public performance, distribution,
126
+ dissemination, communication, or importation, and to make material
127
+ available to the public including in ways that members of the
128
+ public may access the material from a place and at a time
129
+ individually chosen by them.
130
+
131
+ k. Sui Generis Database Rights means rights other than copyright
132
+ resulting from Directive 96/9/EC of the European Parliament and of
133
+ the Council of 11 March 1996 on the legal protection of databases,
134
+ as amended and/or succeeded, as well as other essentially
135
+ equivalent rights anywhere in the world.
136
+
137
+ l. You means the individual or entity exercising the Licensed Rights
138
+ under this Public License. Your has a corresponding meaning.
139
+
140
+ Section 2 -- Scope.
141
+
142
+ a. License grant.
143
+
144
+ 1. Subject to the terms and conditions of this Public License,
145
+ the Licensor hereby grants You a worldwide, royalty-free,
146
+ non-sublicensable, non-exclusive, irrevocable license to
147
+ exercise the Licensed Rights in the Licensed Material to:
148
+
149
+ a. reproduce and Share the Licensed Material, in whole or
150
+ in part, for NonCommercial purposes only; and
151
+
152
+ b. produce, reproduce, and Share Adapted Material for
153
+ NonCommercial purposes only.
154
+
155
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
156
+ Exceptions and Limitations apply to Your use, this Public
157
+ License does not apply, and You do not need to comply with
158
+ its terms and conditions.
159
+
160
+ 3. Term. The term of this Public License is specified in Section
161
+ 6(a).
162
+
163
+ 4. Media and formats; technical modifications allowed. The
164
+ Licensor authorizes You to exercise the Licensed Rights in
165
+ all media and formats whether now known or hereafter created,
166
+ and to make technical modifications necessary to do so. The
167
+ Licensor waives and/or agrees not to assert any right or
168
+ authority to forbid You from making technical modifications
169
+ necessary to exercise the Licensed Rights, including
170
+ technical modifications necessary to circumvent Effective
171
+ Technological Measures. For purposes of this Public License,
172
+ simply making modifications authorized by this Section 2(a)
173
+ (4) never produces Adapted Material.
174
+
175
+ 5. Downstream recipients.
176
+
177
+ a. Offer from the Licensor -- Licensed Material. Every
178
+ recipient of the Licensed Material automatically
179
+ receives an offer from the Licensor to exercise the
180
+ Licensed Rights under the terms and conditions of this
181
+ Public License.
182
+
183
+ b. No downstream restrictions. You may not offer or impose
184
+ any additional or different terms or conditions on, or
185
+ apply any Effective Technological Measures to, the
186
+ Licensed Material if doing so restricts exercise of the
187
+ Licensed Rights by any recipient of the Licensed
188
+ Material.
189
+
190
+ 6. No endorsement. Nothing in this Public License constitutes or
191
+ may be construed as permission to assert or imply that You
192
+ are, or that Your use of the Licensed Material is, connected
193
+ with, or sponsored, endorsed, or granted official status by,
194
+ the Licensor or others designated to receive attribution as
195
+ provided in Section 3(a)(1)(A)(i).
196
+
197
+ b. Other rights.
198
+
199
+ 1. Moral rights, such as the right of integrity, are not
200
+ licensed under this Public License, nor are publicity,
201
+ privacy, and/or other similar personality rights; however, to
202
+ the extent possible, the Licensor waives and/or agrees not to
203
+ assert any such rights held by the Licensor to the limited
204
+ extent necessary to allow You to exercise the Licensed
205
+ Rights, but not otherwise.
206
+
207
+ 2. Patent and trademark rights are not licensed under this
208
+ Public License.
209
+
210
+ 3. To the extent possible, the Licensor waives any right to
211
+ collect royalties from You for the exercise of the Licensed
212
+ Rights, whether directly or through a collecting society
213
+ under any voluntary or waivable statutory or compulsory
214
+ licensing scheme. In all other cases the Licensor expressly
215
+ reserves any right to collect such royalties, including when
216
+ the Licensed Material is used other than for NonCommercial
217
+ purposes.
218
+
219
+ Section 3 -- License Conditions.
220
+
221
+ Your exercise of the Licensed Rights is expressly made subject to the
222
+ following conditions.
223
+
224
+ a. Attribution.
225
+
226
+ 1. If You Share the Licensed Material (including in modified
227
+ form), You must:
228
+
229
+ a. retain the following if it is supplied by the Licensor
230
+ with the Licensed Material:
231
+
232
+ i. identification of the creator(s) of the Licensed
233
+ Material and any others designated to receive
234
+ attribution, in any reasonable manner requested by
235
+ the Licensor (including by pseudonym if
236
+ designated);
237
+
238
+ ii. a copyright notice;
239
+
240
+ iii. a notice that refers to this Public License;
241
+
242
+ iv. a notice that refers to the disclaimer of
243
+ warranties;
244
+
245
+ v. a URI or hyperlink to the Licensed Material to the
246
+ extent reasonably practicable;
247
+
248
+ b. indicate if You modified the Licensed Material and
249
+ retain an indication of any previous modifications; and
250
+
251
+ c. indicate the Licensed Material is licensed under this
252
+ Public License, and include the text of, or the URI or
253
+ hyperlink to, this Public License.
254
+
255
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
256
+ reasonable manner based on the medium, means, and context in
257
+ which You Share the Licensed Material. For example, it may be
258
+ reasonable to satisfy the conditions by providing a URI or
259
+ hyperlink to a resource that includes the required
260
+ information.
261
+
262
+ 3. If requested by the Licensor, You must remove any of the
263
+ information required by Section 3(a)(1)(A) to the extent
264
+ reasonably practicable.
265
+
266
+ 4. If You Share Adapted Material You produce, the Adapter's
267
+ License You apply must not prevent recipients of the Adapted
268
+ Material from complying with this Public License.
269
+
270
+ Section 4 -- Sui Generis Database Rights.
271
+
272
+ Where the Licensed Rights include Sui Generis Database Rights that
273
+ apply to Your use of the Licensed Material:
274
+
275
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
276
+ to extract, reuse, reproduce, and Share all or a substantial
277
+ portion of the contents of the database for NonCommercial purposes
278
+ only;
279
+
280
+ b. if You include all or a substantial portion of the database
281
+ contents in a database in which You have Sui Generis Database
282
+ Rights, then the database in which You have Sui Generis Database
283
+ Rights (but not its individual contents) is Adapted Material; and
284
+
285
+ c. You must comply with the conditions in Section 3(a) if You Share
286
+ all or a substantial portion of the contents of the database.
287
+
288
+ For the avoidance of doubt, this Section 4 supplements and does not
289
+ replace Your obligations under this Public License where the Licensed
290
+ Rights include other Copyright and Similar Rights.
291
+
292
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
293
+
294
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
295
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
296
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
297
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
298
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
299
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
300
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
301
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
302
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
303
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
304
+
305
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
306
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
307
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
308
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
309
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
310
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
311
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
312
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
313
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
314
+
315
+ c. The disclaimer of warranties and limitation of liability provided
316
+ above shall be interpreted in a manner that, to the extent
317
+ possible, most closely approximates an absolute disclaimer and
318
+ waiver of all liability.
319
+
320
+ Section 6 -- Term and Termination.
321
+
322
+ a. This Public License applies for the term of the Copyright and
323
+ Similar Rights licensed here. However, if You fail to comply with
324
+ this Public License, then Your rights under this Public License
325
+ terminate automatically.
326
+
327
+ b. Where Your right to use the Licensed Material has terminated under
328
+ Section 6(a), it reinstates:
329
+
330
+ 1. automatically as of the date the violation is cured, provided
331
+ it is cured within 30 days of Your discovery of the
332
+ violation; or
333
+
334
+ 2. upon express reinstatement by the Licensor.
335
+
336
+ For the avoidance of doubt, this Section 6(b) does not affect any
337
+ right the Licensor may have to seek remedies for Your violations
338
+ of this Public License.
339
+
340
+ c. For the avoidance of doubt, the Licensor may also offer the
341
+ Licensed Material under separate terms or conditions or stop
342
+ distributing the Licensed Material at any time; however, doing so
343
+ will not terminate this Public License.
344
+
345
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
346
+ License.
347
+
348
+ Section 7 -- Other Terms and Conditions.
349
+
350
+ a. The Licensor shall not be bound by any additional or different
351
+ terms or conditions communicated by You unless expressly agreed.
352
+
353
+ b. Any arrangements, understandings, or agreements regarding the
354
+ Licensed Material not stated herein are separate from and
355
+ independent of the terms and conditions of this Public License.
356
+
357
+ Section 8 -- Interpretation.
358
+
359
+ a. For the avoidance of doubt, this Public License does not, and
360
+ shall not be interpreted to, reduce, limit, restrict, or impose
361
+ conditions on any use of the Licensed Material that could lawfully
362
+ be made without permission under this Public License.
363
+
364
+ b. To the extent possible, if any provision of this Public License is
365
+ deemed unenforceable, it shall be automatically reformed to the
366
+ minimum extent necessary to make it enforceable. If the provision
367
+ cannot be reformed, it shall be severed from this Public License
368
+ without affecting the enforceability of the remaining terms and
369
+ conditions.
370
+
371
+ c. No term or condition of this Public License will be waived and no
372
+ failure to comply consented to unless expressly agreed to by the
373
+ Licensor.
374
+
375
+ d. Nothing in this Public License constitutes or may be interpreted
376
+ as a limitation upon, or waiver of, any privileges and immunities
377
+ that apply to the Licensor or You, including from the legal
378
+ processes of any jurisdiction or authority.
379
+
380
+ =======================================================================
381
+
382
+ Creative Commons is not a party to its public
383
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
384
+ its public licenses to material it publishes and in those instances
385
+ will be considered the “Licensor.” The text of the Creative Commons
386
+ public licenses is dedicated to the public domain under the CC0 Public
387
+ Domain Dedication. Except for the limited purpose of indicating that
388
+ material is shared under a Creative Commons public license or as
389
+ otherwise permitted by the Creative Commons policies published at
390
+ creativecommons.org/policies, Creative Commons does not authorize the
391
+ use of the trademark "Creative Commons" or any other trademark or logo
392
+ of Creative Commons without its prior written consent including,
393
+ without limitation, in connection with any unauthorized modifications
394
+ to any of its public licenses or any other arrangements,
395
+ understandings, or agreements concerning use of licensed material. For
396
+ the avoidance of doubt, this paragraph does not form part of the
397
+ public licenses.
398
+
399
+ Creative Commons may be contacted at creativecommons.org.
PuzzleTuning/Counterpart PreTrain Methods/gcmae/LINPROBE.md ADDED
@@ -0,0 +1,15 @@
1
+ ## Linear probe GCMAE
2
+
3
+ ```
4
+ python main_linprobe.py \
5
+ --data_path_train path/to/train/data \
6
+ --data_path_val path/to/val/data \
7
+ --nb_classes 2 \
8
+ --output_dir path/to/output/dir \
9
+ --log_dir path/to/log/dir \
10
+ --batch_size 512 \
11
+ --model vit_base_patch16 \
12
+ --epochs 90 \
13
+ --finetune path/to/pth/path
14
+ ```
15
+
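For reference, a linear probe keeps the pre-trained backbone frozen and trains only a new linear classifier on its features. `main_linprobe.py` also takes care of checkpoint loading, feature normalization and distributed training; the snippet below is a bare-bones sketch of the idea only, and every function and variable name is a placeholder rather than the script's actual API.

```python
# Bare-bones linear-probe sketch: freeze the backbone, train a linear head.
import torch
import torch.nn as nn


def build_linear_probe(backbone: nn.Module, feat_dim: int, nb_classes: int) -> nn.Linear:
    for p in backbone.parameters():  # freeze all pre-trained weights
        p.requires_grad = False
    backbone.eval()
    return nn.Linear(feat_dim, nb_classes)


def probe_step(backbone, head, images, labels, optimizer):
    with torch.no_grad():  # features come from the frozen encoder
        feats = backbone(images)
    loss = nn.functional.cross_entropy(head(feats), labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


if __name__ == "__main__":
    backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 224 * 224, 768))  # stand-in for a ViT encoder
    head = build_linear_probe(backbone, feat_dim=768, nb_classes=2)
    opt = torch.optim.SGD(head.parameters(), lr=0.1)
    x, y = torch.randn(4, 3, 224, 224), torch.randint(0, 2, (4,))
    print(probe_step(backbone, head, x, y, opt))
```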
PuzzleTuning/Counterpart PreTrain Methods/gcmae/PRETRAIN.md ADDED
@@ -0,0 +1,21 @@
1
+ ## Pre-training GCMAE
2
+
3
+ To pre-train ViT-Base (recommended default)
4
+ ```
5
+ python main_pretrain.py \
6
+ --data_path path/to/data \
7
+ --data_val_path path/to/data \
8
+ --output_dir path/to/output/dir \
9
+ --log_dir path/to/log/dir \
10
+ --batch_size 128 \
11
+ --model gcmae_vit_base_patch16 \
12
+ --norm_pix_loss \
13
+ --mask_ratio 0.5 \
14
+ --epochs 80 \
15
+ --warmup_epochs 40 \
16
+ --blr 1e-3 --weight_decay 0.05 \
17
+ --low_dim 768 \
18
+ --nce_k 8192 \
19
+ --nce_t 0.07 \
20
+ --nce_m 0.5
21
+ ```
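Two of the flags above come from the MAE recipe that GCMAE builds on: `--mask_ratio 0.5` hides half of the patches, and `--norm_pix_loss` computes the reconstruction loss against per-patch normalized pixel values. The sketch below shows that normalization as it is commonly implemented in MAE-style code; tensor names and shapes are placeholders, so check the pre-training code in this folder for the exact variant used here.

```python
# Per-patch pixel normalization implied by --norm_pix_loss (MAE-style sketch).
# `target` holds patchified pixels of shape [B, num_patches, patch_dim].
import torch


def normalize_pix_target(target: torch.Tensor, eps: float = 1.0e-6) -> torch.Tensor:
    mean = target.mean(dim=-1, keepdim=True)   # per-patch mean
    var = target.var(dim=-1, keepdim=True)     # per-patch variance
    return (target - mean) / (var + eps) ** 0.5


if __name__ == "__main__":
    target = torch.randn(2, 196, 16 * 16 * 3)  # 2 images, 14x14 patches of 16x16x3 pixels
    print(normalize_pix_target(target).shape)  # torch.Size([2, 196, 768])
```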
PuzzleTuning/Counterpart PreTrain Methods/gcmae/README.md ADDED
@@ -0,0 +1,15 @@
1
+ # GCMAE
2
+
3
+ The original repo of GCMAE can be found [here](https://github.com/StarUniversus/gcmae)
4
+
5
+ To install the environment:
6
+ ```bash
7
+ pip install -r requirements.txt
8
+ ```
9
+
10
+
11
+ To start pretraining:
12
+ ```bash
13
+ # You need to alter the script according to your directories
14
+ bash pretrain.sh
15
+ ```
PuzzleTuning/Counterpart PreTrain Methods/gcmae/README_origin.md ADDED
@@ -0,0 +1,39 @@
1
+ # GCMAE
2
+
3
+ <p align="center">
4
+ <img src="./GCMAE.png" width="800">
5
+ </p>
6
+
7
+ The official implementation of the paper [Global Contrast Masked Autoencoders Are Powerful Pathological Representation Learners](https://arxiv.org/abs/2205.09048)
8
+
9
+ ```
10
+ @article{li2022gcmae,
11
+ author = {Quan, Hao and Li, Xingyu and Chen, Weixing and Bai, Qun and Zou, Mingchen and Yang, Ruijie and Zheng, Tingting and Qi, Ruiqun and Gao, Xinghua and Cui, Xiaoyu},
12
+ title = {Global Contrast Masked Autoencoders Are Powerful Pathological Representation Learners},
13
+ journal={arXiv:2205.09048},
14
+ year = {2022},
15
+ }
16
+ ```
17
+ ## Abstract
18
+ Building on digital whole-slide scanning, artificial intelligence algorithms represented by deep learning have achieved remarkable results in computational pathology. Compared with other medical images such as Computed Tomography (CT) or Magnetic Resonance Imaging (MRI), pathological images are more difficult to annotate, so datasets suitable for supervised learning are extremely scarce. In this study, a self-supervised learning (SSL) model, Global Contrast Masked Autoencoders (GCMAE), is proposed; it represents both the global and the local domain-specific features of a whole slide image (WSI) and shows excellent cross-dataset transfer ability. The Camelyon16 and NCTCRC datasets are used to evaluate the performance of our model. On transfer learning tasks across different datasets, the experimental results show that GCMAE achieves better linear classification accuracy than MAE, reaching **81.10%** and **89.22%** respectively. Our method outperforms the previous state-of-the-art algorithm and even surpasses supervised learning (improved by **3.86%** on the NCTCRC dataset).
19
+
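The abstract describes GCMAE as pairing MAE-style masked reconstruction with a global contrastive objective; in this folder the actual objective is built in `main_pretrain.py` together with the memory-bank NCE modules under `lib/`. The snippet below is only a minimal, self-contained sketch of how such a combined loss can be wired up: the in-batch InfoNCE term, the 0.5 weighting, and all tensor names are illustrative assumptions rather than the repository's implementation, which contrasts against a feature memory bank instead (see the `--nce_k`/`--nce_m` flags in PRETRAIN.md).

```python
# Minimal sketch of a "masked reconstruction + global contrast" objective
# in the spirit of GCMAE (illustrative, not the repository's loss).
import torch
import torch.nn.functional as F


def masked_reconstruction_loss(pred, target, mask):
    """MSE over masked patches only (pred/target: [B, N, D], mask: [B, N], 1 = masked)."""
    loss = ((pred - target) ** 2).mean(dim=-1)   # per-patch error
    return (loss * mask).sum() / mask.sum()      # average over masked patches


def info_nce_loss(z1, z2, temperature=0.07):
    """In-batch InfoNCE between two sets of global features ([B, D])."""
    z1, z2 = F.normalize(z1, dim=-1), F.normalize(z2, dim=-1)
    logits = z1 @ z2.t() / temperature           # [B, B] similarity matrix
    labels = torch.arange(z1.size(0), device=z1.device)
    return F.cross_entropy(logits, labels)


def gcmae_style_loss(pred, target, mask, z1, z2, contrast_weight=0.5):
    return masked_reconstruction_loss(pred, target, mask) + contrast_weight * info_nce_loss(z1, z2)


if __name__ == "__main__":
    B, N, D = 4, 196, 768
    pred, target = torch.randn(B, N, D), torch.randn(B, N, D)
    mask = (torch.rand(B, N) < 0.5).float()
    z1, z2 = torch.randn(B, D), torch.randn(B, D)
    print(gcmae_style_loss(pred, target, mask, z1, z2))
```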
20
+ ## Installation
21
+ This repo is a modification of the [mae repo](https://github.com/facebookresearch/mae). Installation and preparation follow that repo.
22
+
23
+ ## Usage
24
+
25
+ * [PRETRAIN](PRETRAIN.md)
26
+
27
+ * [LINPROBE](LINPROBE.md)
28
+
29
+ * [FINETUNE](FINETUNE.md)
30
+
31
+ * [Visual GCMAE feature representation](VISUAL.md)
32
+
33
+ ## Dataset
34
+ * [Camelyon16](https://pan.baidu.com/s/1N0fqJR9u8yq-y6ZY0mSoUw?pwd=noms)
35
+ * [NCT-CRC-HE-100K](https://zenodo.org/record/1214456)
36
+ * [BreakHis](https://web.inf.ufpr.br/vri/databases/breast-cancer-histopathological-database-breakhis/)
37
+
38
+ ## License
39
+ Distributed under the CC-BY-NC 4.0 License. See [LICENSE](LICENSE) for more information.
PuzzleTuning/Counterpart PreTrain Methods/gcmae/VISUAL.md ADDED
@@ -0,0 +1,10 @@
1
+ ## Visual GCMAE feature representation
2
+
3
+ ```
4
+ python tsne.py \
5
+ --batch_size 128 \
6
+ --model vit_base_patch16 \
7
+ --finetune path/to/pth \
8
+ --save_path path/to/png \
9
+ --data_path_val path/to/data
10
+ ```
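The command above drives the repository's own `tsne.py`. As a generic illustration of what such a visualization does, the sketch below projects extracted feature vectors to 2-D with scikit-learn's t-SNE and scatter-plots them coloured by class label; the function and argument names are placeholders, not the `tsne.py` interface.

```python
# Generic t-SNE feature-visualization sketch (not the tsne.py CLI itself).
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE


def plot_tsne(features: np.ndarray, labels: np.ndarray, save_path: str) -> None:
    """features: [N, D] extracted embeddings, labels: [N] integer class ids."""
    emb = TSNE(n_components=2, init="pca", random_state=0).fit_transform(features)
    plt.figure(figsize=(6, 6))
    plt.scatter(emb[:, 0], emb[:, 1], c=labels, s=5, cmap="tab10")
    plt.axis("off")
    plt.savefig(save_path, dpi=300, bbox_inches="tight")


if __name__ == "__main__":
    feats = np.random.randn(500, 768).astype(np.float32)  # placeholder embeddings
    labs = np.random.randint(0, 9, size=500)               # placeholder labels
    plot_tsne(feats, labs, "tsne_demo.png")
```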
PuzzleTuning/Counterpart PreTrain Methods/gcmae/engine_finetune.py ADDED
@@ -0,0 +1,183 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # DeiT: https://github.com/facebookresearch/deit
9
+ # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10
+ # MAE: https://github.com/facebookresearch/mae
11
+ # --------------------------------------------------------
12
+
13
+ import math
14
+ import sys
15
+ from typing import Iterable, Optional
16
+
17
+ import torch
18
+
19
+ from timm.data import Mixup
20
+ from timm.utils import accuracy
21
+
22
+ import util.misc as misc
23
+ import util.lr_sched as lr_sched
24
+ from sklearn.metrics import roc_auc_score
25
+ from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
26
+ import torch.nn.functional as F
27
+ import numpy
28
+ import numpy as np
29
+ from torchmetrics import Specificity, AUROC
30
+
31
+
32
+
33
+ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
34
+ data_loader: Iterable, optimizer: torch.optim.Optimizer,
35
+ device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
36
+ mixup_fn: Optional[Mixup] = None, log_writer=None,
37
+ args=None):
38
+ model.train(True)
39
+ metric_logger = misc.MetricLogger(delimiter=" ")
40
+ metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
41
+ header = 'Epoch: [{}]'.format(epoch)
42
+ print_freq = 20
43
+
44
+ accum_iter = args.accum_iter
45
+
46
+ optimizer.zero_grad()
47
+
48
+ if log_writer is not None:
49
+ print('log_dir: {}'.format(log_writer.log_dir))
50
+
51
+ for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
52
+
53
+ # we use a per iteration (instead of per epoch) lr scheduler
54
+ if data_iter_step % accum_iter == 0:
55
+ lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
56
+
57
+ samples = samples.to(device, non_blocking=True)
58
+ targets = targets.to(device, non_blocking=True)
59
+
60
+ if mixup_fn is not None:
61
+ samples, targets = mixup_fn(samples, targets)
62
+
63
+ with torch.cuda.amp.autocast():
64
+ outputs = model(samples)
65
+ loss = criterion(outputs, targets)
66
+
67
+ loss_value = loss.item()
68
+
69
+ if not math.isfinite(loss_value):
70
+ print("Loss is {}, stopping training".format(loss_value))
71
+ sys.exit(1)
72
+
73
+ loss /= accum_iter
74
+ loss_scaler(loss, optimizer, clip_grad=max_norm,
75
+ parameters=model.parameters(), create_graph=False,
76
+ update_grad=(data_iter_step + 1) % accum_iter == 0)
77
+ if (data_iter_step + 1) % accum_iter == 0:
78
+ optimizer.zero_grad()
79
+
80
+ torch.cuda.synchronize()
81
+
82
+ metric_logger.update(loss=loss_value)
83
+ min_lr = 10.
84
+ max_lr = 0.
85
+ for group in optimizer.param_groups:
86
+ min_lr = min(min_lr, group["lr"])
87
+ max_lr = max(max_lr, group["lr"])
88
+
89
+ metric_logger.update(lr=max_lr)
90
+
91
+ loss_value_reduce = misc.all_reduce_mean(loss_value)
92
+ if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
93
+ """ We use epoch_1000x as the x-axis in tensorboard.
94
+ This calibrates different curves when batch size changes.
95
+ """
96
+ epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
97
+ log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x)
98
+ log_writer.add_scalar('lr', max_lr, epoch_1000x)
99
+
100
+ # gather the stats from all processes
101
+ metric_logger.synchronize_between_processes()
102
+ print("Averaged stats:", metric_logger)
103
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
104
+
105
+ def confusion_m(y_true, y_pred):
106
+ tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
107
+ return tn, fp, fn, tp
108
+
109
+ def compute_metrics_binary(probs, preds, targets):
110
+ auc = roc_auc_score(targets, probs) * 100
111
+ precision = precision_score(targets, preds) * 100
112
+ recall = recall_score(targets, preds) * 100
113
+ f1 = f1_score(targets, preds) * 100
114
+ tn, fp, fn, tp = confusion_m(targets, preds)
115
+ specificity = (tn / float(tn+fp)) * 100
116
+ return auc, precision, recall, f1, specificity
117
+
118
+ def compute_metrics_multiclass(probs, preds, targets, nb_classes):
119
+ preds_tensor, probs_tensor, targets_tensor = torch.tensor(preds), torch.tensor(probs), torch.tensor(targets)
120
+ auroc = AUROC(average='macro', num_classes=nb_classes)
121
+ auc = auroc(probs_tensor, targets_tensor) * 100
122
+ precision = precision_score(targets, preds, average='macro') * 100
123
+ recall = recall_score(targets, preds, average='macro') * 100
124
+ f1 = f1_score(targets, preds, average='macro') * 100
125
+ speci = Specificity(average='macro', num_classes=nb_classes)
126
+ specificity = speci(preds_tensor, targets_tensor) * 100
127
+ return auc, precision, recall, f1, specificity
128
+
129
+ @torch.no_grad()
130
+ def evaluate(data_loader, model, device, nb_classes):
131
+ criterion = torch.nn.CrossEntropyLoss()
132
+ m = torch.nn.Softmax(dim=1)
133
+ metric_logger = misc.MetricLogger(delimiter=" ")
134
+ header = 'Test:'
135
+
136
+ # switch to evaluation mode
137
+ model.eval()
138
+ probs = []
139
+ targets = []
140
+ preds = []
141
+
142
+ for batch in metric_logger.log_every(data_loader, 10, header):
143
+ images = batch[0]
144
+ target = batch[-1]
145
+ images = images.to(device, non_blocking=True)
146
+ target = target.to(device, non_blocking=True)
147
+
148
+ # compute output
149
+ with torch.cuda.amp.autocast():
150
+ output = model(images)
151
+ loss = criterion(output, target)
152
+ output = m(output)
153
+ score, pred = output.topk(1, 1, True, True)
154
+ if nb_classes == 2:
155
+ prob = output[:, 1]
156
+ elif nb_classes > 2:
157
+ prob = output
158
+
159
+ probs.extend(prob.detach().cpu().numpy())
160
+ targets.extend(target.detach().cpu().numpy())
161
+ preds.extend(pred.tolist())
162
+ # with fewer than 5 classes, top-5 accuracy is undefined, so report top-1 twice
+ if nb_classes < 5:
163
+ acc1, acc5 = accuracy(output, target, topk=(1, 1))
164
+ else:
165
+ acc1, acc5 = accuracy(output, target, topk=(1, 5))
166
+
167
+ batch_size = images.shape[0]
168
+ metric_logger.update(loss=loss.item())
169
+ metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
170
+ metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
171
+ # gather the stats from all processes
172
+ metric_logger.synchronize_between_processes()
173
+ print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
174
+ .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
175
+ if nb_classes == 2:
176
+ print("binary class metrics!")
177
+ auc, precision, recall, f1, specificity = compute_metrics_binary(probs, preds, targets)
178
+
179
+ elif nb_classes > 2:
180
+ print("multi_class metrics!")
181
+ auc, precision, recall, f1, specificity = compute_metrics_multiclass(probs, preds, targets, nb_classes)
182
+
183
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}, auc, precision, recall, f1, specificity
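The metric helpers above wrap scikit-learn and torchmetrics; a dummy-data sketch of the binary case with flat per-sample lists:

```
# Dummy-data usage of compute_metrics_binary (defined above); values are illustrative.
probs = [0.9, 0.2, 0.7, 0.4]    # predicted P(class 1) per sample
preds = [1, 0, 1, 0]            # argmax predictions
targets = [1, 0, 0, 0]          # ground-truth labels
auc, precision, recall, f1, specificity = compute_metrics_binary(probs, preds, targets)
print(auc, precision, recall, f1, specificity)   # all reported as percentages
```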
PuzzleTuning/Counterpart PreTrain Methods/gcmae/engine_pretrain.py ADDED
@@ -0,0 +1,99 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # DeiT: https://github.com/facebookresearch/deit
9
+ # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10
+ # --------------------------------------------------------
11
+ import math
12
+ import sys
13
+ from typing import Iterable
14
+
15
+ import torch
16
+
17
+ import util.misc as misc
18
+ import util.lr_sched as lr_sched
19
+ from test_npid import NN, kNN
20
+
21
+ def train_one_epoch(model: torch.nn.Module,
22
+ data_loader: Iterable,
23
+ optimizer: torch.optim.Optimizer,
24
+ device: torch.device, epoch: int, loss_scaler,
25
+ log_writer=None,
26
+ args=None,
27
+ lemniscate=None,
28
+ ):
29
+ model.train(True)
30
+ metric_logger = misc.MetricLogger(delimiter=" ")
31
+ metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
32
+ header = 'Epoch: [{}]'.format(epoch)
33
+ print_freq = 20
34
+
35
+ accum_iter = args.accum_iter
36
+
37
+ optimizer.zero_grad()
38
+
39
+ if log_writer is not None:
40
+ print('log_dir: {}'.format(log_writer.log_dir))
41
+
42
+ for data_iter_step, (samples, _, index) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
43
+
44
+ # we use a per iteration (instead of per epoch) lr scheduler
45
+ if data_iter_step % accum_iter == 0:
46
+ lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
47
+
48
+ samples = samples.to(device, non_blocking=True)
49
+ index = index.to(device, non_blocking=True)
50
+ with torch.cuda.amp.autocast():
51
+ loss_mae, _, _, loss_npid, _ = model(samples, mask_ratio=args.mask_ratio, index=index, is_train=True)
52
+ # GCMAE objective: local MAE reconstruction loss plus the global NPID contrastive loss (weight 0.1)
+ loss = loss_mae + 0.1 * loss_npid
53
+
54
+ loss_value = loss.item()
55
+
56
+ if not math.isfinite(loss_value):
57
+ print("Loss is {}, stopping training".format(loss_value))
58
+ sys.exit(1)
59
+
60
+ loss /= accum_iter
61
+ loss_scaler(loss, optimizer, parameters=model.parameters(),
62
+ update_grad=(data_iter_step + 1) % accum_iter == 0)
63
+ if (data_iter_step + 1) % accum_iter == 0:
64
+ optimizer.zero_grad()
65
+
66
+ torch.cuda.synchronize()
67
+
68
+
69
+ metric_logger.update(loss_all=loss_value)
70
+ metric_logger.update(loss_mae=loss_mae.item())
71
+ metric_logger.update(loss_npid=loss_npid.item())
72
+
73
+
74
+ lr = optimizer.param_groups[0]["lr"]
75
+ metric_logger.update(lr=lr)
76
+
77
+ loss_value_reduce = misc.all_reduce_mean(loss_value)
78
+ if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
79
+ """ We use epoch_1000x as the x-axis in tensorboard.
80
+ This calibrates different curves when batch size changes.
81
+ """
82
+
83
+ epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
84
+ log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)
85
+ log_writer.add_scalar('origin_loss/train_loss_mae', loss_mae.item(), epoch_1000x)
86
+ log_writer.add_scalar('origin_loss/train_loss_npid', loss_npid.item(), epoch_1000x)
87
+ log_writer.add_scalar('lr', lr, epoch_1000x)
88
+
89
+ # pred1 = NN(epoch, model, lemniscate, data_loader, data_loader_val)
90
+ # log_writer.add_scalar('NN_ac', pred1, epoch)
91
+ # if args.output_dir and (epoch % 20 == 0 or epoch + 1 == args.epochs):
92
+
93
+ # top1 = kNN(0, model, lemniscate, data_loader, data_loader_val, 200, args.nce_t)
94
+ # log_writer.add_scalar('KNN_top1', top1, epoch)
95
+
96
+ metric_logger.synchronize_between_processes()
97
+ print("Averaged stats:", metric_logger)
98
+ # return {k: meter.global_avg for k, meter in metric_logger.meters.items()}, pred1
99
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
PuzzleTuning/Counterpart PreTrain Methods/gcmae/lib/NCEAverage.py ADDED
@@ -0,0 +1,94 @@
1
+ import torch
2
+ from torch.autograd import Function
3
+ from torch import nn
4
+ from .alias_multinomial import AliasMethod
5
+ import math
6
+
7
+ class NCEFunction(Function):
8
+ @staticmethod
9
+ def forward(self, x, y, memory, idx, params):
10
+ K = int(params[0].item())
11
+ T = params[1].item()
12
+ Z = params[2].item()
13
+
14
+ momentum = params[3].item()
15
+ batchSize = x.size(0)
16
+ outputSize = memory.size(0)
17
+ inputSize = memory.size(1)
18
+
19
+ # sample positives & negatives
20
+ idx.select(1,0).copy_(y.detach())
21
+
22
+ # sample corresponding weights
23
+ weight = torch.index_select(memory, 0, idx.view(-1))
24
+ weight.resize_(batchSize, K+1, inputSize)
25
+
26
+ # inner product
27
+ with torch.cuda.amp.autocast(enabled = False):
28
+ out = torch.bmm(weight, x.detach().reshape(batchSize, inputSize, 1))
29
+ #print(out)
30
+ out.div_(T).exp_() # batchSize * self.K+1
31
+ #x.detach().resize_(batchSize, inputSize)
32
+ #print(out)
33
+ if Z < 0:
34
+ params[2] = out.mean() * outputSize
35
+ Z = params[2].item()
36
+ print("normalization constant Z is set to {:.1f}".format(Z))
37
+
38
+ out.div_(Z).resize_(batchSize, K+1)
39
+
40
+ self.save_for_backward(x, memory, y, weight, out, params)
41
+
42
+ return out
43
+
44
+ @staticmethod
45
+ def backward(self, gradOutput):
46
+ x, memory, y, weight, out, params = self.saved_tensors
47
+ K = int(params[0].item())
48
+ T = params[1].item()
49
+ Z = params[2].item()
50
+ momentum = params[3].item()
51
+ batchSize = gradOutput.size(0)
52
+
53
+ # gradients d Pm / d linear = exp(linear) / Z
54
+ gradOutput.detach().mul_(out.detach())
55
+ # add temperature
56
+ gradOutput.detach().div_(T)
57
+
58
+
59
+
60
+ # gradient of linear
61
+ with torch.cuda.amp.autocast(enabled = False):
62
+ gradInput = torch.bmm(gradOutput.detach().reshape(batchSize, 1, K+1), weight)
63
+ gradInput.resize_as_(x)
64
+
65
+ # update the non-parametric memory bank
66
+ weight_pos = weight.select(1, 0).resize_as_(x)
67
+ weight_pos.mul_(momentum)
68
+ weight_pos.add_(torch.mul(x.detach(), 1-momentum))
69
+ w_norm = weight_pos.pow(2).sum(1, keepdim=True).pow(0.5)
70
+ updated_weight = weight_pos.div(w_norm)
71
+ memory.index_copy_(0, y, updated_weight)
72
+
73
+ return gradInput, None, None, None, None
74
+
75
+ class NCEAverage(nn.Module):
76
+
77
+ def __init__(self, inputSize, outputSize, K, T=0.07, momentum=0.5, Z=None):
78
+ super(NCEAverage, self).__init__()
79
+ self.nLem = outputSize
80
+ self.unigrams = torch.ones(self.nLem)
81
+ self.multinomial = AliasMethod(self.unigrams)
82
+ self.multinomial.cuda()
83
+ self.K = K
84
+
85
+ self.register_buffer('params',torch.tensor([K, T, -1, momentum]))
86
+ stdv = 1. / math.sqrt(inputSize/3)
87
+ self.register_buffer('memory', torch.rand(outputSize, inputSize).mul_(2*stdv).add_(-stdv))
88
+
89
+ def forward(self, x, y):
90
+ batchSize = x.size(0)
91
+ idx = self.multinomial.draw(batchSize * (self.K+1)).view(batchSize, -1)
92
+ out = NCEFunction.apply(x, y, self.memory, idx, self.params)
93
+ return out
94
+
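In the instance-discrimination (NPID) formulation that `NCEAverage` follows, the memory bank keeps one L2-normalized vector per training image; the forward pass scores a feature against its positive entry and $K$ sampled noise entries, and the backward pass refreshes the positive entry with a momentum update. In symbols, the code above computes

$$
P(i \mid v) = \frac{\exp\left(v_i^{\top} v / T\right)}{Z},
\qquad
v_i \leftarrow \frac{m\, v_i + (1-m)\, f_i}{\left\lVert m\, v_i + (1-m)\, f_i \right\rVert_2},
$$

where $T$ is the temperature `params[1]`, $Z$ the normalization constant estimated once from the first batch, $m$ the momentum, and $f_i$ the current feature of image $i$.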
PuzzleTuning/Counterpart PreTrain Methods/gcmae/lib/NCECriterion.py ADDED
@@ -0,0 +1,38 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+ eps = 1e-7
5
+
6
+ class NCECriterion(nn.Module):
7
+
8
+ def __init__(self, nLem):
9
+ super(NCECriterion, self).__init__()
10
+ self.nLem = nLem
11
+
12
+ def forward(self, x, targets):
13
+ batchSize = x.size(0)
14
+ K = x.size(1)-1
15
+ Pnt = 1 / float(self.nLem)
16
+ Pns = 1 / float(self.nLem)
17
+
18
+ # eq 5.1 : P(origin=model) = Pmt / (Pmt + k*Pnt)
19
+ Pmt = x.select(1,0)
20
+ Pmt_div = Pmt.add(K * Pnt + eps)
21
+ lnPmt = torch.div(Pmt, Pmt_div)
22
+
23
+ # eq 5.2 : P(origin=noise) = k*Pns / (Pms + k*Pns)
24
+ Pon_div = x.narrow(1,1,K).add(K * Pns + eps)
25
+ Pon = Pon_div.clone().fill_(K * Pns)
26
+ lnPon = torch.div(Pon, Pon_div)
27
+
28
+ # equation 6 in ref. A
29
+ lnPmt.log_()
30
+ lnPon.log_()
31
+
32
+ lnPmtsum = lnPmt.sum(0)
33
+ lnPonsum = lnPon.view(-1, 1).sum(0)
34
+
35
+ loss = - (lnPmtsum + lnPonsum) / batchSize
36
+
37
+ return loss
38
+
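`NCECriterion` is the noise-contrastive estimation loss from the same NPID formulation, with a uniform noise distribution $P_n = 1/n$ over the $n$ memory entries. For a batch of $B$ instances with $K$ noise samples each, the code above evaluates

$$
\mathcal{L} = -\frac{1}{B} \sum_{i=1}^{B} \left[ \log \frac{P(i \mid v_i)}{P(i \mid v_i) + K P_n}
+ \sum_{j=1}^{K} \log \frac{K P_n}{P(j \mid v_i) + K P_n} \right],
$$

where the $P(\cdot \mid v_i)$ terms are the normalized scores produced by `NCEAverage`.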
PuzzleTuning/Counterpart PreTrain Methods/gcmae/lib/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # nothing
PuzzleTuning/Counterpart PreTrain Methods/gcmae/lib/__pycache__/NCEAverage.cpython-38.pyc ADDED
Binary file (2.96 kB). View file
 
PuzzleTuning/Counterpart PreTrain Methods/gcmae/lib/__pycache__/NCECriterion.cpython-38.pyc ADDED
Binary file (1.13 kB). View file
 
PuzzleTuning/Counterpart PreTrain Methods/gcmae/lib/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (127 Bytes). View file
 
PuzzleTuning/Counterpart PreTrain Methods/gcmae/lib/__pycache__/alias_multinomial.cpython-38.pyc ADDED
Binary file (1.68 kB). View file
 
PuzzleTuning/Counterpart PreTrain Methods/gcmae/lib/__pycache__/normalize.cpython-38.pyc ADDED
Binary file (819 Bytes). View file
 
PuzzleTuning/Counterpart PreTrain Methods/gcmae/lib/__pycache__/utils.cpython-38.pyc ADDED
Binary file (873 Bytes). View file
 
PuzzleTuning/Counterpart PreTrain Methods/gcmae/lib/alias_multinomial.py ADDED
@@ -0,0 +1,64 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+ class AliasMethod(object):
5
+ '''
6
+ From: https://hips.seas.harvard.edu/blog/2013/03/03/the-alias-method-efficient-sampling-with-many-discrete-outcomes/
7
+ '''
8
+ def __init__(self, probs):
9
+
10
+ if probs.sum() > 1:
11
+ probs.div_(probs.sum())
12
+ K = len(probs)
13
+ self.prob = torch.zeros(K)
14
+ self.alias = torch.LongTensor([0]*K)
15
+
16
+ # Sort the data into the outcomes with probabilities
17
+ # that are larger and smaller than 1/K.
18
+ smaller = []
19
+ larger = []
20
+ for kk, prob in enumerate(probs):
21
+ self.prob[kk] = K*prob
22
+ if self.prob[kk] < 1.0:
23
+ smaller.append(kk)
24
+ else:
25
+ larger.append(kk)
26
+
27
+ # Loop through and create little binary mixtures that
28
+ # appropriately allocate the larger outcomes over the
29
+ # overall uniform mixture.
30
+ while len(smaller) > 0 and len(larger) > 0:
31
+ small = smaller.pop()
32
+ large = larger.pop()
33
+
34
+ self.alias[small] = large
35
+ self.prob[large] = (self.prob[large] - 1.0) + self.prob[small]
36
+
37
+ if self.prob[large] < 1.0:
38
+ smaller.append(large)
39
+ else:
40
+ larger.append(large)
41
+
42
+ for last_one in smaller+larger:
43
+ self.prob[last_one] = 1
44
+
45
+ def cuda(self):
46
+ self.prob = self.prob.cuda()
47
+ self.alias = self.alias.cuda()
48
+
49
+ def draw(self, N):
50
+ '''
51
+ Draw N samples from multinomial
52
+ '''
53
+ K = self.alias.size(0)
54
+
55
+ kk = torch.zeros(N, dtype=torch.long, device=self.prob.device).random_(0, K)
56
+ prob = self.prob.index_select(0, kk)
57
+ alias = self.alias.index_select(0, kk)
58
+ # b is whether a random number is greater than q
59
+ b = torch.bernoulli(prob)
60
+ oq = kk.mul(b.long())
61
+ oj = alias.mul((1-b).long())
62
+
63
+ return oq + oj
64
+
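`AliasMethod` is the classic alias-table construction for O(1) sampling from a fixed discrete distribution; `NCEAverage` instantiates it with uniform weights to draw negative indices. A minimal CPU usage sketch with an illustrative 4-outcome distribution:

```
import torch

# Build an alias table for an (unnormalized) 4-outcome distribution.
probs = torch.tensor([1.0, 2.0, 3.0, 4.0])
sampler = AliasMethod(probs)      # normalizes in place, then fills the prob/alias tables
idx = sampler.draw(100000)        # LongTensor of 100000 indices in [0, 4)
print(torch.bincount(idx, minlength=4).float() / idx.numel())
# expected to be close to tensor([0.10, 0.20, 0.30, 0.40])
```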
PuzzleTuning/Counterpart PreTrain Methods/gcmae/lib/normalize.py ADDED
@@ -0,0 +1,14 @@
1
+ import torch
2
+ from torch.autograd import Variable
3
+ from torch import nn
4
+
5
+ class Normalize(nn.Module):
6
+
7
+ def __init__(self, power=2):
8
+ super(Normalize, self).__init__()
9
+ self.power = power
10
+
11
+ def forward(self, x):
12
+ norm = x.pow(self.power).sum(1, keepdim=True).pow(1./self.power)
13
+ out = x.div(norm)
14
+ return out
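`Normalize` performs row-wise Lp normalization (L2 with the default `power=2`), which keeps feature vectors on the unit hypersphere. A one-line usage sketch:

```
import torch

feats = torch.randn(8, 128)
unit_feats = Normalize()(feats)    # divide each row by its L2 norm
print(unit_feats.norm(dim=1))      # every entry is approximately 1.0
```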