init submit
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- PuzzleTuning/Archive/PuzzleTuning_demoset.zip +3 -0
- PuzzleTuning/Archive/build_CPIA_dataset.sh +35 -0
- PuzzleTuning/Archive/warwick_CLS.zip +3 -0
- PuzzleTuning/Backbone/GetPromptModel.py +87 -0
- PuzzleTuning/Backbone/ResHybrid.py +481 -0
- PuzzleTuning/Backbone/Transformer_blocks.py +1631 -0
- PuzzleTuning/Backbone/VPT_structure.py +133 -0
- PuzzleTuning/Backbone/attention_modules.py +303 -0
- PuzzleTuning/Backbone/counterpart_models/README.md +25 -0
- PuzzleTuning/Backbone/counterpart_models/conformer.py +483 -0
- PuzzleTuning/Backbone/counterpart_models/crossformer.py +624 -0
- PuzzleTuning/Backbone/counterpart_models/crossformer_backbone.py +659 -0
- PuzzleTuning/Backbone/getmodel.py +392 -0
- PuzzleTuning/Counterpart PreTrain Methods/ReadMe.md +2 -0
- PuzzleTuning/Counterpart PreTrain Methods/dino-main/LICENSE +201 -0
- PuzzleTuning/Counterpart PreTrain Methods/dino-main/README.md +14 -0
- PuzzleTuning/Counterpart PreTrain Methods/dino-main/eval_copy_detection.py +301 -0
- PuzzleTuning/Counterpart PreTrain Methods/dino-main/eval_image_retrieval.py +201 -0
- PuzzleTuning/Counterpart PreTrain Methods/dino-main/eval_knn.py +242 -0
- PuzzleTuning/Counterpart PreTrain Methods/dino-main/eval_linear.py +281 -0
- PuzzleTuning/Counterpart PreTrain Methods/dino-main/eval_video_segmentation.py +292 -0
- PuzzleTuning/Counterpart PreTrain Methods/dino-main/hubconf.py +151 -0
- PuzzleTuning/Counterpart PreTrain Methods/dino-main/main_dino.py +515 -0
- PuzzleTuning/Counterpart PreTrain Methods/dino-main/run_with_submitit.py +132 -0
- PuzzleTuning/Counterpart PreTrain Methods/dino-main/utils.py +829 -0
- PuzzleTuning/Counterpart PreTrain Methods/dino-main/video_generation.py +378 -0
- PuzzleTuning/Counterpart PreTrain Methods/dino-main/vision_transformer.py +291 -0
- PuzzleTuning/Counterpart PreTrain Methods/dino-main/visualize_attention.py +213 -0
- PuzzleTuning/Counterpart PreTrain Methods/gcmae/FINETUNE.md +13 -0
- PuzzleTuning/Counterpart PreTrain Methods/gcmae/GCMAE.png +0 -0
- PuzzleTuning/Counterpart PreTrain Methods/gcmae/LICENSE +399 -0
- PuzzleTuning/Counterpart PreTrain Methods/gcmae/LINPROBE.md +15 -0
- PuzzleTuning/Counterpart PreTrain Methods/gcmae/PRETRAIN.md +21 -0
- PuzzleTuning/Counterpart PreTrain Methods/gcmae/README.md +15 -0
- PuzzleTuning/Counterpart PreTrain Methods/gcmae/README_origin.md +39 -0
- PuzzleTuning/Counterpart PreTrain Methods/gcmae/VISUAL.md +10 -0
- PuzzleTuning/Counterpart PreTrain Methods/gcmae/engine_finetune.py +183 -0
- PuzzleTuning/Counterpart PreTrain Methods/gcmae/engine_pretrain.py +99 -0
- PuzzleTuning/Counterpart PreTrain Methods/gcmae/lib/NCEAverage.py +94 -0
- PuzzleTuning/Counterpart PreTrain Methods/gcmae/lib/NCECriterion.py +38 -0
- PuzzleTuning/Counterpart PreTrain Methods/gcmae/lib/__init__.py +1 -0
- PuzzleTuning/Counterpart PreTrain Methods/gcmae/lib/__pycache__/NCEAverage.cpython-38.pyc +0 -0
- PuzzleTuning/Counterpart PreTrain Methods/gcmae/lib/__pycache__/NCECriterion.cpython-38.pyc +0 -0
- PuzzleTuning/Counterpart PreTrain Methods/gcmae/lib/__pycache__/__init__.cpython-38.pyc +0 -0
- PuzzleTuning/Counterpart PreTrain Methods/gcmae/lib/__pycache__/alias_multinomial.cpython-38.pyc +0 -0
- PuzzleTuning/Counterpart PreTrain Methods/gcmae/lib/__pycache__/normalize.cpython-38.pyc +0 -0
- PuzzleTuning/Counterpart PreTrain Methods/gcmae/lib/__pycache__/utils.cpython-38.pyc +0 -0
- PuzzleTuning/Counterpart PreTrain Methods/gcmae/lib/alias_multinomial.py +64 -0
- PuzzleTuning/Counterpart PreTrain Methods/gcmae/lib/normalize.py +14 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
PuzzleTuning/Counterpart[[:space:]]PreTrain[[:space:]]Methods/gcmae/nohup.out filter=lfs diff=lfs merge=lfs -text
|
PuzzleTuning/Archive/PuzzleTuning_demoset.zip
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:363cbdc5f8e944f99542a5727d4b2457dd79bbfca7eb70d76857de4b8be92858
|
3 |
+
size 4987321
|
PuzzleTuning/Archive/build_CPIA_dataset.sh
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/sh
|
2 |
+
# go to the dataset location
|
3 |
+
|
4 |
+
# altering the zip files
|
5 |
+
|
6 |
+
zip -F L.zip --out L_Scale.zip
|
7 |
+
zip -FF L_Scale.zip --out L.zip -fz
|
8 |
+
zip -F M.zip --out M_Scale.zip
|
9 |
+
zip -FF M_Scale.zip --out M.zip -fz
|
10 |
+
|
11 |
+
rm -f L_Scale.zip
|
12 |
+
rm -f L.z01
|
13 |
+
rm -f M_Scale.zip
|
14 |
+
rm -f M.z01
|
15 |
+
rm -f M.z02
|
16 |
+
|
17 |
+
# build a directory of datasets
|
18 |
+
mkdir datasets
|
19 |
+
mv L.zip datasets
|
20 |
+
mv M.zip datasets
|
21 |
+
mv S.zip datasets
|
22 |
+
|
23 |
+
cd datasets
|
24 |
+
unzip L.zip
|
25 |
+
unzip M.zip
|
26 |
+
unzip S.zip
|
27 |
+
|
28 |
+
rm -f L.zip
|
29 |
+
rm -f M.zip
|
30 |
+
rm -f S.zip
|
31 |
+
|
32 |
+
mkdir All
|
33 |
+
cp -r L/* All/ &
|
34 |
+
cp -r M/* All/ &
|
35 |
+
cp -r S/* All/
|
PuzzleTuning/Archive/warwick_CLS.zip
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9813cdc86e2420476e06638965df6040cec0197b148be786f0bf88b020e445f6
|
3 |
+
size 15098031
|
PuzzleTuning/Backbone/GetPromptModel.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
build_promptmodel Script ver: Oct 17th 14:20
|
3 |
+
|
4 |
+
"""
|
5 |
+
|
6 |
+
try:
|
7 |
+
from Backbone.VPT_structure import *
|
8 |
+
except:
|
9 |
+
from Backbone.VPT_structure import *
|
10 |
+
|
11 |
+
|
12 |
+
def build_promptmodel(num_classes=1000, edge_size=224, model_idx='ViT', patch_size=16,
|
13 |
+
Prompt_Token_num=20, VPT_type="Deep", prompt_state_dict=None, base_state_dict='timm'):
|
14 |
+
"""
|
15 |
+
following the https://github.com/sagizty/VPT
|
16 |
+
this build the VPT (prompt version of ViT), with additional prompt tokens,
|
17 |
+
each layer the information become [B, N_patch + N_prompt, Dim]
|
18 |
+
|
19 |
+
During training only the prompt tokens and the head layer are
|
20 |
+
set to be learnable while the rest of Transformer layers are frozen
|
21 |
+
|
22 |
+
# VPT_type = "Shallow" / "Deep"
|
23 |
+
- Shallow: concatenate N_prompt of prompt tokens before the first Transformer Encoder block,
|
24 |
+
each layer the information become [B, N_patch + N_prompt, Dim]
|
25 |
+
- Deep: concatenate N_prompt of prompt tokens to each Transformer Encoder block,
|
26 |
+
this will replace the output prompt tokens learnt form previous encoder.
|
27 |
+
"""
|
28 |
+
|
29 |
+
if model_idx[0:3] == 'ViT':
|
30 |
+
|
31 |
+
if base_state_dict is None:
|
32 |
+
basic_state_dict = None
|
33 |
+
|
34 |
+
elif type(base_state_dict) == str:
|
35 |
+
if base_state_dict == 'timm':
|
36 |
+
# ViT_Prompt
|
37 |
+
import timm
|
38 |
+
# from pprint import pprint
|
39 |
+
# model_names = timm.list_models('*vit*')
|
40 |
+
# pprint(model_names)
|
41 |
+
|
42 |
+
basic_model = timm.create_model('vit_base_patch' + str(patch_size) + '_' + str(edge_size),
|
43 |
+
pretrained=True)
|
44 |
+
basic_state_dict = basic_model.state_dict()
|
45 |
+
print('in prompt model building, timm ViT loaded for base_state_dict')
|
46 |
+
|
47 |
+
else:
|
48 |
+
basic_state_dict = None
|
49 |
+
print('in prompt model building, no vaild str for base_state_dict')
|
50 |
+
|
51 |
+
else: # state dict: collections.OrderedDict
|
52 |
+
basic_state_dict = base_state_dict
|
53 |
+
print('in prompt model building, a .pth base_state_dict loaded')
|
54 |
+
|
55 |
+
model = VPT_ViT(img_size=edge_size, patch_size=patch_size, Prompt_Token_num=Prompt_Token_num,
|
56 |
+
VPT_type=VPT_type, basic_state_dict=basic_state_dict)
|
57 |
+
|
58 |
+
model.New_CLS_head(num_classes)
|
59 |
+
|
60 |
+
if prompt_state_dict is not None:
|
61 |
+
try:
|
62 |
+
model.load_prompt(prompt_state_dict)
|
63 |
+
except:
|
64 |
+
print('erro in .pth prompt_state_dict')
|
65 |
+
else:
|
66 |
+
print('in prompt model building, a .pth prompt_state_dict loaded')
|
67 |
+
|
68 |
+
model.Freeze()
|
69 |
+
else:
|
70 |
+
print("The model is not difined in the Prompt script!!")
|
71 |
+
return -1
|
72 |
+
|
73 |
+
try:
|
74 |
+
img = torch.randn(1, 3, edge_size, edge_size)
|
75 |
+
preds = model(img) # (1, class_number)
|
76 |
+
print('Build VPT model with in/out shape: ', img.shape, ' -> ', preds.shape)
|
77 |
+
|
78 |
+
except:
|
79 |
+
print("Problem exist in the model defining process!!")
|
80 |
+
return -1
|
81 |
+
else:
|
82 |
+
print('model is ready now!')
|
83 |
+
return model
|
84 |
+
|
85 |
+
|
86 |
+
if __name__ == '__main__':
|
87 |
+
model = build_promptmodel(prompt_state_dict=None, base_state_dict='timm', num_classes=0)
|
PuzzleTuning/Backbone/ResHybrid.py
ADDED
@@ -0,0 +1,481 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
MSHT
|
3 |
+
|
4 |
+
Models ver: OCT 27th 20:00 official release
|
5 |
+
|
6 |
+
by the authors, check our github page:
|
7 |
+
https://github.com/sagizty/Multi-Stage-Hybrid-Transformer
|
8 |
+
|
9 |
+
|
10 |
+
ResNet stages' feature map
|
11 |
+
|
12 |
+
# input = 3, 384, 384
|
13 |
+
torch.Size([1, 256, 96, 96])
|
14 |
+
torch.Size([1, 512, 48, 48])
|
15 |
+
torch.Size([1, 1024, 24, 24])
|
16 |
+
torch.Size([1, 2048, 12, 12])
|
17 |
+
torch.Size([1, 1000])
|
18 |
+
|
19 |
+
# input = 3, 224, 224
|
20 |
+
torch.Size([1, 256, 56, 56])
|
21 |
+
torch.Size([1, 512, 28, 28])
|
22 |
+
torch.Size([1, 1024, 14, 14])
|
23 |
+
torch.Size([1, 2048, 7, 7])
|
24 |
+
torch.Size([1, 1000])
|
25 |
+
|
26 |
+
ref
|
27 |
+
https://note.youdao.com/ynoteshare1/index.html?id=5a7dbe1a71713c317062ddeedd97d98e&type=note
|
28 |
+
"""
|
29 |
+
import torch
|
30 |
+
from torch import nn
|
31 |
+
from functools import partial
|
32 |
+
from torchsummary import summary
|
33 |
+
import os
|
34 |
+
from Backbone import Transformer_blocks
|
35 |
+
|
36 |
+
|
37 |
+
# ResNet Bottleneck_block_constructor
|
38 |
+
class Bottleneck_block_constructor(nn.Module):
|
39 |
+
|
40 |
+
extention = 4
|
41 |
+
|
42 |
+
# 定义初始化的网络和参数
|
43 |
+
def __init__(self, inplane, midplane, stride, downsample=None):
|
44 |
+
super(Bottleneck_block_constructor, self).__init__()
|
45 |
+
|
46 |
+
outplane = midplane * self.extention
|
47 |
+
|
48 |
+
self.conv1 = nn.Conv2d(inplane, midplane, kernel_size=1, stride=stride, bias=False)
|
49 |
+
self.bn1 = nn.BatchNorm2d(midplane)
|
50 |
+
|
51 |
+
self.conv2 = nn.Conv2d(midplane, midplane, kernel_size=3, stride=1, padding=1, bias=False)
|
52 |
+
self.bn2 = nn.BatchNorm2d(midplane)
|
53 |
+
|
54 |
+
self.conv3 = nn.Conv2d(midplane, outplane, kernel_size=1, stride=1, bias=False)
|
55 |
+
self.bn3 = nn.BatchNorm2d(midplane * self.extention)
|
56 |
+
|
57 |
+
self.relu = nn.ReLU(inplace=False)
|
58 |
+
|
59 |
+
self.downsample = downsample
|
60 |
+
self.stride = stride
|
61 |
+
|
62 |
+
def forward(self, x):
|
63 |
+
|
64 |
+
out = self.relu(self.bn1(self.conv1(x)))
|
65 |
+
out = self.relu(self.bn2(self.conv2(out)))
|
66 |
+
out = self.relu(self.bn3(self.conv3(out)))
|
67 |
+
|
68 |
+
if self.downsample is not None:
|
69 |
+
residual = self.downsample(x)
|
70 |
+
else:
|
71 |
+
residual = x
|
72 |
+
|
73 |
+
out += residual
|
74 |
+
|
75 |
+
out = self.relu(out)
|
76 |
+
|
77 |
+
return out
|
78 |
+
|
79 |
+
|
80 |
+
# Hybrid_backbone of ResNets
|
81 |
+
class Hybrid_backbone_4(nn.Module):
|
82 |
+
|
83 |
+
def __init__(self, block_constructor, bottleneck_channels_setting=None, identity_layers_setting=None,
|
84 |
+
stage_stride_setting=None, fc_num_classes=None, feature_idx=None):
|
85 |
+
|
86 |
+
if bottleneck_channels_setting is None:
|
87 |
+
bottleneck_channels_setting = [64, 128, 256, 512]
|
88 |
+
if identity_layers_setting is None:
|
89 |
+
identity_layers_setting = [3, 4, 6, 3]
|
90 |
+
if stage_stride_setting is None:
|
91 |
+
stage_stride_setting = [1, 2, 2, 2]
|
92 |
+
|
93 |
+
self.inplane = 64
|
94 |
+
self.fc_num_classes = fc_num_classes
|
95 |
+
self.feature_idx = feature_idx
|
96 |
+
|
97 |
+
super(Hybrid_backbone_4, self).__init__()
|
98 |
+
|
99 |
+
self.block_constructor = block_constructor # Bottleneck_block_constructor
|
100 |
+
self.bcs = bottleneck_channels_setting # [64, 128, 256, 512]
|
101 |
+
self.ils = identity_layers_setting # [3, 4, 6, 3]
|
102 |
+
self.sss = stage_stride_setting # [1, 2, 2, 2]
|
103 |
+
|
104 |
+
# stem
|
105 |
+
# alter the RGB pic chanel to match inplane
|
106 |
+
self.conv1 = nn.Conv2d(3, self.inplane, kernel_size=7, stride=2, padding=3, bias=False)
|
107 |
+
self.bn1 = nn.BatchNorm2d(self.inplane)
|
108 |
+
self.relu = nn.ReLU()
|
109 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, padding=1, stride=2)
|
110 |
+
|
111 |
+
# ResNet stages
|
112 |
+
self.layer1 = self.make_stage_layer(self.block_constructor, self.bcs[0], self.ils[0], self.sss[0])
|
113 |
+
self.layer2 = self.make_stage_layer(self.block_constructor, self.bcs[1], self.ils[1], self.sss[1])
|
114 |
+
self.layer3 = self.make_stage_layer(self.block_constructor, self.bcs[2], self.ils[2], self.sss[2])
|
115 |
+
self.layer4 = self.make_stage_layer(self.block_constructor, self.bcs[3], self.ils[3], self.sss[3])
|
116 |
+
|
117 |
+
# cls head
|
118 |
+
if self.fc_num_classes is not None:
|
119 |
+
self.avgpool = nn.AvgPool2d(7)
|
120 |
+
self.fc = nn.Linear(512 * self.block_constructor.extention, fc_num_classes)
|
121 |
+
|
122 |
+
def forward(self, x):
|
123 |
+
|
124 |
+
# stem
|
125 |
+
out = self.conv1(x)
|
126 |
+
out = self.bn1(out)
|
127 |
+
out = self.relu(out)
|
128 |
+
stem_out = self.maxpool(out)
|
129 |
+
|
130 |
+
# Resnet block of 4 stages
|
131 |
+
stage1_out = self.layer1(stem_out)
|
132 |
+
stage2_out = self.layer2(stage1_out)
|
133 |
+
stage3_out = self.layer3(stage2_out)
|
134 |
+
stage4_out = self.layer4(stage3_out)
|
135 |
+
|
136 |
+
if self.fc_num_classes is not None:
|
137 |
+
# connect to cls head mlp if asked
|
138 |
+
fc_out = self.avgpool(stage4_out)
|
139 |
+
fc_out = torch.flatten(fc_out, 1)
|
140 |
+
fc_out = self.fc(fc_out)
|
141 |
+
|
142 |
+
# get what we need for different usage
|
143 |
+
if self.feature_idx == 'stages':
|
144 |
+
if self.fc_num_classes is not None:
|
145 |
+
return stage1_out, stage2_out, stage3_out, stage4_out, fc_out
|
146 |
+
else:
|
147 |
+
return stage1_out, stage2_out, stage3_out, stage4_out
|
148 |
+
elif self.feature_idx == 'features':
|
149 |
+
if self.fc_num_classes is not None:
|
150 |
+
return stem_out, stage1_out, stage2_out, stage3_out, stage4_out, fc_out
|
151 |
+
else:
|
152 |
+
return stem_out, stage1_out, stage2_out, stage3_out, stage4_out
|
153 |
+
else: # self.feature_idx is None
|
154 |
+
if self.fc_num_classes is not None:
|
155 |
+
return fc_out
|
156 |
+
else:
|
157 |
+
return stage4_out
|
158 |
+
|
159 |
+
def make_stage_layer(self, block_constractor, midplane, block_num, stride=1):
|
160 |
+
"""
|
161 |
+
block:
|
162 |
+
midplane:usually = output chanel/4
|
163 |
+
block_num:
|
164 |
+
stride:stride of the ResNet Conv Block
|
165 |
+
"""
|
166 |
+
|
167 |
+
block_list = []
|
168 |
+
|
169 |
+
outplane = midplane * block_constractor.extention # extention
|
170 |
+
|
171 |
+
if stride != 1 or self.inplane != outplane:
|
172 |
+
downsample = nn.Sequential(
|
173 |
+
nn.Conv2d(self.inplane, outplane, stride=stride, kernel_size=1, bias=False),
|
174 |
+
nn.BatchNorm2d(midplane * block_constractor.extention)
|
175 |
+
)
|
176 |
+
else:
|
177 |
+
downsample = None
|
178 |
+
|
179 |
+
# Conv Block
|
180 |
+
conv_block = block_constractor(self.inplane, midplane, stride=stride, downsample=downsample)
|
181 |
+
block_list.append(conv_block)
|
182 |
+
|
183 |
+
self.inplane = outplane # update inplane for the next stage
|
184 |
+
|
185 |
+
# Identity Block
|
186 |
+
for i in range(1, block_num):
|
187 |
+
block_list.append(block_constractor(self.inplane, midplane, stride=1, downsample=None))
|
188 |
+
|
189 |
+
return nn.Sequential(*block_list) # stack blocks
|
190 |
+
|
191 |
+
|
192 |
+
class Hybrid_backbone_3(nn.Module): # 3 stages version
|
193 |
+
|
194 |
+
def __init__(self, block_constructor, bottleneck_channels_setting=None, identity_layers_setting=None,
|
195 |
+
stage_stride_setting=None, fc_num_classes=None, feature_idx=None):
|
196 |
+
|
197 |
+
if bottleneck_channels_setting is None:
|
198 |
+
bottleneck_channels_setting = [64, 128, 256]
|
199 |
+
if identity_layers_setting is None:
|
200 |
+
identity_layers_setting = [3, 4, 6]
|
201 |
+
if stage_stride_setting is None:
|
202 |
+
stage_stride_setting = [1, 2, 2]
|
203 |
+
|
204 |
+
self.inplane = 64
|
205 |
+
self.fc_num_classes = fc_num_classes
|
206 |
+
self.feature_idx = feature_idx
|
207 |
+
|
208 |
+
super(Hybrid_backbone_3, self).__init__()
|
209 |
+
|
210 |
+
self.block_constructor = block_constructor # Bottleneck_block_constructor
|
211 |
+
self.bcs = bottleneck_channels_setting # [64, 128, 256]
|
212 |
+
self.ils = identity_layers_setting # [3, 4, 6]
|
213 |
+
self.sss = stage_stride_setting # [1, 2, 2]
|
214 |
+
|
215 |
+
# stem
|
216 |
+
self.conv1 = nn.Conv2d(3, self.inplane, kernel_size=7, stride=2, padding=3, bias=False)
|
217 |
+
self.bn1 = nn.BatchNorm2d(self.inplane)
|
218 |
+
self.relu = nn.ReLU()
|
219 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, padding=1, stride=2)
|
220 |
+
|
221 |
+
# ResNet 3 stages
|
222 |
+
self.layer1 = self.make_stage_layer(self.block_constructor, self.bcs[0], self.ils[0], self.sss[0])
|
223 |
+
self.layer2 = self.make_stage_layer(self.block_constructor, self.bcs[1], self.ils[1], self.sss[1])
|
224 |
+
self.layer3 = self.make_stage_layer(self.block_constructor, self.bcs[2], self.ils[2], self.sss[2])
|
225 |
+
|
226 |
+
if self.fc_num_classes is not None:
|
227 |
+
self.avgpool = nn.AvgPool2d(24) # 224-14 384-24
|
228 |
+
self.fc = nn.Linear(self.bcs[-1] * self.block_constructor.extention, fc_num_classes)
|
229 |
+
|
230 |
+
def forward(self, x):
|
231 |
+
# stem:conv+bn+relu+maxpool
|
232 |
+
out = self.conv1(x)
|
233 |
+
out = self.bn1(out)
|
234 |
+
out = self.relu(out)
|
235 |
+
stem_out = self.maxpool(out)
|
236 |
+
|
237 |
+
# Resnet 3 stages
|
238 |
+
stage1_out = self.layer1(stem_out)
|
239 |
+
stage2_out = self.layer2(stage1_out)
|
240 |
+
stage3_out = self.layer3(stage2_out)
|
241 |
+
|
242 |
+
if self.fc_num_classes is not None:
|
243 |
+
fc_out = self.avgpool(stage3_out)
|
244 |
+
fc_out = torch.flatten(fc_out, 1)
|
245 |
+
fc_out = self.fc(fc_out)
|
246 |
+
|
247 |
+
if self.feature_idx == 'stages':
|
248 |
+
if self.fc_num_classes is not None:
|
249 |
+
return stage1_out, stage2_out, stage3_out, fc_out
|
250 |
+
else:
|
251 |
+
return stage1_out, stage2_out, stage3_out
|
252 |
+
elif self.feature_idx == 'features':
|
253 |
+
if self.fc_num_classes is not None:
|
254 |
+
return stem_out, stage1_out, stage2_out, stage3_out, fc_out
|
255 |
+
else:
|
256 |
+
return stem_out, stage1_out, stage2_out, stage3_out
|
257 |
+
else: # self.feature_idx is None
|
258 |
+
if self.fc_num_classes is not None:
|
259 |
+
return fc_out
|
260 |
+
else:
|
261 |
+
return stage3_out
|
262 |
+
|
263 |
+
def make_stage_layer(self, block_constractor, midplane, block_num, stride=1):
|
264 |
+
"""
|
265 |
+
block:
|
266 |
+
midplane:
|
267 |
+
block_num:
|
268 |
+
stride:
|
269 |
+
"""
|
270 |
+
|
271 |
+
block_list = []
|
272 |
+
|
273 |
+
outplane = midplane * block_constractor.extention # extention
|
274 |
+
|
275 |
+
if stride != 1 or self.inplane != outplane:
|
276 |
+
downsample = nn.Sequential(
|
277 |
+
nn.Conv2d(self.inplane, outplane, stride=stride, kernel_size=1, bias=False),
|
278 |
+
nn.BatchNorm2d(midplane * block_constractor.extention)
|
279 |
+
)
|
280 |
+
else:
|
281 |
+
downsample = None
|
282 |
+
|
283 |
+
# Conv Block
|
284 |
+
conv_block = block_constractor(self.inplane, midplane, stride=stride, downsample=downsample)
|
285 |
+
block_list.append(conv_block)
|
286 |
+
|
287 |
+
self.inplane = outplane
|
288 |
+
|
289 |
+
# Identity Block
|
290 |
+
for i in range(1, block_num):
|
291 |
+
block_list.append(block_constractor(self.inplane, midplane, stride=1, downsample=None))
|
292 |
+
|
293 |
+
return nn.Sequential(*block_list)
|
294 |
+
|
295 |
+
|
296 |
+
def Hybrid_a(backbone, img_size=224, patch_size=1, in_chans=3, num_classes=1000, embed_dim=768, depth=8,
|
297 |
+
num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, drop_rate=0., attn_drop_rate=0.,
|
298 |
+
drop_path_rate=0., norm_layer=None, act_layer=None):
|
299 |
+
# directly stack CNNs and Transformer blocks
|
300 |
+
embed_layer = partial(Transformer_blocks.Hybrid_feature_map_Embed, backbone=backbone)
|
301 |
+
|
302 |
+
Hybrid_model = Transformer_blocks.VisionTransformer(img_size, patch_size, in_chans, num_classes, embed_dim, depth,
|
303 |
+
num_heads, mlp_ratio, qkv_bias, representation_size,
|
304 |
+
drop_rate, attn_drop_rate, drop_path_rate, embed_layer,
|
305 |
+
norm_layer, act_layer)
|
306 |
+
|
307 |
+
return Hybrid_model
|
308 |
+
|
309 |
+
|
310 |
+
def create_model(model_idx, edge_size, pretrained=True, num_classes=2, drop_rate=0., attn_drop_rate=0.,
|
311 |
+
drop_path_rate=0., use_cls_token=True, use_pos_embedding=True, use_att_module='SimAM'):
|
312 |
+
"""
|
313 |
+
get one of MSHT models
|
314 |
+
|
315 |
+
:param model_idx: the model we are going to use. by the format of Model_size_other_info
|
316 |
+
:param edge_size: the input edge size of the dataloder
|
317 |
+
:param pretrained: The backbone CNN is initiate randomly or by its official Pretrained models
|
318 |
+
:param num_classes: classification required number of your dataset
|
319 |
+
|
320 |
+
:param drop_rate: The dropout layer's probility of proposed models
|
321 |
+
:param attn_drop_rate: The dropout layer(right after the MHSA block or MHGA block)'s probility of proposed models
|
322 |
+
:param drop_path_rate: The probility of stochastic depth
|
323 |
+
|
324 |
+
:param use_cls_token: To use the class token
|
325 |
+
:param use_pos_embedding: To use the positional enbedding
|
326 |
+
:param use_att_module: To use which attention module in the FGD Focus block
|
327 |
+
# use_att_module in ['SimAM', 'CBAM', 'SE'] different attention module we applied in the ablation study
|
328 |
+
|
329 |
+
:return: prepared model
|
330 |
+
"""
|
331 |
+
|
332 |
+
if pretrained:
|
333 |
+
from torchvision import models
|
334 |
+
backbone_weights = models.resnet50(pretrained=True).state_dict()
|
335 |
+
# True for pretrained Resnet50 model, False will randomly initiate
|
336 |
+
else:
|
337 |
+
backbone_weights = None
|
338 |
+
|
339 |
+
if model_idx[0:11] == 'Hybrid1_224' and edge_size == 224: # ablation study: no focus depth=8 edge_size == 224
|
340 |
+
backbone = Hybrid_backbone_4(block_constructor=Bottleneck_block_constructor,
|
341 |
+
bottleneck_channels_setting=[64, 128, 256, 512],
|
342 |
+
identity_layers_setting=[3, 4, 6, 3],
|
343 |
+
stage_stride_setting=[1, 2, 2, 2],
|
344 |
+
fc_num_classes=None,
|
345 |
+
feature_idx=None)
|
346 |
+
|
347 |
+
if pretrained:
|
348 |
+
try:
|
349 |
+
backbone.load_state_dict(backbone_weights, False)
|
350 |
+
except:
|
351 |
+
print("backbone not loaded")
|
352 |
+
else:
|
353 |
+
print("backbone loaded")
|
354 |
+
|
355 |
+
model = Hybrid_a(backbone, img_size=edge_size, patch_size=1, in_chans=3, num_classes=num_classes, embed_dim=768,
|
356 |
+
depth=8, num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None,
|
357 |
+
drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate,
|
358 |
+
norm_layer=None, act_layer=None)
|
359 |
+
|
360 |
+
elif model_idx[0:11] == 'Hybrid1_384' and edge_size == 384: # ablation study: no focus depth=8 edge_size == 384
|
361 |
+
backbone = Hybrid_backbone_4(block_constructor=Bottleneck_block_constructor,
|
362 |
+
bottleneck_channels_setting=[64, 128, 256, 512],
|
363 |
+
identity_layers_setting=[3, 4, 6, 3],
|
364 |
+
stage_stride_setting=[1, 2, 2, 2],
|
365 |
+
fc_num_classes=None,
|
366 |
+
feature_idx=None)
|
367 |
+
|
368 |
+
if pretrained:
|
369 |
+
try:
|
370 |
+
backbone.load_state_dict(backbone_weights, False)
|
371 |
+
except:
|
372 |
+
print("backbone not loaded")
|
373 |
+
else:
|
374 |
+
print("backbone loaded")
|
375 |
+
|
376 |
+
model = Hybrid_a(backbone, img_size=edge_size, patch_size=1, in_chans=3, num_classes=num_classes, embed_dim=768,
|
377 |
+
depth=8, num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None,
|
378 |
+
drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate,
|
379 |
+
norm_layer=None, act_layer=None)
|
380 |
+
|
381 |
+
elif model_idx[0:11] == 'Hybrid2_224' and edge_size == 224: # Proposed model ablation study: edge_size==224
|
382 |
+
backbone = Hybrid_backbone_4(block_constructor=Bottleneck_block_constructor,
|
383 |
+
bottleneck_channels_setting=[64, 128, 256, 512],
|
384 |
+
identity_layers_setting=[3, 4, 6, 3],
|
385 |
+
stage_stride_setting=[1, 2, 2, 2],
|
386 |
+
fc_num_classes=None,
|
387 |
+
feature_idx='stages')
|
388 |
+
if pretrained:
|
389 |
+
try:
|
390 |
+
backbone.load_state_dict(backbone_weights, False)
|
391 |
+
except:
|
392 |
+
print("backbone not loaded")
|
393 |
+
else:
|
394 |
+
print("backbone loaded")
|
395 |
+
|
396 |
+
model = Transformer_blocks.Stage_wise_hybrid_Transformer(backbone, num_classes=num_classes,
|
397 |
+
drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
|
398 |
+
drop_path_rate=drop_path_rate,
|
399 |
+
use_cls_token=use_cls_token,
|
400 |
+
use_pos_embedding=use_pos_embedding,
|
401 |
+
use_att_module=use_att_module,
|
402 |
+
stage_size=(56, 28, 14, 7),
|
403 |
+
stage_dim=[256, 512, 1024, 2048])
|
404 |
+
|
405 |
+
elif model_idx[0:11] == 'Hybrid2_384' and edge_size == 384: # Proposed model 384 !!!
|
406 |
+
backbone = Hybrid_backbone_4(block_constructor=Bottleneck_block_constructor,
|
407 |
+
bottleneck_channels_setting=[64, 128, 256, 512],
|
408 |
+
identity_layers_setting=[3, 4, 6, 3],
|
409 |
+
stage_stride_setting=[1, 2, 2, 2],
|
410 |
+
fc_num_classes=None,
|
411 |
+
feature_idx='stages')
|
412 |
+
if pretrained:
|
413 |
+
try:
|
414 |
+
backbone.load_state_dict(backbone_weights, False)
|
415 |
+
except:
|
416 |
+
print("backbone not loaded")
|
417 |
+
else:
|
418 |
+
print("backbone loaded")
|
419 |
+
|
420 |
+
model = Transformer_blocks.Stage_wise_hybrid_Transformer(backbone, num_classes=num_classes,
|
421 |
+
drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
|
422 |
+
drop_path_rate=drop_path_rate,
|
423 |
+
use_cls_token=use_cls_token,
|
424 |
+
use_pos_embedding=use_pos_embedding,
|
425 |
+
use_att_module=use_att_module,
|
426 |
+
stage_size=(96, 48, 24, 12),
|
427 |
+
stage_dim=[256, 512, 1024, 2048])
|
428 |
+
|
429 |
+
elif model_idx[0:11] == 'Hybrid3_224' and edge_size == 224: # Proposed model ablation study: edge_size==224
|
430 |
+
backbone = Hybrid_backbone_3(block_constructor=Bottleneck_block_constructor,
|
431 |
+
bottleneck_channels_setting=[64, 128, 256],
|
432 |
+
identity_layers_setting=[3, 4, 6],
|
433 |
+
stage_stride_setting=[1, 2, 2],
|
434 |
+
fc_num_classes=None,
|
435 |
+
feature_idx='stages')
|
436 |
+
if pretrained:
|
437 |
+
try:
|
438 |
+
backbone.load_state_dict(backbone_weights, False)
|
439 |
+
except:
|
440 |
+
print("backbone not loaded")
|
441 |
+
else:
|
442 |
+
print("backbone loaded")
|
443 |
+
|
444 |
+
model = Transformer_blocks.Stage_wise_hybrid_Transformer(backbone, num_classes=num_classes,
|
445 |
+
drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
|
446 |
+
drop_path_rate=drop_path_rate,
|
447 |
+
use_cls_token=use_cls_token,
|
448 |
+
use_pos_embedding=use_pos_embedding,
|
449 |
+
use_att_module=use_att_module,
|
450 |
+
stage_size=(56, 28, 14),
|
451 |
+
stage_dim=[256, 512, 1024])
|
452 |
+
|
453 |
+
elif model_idx[0:11] == 'Hybrid3_384' and edge_size == 384: # Proposed model 384 !!!
|
454 |
+
backbone = Hybrid_backbone_3(block_constructor=Bottleneck_block_constructor,
|
455 |
+
bottleneck_channels_setting=[64, 128, 256],
|
456 |
+
identity_layers_setting=[3, 4, 6],
|
457 |
+
stage_stride_setting=[1, 2, 2],
|
458 |
+
fc_num_classes=None,
|
459 |
+
feature_idx='stages')
|
460 |
+
if pretrained:
|
461 |
+
try:
|
462 |
+
backbone.load_state_dict(backbone_weights, False)
|
463 |
+
except:
|
464 |
+
print("backbone not loaded")
|
465 |
+
else:
|
466 |
+
print("backbone loaded")
|
467 |
+
|
468 |
+
model = Transformer_blocks.Stage_wise_hybrid_Transformer(backbone, num_classes=num_classes,
|
469 |
+
drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
|
470 |
+
drop_path_rate=drop_path_rate,
|
471 |
+
use_cls_token=use_cls_token,
|
472 |
+
use_pos_embedding=use_pos_embedding,
|
473 |
+
use_att_module=use_att_module,
|
474 |
+
stage_size=(96, 48, 24),
|
475 |
+
stage_dim=[256, 512, 1024])
|
476 |
+
|
477 |
+
else:
|
478 |
+
print('not a valid hybrid model')
|
479 |
+
return -1
|
480 |
+
|
481 |
+
return model
|
PuzzleTuning/Backbone/Transformer_blocks.py
ADDED
@@ -0,0 +1,1631 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Transformer blocks script ver: OCT 28th 15:00
|
3 |
+
|
4 |
+
bug fix: 'Cross-attn' name is used in MHGA for compareability
|
5 |
+
|
6 |
+
by the authors, check our github page:
|
7 |
+
https://github.com/sagizty/Multi-Stage-Hybrid-Transformer
|
8 |
+
|
9 |
+
based on:timm
|
10 |
+
https://www.freeaihub.com/post/94067.html
|
11 |
+
|
12 |
+
"""
|
13 |
+
|
14 |
+
import math
|
15 |
+
import logging
|
16 |
+
from functools import partial
|
17 |
+
from collections import OrderedDict
|
18 |
+
|
19 |
+
import torch
|
20 |
+
import torch.nn as nn
|
21 |
+
import torch.nn.functional as F
|
22 |
+
|
23 |
+
import numpy as np
|
24 |
+
|
25 |
+
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
26 |
+
|
27 |
+
from timm.models.layers import StdConv2dSame, DropPath, to_2tuple, trunc_normal_
|
28 |
+
|
29 |
+
from .attention_modules import simam_module, cbam_module, se_module
|
30 |
+
|
31 |
+
|
32 |
+
class FFN(nn.Module): # Mlp from timm
|
33 |
+
"""
|
34 |
+
FFN (from timm)
|
35 |
+
|
36 |
+
:param in_features:
|
37 |
+
:param hidden_features:
|
38 |
+
:param out_features:
|
39 |
+
:param act_layer:
|
40 |
+
:param drop:
|
41 |
+
"""
|
42 |
+
|
43 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
44 |
+
super().__init__()
|
45 |
+
|
46 |
+
out_features = out_features or in_features
|
47 |
+
hidden_features = hidden_features or in_features
|
48 |
+
|
49 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
50 |
+
self.act = act_layer()
|
51 |
+
|
52 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
53 |
+
self.drop = nn.Dropout(drop)
|
54 |
+
|
55 |
+
def forward(self, x):
|
56 |
+
x = self.fc1(x)
|
57 |
+
x = self.act(x)
|
58 |
+
x = self.drop(x)
|
59 |
+
|
60 |
+
x = self.fc2(x)
|
61 |
+
x = self.drop(x)
|
62 |
+
|
63 |
+
return x
|
64 |
+
|
65 |
+
|
66 |
+
class Attention(nn.Module): # qkv Transform + MSA(MHSA) (Attention from timm)
|
67 |
+
"""
|
68 |
+
qkv Transform + MSA(MHSA) (from timm)
|
69 |
+
|
70 |
+
# input x.shape = batch, patch_number, patch_dim
|
71 |
+
# output x.shape = batch, patch_number, patch_dim
|
72 |
+
|
73 |
+
:param dim: dim=CNN feature dim, because the patch size is 1x1
|
74 |
+
:param num_heads:
|
75 |
+
:param qkv_bias:
|
76 |
+
:param qk_scale: by default head_dim ** -0.5 (squre root)
|
77 |
+
:param attn_drop: dropout rate after MHSA
|
78 |
+
:param proj_drop:
|
79 |
+
|
80 |
+
"""
|
81 |
+
|
82 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
83 |
+
super().__init__()
|
84 |
+
self.num_heads = num_heads
|
85 |
+
head_dim = dim // num_heads
|
86 |
+
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
|
87 |
+
self.scale = qk_scale or head_dim ** -0.5
|
88 |
+
|
89 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
90 |
+
|
91 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
92 |
+
self.proj = nn.Linear(dim, dim)
|
93 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
94 |
+
|
95 |
+
def forward(self, x):
|
96 |
+
# input x.shape = batch, patch_number, patch_dim
|
97 |
+
batch, patch_number, patch_dim = x.shape
|
98 |
+
|
99 |
+
# mlp transform + head split [N, P, D] -> [N, P, 3D] -> [N, P, 3, H, D/H] -> [3, N, H, P, D/H]
|
100 |
+
qkv = self.qkv(x).reshape(batch, patch_number, 3, self.num_heads, patch_dim //
|
101 |
+
self.num_heads).permute(2, 0, 3, 1, 4)
|
102 |
+
# 3 [N, H, P, D/H]
|
103 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
104 |
+
|
105 |
+
# [N, H, P, D/H] -> [N, H, P, D/H]
|
106 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
107 |
+
attn = attn.softmax(dim=-1)
|
108 |
+
|
109 |
+
attn = self.attn_drop(attn) # Dropout
|
110 |
+
|
111 |
+
# head fusion [N, H, P, D/H] -> [N, P, H, D/H] -> [N, P, D]
|
112 |
+
x = (attn @ v).transpose(1, 2).reshape(batch, patch_number, patch_dim)
|
113 |
+
|
114 |
+
x = self.proj(x)
|
115 |
+
x = self.proj_drop(x) # mlp
|
116 |
+
|
117 |
+
# output x.shape = batch, patch_number, patch_dim
|
118 |
+
return x
|
119 |
+
|
120 |
+
|
121 |
+
class Encoder_Block(nn.Module): # teansformer Block from timm
|
122 |
+
|
123 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
124 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
125 |
+
"""
|
126 |
+
# input x.shape = batch, patch_number, patch_dim
|
127 |
+
# output x.shape = batch, patch_number, patch_dim
|
128 |
+
|
129 |
+
:param dim: dim
|
130 |
+
:param num_heads:
|
131 |
+
:param mlp_ratio: FFN
|
132 |
+
:param qkv_bias:
|
133 |
+
:param qk_scale: by default head_dim ** -0.5 (squre root)
|
134 |
+
:param drop:
|
135 |
+
:param attn_drop: dropout rate after Attention
|
136 |
+
:param drop_path: dropout rate after sd
|
137 |
+
:param act_layer: FFN act
|
138 |
+
:param norm_layer: Pre Norm
|
139 |
+
"""
|
140 |
+
super().__init__()
|
141 |
+
# Pre Norm
|
142 |
+
self.norm1 = norm_layer(dim) # Transformer used the nn.LayerNorm
|
143 |
+
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
|
144 |
+
proj_drop=drop)
|
145 |
+
# NOTE from timm: drop path for stochastic depth, we shall see if this is better than dropout here
|
146 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() # stochastic depth
|
147 |
+
|
148 |
+
# Add & Norm
|
149 |
+
self.norm2 = norm_layer(dim)
|
150 |
+
|
151 |
+
# FFN
|
152 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
153 |
+
self.mlp = FFN(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
154 |
+
|
155 |
+
def forward(self, x):
|
156 |
+
x = x + self.drop_path(self.attn(self.norm1(x)))
|
157 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
158 |
+
return x
|
159 |
+
|
160 |
+
|
161 |
+
class Guided_Attention(nn.Module): # q1 k1 v0 Transform + MSA(MHSA) (based on timm Attention)
|
162 |
+
"""
|
163 |
+
notice the q abd k is guided information from Focus module
|
164 |
+
qkv Transform + MSA(MHSA) (from timm)
|
165 |
+
|
166 |
+
# 3 input of x.shape = batch, patch_number, patch_dim
|
167 |
+
# 1 output of x.shape = batch, patch_number, patch_dim
|
168 |
+
|
169 |
+
:param dim: dim = CNN feature dim, because the patch size is 1x1
|
170 |
+
:param num_heads:
|
171 |
+
:param qkv_bias:
|
172 |
+
:param qk_scale: by default head_dim ** -0.5 (squre root)
|
173 |
+
:param attn_drop:
|
174 |
+
:param proj_drop:
|
175 |
+
|
176 |
+
"""
|
177 |
+
|
178 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
179 |
+
super().__init__()
|
180 |
+
self.num_heads = num_heads
|
181 |
+
head_dim = dim // num_heads
|
182 |
+
self.scale = qk_scale or head_dim ** -0.5
|
183 |
+
|
184 |
+
self.qT = nn.Linear(dim, dim, bias=qkv_bias)
|
185 |
+
self.kT = nn.Linear(dim, dim, bias=qkv_bias)
|
186 |
+
self.vT = nn.Linear(dim, dim, bias=qkv_bias)
|
187 |
+
|
188 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
189 |
+
|
190 |
+
self.proj = nn.Linear(dim, dim)
|
191 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
192 |
+
|
193 |
+
def forward(self, q_encoder, k_encoder, v_input):
|
194 |
+
# 3 input of x.shape = batch, patch_number, patch_dim
|
195 |
+
batch, patch_number, patch_dim = v_input.shape
|
196 |
+
|
197 |
+
q = self.qT(q_encoder).reshape(batch, patch_number, 1, self.num_heads,
|
198 |
+
patch_dim // self.num_heads).permute(2, 0, 3, 1, 4)
|
199 |
+
k = self.kT(k_encoder).reshape(batch, patch_number, 1, self.num_heads,
|
200 |
+
patch_dim // self.num_heads).permute(2, 0, 3, 1, 4)
|
201 |
+
v = self.vT(v_input).reshape(batch, patch_number, 1, self.num_heads,
|
202 |
+
patch_dim // self.num_heads).permute(2, 0, 3, 1, 4)
|
203 |
+
q = q[0]
|
204 |
+
k = k[0]
|
205 |
+
v = v[0]
|
206 |
+
|
207 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
208 |
+
attn = attn.softmax(dim=-1)
|
209 |
+
|
210 |
+
attn = self.attn_drop(attn) # Dropout
|
211 |
+
|
212 |
+
x = (attn @ v).transpose(1, 2).reshape(batch, patch_number, patch_dim)
|
213 |
+
|
214 |
+
x = self.proj(x)
|
215 |
+
x = self.proj_drop(x) # mlp Dropout
|
216 |
+
|
217 |
+
# output of x.shape = batch, patch_number, patch_dim
|
218 |
+
return x
|
219 |
+
|
220 |
+
|
221 |
+
class Decoder_Block(nn.Module):
|
222 |
+
# FGD Decoder (Transformer encoder + Guided Attention block block)
|
223 |
+
def __init__(self, dim, num_heads=8, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
224 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
225 |
+
"""
|
226 |
+
# input x.shape = batch, patch_number, patch_dim
|
227 |
+
# output x.shape = batch, patch_number, patch_dim
|
228 |
+
|
229 |
+
:param dim: dim=CNN feature dim, because the patch size is 1x1
|
230 |
+
:param num_heads: multi-head
|
231 |
+
:param mlp_ratio: FFN expand ratio
|
232 |
+
:param qkv_bias: qkv MLP bias
|
233 |
+
:param qk_scale: by default head_dim ** -0.5 (squre root)
|
234 |
+
:param drop: the MLP after MHSA equipt a dropout rate
|
235 |
+
:param attn_drop: dropout rate after attention block
|
236 |
+
:param drop_path: dropout rate for stochastic depth
|
237 |
+
:param act_layer: FFN act
|
238 |
+
:param norm_layer: Pre Norm strategy with norm layer
|
239 |
+
"""
|
240 |
+
super().__init__()
|
241 |
+
# Pre Norm
|
242 |
+
self.norm0 = norm_layer(dim) # nn.LayerNorm
|
243 |
+
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
|
244 |
+
proj_drop=drop)
|
245 |
+
# stochastic depth
|
246 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
247 |
+
|
248 |
+
# Pre Norm
|
249 |
+
self.norm1 = norm_layer(dim)
|
250 |
+
|
251 |
+
# FFN1
|
252 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
253 |
+
self.FFN1 = FFN(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
254 |
+
|
255 |
+
# Guided_Attention
|
256 |
+
self.Cross_attn = Guided_Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
257 |
+
attn_drop=attn_drop, proj_drop=drop)
|
258 |
+
|
259 |
+
# Add & Norm
|
260 |
+
self.norm2 = norm_layer(dim)
|
261 |
+
# FFN2
|
262 |
+
self.FFN2 = FFN(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
263 |
+
|
264 |
+
# Add & Norm
|
265 |
+
self.norm3 = norm_layer(dim)
|
266 |
+
|
267 |
+
def forward(self, q_encoder, k_encoder, v_input):
|
268 |
+
v_self = v_input + self.drop_path(self.attn(self.norm0(v_input)))
|
269 |
+
|
270 |
+
v_self = v_self + self.drop_path(self.FFN1(self.norm1(v_self)))
|
271 |
+
|
272 |
+
# norm layer for v only, the normalization of q and k is inside FGD Focus block
|
273 |
+
v_self = v_self + self.drop_path(self.Cross_attn(q_encoder, k_encoder, self.norm2(v_self)))
|
274 |
+
|
275 |
+
v_self = v_self + self.drop_path(self.FFN2(self.norm3(v_self)))
|
276 |
+
|
277 |
+
return v_self
|
278 |
+
|
279 |
+
|
280 |
+
'''
|
281 |
+
# testing example
|
282 |
+
|
283 |
+
model=Decoder_Block(dim=768)
|
284 |
+
k = torch.randn(7, 49, 768)
|
285 |
+
q = torch.randn(7, 49, 768)
|
286 |
+
v = torch.randn(7, 49, 768)
|
287 |
+
x = model(k,q,v)
|
288 |
+
print(x.shape)
|
289 |
+
'''
|
290 |
+
|
291 |
+
|
292 |
+
# MViT modules
|
293 |
+
# from https://github.com/facebookresearch/SlowFast/slowfast/models/attention.py
|
294 |
+
def attention_pool(tensor, pool, thw_shape, has_cls_embed=True, norm=None):
|
295 |
+
"""
|
296 |
+
attention pooling constructor
|
297 |
+
|
298 |
+
input:
|
299 |
+
tensor of (B, Head, N, C) or (B, N, C)
|
300 |
+
thw_shape: T, H, W 对应CNN的特征图形状(2D形状)T is video frams
|
301 |
+
|
302 |
+
numpy.prob(T, H, W) == N(Num_patches) - 1 (cls token if it is there)
|
303 |
+
|
304 |
+
output:
|
305 |
+
tensor of (B, Head, N_O, C) or (B, N_O, C)
|
306 |
+
thw_shape: T_O, H_O, W_O
|
307 |
+
|
308 |
+
:param tensor: input feature patches
|
309 |
+
:param pool: pooling/conv layer
|
310 |
+
:param thw_shape: reconstruction feature map shape
|
311 |
+
:param has_cls_embed: if cls token is used
|
312 |
+
:param norm: norm layer
|
313 |
+
|
314 |
+
"""
|
315 |
+
if pool is None: # no pool
|
316 |
+
return tensor, thw_shape
|
317 |
+
|
318 |
+
tensor_dim = tensor.ndim
|
319 |
+
|
320 |
+
# fix dim: [B, Head, N, C]
|
321 |
+
# N is Num_patches in Transformer modeling
|
322 |
+
|
323 |
+
if tensor_dim == 4:
|
324 |
+
pass
|
325 |
+
elif tensor_dim == 3: # [B, N, C] -> [B, Head(1), N, C]
|
326 |
+
tensor = tensor.unsqueeze(1)
|
327 |
+
else:
|
328 |
+
raise NotImplementedError(f"Unsupported input dimension {tensor.shape}")
|
329 |
+
|
330 |
+
if has_cls_embed:
|
331 |
+
cls_tok, tensor = tensor[:, :, :1, :], tensor[:, :, 1:, :]
|
332 |
+
|
333 |
+
B, Head, N, C = tensor.shape
|
334 |
+
T, H, W = thw_shape # numpy.prob(T, H, W) == N(Num_patches) - 1 (cls token if it is there)
|
335 |
+
|
336 |
+
# [B, Head, N, C] -> [B * Head, T, H, W, C] -> [B * Head, C, T, H, W]
|
337 |
+
tensor = (tensor.reshape(B * Head, T, H, W, C).permute(0, 4, 1, 2, 3).contiguous())
|
338 |
+
# use tensor.contiguous() to matain its memory location
|
339 |
+
|
340 |
+
# [B * Head, C, T, H, W] -> [B * Head, C, T_O, H_O, W_O]
|
341 |
+
tensor = pool(tensor) # 3D Pooling/ 3D Conv
|
342 |
+
|
343 |
+
# output T, H, W
|
344 |
+
thw_shape = [tensor.shape[2], tensor.shape[3], tensor.shape[4]]
|
345 |
+
# output Num_patches: numpy.prob(T, H, W)
|
346 |
+
N_pooled = tensor.shape[2] * tensor.shape[3] * tensor.shape[4]
|
347 |
+
|
348 |
+
# [B * Head, C, T_O, H_O, W_O] -> [B, Head, C, N_O(T_O*H_O*W_O)] -> [B, Head, N_O, C]
|
349 |
+
tensor = tensor.reshape(B, Head, C, N_pooled).transpose(2, 3)
|
350 |
+
|
351 |
+
if has_cls_embed:
|
352 |
+
# [B, Head, N_O, C] -> [B, Head, N_O+1(cls token), C]
|
353 |
+
tensor = torch.cat((cls_tok, tensor), dim=2)
|
354 |
+
|
355 |
+
# norm
|
356 |
+
if norm is not None:
|
357 |
+
tensor = norm(tensor)
|
358 |
+
|
359 |
+
# Assert tensor_dim in [3, 4]
|
360 |
+
if tensor_dim == 4: # [B, Head, N_O, C] multi-head
|
361 |
+
pass
|
362 |
+
else: # tensor_dim == 3: this is a single Head
|
363 |
+
tensor = tensor.squeeze(1) # [B, N_O, C]
|
364 |
+
|
365 |
+
return tensor, thw_shape
|
366 |
+
|
367 |
+
|
368 |
+
'''
|
369 |
+
# case 1 single-head no pooling scale
|
370 |
+
x = torch.randn(1, 197, 768)
|
371 |
+
thw_shape = [1, 14, 14]
|
372 |
+
pool = nn.MaxPool3d((1, 1, 1), (1, 1, 1), (0, 0, 0), ceil_mode=False)
|
373 |
+
y, thw = attention_pool(x, pool, thw_shape)
|
374 |
+
|
375 |
+
print(y.shape) # torch.Size([1, 197, 768])
|
376 |
+
print(thw) # [1, 14, 14]
|
377 |
+
|
378 |
+
|
379 |
+
# case 2 multi-head no pooling scale
|
380 |
+
x = torch.randn(1, 8, 197, 96) # [B, Head, N_O, C] multi-head
|
381 |
+
thw_shape = [1, 14, 14]
|
382 |
+
pool = nn.MaxPool3d((1, 1, 1), (1, 1, 1), (0, 0, 0), ceil_mode=False)
|
383 |
+
y, thw = attention_pool(x, pool, thw_shape)
|
384 |
+
|
385 |
+
print(y.shape) # torch.Size([1, 8, 197, 96])
|
386 |
+
print(thw) # [1, 14, 14]
|
387 |
+
|
388 |
+
|
389 |
+
# case 3 pooling scale
|
390 |
+
x = torch.randn(1, 197, 768)
|
391 |
+
thw_shape = [1, 14, 14]
|
392 |
+
pool = nn.MaxPool3d((1, 2, 2), (1, 2, 2), (0, 0, 0), ceil_mode=False)
|
393 |
+
y, thw = attention_pool(x, pool, thw_shape)
|
394 |
+
|
395 |
+
print(y.shape) # torch.Size([1, 50, 768])
|
396 |
+
print(thw) # [1, 7, 7]
|
397 |
+
|
398 |
+
|
399 |
+
# case 4 multi-head pooling scale
|
400 |
+
x = torch.randn(1, 8, 197, 96) # [B, Head, N_O, C] multi-head
|
401 |
+
thw_shape = [1, 14, 14]
|
402 |
+
pool = nn.MaxPool3d((1, 2, 2), (1, 2, 2), (0, 0, 0), ceil_mode=False)
|
403 |
+
y, thw = attention_pool(x, pool, thw_shape)
|
404 |
+
|
405 |
+
print(y.shape) # torch.Size([1, 8, 50, 96])
|
406 |
+
print(thw) # [1, 7, 7]
|
407 |
+
'''
|
408 |
+
|
409 |
+
|
410 |
+
class MultiScaleAttention(nn.Module): # Attention module
|
411 |
+
"""
|
412 |
+
Attention module constructor
|
413 |
+
|
414 |
+
input:
|
415 |
+
tensor of (B, N, C)
|
416 |
+
thw_shape: T, H, W 对应CNN的特征图形状(2D形状)T is video frams
|
417 |
+
|
418 |
+
numpy.prob(T, H, W) == N(Num_patches) - 1 (cls token if it is there)
|
419 |
+
|
420 |
+
output:
|
421 |
+
tensor of (B, N_O, C)
|
422 |
+
thw_shape: T_O, H_O, W_O
|
423 |
+
|
424 |
+
:param dim: Transformer feature dim
|
425 |
+
:param num_heads: Transformer heads
|
426 |
+
:param qkv_bias: projecting bias
|
427 |
+
:param drop_rate: dropout rate after attention calculation and mlp
|
428 |
+
|
429 |
+
:param kernel_q: pooling kernal size for q
|
430 |
+
:param kernel_kv: pooling kernal size for k and v
|
431 |
+
:param stride_q: pooling kernal stride for q
|
432 |
+
:param stride_kv: pooling kernal stride for k and v
|
433 |
+
|
434 |
+
:param norm_layer: norm layer
|
435 |
+
:param has_cls_embed: if cls token is used
|
436 |
+
:param mode: mode for attention pooling(downsampling) Options include `conv`, `avg`, and `max`.
|
437 |
+
:param pool_first: process pooling(downsampling) before liner projecting
|
438 |
+
|
439 |
+
"""
|
440 |
+
|
441 |
+
def __init__(
|
442 |
+
self,
|
443 |
+
dim,
|
444 |
+
num_heads=8,
|
445 |
+
qkv_bias=False,
|
446 |
+
drop_rate=0.0,
|
447 |
+
kernel_q=(1, 1, 1),
|
448 |
+
kernel_kv=(1, 1, 1),
|
449 |
+
stride_q=(1, 1, 1),
|
450 |
+
stride_kv=(1, 1, 1),
|
451 |
+
norm_layer=nn.LayerNorm,
|
452 |
+
has_cls_embed=True,
|
453 |
+
# Options include `conv`, `avg`, and `max`.
|
454 |
+
mode="conv",
|
455 |
+
# If True, perform pool before projection.
|
456 |
+
pool_first=False,
|
457 |
+
):
|
458 |
+
super().__init__()
|
459 |
+
|
460 |
+
self.pool_first = pool_first
|
461 |
+
self.drop_rate = drop_rate
|
462 |
+
self.num_heads = num_heads
|
463 |
+
head_dim = dim // num_heads
|
464 |
+
self.scale = head_dim ** -0.5 # squre root
|
465 |
+
self.has_cls_embed = has_cls_embed
|
466 |
+
|
467 |
+
padding_q = [int(q // 2) for q in kernel_q] # 以半个kernal size进行padding,向下取整
|
468 |
+
padding_kv = [int(kv // 2) for kv in kernel_kv]
|
469 |
+
|
470 |
+
# projecting mlp
|
471 |
+
self.q = nn.Linear(dim, dim, bias=qkv_bias)
|
472 |
+
self.k = nn.Linear(dim, dim, bias=qkv_bias)
|
473 |
+
self.v = nn.Linear(dim, dim, bias=qkv_bias)
|
474 |
+
self.proj = nn.Linear(dim, dim)
|
475 |
+
|
476 |
+
if drop_rate > 0.0:
|
477 |
+
self.proj_drop = nn.Dropout(drop_rate)
|
478 |
+
|
479 |
+
# Skip pooling with kernel and stride size of (1, 1, 1).
|
480 |
+
if np.prod(kernel_q) == 1 and np.prod(stride_q) == 1:
|
481 |
+
kernel_q = () # clear
|
482 |
+
if np.prod(kernel_kv) == 1 and np.prod(stride_kv) == 1:
|
483 |
+
kernel_kv = ()
|
484 |
+
|
485 |
+
if mode in ("avg", "max"): # use nn.MaxPool3d or nn.AvgPool3d
|
486 |
+
pool_op = nn.MaxPool3d if mode == "max" else nn.AvgPool3d
|
487 |
+
self.pool_q = (
|
488 |
+
pool_op(kernel_q, stride_q, padding_q, ceil_mode=False)
|
489 |
+
if len(kernel_q) > 0
|
490 |
+
else None # Skip pooling if kernel is cleared
|
491 |
+
)
|
492 |
+
self.pool_k = (
|
493 |
+
pool_op(kernel_kv, stride_kv, padding_kv, ceil_mode=False)
|
494 |
+
if len(kernel_kv) > 0
|
495 |
+
else None
|
496 |
+
)
|
497 |
+
self.pool_v = (
|
498 |
+
pool_op(kernel_kv, stride_kv, padding_kv, ceil_mode=False)
|
499 |
+
if len(kernel_kv) > 0
|
500 |
+
else None
|
501 |
+
)
|
502 |
+
|
503 |
+
elif mode == "conv": # use nn.Conv3d with depth wise conv and fixed channel setting
|
504 |
+
self.pool_q = (
|
505 |
+
nn.Conv3d(
|
506 |
+
head_dim,
|
507 |
+
head_dim,
|
508 |
+
kernel_q,
|
509 |
+
stride=stride_q,
|
510 |
+
padding=padding_q,
|
511 |
+
groups=head_dim,
|
512 |
+
bias=False,
|
513 |
+
)
|
514 |
+
if len(kernel_q) > 0
|
515 |
+
else None
|
516 |
+
)
|
517 |
+
self.norm_q = norm_layer(head_dim) if len(kernel_q) > 0 else None
|
518 |
+
|
519 |
+
self.pool_k = (
|
520 |
+
nn.Conv3d(
|
521 |
+
head_dim,
|
522 |
+
head_dim,
|
523 |
+
kernel_kv,
|
524 |
+
stride=stride_kv,
|
525 |
+
padding=padding_kv,
|
526 |
+
groups=head_dim,
|
527 |
+
bias=False,
|
528 |
+
)
|
529 |
+
if len(kernel_kv) > 0
|
530 |
+
else None
|
531 |
+
)
|
532 |
+
self.norm_k = norm_layer(head_dim) if len(kernel_kv) > 0 else None
|
533 |
+
|
534 |
+
self.pool_v = (
|
535 |
+
nn.Conv3d(
|
536 |
+
head_dim,
|
537 |
+
head_dim,
|
538 |
+
kernel_kv,
|
539 |
+
stride=stride_kv,
|
540 |
+
padding=padding_kv,
|
541 |
+
groups=head_dim,
|
542 |
+
bias=False,
|
543 |
+
)
|
544 |
+
if len(kernel_kv) > 0
|
545 |
+
else None
|
546 |
+
)
|
547 |
+
self.norm_v = norm_layer(head_dim) if len(kernel_kv) > 0 else None
|
548 |
+
else:
|
549 |
+
raise NotImplementedError(f"Unsupported model {mode}")
|
550 |
+
|
551 |
+
def forward(self, x, thw_shape):
|
552 |
+
"""
|
553 |
+
x: Transformer feature patches
|
554 |
+
thw_shape: reconstruction feature map shape
|
555 |
+
"""
|
556 |
+
|
557 |
+
B, N, C = x.shape
|
558 |
+
|
559 |
+
# step 1: duplicate projecting + head split: [B, N, C] -> [B, H, N, C/H]
|
560 |
+
|
561 |
+
if self.pool_first: # step a.1 embedding
|
562 |
+
# head split [B, N, C] -> [B, N, H, C/H] -> [B, H, N, C/H]
|
563 |
+
x = x.reshape(B, N, self.num_heads, C // self.num_heads).permute(
|
564 |
+
0, 2, 1, 3
|
565 |
+
)
|
566 |
+
q = k = v = x
|
567 |
+
|
568 |
+
else: # step b.1 projecting first
|
569 |
+
# mlp transform + head split: [B, N, C] -> [B, N, H, C/H] -> [B, H, N, C/H]
|
570 |
+
# todo 这里我觉得可能共享mlp映射更好,能有更好的交互,但是分离mlp更节约计算量
|
571 |
+
q = k = v = x
|
572 |
+
q = (
|
573 |
+
self.q(q)
|
574 |
+
.reshape(B, N, self.num_heads, C // self.num_heads)
|
575 |
+
.permute(0, 2, 1, 3)
|
576 |
+
)
|
577 |
+
k = (
|
578 |
+
self.k(k)
|
579 |
+
.reshape(B, N, self.num_heads, C // self.num_heads)
|
580 |
+
.permute(0, 2, 1, 3)
|
581 |
+
)
|
582 |
+
v = (
|
583 |
+
self.v(v)
|
584 |
+
.reshape(B, N, self.num_heads, C // self.num_heads)
|
585 |
+
.permute(0, 2, 1, 3)
|
586 |
+
)
|
587 |
+
|
588 |
+
# step 2: calculate attention_pool feature sequence and its shape
|
589 |
+
# [B, H, N0, C/H] -> [B, H, N1, C/H]
|
590 |
+
q, q_shape = attention_pool(
|
591 |
+
q,
|
592 |
+
self.pool_q,
|
593 |
+
thw_shape,
|
594 |
+
has_cls_embed=self.has_cls_embed,
|
595 |
+
norm=self.norm_q if hasattr(self, "norm_q") else None,
|
596 |
+
)
|
597 |
+
k, k_shape = attention_pool(
|
598 |
+
k,
|
599 |
+
self.pool_k,
|
600 |
+
thw_shape,
|
601 |
+
has_cls_embed=self.has_cls_embed,
|
602 |
+
norm=self.norm_k if hasattr(self, "norm_k") else None,
|
603 |
+
)
|
604 |
+
v, v_shape = attention_pool(
|
605 |
+
v,
|
606 |
+
self.pool_v,
|
607 |
+
thw_shape,
|
608 |
+
has_cls_embed=self.has_cls_embed,
|
609 |
+
norm=self.norm_v if hasattr(self, "norm_v") else None,
|
610 |
+
)
|
611 |
+
|
612 |
+
if self.pool_first: # step a.3 MLP projecting
|
613 |
+
# calculate patch number, q_N, k_N, v_N
|
614 |
+
q_N = (
|
615 |
+
np.prod(q_shape) + 1
|
616 |
+
if self.has_cls_embed
|
617 |
+
else np.prod(q_shape)
|
618 |
+
)
|
619 |
+
k_N = (
|
620 |
+
np.prod(k_shape) + 1
|
621 |
+
if self.has_cls_embed
|
622 |
+
else np.prod(k_shape)
|
623 |
+
)
|
624 |
+
v_N = (
|
625 |
+
np.prod(v_shape) + 1
|
626 |
+
if self.has_cls_embed
|
627 |
+
else np.prod(v_shape)
|
628 |
+
)
|
629 |
+
|
630 |
+
# [B, H, N1, C/H] -> [B, N1, H, C/H] -> [B, N1, C] -> MLP
|
631 |
+
# -> [B, N1, C] -> [B, N1, H, C/H] -> [B, H, N1, C/H]
|
632 |
+
q = q.permute(0, 2, 1, 3).reshape(B, q_N, C)
|
633 |
+
q = (
|
634 |
+
self.q(q)
|
635 |
+
.reshape(B, q_N, self.num_heads, C // self.num_heads)
|
636 |
+
.permute(0, 2, 1, 3)
|
637 |
+
)
|
638 |
+
|
639 |
+
v = v.permute(0, 2, 1, 3).reshape(B, v_N, C)
|
640 |
+
v = (
|
641 |
+
self.v(v)
|
642 |
+
.reshape(B, v_N, self.num_heads, C // self.num_heads)
|
643 |
+
.permute(0, 2, 1, 3)
|
644 |
+
)
|
645 |
+
|
646 |
+
k = k.permute(0, 2, 1, 3).reshape(B, k_N, C)
|
647 |
+
k = (
|
648 |
+
self.k(k)
|
649 |
+
.reshape(B, k_N, self.num_heads, C // self.num_heads)
|
650 |
+
.permute(0, 2, 1, 3)
|
651 |
+
)
|
652 |
+
|
653 |
+
# step 3: attention calculation
|
654 |
+
# multi-head self attention [B, H, N1, C/H] -> [B, H, N1, C/H]
|
655 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
656 |
+
attn = attn.softmax(dim=-1)
|
657 |
+
|
658 |
+
# head squeeze [B, H, N1, C/H] -> [B, N1, H, C/H] -> [B, N1, C]
|
659 |
+
N = q.shape[2]
|
660 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
661 |
+
|
662 |
+
# step 4: mlp stablization and dropout [B, N1, C] -> [B, N1, C]
|
663 |
+
x = self.proj(x)
|
664 |
+
if self.drop_rate > 0.0:
|
665 |
+
x = self.proj_drop(x)
|
666 |
+
|
667 |
+
return x, q_shape
|
668 |
+
|
669 |
+
|
670 |
+
'''
|
671 |
+
# case 1
|
672 |
+
model = MultiScaleAttention(768)
|
673 |
+
x = torch.randn(1, 197, 768)
|
674 |
+
y, thw = model(x, [1, 14, 14])
|
675 |
+
print(y.shape)
|
676 |
+
|
677 |
+
|
678 |
+
# case 2
|
679 |
+
kernel_q = (1, 2, 2)
|
680 |
+
kernel_kv = (1, 2, 2)
|
681 |
+
stride_q = (1, 2, 2)
|
682 |
+
stride_kv = (1, 2, 2)
|
683 |
+
# MultiScaleAttention 中设计以半个kernal size进行padding,向下取整
|
684 |
+
|
685 |
+
model = MultiScaleAttention(768, kernel_q=kernel_q, kernel_kv=kernel_kv, stride_q=stride_q, stride_kv=stride_kv)
|
686 |
+
x = torch.randn(1, 197, 768)
|
687 |
+
y, thw = model(x, [1, 14, 14])
|
688 |
+
|
689 |
+
print(y.shape) # 输出torch.Size([1, 65, 768]):不padding是7*7 由于padding变成8*8, 之后加上cls token
|
690 |
+
'''
|
691 |
+
|
692 |
+
|
693 |
+
class MultiScaleBlock(nn.Module): # MViT Encoder
|
694 |
+
"""
|
695 |
+
Attention module constructor
|
696 |
+
|
697 |
+
input:
|
698 |
+
tensor of (B, N, C)
|
699 |
+
thw_shape: T, H, W 对应CNN的特征图形状(2D形状)T is video frams
|
700 |
+
|
701 |
+
numpy.prob(T, H, W) == N(Num_patches) - 1 (cls token if it is there)
|
702 |
+
|
703 |
+
output:
|
704 |
+
tensor of (B, N_O, C)
|
705 |
+
thw_shape: T_O, H_O, W_O
|
706 |
+
|
707 |
+
:param dim: Transformer feature dim
|
708 |
+
:param dim_out:
|
709 |
+
|
710 |
+
:param num_heads: Transformer heads
|
711 |
+
:param mlp_ratio: FFN hidden expansion
|
712 |
+
:param qkv_bias: projecting bias
|
713 |
+
:param drop_rate: dropout rate after attention calculation and mlp
|
714 |
+
:param drop_path: dropout rate for SD
|
715 |
+
:param act_layer: FFN act
|
716 |
+
:param norm_layer: Pre Norm
|
717 |
+
|
718 |
+
:param up_rate:
|
719 |
+
:param kernel_q: pooling kernal size for q
|
720 |
+
:param kernel_kv: pooling kernal size for k and v
|
721 |
+
:param stride_q: pooling kernal stride for q
|
722 |
+
:param stride_kv: pooling kernal stride for k and v
|
723 |
+
|
724 |
+
:param has_cls_embed: if cls token is used
|
725 |
+
:param mode: mode for attention pooling(downsampling) Options include `conv`, `avg`, and `max`.
|
726 |
+
:param pool_first: process pooling(downsampling) before liner projecting
|
727 |
+
|
728 |
+
"""
|
729 |
+
|
730 |
+
def __init__(
|
731 |
+
self,
|
732 |
+
dim,
|
733 |
+
dim_out,
|
734 |
+
num_heads=8,
|
735 |
+
mlp_ratio=4.0,
|
736 |
+
qkv_bias=False,
|
737 |
+
drop_rate=0.0,
|
738 |
+
drop_path=0.0,
|
739 |
+
act_layer=nn.GELU,
|
740 |
+
norm_layer=nn.LayerNorm,
|
741 |
+
up_rate=None,
|
742 |
+
kernel_q=(1, 1, 1),
|
743 |
+
kernel_kv=(1, 1, 1),
|
744 |
+
stride_q=(1, 1, 1),
|
745 |
+
stride_kv=(1, 1, 1),
|
746 |
+
has_cls_embed=True,
|
747 |
+
mode="conv",
|
748 |
+
pool_first=False,
|
749 |
+
):
|
750 |
+
super().__init__()
|
751 |
+
|
752 |
+
self.has_cls_embed = has_cls_embed
|
753 |
+
|
754 |
+
# step 1: Attention projecting
|
755 |
+
self.dim = dim
|
756 |
+
self.dim_out = dim_out
|
757 |
+
self.norm1 = norm_layer(dim) # pre-norm
|
758 |
+
|
759 |
+
self.attn = MultiScaleAttention(
|
760 |
+
dim,
|
761 |
+
num_heads=num_heads,
|
762 |
+
qkv_bias=qkv_bias,
|
763 |
+
drop_rate=drop_rate,
|
764 |
+
kernel_q=kernel_q,
|
765 |
+
kernel_kv=kernel_kv,
|
766 |
+
stride_q=stride_q,
|
767 |
+
stride_kv=stride_kv,
|
768 |
+
norm_layer=nn.LayerNorm,
|
769 |
+
has_cls_embed=self.has_cls_embed,
|
770 |
+
mode=mode,
|
771 |
+
pool_first=pool_first,
|
772 |
+
)
|
773 |
+
|
774 |
+
self.drop_path = (DropPath(drop_path) if drop_path > 0.0 else nn.Identity())
|
775 |
+
|
776 |
+
# residual connection for Attention projecting
|
777 |
+
kernel_skip = kernel_q # fixme ori: [s + 1 if s > 1 else s for s in stride_q]
|
778 |
+
stride_skip = stride_q
|
779 |
+
padding_skip = [int(skip // 2) for skip in kernel_skip] # 以半个kernal size进行padding,向下取整
|
780 |
+
|
781 |
+
self.pool_skip = (
|
782 |
+
nn.MaxPool3d(kernel_skip, stride_skip, padding_skip, ceil_mode=False)
|
783 |
+
if len(kernel_skip) > 0
|
784 |
+
else None)
|
785 |
+
|
786 |
+
self.norm2 = norm_layer(dim) # pre-norm
|
787 |
+
|
788 |
+
# step 2: FFN projecting
|
789 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
790 |
+
|
791 |
+
# here use FFN to encode feature into abstractive information in the dimension
|
792 |
+
# TODO: check the use case for up_rate, and merge the following lines
|
793 |
+
if up_rate is not None and up_rate > 1:
|
794 |
+
mlp_dim_out = dim * up_rate
|
795 |
+
else:
|
796 |
+
mlp_dim_out = dim_out
|
797 |
+
|
798 |
+
self.mlp = FFN(
|
799 |
+
in_features=dim,
|
800 |
+
hidden_features=mlp_hidden_dim,
|
801 |
+
out_features=mlp_dim_out,
|
802 |
+
act_layer=act_layer,
|
803 |
+
drop=drop_rate,
|
804 |
+
)
|
805 |
+
|
806 |
+
# residual connection for FFN projecting
|
807 |
+
if dim != dim_out:
|
808 |
+
self.proj = nn.Linear(dim, dim_out)
|
809 |
+
|
810 |
+
def forward(self, x, thw_shape):
|
811 |
+
# step 1: Attention projecting
|
812 |
+
x_block, thw_shape_new = self.attn(self.norm1(x), thw_shape)
|
813 |
+
# residual connection for Attention projecting
|
814 |
+
x_res, _ = attention_pool(x, self.pool_skip, thw_shape, has_cls_embed=self.has_cls_embed)
|
815 |
+
x = x_res + self.drop_path(x_block)
|
816 |
+
|
817 |
+
# step 2: FFN projecting
|
818 |
+
x_norm = self.norm2(x)
|
819 |
+
x_mlp = self.mlp(x_norm)
|
820 |
+
# residual connection for FFN projecting
|
821 |
+
if self.dim != self.dim_out:
|
822 |
+
x = self.proj(x_norm)
|
823 |
+
x = x + self.drop_path(x_mlp)
|
824 |
+
|
825 |
+
return x, thw_shape_new
|
826 |
+
|
827 |
+
|
828 |
+
'''
|
829 |
+
# case 1
|
830 |
+
model = MultiScaleBlock(768,1024)
|
831 |
+
x = torch.randn(1, 197, 768)
|
832 |
+
y, thw = model(x, [1, 14, 14])
|
833 |
+
print(y.shape) # torch.Size([1, 197, 1024])
|
834 |
+
|
835 |
+
|
836 |
+
# case 2
|
837 |
+
kernel_q = (1, 2, 2)
|
838 |
+
kernel_kv = (1, 2, 2)
|
839 |
+
stride_q = (1, 2, 2)
|
840 |
+
stride_kv = (1, 2, 2)
|
841 |
+
# MultiScaleAttention 中设计以半个kernal size进行padding,向下取整
|
842 |
+
|
843 |
+
model = MultiScaleBlock(768, 1024, kernel_q=kernel_q, kernel_kv=kernel_kv, stride_q=stride_q, stride_kv=stride_kv)
|
844 |
+
x = torch.randn(1, 197, 768)
|
845 |
+
y, thw = model(x, [1, 14, 14])
|
846 |
+
|
847 |
+
print(y.shape) # 输出torch.Size([1, 65, 1024]):不padding是7*7 由于padding变成8*8, 之后加上cls token
|
848 |
+
'''
|
849 |
+
|
850 |
+
|
851 |
+
class PatchEmbed(nn.Module): # PatchEmbed from timm
|
852 |
+
"""
|
853 |
+
Image to Patch Embedding
|
854 |
+
"""
|
855 |
+
|
856 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
|
857 |
+
super().__init__()
|
858 |
+
img_size = to_2tuple(img_size)
|
859 |
+
patch_size = to_2tuple(patch_size)
|
860 |
+
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
|
861 |
+
self.img_size = img_size
|
862 |
+
self.patch_size = patch_size
|
863 |
+
self.num_patches = num_patches
|
864 |
+
|
865 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
866 |
+
|
867 |
+
def forward(self, x):
|
868 |
+
B, C, H, W = x.shape
|
869 |
+
# FIXME look at relaxing size constraints
|
870 |
+
assert H == self.img_size[0] and W == self.img_size[1], \
|
871 |
+
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
872 |
+
x = self.proj(x).flatten(2).transpose(1, 2)
|
873 |
+
|
874 |
+
# x: (B, 14*14, 768)
|
875 |
+
return x
|
876 |
+
|
877 |
+
|
878 |
+
class Hybrid_feature_map_Embed(nn.Module): # HybridEmbed from timm
|
879 |
+
"""
|
880 |
+
CNN Feature Map Embedding, required backbone which is just for referance here
|
881 |
+
Extract feature map from CNN, flatten, project to embedding dim.
|
882 |
+
|
883 |
+
# input x.shape = batch, feature_dim, feature_size[0], feature_size[1]
|
884 |
+
# output x.shape = batch, patch_number, patch_dim
|
885 |
+
"""
|
886 |
+
|
887 |
+
def __init__(self, backbone, img_size=224, patch_size=1, feature_size=None, feature_dim=None,
|
888 |
+
in_chans=3, embed_dim=768):
|
889 |
+
super().__init__()
|
890 |
+
|
891 |
+
assert isinstance(backbone, nn.Module)
|
892 |
+
|
893 |
+
img_size = to_2tuple(img_size)
|
894 |
+
patch_size = to_2tuple(patch_size)
|
895 |
+
self.img_size = img_size
|
896 |
+
self.patch_size = patch_size
|
897 |
+
self.backbone = backbone
|
898 |
+
|
899 |
+
if feature_size is None or feature_dim is None: # backbone output feature_size
|
900 |
+
with torch.no_grad():
|
901 |
+
# NOTE Most reliable way of determining output dims is to run forward pass
|
902 |
+
training = backbone.training
|
903 |
+
if training:
|
904 |
+
backbone.eval()
|
905 |
+
o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
|
906 |
+
if isinstance(o, (list, tuple)):
|
907 |
+
o = o[-1] # last feature if backbone outputs list/tuple of features
|
908 |
+
feature_size = o.shape[-2:]
|
909 |
+
feature_dim = o.shape[1]
|
910 |
+
backbone.train(training)
|
911 |
+
else:
|
912 |
+
feature_size = to_2tuple(feature_size)
|
913 |
+
'''
|
914 |
+
if hasattr(self.backbone, 'feature_info'):
|
915 |
+
feature_dim = self.backbone.feature_info.channels()[-1]
|
916 |
+
else:
|
917 |
+
feature_dim = self.backbone.num_features
|
918 |
+
'''
|
919 |
+
|
920 |
+
assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
|
921 |
+
|
922 |
+
self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1]) # patchlize
|
923 |
+
|
924 |
+
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
925 |
+
|
926 |
+
self.proj = nn.Conv2d(in_channels=feature_dim, out_channels=embed_dim,
|
927 |
+
kernel_size=patch_size, stride=patch_size)
|
928 |
+
|
929 |
+
def forward(self, x):
|
930 |
+
x = self.backbone(x)
|
931 |
+
if isinstance(x, (list, tuple)):
|
932 |
+
x = x[-1] # last feature if backbone outputs list/tuple of features
|
933 |
+
|
934 |
+
x = self.proj(x).flatten(2).transpose(1, 2) # shape = ( )
|
935 |
+
"""
|
936 |
+
x.shape: batch, feature_dim, feature_size[0], feature_size[1]
|
937 |
+
proj(x).shape: batch, embed_dim, patch_height_num, patch_width_num
|
938 |
+
flatten(2).shape: batch, embed_dim, patch_num
|
939 |
+
.transpose(1, 2).shape: batch feature_patch_number feature_patch_dim
|
940 |
+
"""
|
941 |
+
# output: x.shape = batch, patch_number, patch_dim
|
942 |
+
return x
|
943 |
+
|
944 |
+
|
945 |
+
class Last_feature_map_Embed(nn.Module):
|
946 |
+
"""
|
947 |
+
use this block to connect last CNN stage to the first Transformer block
|
948 |
+
Extract feature map from CNN, flatten, project to embedding dim.
|
949 |
+
|
950 |
+
# input x.shape = batch, feature_dim, feature_size[0], feature_size[1]
|
951 |
+
# output x.shape = batch, patch_number, patch_dim
|
952 |
+
"""
|
953 |
+
|
954 |
+
def __init__(self, patch_size=1, feature_size=(7, 7), feature_dim=2048, embed_dim=768,
|
955 |
+
Attention_module=None):
|
956 |
+
super().__init__()
|
957 |
+
|
958 |
+
# Attention module
|
959 |
+
if Attention_module is not None:
|
960 |
+
if Attention_module == 'SimAM':
|
961 |
+
self.Attention_module = simam_module(e_lambda=1e-4)
|
962 |
+
elif Attention_module == 'CBAM':
|
963 |
+
self.Attention_module = cbam_module(gate_channels=feature_dim)
|
964 |
+
elif Attention_module == 'SE':
|
965 |
+
self.Attention_module = se_module(channel=feature_dim)
|
966 |
+
else:
|
967 |
+
self.Attention_module = None
|
968 |
+
|
969 |
+
patch_size = to_2tuple(patch_size)
|
970 |
+
self.patch_size = patch_size
|
971 |
+
|
972 |
+
feature_size = to_2tuple(feature_size)
|
973 |
+
|
974 |
+
# feature map should be matching the size
|
975 |
+
assert feature_size[0] % self.patch_size[0] == 0 and feature_size[1] % self.patch_size[1] == 0
|
976 |
+
|
977 |
+
self.grid_size = (feature_size[0] // self.patch_size[0], feature_size[1] // self.patch_size[1]) # patch
|
978 |
+
|
979 |
+
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
980 |
+
|
981 |
+
# use the conv to split the patch by the following design:
|
982 |
+
self.proj = nn.Conv2d(in_channels=feature_dim, out_channels=embed_dim,
|
983 |
+
kernel_size=self.patch_size, stride=self.patch_size)
|
984 |
+
|
985 |
+
def forward(self, x):
|
986 |
+
if self.Attention_module is not None:
|
987 |
+
x = self.Attention_module(x)
|
988 |
+
|
989 |
+
if isinstance(x, (list, tuple)):
|
990 |
+
x = x[-1] # last feature if backbone outputs list/tuple of features
|
991 |
+
|
992 |
+
x = self.proj(x).flatten(2).transpose(1, 2)
|
993 |
+
"""
|
994 |
+
x.shape: batch, feature_dim, feature_size[0], feature_size[1]
|
995 |
+
proj(x).shape: batch, embed_dim, patch_height_num, patch_width_num
|
996 |
+
flatten(2).shape: batch, embed_dim, patch_num
|
997 |
+
.transpose(1, 2).shape: batch feature_patch_number feature_patch_dim
|
998 |
+
"""
|
999 |
+
# output 格式 x.shape = batch, patch_number, patch_dim
|
1000 |
+
return x
|
1001 |
+
|
1002 |
+
|
1003 |
+
class Focus_Embed(nn.Module): # Attention guided module for hybridzing the early stages CNN feature
|
1004 |
+
"""
|
1005 |
+
FGD Focus module
|
1006 |
+
Extract feature map from CNN, flatten, project to embedding dim. and use them as attention guidance
|
1007 |
+
|
1008 |
+
input: x.shape = batch, feature_dim, feature_size[0], feature_size[1]
|
1009 |
+
|
1010 |
+
Firstly, an attention block will be used to stable the feature projecting process
|
1011 |
+
|
1012 |
+
Secondly, for each feature map,the focus will be 2 path: gaze and glance
|
1013 |
+
in gaze path Max pool will be applied to get prominent information
|
1014 |
+
in glance path Avg pool will be applied to get general information
|
1015 |
+
|
1016 |
+
after the dual pooling path 2 seperate CNNs will be used to project the dimension
|
1017 |
+
Finally, flattern and transpose will be applied
|
1018 |
+
|
1019 |
+
output 2 attention guidance: gaze, glance
|
1020 |
+
x.shape = batch, patch_number, patch_dim
|
1021 |
+
|
1022 |
+
|
1023 |
+
ref:
|
1024 |
+
ResNet50's feature map from different stages (edge size of 224)
|
1025 |
+
stage 1 output feature map: torch.Size([b, 256, 56, 56])
|
1026 |
+
stage 2 output feature map: torch.Size([b, 512, 28, 28])
|
1027 |
+
stage 3 output feature map: torch.Size([b, 1024, 14, 14])
|
1028 |
+
stage 4 output feature map: torch.Size([b, 2048, 7, 7])
|
1029 |
+
"""
|
1030 |
+
|
1031 |
+
def __init__(self, patch_size=1, target_feature_size=(7, 7), feature_size=(56, 56), feature_dim=256, embed_dim=768,
|
1032 |
+
Attention_module=None, norm_layer=nn.LayerNorm):
|
1033 |
+
super().__init__()
|
1034 |
+
patch_size = to_2tuple(patch_size)
|
1035 |
+
feature_size = to_2tuple(feature_size) # patch size of the current feature map
|
1036 |
+
|
1037 |
+
target_feature_size = to_2tuple(target_feature_size) # patch size of the last feature map
|
1038 |
+
|
1039 |
+
# cheak feature map can be patchlize to target_feature_size
|
1040 |
+
assert feature_size[0] % target_feature_size[0] == 0 and feature_size[1] % target_feature_size[1] == 0
|
1041 |
+
|
1042 |
+
# cheak target_feature map can be patchlize to patch
|
1043 |
+
assert target_feature_size[0] % patch_size[0] == 0 and target_feature_size[1] % patch_size[1] == 0
|
1044 |
+
|
1045 |
+
# Attention block
|
1046 |
+
if Attention_module is not None:
|
1047 |
+
if Attention_module == 'SimAM':
|
1048 |
+
self.Attention_module = simam_module(e_lambda=1e-4)
|
1049 |
+
elif Attention_module == 'CBAM':
|
1050 |
+
self.Attention_module = cbam_module(gate_channels=feature_dim)
|
1051 |
+
elif Attention_module == 'SE':
|
1052 |
+
self.Attention_module = se_module(channel=feature_dim)
|
1053 |
+
else:
|
1054 |
+
self.Attention_module = None
|
1055 |
+
|
1056 |
+
# split focus ROI
|
1057 |
+
self.focus_size = (feature_size[0] // target_feature_size[0], feature_size[1] // target_feature_size[1])
|
1058 |
+
self.num_focus = self.focus_size[0] * self.focus_size[1]
|
1059 |
+
# by kernel_size=focus_size, stride=focus_size design
|
1060 |
+
# output_size=target_feature_size=7x7 so as to match the minist feature map
|
1061 |
+
|
1062 |
+
self.gaze = nn.MaxPool2d(self.focus_size, stride=self.focus_size)
|
1063 |
+
self.glance = nn.AvgPool2d(self.focus_size, stride=self.focus_size)
|
1064 |
+
# x.shape: batch, feature_dim, target_feature_size[0], target_feature_size[1]
|
1065 |
+
|
1066 |
+
# split patch
|
1067 |
+
self.grid_size = (target_feature_size[0] // patch_size[0], target_feature_size[1] // patch_size[1])
|
1068 |
+
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
1069 |
+
|
1070 |
+
# use CNN to project dim to patch_dim
|
1071 |
+
self.gaze_proj = nn.Conv2d(in_channels=feature_dim, out_channels=embed_dim,
|
1072 |
+
kernel_size=patch_size, stride=patch_size)
|
1073 |
+
self.glance_proj = nn.Conv2d(in_channels=feature_dim, out_channels=embed_dim,
|
1074 |
+
kernel_size=patch_size, stride=patch_size)
|
1075 |
+
|
1076 |
+
self.norm_q = norm_layer(embed_dim) # Transformer nn.LayerNorm
|
1077 |
+
self.norm_k = norm_layer(embed_dim) # Transformer nn.LayerNorm
|
1078 |
+
|
1079 |
+
def forward(self, x):
|
1080 |
+
if self.Attention_module is not None:
|
1081 |
+
x = self.Attention_module(x)
|
1082 |
+
|
1083 |
+
if isinstance(x, (list, tuple)):
|
1084 |
+
x = x[-1] # last feature if backbone outputs list/tuple of features
|
1085 |
+
|
1086 |
+
q = self.norm_q(self.gaze_proj(self.gaze(x)).flatten(2).transpose(1, 2))
|
1087 |
+
k = self.norm_k(self.glance_proj(self.glance(x)).flatten(2).transpose(1, 2))
|
1088 |
+
"""
|
1089 |
+
x.shape: batch, feature_dim, feature_size[0], feature_size[1]
|
1090 |
+
gaze/glance(x).shape: batch, feature_dim, target_feature_size[0], target_feature_size[1]
|
1091 |
+
proj(x).shape: batch, embed_dim, patch_height_num, patch_width_num
|
1092 |
+
flatten(2).shape: batch, embed_dim, patch_num
|
1093 |
+
.transpose(1, 2).shape: batch feature_patch_number feature_patch_dim
|
1094 |
+
"""
|
1095 |
+
# output x.shape = batch, patch_number, patch_dim
|
1096 |
+
return q, k
|
1097 |
+
|
1098 |
+
|
1099 |
+
'''
|
1100 |
+
# test sample
|
1101 |
+
model = Focus_Embed()
|
1102 |
+
x = torch.randn(4, 256, 56, 56)
|
1103 |
+
y1,y2 = model(x)
|
1104 |
+
print(y1.shape)
|
1105 |
+
print(y2.shape)
|
1106 |
+
'''
|
1107 |
+
|
1108 |
+
|
1109 |
+
class Focus_SEmbed(nn.Module): # Attention guided module for hybridzing the early stages CNN feature
|
1110 |
+
"""
|
1111 |
+
|
1112 |
+
self focus (q=k) based on FGD Focus block
|
1113 |
+
|
1114 |
+
Extract feature map from CNN, flatten, project to embedding dim. and use them as attention guidance
|
1115 |
+
|
1116 |
+
input: x.shape = batch, feature_dim, feature_size[0], feature_size[1]
|
1117 |
+
|
1118 |
+
Firstly, an attention block will be used to stable the feature projecting process
|
1119 |
+
|
1120 |
+
Secondly, for each feature map,the focus will be 1 path: glance
|
1121 |
+
in glance path Avg pool will be applied to get general information
|
1122 |
+
|
1123 |
+
after the pooling process 1 CNN will be used to project the dimension
|
1124 |
+
Finally, flattern and transpose will be applied
|
1125 |
+
|
1126 |
+
output 2 attention guidance: glance, glance
|
1127 |
+
x.shape = batch, patch_number, patch_dim
|
1128 |
+
"""
|
1129 |
+
|
1130 |
+
def __init__(self, patch_size=1, target_feature_size=(7, 7), feature_size=(56, 56), feature_dim=256, embed_dim=768,
|
1131 |
+
Attention_module=None, norm_layer=nn.LayerNorm):
|
1132 |
+
super().__init__()
|
1133 |
+
patch_size = to_2tuple(patch_size)
|
1134 |
+
feature_size = to_2tuple(feature_size)
|
1135 |
+
|
1136 |
+
target_feature_size = to_2tuple(target_feature_size)
|
1137 |
+
|
1138 |
+
assert feature_size[0] % target_feature_size[0] == 0 and feature_size[1] % target_feature_size[1] == 0
|
1139 |
+
|
1140 |
+
assert target_feature_size[0] % patch_size[0] == 0 and target_feature_size[1] % patch_size[1] == 0
|
1141 |
+
|
1142 |
+
if Attention_module is not None:
|
1143 |
+
if Attention_module == 'SimAM':
|
1144 |
+
self.Attention_module = simam_module(e_lambda=1e-4)
|
1145 |
+
elif Attention_module == 'CBAM':
|
1146 |
+
self.Attention_module = cbam_module(gate_channels=feature_dim)
|
1147 |
+
elif Attention_module == 'SE':
|
1148 |
+
self.Attention_module = se_module(channel=feature_dim)
|
1149 |
+
else:
|
1150 |
+
self.Attention_module = None
|
1151 |
+
|
1152 |
+
self.focus_size = (feature_size[0] // target_feature_size[0], feature_size[1] // target_feature_size[1])
|
1153 |
+
self.num_focus = self.focus_size[0] * self.focus_size[1]
|
1154 |
+
|
1155 |
+
self.gaze = nn.MaxPool2d(self.focus_size, stride=self.focus_size)
|
1156 |
+
|
1157 |
+
self.grid_size = (target_feature_size[0] // patch_size[0], target_feature_size[1] // patch_size[1])
|
1158 |
+
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
1159 |
+
|
1160 |
+
self.proj = nn.Conv2d(in_channels=feature_dim, out_channels=embed_dim, kernel_size=patch_size,
|
1161 |
+
stride=patch_size)
|
1162 |
+
|
1163 |
+
self.norm_f = norm_layer(embed_dim)
|
1164 |
+
|
1165 |
+
def forward(self, x):
|
1166 |
+
if self.Attention_module is not None:
|
1167 |
+
x = self.Attention_module(x)
|
1168 |
+
|
1169 |
+
if isinstance(x, (list, tuple)):
|
1170 |
+
x = x[-1] # last feature if backbone outputs list/tuple of features
|
1171 |
+
|
1172 |
+
q = self.norm_f(self.proj(self.gaze(x)).flatten(2).transpose(1, 2))
|
1173 |
+
k = q
|
1174 |
+
"""
|
1175 |
+
x.shape: batch, feature_dim, feature_size[0], feature_size[1]
|
1176 |
+
gaze/glance(x).shape: batch, feature_dim, target_feature_size[0], target_feature_size[1]
|
1177 |
+
proj(x).shape: batch, embed_dim, patch_height_num, patch_width_num
|
1178 |
+
flatten(2).shape: batch, embed_dim, patch_num
|
1179 |
+
.transpose(1, 2).shape: batch feature_patch_number feature_patch_dim
|
1180 |
+
"""
|
1181 |
+
# output x.shape = batch, patch_number, patch_dim
|
1182 |
+
return q, k
|
1183 |
+
|
1184 |
+
|
1185 |
+
class Focus_Aggressive(nn.Module): # Attention guided module for hybridzing the early stages CNN feature
|
1186 |
+
"""
|
1187 |
+
Aggressive CNN Focus based on FGD Focus block
|
1188 |
+
|
1189 |
+
Extract feature map from CNN, flatten, project to embedding dim. and use them as attention guidance
|
1190 |
+
|
1191 |
+
input: x.shape = batch, feature_dim, feature_size[0], feature_size[1]
|
1192 |
+
|
1193 |
+
Firstly, an attention block will be used to stable the feature projecting process
|
1194 |
+
|
1195 |
+
Secondly, 2 CNNs will be used to project the dimension
|
1196 |
+
|
1197 |
+
Finally, flattern and transpose will be applied
|
1198 |
+
|
1199 |
+
output 2 attention guidance: gaze, glance
|
1200 |
+
x.shape = batch, patch_number, patch_dim
|
1201 |
+
|
1202 |
+
"""
|
1203 |
+
|
1204 |
+
def __init__(self, patch_size=1, target_feature_size=(7, 7), feature_size=(56, 56), feature_dim=256, embed_dim=768,
|
1205 |
+
Attention_module=None, norm_layer=nn.LayerNorm):
|
1206 |
+
super().__init__()
|
1207 |
+
patch_size = to_2tuple(patch_size) # patch size of the last feature map
|
1208 |
+
feature_size = to_2tuple(feature_size)
|
1209 |
+
|
1210 |
+
target_feature_size = to_2tuple(target_feature_size)
|
1211 |
+
|
1212 |
+
assert feature_size[0] % target_feature_size[0] == 0 and feature_size[1] % target_feature_size[1] == 0
|
1213 |
+
|
1214 |
+
assert target_feature_size[0] % patch_size[0] == 0 and target_feature_size[1] % patch_size[1] == 0
|
1215 |
+
|
1216 |
+
if Attention_module is not None:
|
1217 |
+
if Attention_module == 'SimAM':
|
1218 |
+
self.Attention_module = simam_module(e_lambda=1e-4)
|
1219 |
+
elif Attention_module == 'CBAM':
|
1220 |
+
self.Attention_module = cbam_module(gate_channels=feature_dim)
|
1221 |
+
elif Attention_module == 'SE':
|
1222 |
+
self.Attention_module = se_module(channel=feature_dim)
|
1223 |
+
else:
|
1224 |
+
self.Attention_module = None
|
1225 |
+
|
1226 |
+
self.focus_size = (feature_size[0] // target_feature_size[0], feature_size[1] // target_feature_size[1])
|
1227 |
+
|
1228 |
+
self.grid_size = (self.focus_size[0] * patch_size[0], self.focus_size[1] * patch_size[1])
|
1229 |
+
self.num_patches = (feature_size[0] // self.grid_size[0]) * (feature_size[1] // self.grid_size[1])
|
1230 |
+
|
1231 |
+
self.gaze_proj = nn.Conv2d(in_channels=feature_dim, out_channels=embed_dim,
|
1232 |
+
kernel_size=self.grid_size, stride=self.grid_size)
|
1233 |
+
self.glance_proj = nn.Conv2d(in_channels=feature_dim, out_channels=embed_dim,
|
1234 |
+
kernel_size=self.grid_size, stride=self.grid_size)
|
1235 |
+
|
1236 |
+
self.norm_q = norm_layer(embed_dim)
|
1237 |
+
self.norm_k = norm_layer(embed_dim)
|
1238 |
+
|
1239 |
+
def forward(self, x):
|
1240 |
+
if self.Attention_module is not None:
|
1241 |
+
x = self.Attention_module(x)
|
1242 |
+
|
1243 |
+
if isinstance(x, (list, tuple)):
|
1244 |
+
x = x[-1] # last feature if backbone outputs list/tuple of features
|
1245 |
+
|
1246 |
+
q = self.norm_q(self.gaze_proj(x).flatten(2).transpose(1, 2))
|
1247 |
+
k = self.norm_k(self.glance_proj(x).flatten(2).transpose(1, 2))
|
1248 |
+
"""
|
1249 |
+
x.shape: batch, feature_dim, feature_size[0], feature_size[1]
|
1250 |
+
proj(x).shape: batch, embed_dim, patch_height_num, patch_width_num
|
1251 |
+
flatten(2).shape: batch, embed_dim, patch_num
|
1252 |
+
.transpose(1, 2).shape: batch feature_patch_number feature_patch_dim
|
1253 |
+
"""
|
1254 |
+
# output x.shape = batch, patch_number, patch_dim
|
1255 |
+
return q, k
|
1256 |
+
|
1257 |
+
|
1258 |
+
class Focus_SAggressive(nn.Module): # Attention guided module for hybridzing the early stages CNN feature
|
1259 |
+
"""
|
1260 |
+
Aggressive CNN self Focus
|
1261 |
+
Extract feature map from CNN, flatten, project to embedding dim. and use them as attention guidance
|
1262 |
+
|
1263 |
+
input: x.shape = batch, feature_dim, feature_size[0], feature_size[1]
|
1264 |
+
|
1265 |
+
Firstly, an attention block will be used to stable the feature projecting process
|
1266 |
+
|
1267 |
+
Secondly, 1 CNN will be used to project the dimension
|
1268 |
+
|
1269 |
+
Finally, flattern and transpose will be applied
|
1270 |
+
|
1271 |
+
output 2 attention guidance: glance, glance
|
1272 |
+
x.shape = batch, patch_number, patch_dim
|
1273 |
+
"""
|
1274 |
+
|
1275 |
+
def __init__(self, patch_size=1, target_feature_size=(7, 7), feature_size=(56, 56), feature_dim=256, embed_dim=768,
|
1276 |
+
Attention_module=None, norm_layer=nn.LayerNorm):
|
1277 |
+
super().__init__()
|
1278 |
+
patch_size = to_2tuple(patch_size)
|
1279 |
+
feature_size = to_2tuple(feature_size)
|
1280 |
+
|
1281 |
+
target_feature_size = to_2tuple(target_feature_size)
|
1282 |
+
|
1283 |
+
assert feature_size[0] % target_feature_size[0] == 0 and feature_size[1] % target_feature_size[1] == 0
|
1284 |
+
|
1285 |
+
assert target_feature_size[0] % patch_size[0] == 0 and target_feature_size[1] % patch_size[1] == 0
|
1286 |
+
|
1287 |
+
if Attention_module is not None:
|
1288 |
+
if Attention_module == 'SimAM':
|
1289 |
+
self.Attention_module = simam_module(e_lambda=1e-4)
|
1290 |
+
elif Attention_module == 'CBAM':
|
1291 |
+
self.Attention_module = cbam_module(gate_channels=feature_dim)
|
1292 |
+
elif Attention_module == 'SE':
|
1293 |
+
self.Attention_module = se_module(channel=feature_dim)
|
1294 |
+
else:
|
1295 |
+
self.Attention_module = None
|
1296 |
+
|
1297 |
+
self.focus_size = (feature_size[0] // target_feature_size[0], feature_size[1] // target_feature_size[1])
|
1298 |
+
|
1299 |
+
self.grid_size = (self.focus_size[0] * patch_size[0], self.focus_size[1] * patch_size[1])
|
1300 |
+
self.num_patches = (feature_size[0] // self.grid_size[0]) * (feature_size[1] // self.grid_size[1])
|
1301 |
+
|
1302 |
+
self.proj = nn.Conv2d(in_channels=feature_dim, out_channels=embed_dim,
|
1303 |
+
kernel_size=self.grid_size, stride=self.grid_size)
|
1304 |
+
|
1305 |
+
self.norm_f = norm_layer(embed_dim)
|
1306 |
+
|
1307 |
+
def forward(self, x):
|
1308 |
+
if self.Attention_module is not None:
|
1309 |
+
x = self.Attention_module(x)
|
1310 |
+
|
1311 |
+
if isinstance(x, (list, tuple)):
|
1312 |
+
x = x[-1] # last feature if backbone outputs list/tuple of features
|
1313 |
+
|
1314 |
+
q = self.norm_f(self.proj(x).flatten(2).transpose(1, 2))
|
1315 |
+
k = q
|
1316 |
+
"""
|
1317 |
+
x.shape: batch, feature_dim, feature_size[0], feature_size[1]
|
1318 |
+
proj(x).shape: batch, embed_dim, patch_height_num, patch_width_num
|
1319 |
+
flatten(2).shape: batch, embed_dim, patch_num
|
1320 |
+
.transpose(1, 2).shape: batch feature_patch_number feature_patch_dim
|
1321 |
+
"""
|
1322 |
+
# output x.shape = batch, patch_number, patch_dim
|
1323 |
+
return q, k
|
1324 |
+
|
1325 |
+
|
1326 |
+
class VisionTransformer(nn.Module): # From timm to review the ViT and ViT_resn5
|
1327 |
+
"""
|
1328 |
+
Vision Transformer
|
1329 |
+
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
|
1330 |
+
- https://arxiv.org/abs/2010.11929
|
1331 |
+
Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
|
1332 |
+
- https://arxiv.org/abs/2012.12877
|
1333 |
+
"""
|
1334 |
+
|
1335 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
1336 |
+
num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, drop_rate=0., attn_drop_rate=0.,
|
1337 |
+
drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None, act_layer=None):
|
1338 |
+
"""
|
1339 |
+
Args:
|
1340 |
+
img_size (int, tuple): input image size
|
1341 |
+
patch_size (int, tuple): patch size
|
1342 |
+
in_chans (int): number of input channels
|
1343 |
+
num_classes (int): number of classes for classification head
|
1344 |
+
embed_dim (int): embedding dimension
|
1345 |
+
depth (int): depth of transformer
|
1346 |
+
num_heads (int): number of attention heads
|
1347 |
+
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
1348 |
+
qkv_bias (bool): enable bias for qkv if True
|
1349 |
+
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
|
1350 |
+
drop_rate (float): dropout rate
|
1351 |
+
attn_drop_rate (float): attention dropout rate
|
1352 |
+
drop_path_rate (float): stochastic depth rate
|
1353 |
+
embed_layer (nn.Module): patch embedding layer
|
1354 |
+
norm_layer: (nn.Module): normalization layer
|
1355 |
+
"""
|
1356 |
+
super().__init__()
|
1357 |
+
self.num_classes = num_classes
|
1358 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
1359 |
+
self.num_tokens = 1
|
1360 |
+
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
|
1361 |
+
act_layer = act_layer or nn.GELU
|
1362 |
+
|
1363 |
+
self.patch_embed = embed_layer(
|
1364 |
+
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
1365 |
+
num_patches = self.patch_embed.num_patches
|
1366 |
+
|
1367 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
1368 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
|
1369 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
1370 |
+
|
1371 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
1372 |
+
|
1373 |
+
self.blocks = nn.Sequential(*[
|
1374 |
+
Encoder_Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
|
1375 |
+
attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
|
1376 |
+
for i in range(depth)])
|
1377 |
+
|
1378 |
+
self.norm = norm_layer(embed_dim)
|
1379 |
+
|
1380 |
+
# Representation layer
|
1381 |
+
if representation_size:
|
1382 |
+
self.num_features = representation_size
|
1383 |
+
self.pre_logits = nn.Sequential(OrderedDict([
|
1384 |
+
('fc', nn.Linear(embed_dim, representation_size)),
|
1385 |
+
('act', nn.Tanh())
|
1386 |
+
]))
|
1387 |
+
else:
|
1388 |
+
self.pre_logits = nn.Identity()
|
1389 |
+
|
1390 |
+
# Classifier head(s)
|
1391 |
+
self.head = nn.Linear(self.num_features, self.num_classes) if self.num_classes > 0 else nn.Identity()
|
1392 |
+
self.head_dist = None
|
1393 |
+
|
1394 |
+
def forward_features(self, x):
|
1395 |
+
x = self.patch_embed(x)
|
1396 |
+
# print(x.shape,self.pos_embed.shape)
|
1397 |
+
cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
1398 |
+
x = torch.cat((cls_token, x), dim=1)
|
1399 |
+
x = self.pos_drop(x + self.pos_embed)
|
1400 |
+
|
1401 |
+
x = self.blocks(x)
|
1402 |
+
x = self.norm(x)
|
1403 |
+
return self.pre_logits(x[:, 0]) # use cls token for cls head
|
1404 |
+
|
1405 |
+
def forward(self, x):
|
1406 |
+
x = self.forward_features(x)
|
1407 |
+
x = self.head(x)
|
1408 |
+
return x
|
1409 |
+
|
1410 |
+
|
1411 |
+
class Stage_wise_hybrid_Transformer(nn.Module):
|
1412 |
+
"""
|
1413 |
+
MSHT: Multi Stage Backbone Transformer
|
1414 |
+
Stem + 4 ResNet stages(Backbone)is used as backbone
|
1415 |
+
then, last feature map patch embedding is used to connect the CNN output to the decoder1 input
|
1416 |
+
|
1417 |
+
horizonally, 4 ResNet Stage has its feature map connecting to the Focus module
|
1418 |
+
which we be use as attention guidance into the FGD decoder
|
1419 |
+
"""
|
1420 |
+
|
1421 |
+
def __init__(self, backbone, num_classes=1000, patch_size=1, embed_dim=768, depth=4, num_heads=8, mlp_ratio=4.,
|
1422 |
+
qkv_bias=True, representation_size=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
|
1423 |
+
use_cls_token=True, use_pos_embedding=True, use_att_module='SimAM', stage_size=(56, 28, 14, 7),
|
1424 |
+
stage_dim=(256, 512, 1024, 2048), norm_layer=None, act_layer=None):
|
1425 |
+
"""
|
1426 |
+
Args:
|
1427 |
+
backbone (nn.Module): input backbone = stem + 4 ResNet stages
|
1428 |
+
num_classes (int): number of classes for classification head
|
1429 |
+
patch_size (int, tuple): patch size
|
1430 |
+
embed_dim (int): embedding dimension
|
1431 |
+
depth (int): depth of transformer
|
1432 |
+
num_heads (int): number of attention heads
|
1433 |
+
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
1434 |
+
qkv_bias (bool): enable bias for qkv if True
|
1435 |
+
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
|
1436 |
+
drop_rate (float): dropout rate
|
1437 |
+
attn_drop_rate (float): attention dropout rate
|
1438 |
+
drop_path_rate (float): stochastic depth rate
|
1439 |
+
|
1440 |
+
use_cls_token(bool): classification token
|
1441 |
+
use_pos_embedding(bool): use positional embedding
|
1442 |
+
use_att_module(str or None): use which attention module in embedding
|
1443 |
+
|
1444 |
+
stage_size (int, tuple): the stage feature map size of ResNet stages
|
1445 |
+
stage_dim (int, tuple): the stage feature map dimension of ResNet stages
|
1446 |
+
norm_layer: (nn.Module): normalization layer
|
1447 |
+
"""
|
1448 |
+
super().__init__()
|
1449 |
+
self.num_classes = num_classes
|
1450 |
+
if len(stage_dim) != len(stage_size):
|
1451 |
+
raise TypeError('stage_dim and stage_size mismatch!')
|
1452 |
+
else:
|
1453 |
+
self.stage_num = len(stage_dim)
|
1454 |
+
|
1455 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
1456 |
+
|
1457 |
+
self.cls_token_num = 1 if use_cls_token else 0
|
1458 |
+
self.use_pos_embedding = use_pos_embedding
|
1459 |
+
|
1460 |
+
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
|
1461 |
+
act_layer = act_layer or nn.GELU
|
1462 |
+
|
1463 |
+
# backbone CNN
|
1464 |
+
self.backbone = backbone
|
1465 |
+
|
1466 |
+
# Attention module
|
1467 |
+
if use_att_module is not None:
|
1468 |
+
if use_att_module in ['SimAM', 'CBAM', 'SE']:
|
1469 |
+
Attention_module = use_att_module
|
1470 |
+
else:
|
1471 |
+
Attention_module = None
|
1472 |
+
else:
|
1473 |
+
Attention_module = None
|
1474 |
+
|
1475 |
+
self.patch_embed = Last_feature_map_Embed(patch_size=patch_size, feature_size=stage_size[-1],
|
1476 |
+
feature_dim=stage_dim[-1], embed_dim=self.embed_dim,
|
1477 |
+
Attention_module=Attention_module)
|
1478 |
+
num_patches = self.patch_embed.num_patches
|
1479 |
+
|
1480 |
+
# global sharing cls token and positional embedding
|
1481 |
+
self.cls_token_0 = nn.Parameter(torch.zeros(1, 1, embed_dim)) # like message token
|
1482 |
+
if self.use_pos_embedding:
|
1483 |
+
self.pos_embed_0 = nn.Parameter(torch.zeros(1, num_patches + self.cls_token_num, embed_dim))
|
1484 |
+
|
1485 |
+
'''
|
1486 |
+
self.cls_token_1 = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
1487 |
+
self.pos_embed_1 = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
|
1488 |
+
|
1489 |
+
self.cls_token_2 = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
1490 |
+
self.pos_embed_2 = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
|
1491 |
+
|
1492 |
+
self.cls_token_3 = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
1493 |
+
self.pos_embed_3 = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
|
1494 |
+
|
1495 |
+
self.cls_token_4 = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
1496 |
+
self.pos_embed_4 = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
|
1497 |
+
'''
|
1498 |
+
|
1499 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
1500 |
+
# stochastic depth
|
1501 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
1502 |
+
|
1503 |
+
self.dec1 = Decoder_Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
|
1504 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[0], norm_layer=norm_layer,
|
1505 |
+
act_layer=act_layer)
|
1506 |
+
self.Fo1 = Focus_Embed(patch_size=patch_size, target_feature_size=stage_size[-1], feature_size=stage_size[0],
|
1507 |
+
feature_dim=stage_dim[0], embed_dim=embed_dim, Attention_module=Attention_module,
|
1508 |
+
norm_layer=norm_layer)
|
1509 |
+
|
1510 |
+
self.dec2 = Decoder_Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
|
1511 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[1], norm_layer=norm_layer,
|
1512 |
+
act_layer=act_layer)
|
1513 |
+
self.Fo2 = Focus_Embed(patch_size=patch_size, target_feature_size=stage_size[-1], feature_size=stage_size[1],
|
1514 |
+
feature_dim=stage_dim[1], embed_dim=embed_dim, Attention_module=Attention_module,
|
1515 |
+
norm_layer=norm_layer)
|
1516 |
+
|
1517 |
+
self.dec3 = Decoder_Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
|
1518 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[2], norm_layer=norm_layer,
|
1519 |
+
act_layer=act_layer)
|
1520 |
+
self.Fo3 = Focus_Embed(patch_size=patch_size, target_feature_size=stage_size[-1], feature_size=stage_size[2],
|
1521 |
+
feature_dim=stage_dim[2], embed_dim=embed_dim, Attention_module=Attention_module,
|
1522 |
+
norm_layer=norm_layer)
|
1523 |
+
|
1524 |
+
if self.stage_num == 4:
|
1525 |
+
self.dec4 = Decoder_Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
|
1526 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[3], norm_layer=norm_layer,
|
1527 |
+
act_layer=act_layer)
|
1528 |
+
self.Fo4 = Focus_Embed(patch_size=patch_size, target_feature_size=stage_size[-1],
|
1529 |
+
feature_size=stage_size[-1],
|
1530 |
+
feature_dim=stage_dim[-1], embed_dim=embed_dim, Attention_module=Attention_module,
|
1531 |
+
norm_layer=norm_layer)
|
1532 |
+
|
1533 |
+
self.norm = norm_layer(embed_dim)
|
1534 |
+
|
1535 |
+
# Representation layer
|
1536 |
+
if representation_size:
|
1537 |
+
self.num_features = representation_size
|
1538 |
+
self.pre_logits = nn.Sequential(OrderedDict([
|
1539 |
+
('fc', nn.Linear(embed_dim, representation_size)),
|
1540 |
+
('act', nn.Tanh())
|
1541 |
+
]))
|
1542 |
+
else:
|
1543 |
+
self.pre_logits = nn.Identity()
|
1544 |
+
|
1545 |
+
# Classifier head(s)
|
1546 |
+
self.head = nn.Linear(self.num_features, self.num_classes) if self.num_classes > 0 else nn.Identity()
|
1547 |
+
self.head_dist = None
|
1548 |
+
|
1549 |
+
def forward_features(self, x):
|
1550 |
+
if self.stage_num == 3:
|
1551 |
+
stage1_out, stage2_out, stage3_out = self.backbone(x)
|
1552 |
+
# embedding the last feature map
|
1553 |
+
x = self.patch_embed(stage3_out)
|
1554 |
+
|
1555 |
+
elif self.stage_num == 4:
|
1556 |
+
stage1_out, stage2_out, stage3_out, stage4_out = self.backbone(x)
|
1557 |
+
# embedding the last feature map
|
1558 |
+
x = self.patch_embed(stage4_out)
|
1559 |
+
else:
|
1560 |
+
raise TypeError('stage_dim is not legal !')
|
1561 |
+
|
1562 |
+
# get guidance info
|
1563 |
+
s1_q, s1_k = self.Fo1(stage1_out)
|
1564 |
+
s2_q, s2_k = self.Fo2(stage2_out)
|
1565 |
+
s3_q, s3_k = self.Fo3(stage3_out)
|
1566 |
+
if self.stage_num == 4:
|
1567 |
+
s4_q, s4_k = self.Fo4(stage4_out)
|
1568 |
+
|
1569 |
+
if self.cls_token_num != 0: # concat cls token
|
1570 |
+
# process the(cls token / message token)
|
1571 |
+
cls_token_0 = self.cls_token_0.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
1572 |
+
x = torch.cat((cls_token_0, x), dim=1) # 增加classification head patch
|
1573 |
+
|
1574 |
+
s1_q = torch.cat((cls_token_0, s1_q), dim=1)
|
1575 |
+
s1_k = torch.cat((cls_token_0, s1_k), dim=1)
|
1576 |
+
s2_q = torch.cat((cls_token_0, s2_q), dim=1)
|
1577 |
+
s2_k = torch.cat((cls_token_0, s2_k), dim=1)
|
1578 |
+
s3_q = torch.cat((cls_token_0, s3_q), dim=1)
|
1579 |
+
s3_k = torch.cat((cls_token_0, s3_k), dim=1)
|
1580 |
+
if self.stage_num == 4:
|
1581 |
+
s4_q = torch.cat((cls_token_0, s4_q), dim=1)
|
1582 |
+
s4_k = torch.cat((cls_token_0, s4_k), dim=1)
|
1583 |
+
|
1584 |
+
if self.use_pos_embedding:
|
1585 |
+
|
1586 |
+
s1_q = self.pos_drop(s1_q + self.pos_embed_0)
|
1587 |
+
s1_k = self.pos_drop(s1_k + self.pos_embed_0)
|
1588 |
+
s2_q = self.pos_drop(s2_q + self.pos_embed_0)
|
1589 |
+
s2_k = self.pos_drop(s2_k + self.pos_embed_0)
|
1590 |
+
s3_q = self.pos_drop(s3_q + self.pos_embed_0)
|
1591 |
+
s3_k = self.pos_drop(s3_k + self.pos_embed_0)
|
1592 |
+
if self.stage_num == 4:
|
1593 |
+
s4_q = self.pos_drop(s4_q + self.pos_embed_0)
|
1594 |
+
s4_k = self.pos_drop(s4_k + self.pos_embed_0)
|
1595 |
+
|
1596 |
+
# plus to encoding positional infor
|
1597 |
+
x = self.pos_drop(x + self.pos_embed_0)
|
1598 |
+
|
1599 |
+
else:
|
1600 |
+
|
1601 |
+
s1_q = self.pos_drop(s1_q)
|
1602 |
+
s1_k = self.pos_drop(s1_k)
|
1603 |
+
s2_q = self.pos_drop(s2_q)
|
1604 |
+
s2_k = self.pos_drop(s2_k)
|
1605 |
+
s3_q = self.pos_drop(s3_q)
|
1606 |
+
s3_k = self.pos_drop(s3_k)
|
1607 |
+
if self.stage_num == 4:
|
1608 |
+
s4_q = self.pos_drop(s4_q)
|
1609 |
+
s4_k = self.pos_drop(s4_k)
|
1610 |
+
|
1611 |
+
# stem's feature map
|
1612 |
+
x = self.pos_drop(x)
|
1613 |
+
|
1614 |
+
# Decoder module use the guidance to help global modeling process
|
1615 |
+
|
1616 |
+
x = self.dec1(s1_q, s1_k, x)
|
1617 |
+
|
1618 |
+
x = self.dec2(s2_q, s2_k, x)
|
1619 |
+
|
1620 |
+
x = self.dec3(s3_q, s3_k, x)
|
1621 |
+
|
1622 |
+
if self.stage_num == 4:
|
1623 |
+
x = self.dec4(s4_q, s4_k, x)
|
1624 |
+
|
1625 |
+
x = self.norm(x)
|
1626 |
+
return self.pre_logits(x[:, 0]) # take the first cls token
|
1627 |
+
|
1628 |
+
def forward(self, x):
|
1629 |
+
x = self.forward_features(x) # connect the cls token to the cls head
|
1630 |
+
x = self.head(x)
|
1631 |
+
return x
|
PuzzleTuning/Backbone/VPT_structure.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
VPT Script ver: Oct 17th 14:30
|
3 |
+
|
4 |
+
based on
|
5 |
+
timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
6 |
+
"""
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
|
11 |
+
from timm.models.vision_transformer import VisionTransformer, PatchEmbed
|
12 |
+
|
13 |
+
|
14 |
+
class VPT_ViT(VisionTransformer):
|
15 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
16 |
+
num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
|
17 |
+
embed_layer=PatchEmbed, norm_layer=None, act_layer=None, Prompt_Token_num=1,
|
18 |
+
VPT_type="Shallow", basic_state_dict=None):
|
19 |
+
|
20 |
+
# Recreate ViT
|
21 |
+
super().__init__(img_size=img_size, patch_size=patch_size, in_chans=in_chans, num_classes=num_classes,
|
22 |
+
embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio,
|
23 |
+
qkv_bias=qkv_bias, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
|
24 |
+
drop_path_rate=drop_path_rate, embed_layer=embed_layer,
|
25 |
+
norm_layer=norm_layer, act_layer=act_layer)
|
26 |
+
|
27 |
+
# load basic state_dict
|
28 |
+
if basic_state_dict is not None:
|
29 |
+
self.load_state_dict(basic_state_dict, False)
|
30 |
+
|
31 |
+
self.VPT_type = VPT_type
|
32 |
+
if VPT_type == "Deep":
|
33 |
+
self.Prompt_Tokens = nn.Parameter(torch.zeros(depth, Prompt_Token_num, embed_dim))
|
34 |
+
else: # "Shallow"
|
35 |
+
self.Prompt_Tokens = nn.Parameter(torch.zeros(1, Prompt_Token_num, embed_dim))
|
36 |
+
|
37 |
+
def New_CLS_head(self, new_classes=15):
|
38 |
+
if new_classes != 0:
|
39 |
+
self.head = nn.Linear(self.embed_dim, new_classes)
|
40 |
+
else:
|
41 |
+
self.head = nn.Identity()
|
42 |
+
|
43 |
+
def Freeze(self):
|
44 |
+
for param in self.parameters():
|
45 |
+
param.requires_grad = False
|
46 |
+
|
47 |
+
self.Prompt_Tokens.requires_grad = True
|
48 |
+
try:
|
49 |
+
for param in self.head.parameters():
|
50 |
+
param.requires_grad = True
|
51 |
+
except:
|
52 |
+
pass
|
53 |
+
|
54 |
+
def UnFreeze(self):
|
55 |
+
for param in self.parameters():
|
56 |
+
param.requires_grad = True
|
57 |
+
|
58 |
+
def obtain_prompt(self):
|
59 |
+
prompt_state_dict = {'head': self.head.state_dict(),
|
60 |
+
'Prompt_Tokens': self.Prompt_Tokens}
|
61 |
+
# print(prompt_state_dict)
|
62 |
+
return prompt_state_dict
|
63 |
+
|
64 |
+
def load_prompt(self, prompt_state_dict):
|
65 |
+
try:
|
66 |
+
self.head.load_state_dict(prompt_state_dict['head'], False)
|
67 |
+
except:
|
68 |
+
print('head not match, so skip head')
|
69 |
+
else:
|
70 |
+
print('prompt head match')
|
71 |
+
|
72 |
+
if self.Prompt_Tokens.shape == prompt_state_dict['Prompt_Tokens'].shape:
|
73 |
+
|
74 |
+
# device check
|
75 |
+
Prompt_Tokens = nn.Parameter(prompt_state_dict['Prompt_Tokens'].cpu())
|
76 |
+
Prompt_Tokens.to(torch.device(self.Prompt_Tokens.device))
|
77 |
+
|
78 |
+
self.Prompt_Tokens = Prompt_Tokens
|
79 |
+
|
80 |
+
else:
|
81 |
+
print('\n !!! cannot load prompt')
|
82 |
+
print('shape of model req prompt', self.Prompt_Tokens.shape)
|
83 |
+
print('shape of model given prompt', prompt_state_dict['Prompt_Tokens'].shape)
|
84 |
+
print('')
|
85 |
+
|
86 |
+
def forward_features(self, x):
|
87 |
+
x = self.patch_embed(x)
|
88 |
+
# print(x.shape,self.pos_embed.shape)
|
89 |
+
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
|
90 |
+
|
91 |
+
# concatenate CLS token
|
92 |
+
x = torch.cat((cls_token, x), dim=1)
|
93 |
+
x = self.pos_drop(x + self.pos_embed)
|
94 |
+
|
95 |
+
if self.VPT_type == "Deep":
|
96 |
+
|
97 |
+
Prompt_Token_num = self.Prompt_Tokens.shape[1]
|
98 |
+
|
99 |
+
for i in range(len(self.blocks)):
|
100 |
+
# concatenate Prompt_Tokens
|
101 |
+
Prompt_Tokens = self.Prompt_Tokens[i].unsqueeze(0)
|
102 |
+
# firstly concatenate
|
103 |
+
x = torch.cat((x, Prompt_Tokens.expand(x.shape[0], -1, -1)), dim=1)
|
104 |
+
num_tokens = x.shape[1]
|
105 |
+
# lastly remove, a genius trick
|
106 |
+
x = self.blocks[i](x)[:, :num_tokens - Prompt_Token_num]
|
107 |
+
|
108 |
+
else: # self.VPT_type == "Shallow"
|
109 |
+
Prompt_Token_num = self.Prompt_Tokens.shape[1]
|
110 |
+
|
111 |
+
# concatenate Prompt_Tokens
|
112 |
+
Prompt_Tokens = self.Prompt_Tokens.expand(x.shape[0], -1, -1)
|
113 |
+
x = torch.cat((x, Prompt_Tokens), dim=1)
|
114 |
+
num_tokens = x.shape[1]
|
115 |
+
# Sequntially procees
|
116 |
+
x = self.blocks(x)[:, :num_tokens - Prompt_Token_num]
|
117 |
+
|
118 |
+
x = self.norm(x)
|
119 |
+
return x
|
120 |
+
|
121 |
+
def forward(self, x):
|
122 |
+
|
123 |
+
x = self.forward_features(x)
|
124 |
+
|
125 |
+
# use cls token for cls head
|
126 |
+
try:
|
127 |
+
x = self.pre_logits(x[:, 0, :])
|
128 |
+
except:
|
129 |
+
x = self.fc_norm(x[:, 0, :])
|
130 |
+
else:
|
131 |
+
pass
|
132 |
+
x = self.head(x)
|
133 |
+
return x
|
PuzzleTuning/Backbone/attention_modules.py
ADDED
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
attention modules in ['SimAM', 'CBAM', 'SE', 'GAM'] were applied in the ablation study
|
3 |
+
|
4 |
+
ver: Dec 24th 15:00
|
5 |
+
|
6 |
+
|
7 |
+
ref:
|
8 |
+
https://github.com/xmu-xiaoma666/External-Attention-pytorch
|
9 |
+
"""
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import math
|
14 |
+
import torch.nn.functional as F
|
15 |
+
from torch.nn import init
|
16 |
+
|
17 |
+
|
18 |
+
# help func
|
19 |
+
class BasicConv(nn.Module):
|
20 |
+
def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True,
|
21 |
+
bn=True, bias=False):
|
22 |
+
super(BasicConv, self).__init__()
|
23 |
+
self.out_channels = out_planes
|
24 |
+
self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
|
25 |
+
dilation=dilation, groups=groups, bias=bias)
|
26 |
+
self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
|
27 |
+
self.relu = nn.ReLU() if relu else None
|
28 |
+
|
29 |
+
def forward(self, x):
|
30 |
+
x = self.conv(x)
|
31 |
+
if self.bn is not None:
|
32 |
+
x = self.bn(x)
|
33 |
+
if self.relu is not None:
|
34 |
+
x = self.relu(x)
|
35 |
+
return x
|
36 |
+
|
37 |
+
|
38 |
+
class Flatten(nn.Module):
|
39 |
+
def forward(self, x):
|
40 |
+
return x.view(x.size(0), -1)
|
41 |
+
|
42 |
+
|
43 |
+
class ChannelGate(nn.Module):
|
44 |
+
def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
|
45 |
+
super(ChannelGate, self).__init__()
|
46 |
+
self.gate_channels = gate_channels
|
47 |
+
self.mlp = nn.Sequential(
|
48 |
+
Flatten(),
|
49 |
+
nn.Linear(gate_channels, int(gate_channels // reduction_ratio)),
|
50 |
+
nn.ReLU(),
|
51 |
+
nn.Linear(int(gate_channels // reduction_ratio), gate_channels)
|
52 |
+
)
|
53 |
+
self.pool_types = pool_types
|
54 |
+
|
55 |
+
def forward(self, x):
|
56 |
+
channel_att_sum = None
|
57 |
+
for pool_type in self.pool_types:
|
58 |
+
if pool_type == 'avg':
|
59 |
+
avg_pool = F.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
|
60 |
+
channel_att_raw = self.mlp(avg_pool)
|
61 |
+
elif pool_type == 'max':
|
62 |
+
max_pool = F.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
|
63 |
+
channel_att_raw = self.mlp(max_pool)
|
64 |
+
elif pool_type == 'lp':
|
65 |
+
lp_pool = F.lp_pool2d(x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
|
66 |
+
channel_att_raw = self.mlp(lp_pool)
|
67 |
+
elif pool_type == 'lse':
|
68 |
+
# LSE pool only
|
69 |
+
lse_pool = logsumexp_2d(x)
|
70 |
+
channel_att_raw = self.mlp(lse_pool)
|
71 |
+
|
72 |
+
if channel_att_sum is None:
|
73 |
+
channel_att_sum = channel_att_raw
|
74 |
+
else:
|
75 |
+
channel_att_sum = channel_att_sum + channel_att_raw
|
76 |
+
|
77 |
+
scale = F.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
|
78 |
+
return x * scale
|
79 |
+
|
80 |
+
|
81 |
+
def logsumexp_2d(tensor):
|
82 |
+
tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
|
83 |
+
s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
|
84 |
+
outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
|
85 |
+
return outputs
|
86 |
+
|
87 |
+
|
88 |
+
class ChannelPool(nn.Module):
|
89 |
+
def forward(self, x):
|
90 |
+
return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)
|
91 |
+
|
92 |
+
|
93 |
+
class SpatialGate(nn.Module):
|
94 |
+
def __init__(self):
|
95 |
+
super(SpatialGate, self).__init__()
|
96 |
+
kernel_size = 7
|
97 |
+
self.compress = ChannelPool()
|
98 |
+
self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=int((kernel_size - 1) // 2), relu=False)
|
99 |
+
|
100 |
+
def forward(self, x):
|
101 |
+
x_compress = self.compress(x)
|
102 |
+
x_out = self.spatial(x_compress)
|
103 |
+
scale = F.sigmoid(x_out) # broadcasting
|
104 |
+
return x * scale
|
105 |
+
|
106 |
+
|
107 |
+
# attention modules:
|
108 |
+
class cbam_module(nn.Module):
|
109 |
+
"""
|
110 |
+
module:CBAM
|
111 |
+
|
112 |
+
input、output= b, c, h, w
|
113 |
+
|
114 |
+
paper:
|
115 |
+
https://arxiv.org/abs/1807.06521
|
116 |
+
code:
|
117 |
+
https://github.com/ZjjConan/SimAM/blob/master/networks/attentions
|
118 |
+
"""
|
119 |
+
|
120 |
+
def __init__(self, gate_channels, reduction=16, pool_types=['avg', 'max'], no_spatial=False):
|
121 |
+
super(cbam_module, self).__init__()
|
122 |
+
self.ChannelGate = ChannelGate(gate_channels, reduction, pool_types)
|
123 |
+
self.no_spatial = no_spatial
|
124 |
+
if not no_spatial:
|
125 |
+
self.SpatialGate = SpatialGate()
|
126 |
+
|
127 |
+
@staticmethod
|
128 |
+
def get_module_name():
|
129 |
+
return "cbam"
|
130 |
+
|
131 |
+
def forward(self, x):
|
132 |
+
x_out = self.ChannelGate(x)
|
133 |
+
if not self.no_spatial:
|
134 |
+
x_out = self.SpatialGate(x_out)
|
135 |
+
return x_out
|
136 |
+
|
137 |
+
|
138 |
+
class se_module(nn.Module):
|
139 |
+
"""
|
140 |
+
module: SE
|
141 |
+
|
142 |
+
input、output= b, c, h, w
|
143 |
+
|
144 |
+
from paper Squeeze-and-Excitation Networks
|
145 |
+
SE-Net https://arxiv.org/abs/1709.01507
|
146 |
+
code:
|
147 |
+
https://github.com/ZjjConan/SimAM/blob/master/networks/attentions
|
148 |
+
"""
|
149 |
+
|
150 |
+
def __init__(self, channel, reduction=16):
|
151 |
+
super(se_module, self).__init__()
|
152 |
+
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
153 |
+
self.fc = nn.Sequential(
|
154 |
+
nn.Linear(channel, int(channel // reduction), bias=False),
|
155 |
+
nn.ReLU(inplace=True),
|
156 |
+
nn.Linear(int(channel // reduction), channel, bias=False),
|
157 |
+
nn.Sigmoid()
|
158 |
+
)
|
159 |
+
|
160 |
+
@staticmethod
|
161 |
+
def get_module_name():
|
162 |
+
return "se"
|
163 |
+
|
164 |
+
def forward(self, x):
|
165 |
+
b, c, _, _ = x.size()
|
166 |
+
y = self.avg_pool(x).view(b, c)
|
167 |
+
y = self.fc(y).view(b, c, 1, 1)
|
168 |
+
return x * y
|
169 |
+
|
170 |
+
|
171 |
+
class simam_module(torch.nn.Module):
|
172 |
+
"""
|
173 |
+
module:SimAM
|
174 |
+
|
175 |
+
input、output= b, c, h, w
|
176 |
+
|
177 |
+
paper:(ICML)
|
178 |
+
SimAM: A Simple, Parameter-Free Attention Module for Convolutional Neural Networks
|
179 |
+
code:
|
180 |
+
https://github.com/ZjjConan/SimAM/blob/master/networks/attentions/simam_module.py
|
181 |
+
"""
|
182 |
+
|
183 |
+
def __init__(self, channels=None, e_lambda=1e-4):
|
184 |
+
super(simam_module, self).__init__()
|
185 |
+
|
186 |
+
self.activaton = nn.Sigmoid()
|
187 |
+
self.e_lambda = e_lambda
|
188 |
+
|
189 |
+
def __repr__(self):
|
190 |
+
s = self.__class__.__name__ + '('
|
191 |
+
s += ('lambda=%f)' % self.e_lambda)
|
192 |
+
return s
|
193 |
+
|
194 |
+
@staticmethod
|
195 |
+
def get_module_name():
|
196 |
+
return "simam"
|
197 |
+
|
198 |
+
def forward(self, x):
|
199 |
+
b, c, h, w = x.size()
|
200 |
+
|
201 |
+
n = w * h - 1
|
202 |
+
|
203 |
+
x_minus_mu_square = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2)
|
204 |
+
y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2, 3], keepdim=True) / n + self.e_lambda)) + 0.5
|
205 |
+
|
206 |
+
return x * self.activaton(y)
|
207 |
+
|
208 |
+
|
209 |
+
class ResidualAttention(nn.Module):
|
210 |
+
"""
|
211 |
+
module: ResidualAttention
|
212 |
+
|
213 |
+
input、output= b, c, h, w
|
214 |
+
|
215 |
+
Paper:ICCV 2021 Residual Attention: A Simple but Effective Method for Multi-Label Recognition
|
216 |
+
code:https://github.com/xmu-xiaoma666/External-Attention-pytorch/blob/master/attention/ResidualAttention.py
|
217 |
+
"""
|
218 |
+
|
219 |
+
def __init__(self, channel=512, num_class=1000, la=0.2):
|
220 |
+
super().__init__()
|
221 |
+
self.la = la
|
222 |
+
self.fc = nn.Conv2d(in_channels=channel, out_channels=num_class, kernel_size=1, stride=1, bias=False)
|
223 |
+
|
224 |
+
def forward(self, x):
|
225 |
+
b, c, h, w = x.shape
|
226 |
+
y_raw = self.fc(x).flatten(2) # b,num_class,hxw
|
227 |
+
y_avg = torch.mean(y_raw, dim=2) # b,num_class
|
228 |
+
y_max = torch.max(y_raw, dim=2)[0] # b,num_class
|
229 |
+
score = y_avg + self.la * y_max
|
230 |
+
return score
|
231 |
+
|
232 |
+
|
233 |
+
class eca_module(nn.Module):
|
234 |
+
"""Constructs a ECA module.
|
235 |
+
|
236 |
+
Args:
|
237 |
+
channel: Number of channels of the input feature map
|
238 |
+
k_size: Adaptive selection of kernel size
|
239 |
+
"""
|
240 |
+
def __init__(self, channel, k_size=3):
|
241 |
+
super(eca_module, self).__init__()
|
242 |
+
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
243 |
+
self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
|
244 |
+
self.sigmoid = nn.Sigmoid()
|
245 |
+
|
246 |
+
def forward(self, x):
|
247 |
+
# x: input features with shape [b, c, h, w]
|
248 |
+
b, c, h, w = x.size()
|
249 |
+
|
250 |
+
# feature descriptor on the global spatial information
|
251 |
+
y = self.avg_pool(x)
|
252 |
+
|
253 |
+
# Two different branches of ECA module
|
254 |
+
y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
|
255 |
+
|
256 |
+
# Multi-scale information fusion
|
257 |
+
y = self.sigmoid(y)
|
258 |
+
|
259 |
+
return x * y.expand_as(x)
|
260 |
+
|
261 |
+
|
262 |
+
class GAM_Attention(nn.Module):
|
263 |
+
"""
|
264 |
+
module:GAM
|
265 |
+
|
266 |
+
input= b, in_channels, h, w
|
267 |
+
output= b, out_channels, h, w
|
268 |
+
|
269 |
+
paper:
|
270 |
+
Global Attention Mechanism: Retain Information to Enhance Channel-Spatial Interactions
|
271 |
+
https://arxiv.org/abs/2112.05561
|
272 |
+
code:
|
273 |
+
https://mp.weixin.qq.com/s/VL6rXjyUDmHToYTqM32hUg
|
274 |
+
"""
|
275 |
+
def __init__(self, in_channels, out_channels, rate=4):
|
276 |
+
super(GAM_Attention, self).__init__()
|
277 |
+
|
278 |
+
self.channel_attention = nn.Sequential(
|
279 |
+
nn.Linear(in_channels, int(in_channels / rate)),
|
280 |
+
nn.ReLU(inplace=True),
|
281 |
+
nn.Linear(int(in_channels / rate), in_channels)
|
282 |
+
)
|
283 |
+
|
284 |
+
self.spatial_attention = nn.Sequential(
|
285 |
+
nn.Conv2d(in_channels, int(in_channels / rate), kernel_size=7, padding=3),
|
286 |
+
nn.BatchNorm2d(int(in_channels / rate)),
|
287 |
+
nn.ReLU(inplace=True),
|
288 |
+
nn.Conv2d(int(in_channels / rate), out_channels, kernel_size=7, padding=3),
|
289 |
+
nn.BatchNorm2d(out_channels)
|
290 |
+
)
|
291 |
+
|
292 |
+
def forward(self, x):
|
293 |
+
b, c, h, w = x.shape
|
294 |
+
x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
|
295 |
+
x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)
|
296 |
+
x_channel_att = x_att_permute.permute(0, 3, 1, 2)
|
297 |
+
|
298 |
+
x = x * x_channel_att
|
299 |
+
|
300 |
+
x_spatial_att = self.spatial_attention(x).sigmoid()
|
301 |
+
out = x * x_spatial_att
|
302 |
+
|
303 |
+
return out
|
PuzzleTuning/Backbone/counterpart_models/README.md
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Recent SOTA works in fine-grained Tasks
|
2 |
+
|
3 |
+
|
4 |
+
CrossFormer
|
5 |
+
|
6 |
+
Paper:
|
7 |
+
https://arxiv.org/pdf/2108.00154.pdf
|
8 |
+
|
9 |
+
|
10 |
+
Code from:
|
11 |
+
https://github.com/cheerss/CrossFormer
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
Conformer
|
16 |
+
Paper:
|
17 |
+
https://arxiv.org/pdf/2105.03889.pdf
|
18 |
+
|
19 |
+
|
20 |
+
Code from:
|
21 |
+
https://github.com/pengzhiliang/Conformer/blob/main/conformer.py
|
22 |
+
|
23 |
+
|
24 |
+
both work will be compared with official pretrained backbone
|
25 |
+
and a new MLP head (classification head).
|
PuzzleTuning/Backbone/counterpart_models/conformer.py
ADDED
@@ -0,0 +1,483 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
From Conformer with alter: conv and trans cls head was changed to volting together
|
3 |
+
ver: DEC 1st 16:00 official release
|
4 |
+
|
5 |
+
ref: https://github.com/pengzhiliang/Conformer/blob/main/conformer.py
|
6 |
+
"""
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from functools import partial
|
12 |
+
|
13 |
+
from timm.models.layers import DropPath, trunc_normal_
|
14 |
+
|
15 |
+
|
16 |
+
class Mlp(nn.Module): # FFN
|
17 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
18 |
+
super().__init__()
|
19 |
+
out_features = out_features or in_features
|
20 |
+
hidden_features = hidden_features or in_features
|
21 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
22 |
+
self.act = act_layer()
|
23 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
24 |
+
self.drop = nn.Dropout(drop)
|
25 |
+
|
26 |
+
def forward(self, x):
|
27 |
+
x = self.fc1(x)
|
28 |
+
x = self.act(x)
|
29 |
+
x = self.drop(x)
|
30 |
+
x = self.fc2(x)
|
31 |
+
x = self.drop(x)
|
32 |
+
return x
|
33 |
+
|
34 |
+
|
35 |
+
class Attention(nn.Module): # MHSA
|
36 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
37 |
+
super().__init__()
|
38 |
+
self.num_heads = num_heads
|
39 |
+
head_dim = dim // num_heads
|
40 |
+
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
|
41 |
+
self.scale = qk_scale or head_dim ** -0.5
|
42 |
+
|
43 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
44 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
45 |
+
self.proj = nn.Linear(dim, dim)
|
46 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
47 |
+
|
48 |
+
def forward(self, x):
|
49 |
+
B, N, C = x.shape # N is patch number, C is patch dimension
|
50 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) # re arrange
|
51 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
52 |
+
|
53 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
54 |
+
attn = attn.softmax(dim=-1)
|
55 |
+
attn = self.attn_drop(attn)
|
56 |
+
|
57 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
58 |
+
x = self.proj(x)
|
59 |
+
x = self.proj_drop(x)
|
60 |
+
return x
|
61 |
+
|
62 |
+
|
63 |
+
class Block(nn.Module): # Encoder from ViT
|
64 |
+
|
65 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
66 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=partial(nn.LayerNorm, eps=1e-6)):
|
67 |
+
super().__init__()
|
68 |
+
# pre norm 1
|
69 |
+
self.norm1 = norm_layer(dim)
|
70 |
+
# MHSA
|
71 |
+
self.attn = Attention(
|
72 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
73 |
+
|
74 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
75 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
76 |
+
|
77 |
+
# pre norm 2
|
78 |
+
self.norm2 = norm_layer(dim)
|
79 |
+
|
80 |
+
# FFN(MLP)
|
81 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
82 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
83 |
+
|
84 |
+
def forward(self, x):
|
85 |
+
x = x + self.drop_path(self.attn(self.norm1(x)))
|
86 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
87 |
+
return x
|
88 |
+
|
89 |
+
|
90 |
+
class ConvBlock(nn.Module): # ResNet bottleneck Convblock actually
|
91 |
+
|
92 |
+
def __init__(self, inplanes, outplanes, stride=1, res_conv=False, act_layer=nn.ReLU, groups=1,
|
93 |
+
norm_layer=partial(nn.BatchNorm2d, eps=1e-6), drop_block=None, drop_path=None):
|
94 |
+
super(ConvBlock, self).__init__()
|
95 |
+
|
96 |
+
expansion = 4
|
97 |
+
med_planes = outplanes // expansion
|
98 |
+
|
99 |
+
self.conv1 = nn.Conv2d(inplanes, med_planes, kernel_size=1, stride=1, padding=0, bias=False)
|
100 |
+
self.bn1 = norm_layer(med_planes)
|
101 |
+
self.act1 = act_layer(inplace=True)
|
102 |
+
|
103 |
+
self.conv2 = nn.Conv2d(med_planes, med_planes, kernel_size=3, stride=stride, groups=groups, padding=1,
|
104 |
+
bias=False)
|
105 |
+
self.bn2 = norm_layer(med_planes)
|
106 |
+
self.act2 = act_layer(inplace=True)
|
107 |
+
|
108 |
+
self.conv3 = nn.Conv2d(med_planes, outplanes, kernel_size=1, stride=1, padding=0, bias=False)
|
109 |
+
self.bn3 = norm_layer(outplanes)
|
110 |
+
self.act3 = act_layer(inplace=True)
|
111 |
+
|
112 |
+
if res_conv:
|
113 |
+
self.residual_conv = nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=stride, padding=0, bias=False)
|
114 |
+
self.residual_bn = norm_layer(outplanes)
|
115 |
+
|
116 |
+
self.res_conv = res_conv
|
117 |
+
self.drop_block = drop_block
|
118 |
+
self.drop_path = drop_path
|
119 |
+
|
120 |
+
def zero_init_last_bn(self):
|
121 |
+
nn.init.zeros_(self.bn3.weight)
|
122 |
+
|
123 |
+
def forward(self, x, x_t=None, return_x_2=True):
|
124 |
+
residual = x
|
125 |
+
|
126 |
+
x = self.conv1(x)
|
127 |
+
x = self.bn1(x)
|
128 |
+
if self.drop_block is not None:
|
129 |
+
x = self.drop_block(x)
|
130 |
+
x = self.act1(x)
|
131 |
+
|
132 |
+
x = self.conv2(x) if x_t is None else self.conv2(x + x_t)
|
133 |
+
x = self.bn2(x)
|
134 |
+
if self.drop_block is not None:
|
135 |
+
x = self.drop_block(x)
|
136 |
+
x2 = self.act2(x)
|
137 |
+
|
138 |
+
x = self.conv3(x2)
|
139 |
+
x = self.bn3(x)
|
140 |
+
if self.drop_block is not None:
|
141 |
+
x = self.drop_block(x)
|
142 |
+
|
143 |
+
if self.drop_path is not None:
|
144 |
+
x = self.drop_path(x)
|
145 |
+
|
146 |
+
if self.res_conv:
|
147 |
+
residual = self.residual_conv(residual)
|
148 |
+
residual = self.residual_bn(residual)
|
149 |
+
|
150 |
+
x += residual
|
151 |
+
x = self.act3(x)
|
152 |
+
|
153 |
+
if return_x_2:
|
154 |
+
return x, x2
|
155 |
+
else:
|
156 |
+
return x
|
157 |
+
|
158 |
+
|
159 |
+
class FCUDown(nn.Module):
|
160 |
+
""" CNN feature maps -> Transformer patch embeddings
|
161 |
+
"""
|
162 |
+
|
163 |
+
def __init__(self, inplanes, outplanes, dw_stride, act_layer=nn.GELU,
|
164 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6)):
|
165 |
+
super(FCUDown, self).__init__()
|
166 |
+
self.dw_stride = dw_stride
|
167 |
+
|
168 |
+
self.conv_project = nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, padding=0) # fix dimension
|
169 |
+
self.sample_pooling = nn.AvgPool2d(kernel_size=dw_stride, stride=dw_stride) # fix feature map size
|
170 |
+
|
171 |
+
self.ln = norm_layer(outplanes)
|
172 |
+
self.act = act_layer()
|
173 |
+
|
174 |
+
def forward(self, x, x_t):
|
175 |
+
x = self.conv_project(x) # [N, C, H, W]
|
176 |
+
|
177 |
+
x = self.sample_pooling(x).flatten(2).transpose(1, 2)
|
178 |
+
x = self.ln(x)
|
179 |
+
x = self.act(x)
|
180 |
+
|
181 |
+
x = torch.cat([x_t[:, 0][:, None, :], x], dim=1) # concatenate class token from x_t
|
182 |
+
|
183 |
+
return x
|
184 |
+
|
185 |
+
|
186 |
+
class FCUUp(nn.Module):
|
187 |
+
""" Transformer patch embeddings -> CNN feature maps
|
188 |
+
by interpolate operation
|
189 |
+
"""
|
190 |
+
|
191 |
+
def __init__(self, inplanes, outplanes, up_stride, act_layer=nn.ReLU,
|
192 |
+
norm_layer=partial(nn.BatchNorm2d, eps=1e-6), ):
|
193 |
+
super(FCUUp, self).__init__()
|
194 |
+
|
195 |
+
self.up_stride = up_stride
|
196 |
+
self.conv_project = nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, padding=0)
|
197 |
+
self.bn = norm_layer(outplanes)
|
198 |
+
self.act = act_layer()
|
199 |
+
|
200 |
+
def forward(self, x, H, W): # interpolate to
|
201 |
+
B, _, C = x.shape
|
202 |
+
|
203 |
+
# [N, 197, 384] -> [N, 196, 384] -> [N, 384, 196] -> [N, 384, 14, 14]
|
204 |
+
x_r = x[:, 1:].transpose(1, 2).reshape(B, C, H, W) # drop cls token of x_t
|
205 |
+
|
206 |
+
x_r = self.act(self.bn(self.conv_project(x_r)))
|
207 |
+
|
208 |
+
return F.interpolate(x_r, size=(H * self.up_stride, W * self.up_stride)) # interpolate operation
|
209 |
+
|
210 |
+
|
211 |
+
class Med_ConvBlock(nn.Module): # ResNet bottleneck indentity actually
|
212 |
+
""" special case for Convblock without down sampling,
|
213 |
+
"""
|
214 |
+
|
215 |
+
def __init__(self, inplanes, act_layer=nn.ReLU, groups=1, norm_layer=partial(nn.BatchNorm2d, eps=1e-6),
|
216 |
+
drop_block=None, drop_path=None):
|
217 |
+
|
218 |
+
super(Med_ConvBlock, self).__init__()
|
219 |
+
|
220 |
+
expansion = 4
|
221 |
+
med_planes = inplanes // expansion
|
222 |
+
|
223 |
+
self.conv1 = nn.Conv2d(inplanes, med_planes, kernel_size=1, stride=1, padding=0, bias=False)
|
224 |
+
self.bn1 = norm_layer(med_planes)
|
225 |
+
self.act1 = act_layer(inplace=True)
|
226 |
+
|
227 |
+
self.conv2 = nn.Conv2d(med_planes, med_planes, kernel_size=3, stride=1, groups=groups, padding=1, bias=False)
|
228 |
+
self.bn2 = norm_layer(med_planes)
|
229 |
+
self.act2 = act_layer(inplace=True)
|
230 |
+
|
231 |
+
self.conv3 = nn.Conv2d(med_planes, inplanes, kernel_size=1, stride=1, padding=0, bias=False)
|
232 |
+
self.bn3 = norm_layer(inplanes)
|
233 |
+
self.act3 = act_layer(inplace=True)
|
234 |
+
|
235 |
+
self.drop_block = drop_block
|
236 |
+
self.drop_path = drop_path
|
237 |
+
|
238 |
+
def zero_init_last_bn(self):
|
239 |
+
nn.init.zeros_(self.bn3.weight)
|
240 |
+
|
241 |
+
def forward(self, x):
|
242 |
+
residual = x
|
243 |
+
|
244 |
+
x = self.conv1(x)
|
245 |
+
x = self.bn1(x)
|
246 |
+
if self.drop_block is not None:
|
247 |
+
x = self.drop_block(x)
|
248 |
+
x = self.act1(x)
|
249 |
+
|
250 |
+
x = self.conv2(x)
|
251 |
+
x = self.bn2(x)
|
252 |
+
if self.drop_block is not None:
|
253 |
+
x = self.drop_block(x)
|
254 |
+
x = self.act2(x)
|
255 |
+
|
256 |
+
x = self.conv3(x)
|
257 |
+
x = self.bn3(x)
|
258 |
+
if self.drop_block is not None:
|
259 |
+
x = self.drop_block(x)
|
260 |
+
|
261 |
+
if self.drop_path is not None:
|
262 |
+
x = self.drop_path(x)
|
263 |
+
|
264 |
+
x += residual
|
265 |
+
x = self.act3(x)
|
266 |
+
|
267 |
+
return x
|
268 |
+
|
269 |
+
|
270 |
+
class ConvTransBlock(nn.Module):
|
271 |
+
"""
|
272 |
+
Basic module for ConvTransformer, keep feature maps for CNN block and patch embeddings for transformer encoder block
|
273 |
+
"""
|
274 |
+
|
275 |
+
def __init__(self, inplanes, outplanes, res_conv, stride, dw_stride, embed_dim, num_heads=12, mlp_ratio=4.,
|
276 |
+
qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
|
277 |
+
last_fusion=False, num_med_block=0, groups=1):
|
278 |
+
|
279 |
+
super(ConvTransBlock, self).__init__()
|
280 |
+
expansion = 4
|
281 |
+
# ConvBlock
|
282 |
+
self.cnn_block = ConvBlock(inplanes=inplanes, outplanes=outplanes, res_conv=res_conv, stride=stride,
|
283 |
+
groups=groups)
|
284 |
+
|
285 |
+
if last_fusion:
|
286 |
+
self.fusion_block = ConvBlock(inplanes=outplanes, outplanes=outplanes, stride=2, res_conv=True,
|
287 |
+
groups=groups)
|
288 |
+
else:
|
289 |
+
self.fusion_block = ConvBlock(inplanes=outplanes, outplanes=outplanes, groups=groups)
|
290 |
+
|
291 |
+
# identity block
|
292 |
+
if num_med_block > 0:
|
293 |
+
self.med_block = []
|
294 |
+
for i in range(num_med_block):
|
295 |
+
self.med_block.append(Med_ConvBlock(inplanes=outplanes, groups=groups))
|
296 |
+
|
297 |
+
self.med_block = nn.ModuleList(self.med_block) # nn.ModuleList
|
298 |
+
|
299 |
+
# FCU
|
300 |
+
self.squeeze_block = FCUDown(inplanes=outplanes // expansion, outplanes=embed_dim, dw_stride=dw_stride)
|
301 |
+
|
302 |
+
self.expand_block = FCUUp(inplanes=embed_dim, outplanes=outplanes // expansion, up_stride=dw_stride)
|
303 |
+
|
304 |
+
# Transformer Encoder block
|
305 |
+
self.trans_block = Block(
|
306 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
307 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate)
|
308 |
+
|
309 |
+
self.dw_stride = dw_stride
|
310 |
+
self.embed_dim = embed_dim
|
311 |
+
self.num_med_block = num_med_block
|
312 |
+
self.last_fusion = last_fusion
|
313 |
+
|
314 |
+
def forward(self, x, x_t):
|
315 |
+
x, x2 = self.cnn_block(x)
|
316 |
+
|
317 |
+
_, _, H, W = x2.shape
|
318 |
+
|
319 |
+
x_st = self.squeeze_block(x2, x_t)
|
320 |
+
|
321 |
+
x_t = self.trans_block(x_st + x_t)
|
322 |
+
|
323 |
+
if self.num_med_block > 0:
|
324 |
+
for m in self.med_block:
|
325 |
+
x = m(x)
|
326 |
+
|
327 |
+
x_t_r = self.expand_block(x_t, H // self.dw_stride, W // self.dw_stride)
|
328 |
+
x = self.fusion_block(x, x_t_r, return_x_2=False)
|
329 |
+
|
330 |
+
return x, x_t
|
331 |
+
|
332 |
+
|
333 |
+
class Conformer(nn.Module):
|
334 |
+
|
335 |
+
def __init__(self, patch_size=16, in_chans=3, num_classes=1000, base_channel=64, channel_ratio=4, num_med_block=0,
|
336 |
+
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None,
|
337 |
+
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.):
|
338 |
+
|
339 |
+
# Transformer
|
340 |
+
super().__init__()
|
341 |
+
self.num_classes = num_classes
|
342 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
343 |
+
assert depth % 3 == 0
|
344 |
+
|
345 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
346 |
+
self.trans_dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
347 |
+
|
348 |
+
# Classifier head
|
349 |
+
self.trans_norm = nn.LayerNorm(embed_dim)
|
350 |
+
self.trans_cls_head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
351 |
+
self.pooling = nn.AdaptiveAvgPool2d(1)
|
352 |
+
self.conv_cls_head = nn.Linear(int(256 * channel_ratio), num_classes)
|
353 |
+
self.cls_head = nn.Linear(int(2 * num_classes), num_classes)
|
354 |
+
|
355 |
+
# Stem stage: get the feature maps by conv block (copied form resnet.py)
|
356 |
+
self.conv1 = nn.Conv2d(in_chans, 64, kernel_size=7, stride=2, padding=3, bias=False) # 1 / 2 [112, 112]
|
357 |
+
self.bn1 = nn.BatchNorm2d(64)
|
358 |
+
self.act1 = nn.ReLU(inplace=True)
|
359 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) # 1 / 4 [56, 56]
|
360 |
+
|
361 |
+
# 1 stage
|
362 |
+
stage_1_channel = int(base_channel * channel_ratio)
|
363 |
+
trans_dw_stride = patch_size // 4
|
364 |
+
self.conv_1 = ConvBlock(inplanes=64, outplanes=stage_1_channel, res_conv=True, stride=1)
|
365 |
+
# embedding
|
366 |
+
self.trans_patch_conv = nn.Conv2d(64, embed_dim, kernel_size=trans_dw_stride, stride=trans_dw_stride, padding=0)
|
367 |
+
self.trans_1 = Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
|
368 |
+
qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=self.trans_dpr[0],
|
369 |
+
)
|
370 |
+
|
371 |
+
# 2~4 stage
|
372 |
+
init_stage = 2
|
373 |
+
fin_stage = depth // 3 + 1
|
374 |
+
for i in range(init_stage, fin_stage):
|
375 |
+
self.add_module('conv_trans_' + str(i),
|
376 |
+
ConvTransBlock(
|
377 |
+
stage_1_channel, stage_1_channel, False, 1, dw_stride=trans_dw_stride,
|
378 |
+
embed_dim=embed_dim,
|
379 |
+
num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
380 |
+
drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
|
381 |
+
drop_path_rate=self.trans_dpr[i - 1],
|
382 |
+
num_med_block=num_med_block
|
383 |
+
)
|
384 |
+
)
|
385 |
+
|
386 |
+
stage_2_channel = int(base_channel * channel_ratio * 2)
|
387 |
+
# 5~8 stage
|
388 |
+
init_stage = fin_stage # 5
|
389 |
+
fin_stage = fin_stage + depth // 3 # 9
|
390 |
+
for i in range(init_stage, fin_stage):
|
391 |
+
s = 2 if i == init_stage else 1
|
392 |
+
in_channel = stage_1_channel if i == init_stage else stage_2_channel
|
393 |
+
res_conv = True if i == init_stage else False
|
394 |
+
self.add_module('conv_trans_' + str(i),
|
395 |
+
ConvTransBlock(
|
396 |
+
in_channel, stage_2_channel, res_conv, s, dw_stride=trans_dw_stride // 2,
|
397 |
+
embed_dim=embed_dim,
|
398 |
+
num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
399 |
+
drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
|
400 |
+
drop_path_rate=self.trans_dpr[i - 1],
|
401 |
+
num_med_block=num_med_block
|
402 |
+
)
|
403 |
+
)
|
404 |
+
|
405 |
+
stage_3_channel = int(base_channel * channel_ratio * 2 * 2)
|
406 |
+
# 9~12 stage
|
407 |
+
init_stage = fin_stage # 9
|
408 |
+
fin_stage = fin_stage + depth // 3 # 13
|
409 |
+
for i in range(init_stage, fin_stage):
|
410 |
+
s = 2 if i == init_stage else 1
|
411 |
+
in_channel = stage_2_channel if i == init_stage else stage_3_channel
|
412 |
+
res_conv = True if i == init_stage else False
|
413 |
+
last_fusion = True if i == depth else False
|
414 |
+
self.add_module('conv_trans_' + str(i),
|
415 |
+
ConvTransBlock(
|
416 |
+
in_channel, stage_3_channel, res_conv, s, dw_stride=trans_dw_stride // 4,
|
417 |
+
embed_dim=embed_dim,
|
418 |
+
num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
419 |
+
drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
|
420 |
+
drop_path_rate=self.trans_dpr[i - 1],
|
421 |
+
num_med_block=num_med_block, last_fusion=last_fusion
|
422 |
+
)
|
423 |
+
)
|
424 |
+
self.fin_stage = fin_stage
|
425 |
+
|
426 |
+
trunc_normal_(self.cls_token, std=.02)
|
427 |
+
|
428 |
+
self.apply(self._init_weights)
|
429 |
+
|
430 |
+
def _init_weights(self, m):
|
431 |
+
if isinstance(m, nn.Linear):
|
432 |
+
trunc_normal_(m.weight, std=.02)
|
433 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
434 |
+
nn.init.constant_(m.bias, 0)
|
435 |
+
elif isinstance(m, nn.LayerNorm):
|
436 |
+
nn.init.constant_(m.bias, 0)
|
437 |
+
nn.init.constant_(m.weight, 1.0)
|
438 |
+
elif isinstance(m, nn.Conv2d):
|
439 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
440 |
+
elif isinstance(m, nn.BatchNorm2d):
|
441 |
+
nn.init.constant_(m.weight, 1.)
|
442 |
+
nn.init.constant_(m.bias, 0.)
|
443 |
+
elif isinstance(m, nn.GroupNorm):
|
444 |
+
nn.init.constant_(m.weight, 1.)
|
445 |
+
nn.init.constant_(m.bias, 0.)
|
446 |
+
|
447 |
+
@torch.jit.ignore
|
448 |
+
def no_weight_decay(self):
|
449 |
+
return {'cls_token'}
|
450 |
+
|
451 |
+
def forward(self, x):
|
452 |
+
B = x.shape[0]
|
453 |
+
cls_tokens = self.cls_token.expand(B, -1, -1)
|
454 |
+
|
455 |
+
# pdb.set_trace()
|
456 |
+
# stem stage [N, 3, 224, 224] -> [N, 64, 56, 56]
|
457 |
+
x_base = self.maxpool(self.act1(self.bn1(self.conv1(x))))
|
458 |
+
|
459 |
+
# 1 stage
|
460 |
+
x = self.conv_1(x_base, return_x_2=False)
|
461 |
+
# embedding: [N, 64, 56, 56] -> [N, d, p, p] -> [N, d, p^2] -> [N, p^2, d] -> [N, p^2 + 1, d]
|
462 |
+
x_t = self.trans_patch_conv(x_base).flatten(2).transpose(1, 2)
|
463 |
+
x_t = torch.cat([cls_tokens, x_t], dim=1)
|
464 |
+
x_t = self.trans_1(x_t)
|
465 |
+
|
466 |
+
# 2 ~ final
|
467 |
+
for i in range(2, self.fin_stage):
|
468 |
+
x, x_t = eval('self.conv_trans_' + str(i))(x, x_t)
|
469 |
+
|
470 |
+
# conv classification
|
471 |
+
x_p = self.pooling(x).flatten(1)
|
472 |
+
conv_cls = self.conv_cls_head(x_p)
|
473 |
+
|
474 |
+
# trans classification
|
475 |
+
x_t = self.trans_norm(x_t)
|
476 |
+
tran_cls = self.trans_cls_head(x_t[:, 0])
|
477 |
+
|
478 |
+
# 加一个类别投票
|
479 |
+
cls = torch.cat([conv_cls, tran_cls], dim=1)
|
480 |
+
cls = self.cls_head(cls)
|
481 |
+
return cls
|
482 |
+
|
483 |
+
# return [conv_cls, tran_cls]
|
PuzzleTuning/Backbone/counterpart_models/crossformer.py
ADDED
@@ -0,0 +1,624 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.utils.checkpoint as checkpoint
|
4 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
5 |
+
|
6 |
+
|
7 |
+
class Mlp(nn.Module):
|
8 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
9 |
+
super().__init__()
|
10 |
+
out_features = out_features or in_features
|
11 |
+
hidden_features = hidden_features or in_features
|
12 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
13 |
+
self.act = act_layer()
|
14 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
15 |
+
self.drop = nn.Dropout(drop)
|
16 |
+
|
17 |
+
def forward(self, x):
|
18 |
+
x = self.fc1(x)
|
19 |
+
x = self.act(x)
|
20 |
+
x = self.drop(x)
|
21 |
+
x = self.fc2(x)
|
22 |
+
x = self.drop(x)
|
23 |
+
return x
|
24 |
+
|
25 |
+
class DynamicPosBias(nn.Module):
|
26 |
+
def __init__(self, dim, num_heads, residual):
|
27 |
+
super().__init__()
|
28 |
+
self.residual = residual
|
29 |
+
self.num_heads = num_heads
|
30 |
+
self.pos_dim = dim // 4
|
31 |
+
self.pos_proj = nn.Linear(2, self.pos_dim)
|
32 |
+
self.pos1 = nn.Sequential(
|
33 |
+
nn.LayerNorm(self.pos_dim),
|
34 |
+
nn.ReLU(inplace=True),
|
35 |
+
nn.Linear(self.pos_dim, self.pos_dim),
|
36 |
+
)
|
37 |
+
self.pos2 = nn.Sequential(
|
38 |
+
nn.LayerNorm(self.pos_dim),
|
39 |
+
nn.ReLU(inplace=True),
|
40 |
+
nn.Linear(self.pos_dim, self.pos_dim)
|
41 |
+
)
|
42 |
+
self.pos3 = nn.Sequential(
|
43 |
+
nn.LayerNorm(self.pos_dim),
|
44 |
+
nn.ReLU(inplace=True),
|
45 |
+
nn.Linear(self.pos_dim, self.num_heads)
|
46 |
+
)
|
47 |
+
def forward(self, biases):
|
48 |
+
if self.residual:
|
49 |
+
pos = self.pos_proj(biases) # 2Wh-1 * 2Ww-1, heads
|
50 |
+
pos = pos + self.pos1(pos)
|
51 |
+
pos = pos + self.pos2(pos)
|
52 |
+
pos = self.pos3(pos)
|
53 |
+
else:
|
54 |
+
pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases))))
|
55 |
+
return pos
|
56 |
+
|
57 |
+
def flops(self, N):
|
58 |
+
flops = N * 2 * self.pos_dim
|
59 |
+
flops += N * self.pos_dim * self.pos_dim
|
60 |
+
flops += N * self.pos_dim * self.pos_dim
|
61 |
+
flops += N * self.pos_dim * self.num_heads
|
62 |
+
return flops
|
63 |
+
|
64 |
+
class Attention(nn.Module):
|
65 |
+
r""" Multi-head self attention module with dynamic position bias.
|
66 |
+
|
67 |
+
Args:
|
68 |
+
dim (int): Number of input channels.
|
69 |
+
group_size (tuple[int]): The height and width of the group.
|
70 |
+
num_heads (int): Number of attention heads.
|
71 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
72 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
73 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
74 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
75 |
+
"""
|
76 |
+
|
77 |
+
def __init__(self, dim, group_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.,
|
78 |
+
position_bias=True):
|
79 |
+
|
80 |
+
super().__init__()
|
81 |
+
self.dim = dim
|
82 |
+
self.group_size = group_size # Wh, Ww
|
83 |
+
self.num_heads = num_heads
|
84 |
+
head_dim = dim // num_heads
|
85 |
+
self.scale = qk_scale or head_dim ** -0.5
|
86 |
+
self.position_bias = position_bias
|
87 |
+
|
88 |
+
if position_bias:
|
89 |
+
self.pos = DynamicPosBias(self.dim // 4, self.num_heads, residual=False)
|
90 |
+
|
91 |
+
# generate mother-set
|
92 |
+
position_bias_h = torch.arange(1 - self.group_size[0], self.group_size[0])
|
93 |
+
position_bias_w = torch.arange(1 - self.group_size[1], self.group_size[1])
|
94 |
+
biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Wh-1, 2W2-1
|
95 |
+
biases = biases.flatten(1).transpose(0, 1).float()
|
96 |
+
self.register_buffer("biases", biases)
|
97 |
+
|
98 |
+
# get pair-wise relative position index for each token inside the group
|
99 |
+
coords_h = torch.arange(self.group_size[0])
|
100 |
+
coords_w = torch.arange(self.group_size[1])
|
101 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
102 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
103 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
104 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
105 |
+
relative_coords[:, :, 0] += self.group_size[0] - 1 # shift to start from 0
|
106 |
+
relative_coords[:, :, 1] += self.group_size[1] - 1
|
107 |
+
relative_coords[:, :, 0] *= 2 * self.group_size[1] - 1
|
108 |
+
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
109 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
110 |
+
|
111 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
112 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
113 |
+
self.proj = nn.Linear(dim, dim)
|
114 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
115 |
+
|
116 |
+
self.softmax = nn.Softmax(dim=-1)
|
117 |
+
|
118 |
+
def forward(self, x, mask=None):
|
119 |
+
"""
|
120 |
+
Args:
|
121 |
+
x: input features with shape of (num_groups*B, N, C)
|
122 |
+
mask: (0/-inf) mask with shape of (num_groups, Wh*Ww, Wh*Ww) or None
|
123 |
+
"""
|
124 |
+
B_, N, C = x.shape
|
125 |
+
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
126 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
127 |
+
|
128 |
+
q = q * self.scale
|
129 |
+
attn = (q @ k.transpose(-2, -1))
|
130 |
+
|
131 |
+
if self.position_bias:
|
132 |
+
pos = self.pos(self.biases) # 2Wh-1 * 2Ww-1, heads
|
133 |
+
# select position bias
|
134 |
+
relative_position_bias = pos[self.relative_position_index.view(-1)].view(
|
135 |
+
self.group_size[0] * self.group_size[1], self.group_size[0] * self.group_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
136 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
137 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
138 |
+
|
139 |
+
if mask is not None:
|
140 |
+
nW = mask.shape[0]
|
141 |
+
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
142 |
+
attn = attn.view(-1, self.num_heads, N, N)
|
143 |
+
attn = self.softmax(attn)
|
144 |
+
else:
|
145 |
+
attn = self.softmax(attn)
|
146 |
+
|
147 |
+
attn = self.attn_drop(attn)
|
148 |
+
|
149 |
+
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
150 |
+
x = self.proj(x)
|
151 |
+
x = self.proj_drop(x)
|
152 |
+
return x
|
153 |
+
|
154 |
+
def extra_repr(self) -> str:
|
155 |
+
return f'dim={self.dim}, group_size={self.group_size}, num_heads={self.num_heads}'
|
156 |
+
|
157 |
+
def flops(self, N):
|
158 |
+
# calculate flops for 1 group with token length of N
|
159 |
+
flops = 0
|
160 |
+
# qkv = self.qkv(x)
|
161 |
+
flops += N * self.dim * 3 * self.dim
|
162 |
+
# attn = (q @ k.transpose(-2, -1))
|
163 |
+
flops += self.num_heads * N * (self.dim // self.num_heads) * N
|
164 |
+
# x = (attn @ v)
|
165 |
+
flops += self.num_heads * N * N * (self.dim // self.num_heads)
|
166 |
+
# x = self.proj(x)
|
167 |
+
flops += N * self.dim * self.dim
|
168 |
+
if self.position_bias:
|
169 |
+
flops += self.pos.flops(N)
|
170 |
+
return flops
|
171 |
+
|
172 |
+
|
173 |
+
class CrossFormerBlock(nn.Module):
|
174 |
+
r""" CrossFormer Block.
|
175 |
+
|
176 |
+
Args:
|
177 |
+
dim (int): Number of input channels.
|
178 |
+
input_resolution (tuple[int]): Input resulotion.
|
179 |
+
num_heads (int): Number of attention heads.
|
180 |
+
group_size (int): Group size.
|
181 |
+
lsda_flag (int): use SDA or LDA, 0 for SDA and 1 for LDA.
|
182 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
183 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
184 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
185 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
186 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
187 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
188 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
189 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
190 |
+
"""
|
191 |
+
|
192 |
+
def __init__(self, dim, input_resolution, num_heads, group_size=7, lsda_flag=0,
|
193 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
|
194 |
+
act_layer=nn.GELU, norm_layer=nn.LayerNorm, num_patch_size=1):
|
195 |
+
super().__init__()
|
196 |
+
self.dim = dim
|
197 |
+
self.input_resolution = input_resolution
|
198 |
+
self.num_heads = num_heads
|
199 |
+
self.group_size = group_size
|
200 |
+
self.lsda_flag = lsda_flag
|
201 |
+
self.mlp_ratio = mlp_ratio
|
202 |
+
self.num_patch_size = num_patch_size
|
203 |
+
if min(self.input_resolution) <= self.group_size:
|
204 |
+
# if group size is larger than input resolution, we don't partition groups
|
205 |
+
self.lsda_flag = 0
|
206 |
+
self.group_size = min(self.input_resolution)
|
207 |
+
|
208 |
+
self.norm1 = norm_layer(dim)
|
209 |
+
|
210 |
+
self.attn = Attention(
|
211 |
+
dim, group_size=to_2tuple(self.group_size), num_heads=num_heads,
|
212 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
|
213 |
+
position_bias=True)
|
214 |
+
|
215 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
216 |
+
self.norm2 = norm_layer(dim)
|
217 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
218 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
219 |
+
|
220 |
+
attn_mask = None
|
221 |
+
self.register_buffer("attn_mask", attn_mask)
|
222 |
+
|
223 |
+
def forward(self, x):
|
224 |
+
H, W = self.input_resolution
|
225 |
+
B, L, C = x.shape
|
226 |
+
assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W)
|
227 |
+
|
228 |
+
shortcut = x
|
229 |
+
x = self.norm1(x)
|
230 |
+
x = x.view(B, H, W, C)
|
231 |
+
|
232 |
+
# group embeddings
|
233 |
+
G = self.group_size
|
234 |
+
if self.lsda_flag == 0: # 0 for SDA
|
235 |
+
x = x.reshape(B, H // G, G, W // G, G, C).permute(0, 1, 3, 2, 4, 5)
|
236 |
+
else: # 1 for LDA
|
237 |
+
x = x.reshape(B, G, H // G, G, W // G, C).permute(0, 2, 4, 1, 3, 5)
|
238 |
+
x = x.reshape(B * H * W // G**2, G**2, C)
|
239 |
+
|
240 |
+
# multi-head self-attention
|
241 |
+
x = self.attn(x, mask=self.attn_mask) # nW*B, G*G, C
|
242 |
+
|
243 |
+
# ungroup embeddings
|
244 |
+
x = x.reshape(B, H // G, W // G, G, G, C)
|
245 |
+
if self.lsda_flag == 0:
|
246 |
+
x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, H, W, C)
|
247 |
+
else:
|
248 |
+
x = x.permute(0, 3, 1, 4, 2, 5).reshape(B, H, W, C)
|
249 |
+
x = x.view(B, H * W, C)
|
250 |
+
|
251 |
+
# FFN
|
252 |
+
x = shortcut + self.drop_path(x)
|
253 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
254 |
+
|
255 |
+
return x
|
256 |
+
|
257 |
+
def extra_repr(self) -> str:
|
258 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
|
259 |
+
f"group_size={self.group_size}, lsda_flag={self.lsda_flag}, mlp_ratio={self.mlp_ratio}"
|
260 |
+
|
261 |
+
def flops(self):
|
262 |
+
flops = 0
|
263 |
+
H, W = self.input_resolution
|
264 |
+
# norm1
|
265 |
+
flops += self.dim * H * W
|
266 |
+
# LSDA
|
267 |
+
nW = H * W / self.group_size / self.group_size
|
268 |
+
flops += nW * self.attn.flops(self.group_size * self.group_size)
|
269 |
+
# mlp
|
270 |
+
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
|
271 |
+
# norm2
|
272 |
+
flops += self.dim * H * W
|
273 |
+
return flops
|
274 |
+
|
275 |
+
class PatchMerging(nn.Module):
|
276 |
+
r""" Patch Merging Layer.
|
277 |
+
|
278 |
+
Args:
|
279 |
+
input_resolution (tuple[int]): Resolution of input feature.
|
280 |
+
dim (int): Number of input channels.
|
281 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
282 |
+
"""
|
283 |
+
|
284 |
+
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm, patch_size=[2], num_input_patch_size=1):
|
285 |
+
super().__init__()
|
286 |
+
self.input_resolution = input_resolution
|
287 |
+
self.dim = dim
|
288 |
+
self.reductions = nn.ModuleList()
|
289 |
+
self.patch_size = patch_size
|
290 |
+
self.norm = norm_layer(dim)
|
291 |
+
|
292 |
+
for i, ps in enumerate(patch_size):
|
293 |
+
if i == len(patch_size) - 1:
|
294 |
+
out_dim = 2 * dim // 2 ** i
|
295 |
+
else:
|
296 |
+
out_dim = 2 * dim // 2 ** (i + 1)
|
297 |
+
stride = 2
|
298 |
+
padding = (ps - stride) // 2
|
299 |
+
self.reductions.append(nn.Conv2d(dim, out_dim, kernel_size=ps,
|
300 |
+
stride=stride, padding=padding))
|
301 |
+
|
302 |
+
def forward(self, x):
|
303 |
+
"""
|
304 |
+
x: B, H*W, C
|
305 |
+
"""
|
306 |
+
H, W = self.input_resolution
|
307 |
+
B, L, C = x.shape
|
308 |
+
assert L == H * W, "input feature has wrong size"
|
309 |
+
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
|
310 |
+
|
311 |
+
x = self.norm(x)
|
312 |
+
x = x.view(B, H, W, C).permute(0, 3, 1, 2)
|
313 |
+
|
314 |
+
xs = []
|
315 |
+
for i in range(len(self.reductions)):
|
316 |
+
tmp_x = self.reductions[i](x).flatten(2).transpose(1, 2)
|
317 |
+
xs.append(tmp_x)
|
318 |
+
x = torch.cat(xs, dim=2)
|
319 |
+
return x
|
320 |
+
|
321 |
+
def extra_repr(self) -> str:
|
322 |
+
return f"input_resolution={self.input_resolution}, dim={self.dim}"
|
323 |
+
|
324 |
+
def flops(self):
|
325 |
+
H, W = self.input_resolution
|
326 |
+
flops = H * W * self.dim
|
327 |
+
for i, ps in enumerate(self.patch_size):
|
328 |
+
if i == len(self.patch_size) - 1:
|
329 |
+
out_dim = 2 * self.dim // 2 ** i
|
330 |
+
else:
|
331 |
+
out_dim = 2 * self.dim // 2 ** (i + 1)
|
332 |
+
flops += (H // 2) * (W // 2) * ps * ps * out_dim * self.dim
|
333 |
+
return flops
|
334 |
+
|
335 |
+
|
336 |
+
class Stage(nn.Module):
|
337 |
+
""" CrossFormer blocks for one stage.
|
338 |
+
|
339 |
+
Args:
|
340 |
+
dim (int): Number of input channels.
|
341 |
+
input_resolution (tuple[int]): Input resolution.
|
342 |
+
depth (int): Number of blocks.
|
343 |
+
num_heads (int): Number of attention heads.
|
344 |
+
group_size (int): variable G in the paper, one group has GxG embeddings
|
345 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
346 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
347 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
348 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
349 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
350 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
351 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
352 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
353 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
354 |
+
"""
|
355 |
+
|
356 |
+
def __init__(self, dim, input_resolution, depth, num_heads, group_size,
|
357 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
|
358 |
+
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
|
359 |
+
patch_size_end=[4], num_patch_size=None):
|
360 |
+
|
361 |
+
super().__init__()
|
362 |
+
self.dim = dim
|
363 |
+
self.input_resolution = input_resolution
|
364 |
+
self.depth = depth
|
365 |
+
self.use_checkpoint = use_checkpoint
|
366 |
+
|
367 |
+
# build blocks
|
368 |
+
self.blocks = nn.ModuleList()
|
369 |
+
for i in range(depth):
|
370 |
+
lsda_flag = 0 if (i % 2 == 0) else 1
|
371 |
+
self.blocks.append(CrossFormerBlock(dim=dim, input_resolution=input_resolution,
|
372 |
+
num_heads=num_heads, group_size=group_size,
|
373 |
+
lsda_flag=lsda_flag,
|
374 |
+
mlp_ratio=mlp_ratio,
|
375 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
376 |
+
drop=drop, attn_drop=attn_drop,
|
377 |
+
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
378 |
+
norm_layer=norm_layer,
|
379 |
+
num_patch_size=num_patch_size))
|
380 |
+
|
381 |
+
# patch merging layer
|
382 |
+
if downsample is not None:
|
383 |
+
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer,
|
384 |
+
patch_size=patch_size_end, num_input_patch_size=num_patch_size)
|
385 |
+
else:
|
386 |
+
self.downsample = None
|
387 |
+
|
388 |
+
def forward(self, x):
|
389 |
+
for blk in self.blocks:
|
390 |
+
if self.use_checkpoint:
|
391 |
+
x = checkpoint.checkpoint(blk, x)
|
392 |
+
else:
|
393 |
+
x = blk(x)
|
394 |
+
if self.downsample is not None:
|
395 |
+
x = self.downsample(x)
|
396 |
+
return x
|
397 |
+
|
398 |
+
def extra_repr(self) -> str:
|
399 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
|
400 |
+
|
401 |
+
def flops(self):
|
402 |
+
flops = 0
|
403 |
+
for blk in self.blocks:
|
404 |
+
flops += blk.flops()
|
405 |
+
if self.downsample is not None:
|
406 |
+
flops += self.downsample.flops()
|
407 |
+
return flops
|
408 |
+
|
409 |
+
|
410 |
+
class PatchEmbed(nn.Module):
|
411 |
+
r""" Image to Patch Embedding
|
412 |
+
|
413 |
+
Args:
|
414 |
+
img_size (int): Image size. Default: 224.
|
415 |
+
patch_size (int): Patch token size. Default: [4].
|
416 |
+
in_chans (int): Number of input image channels. Default: 3.
|
417 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
418 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
419 |
+
"""
|
420 |
+
|
421 |
+
def __init__(self, img_size=224, patch_size=[4], in_chans=3, embed_dim=96, norm_layer=None):
|
422 |
+
super().__init__()
|
423 |
+
img_size = to_2tuple(img_size)
|
424 |
+
# patch_size = to_2tuple(patch_size)
|
425 |
+
patches_resolution = [img_size[0] // patch_size[0], img_size[0] // patch_size[0]]
|
426 |
+
self.img_size = img_size
|
427 |
+
self.patch_size = patch_size
|
428 |
+
self.patches_resolution = patches_resolution
|
429 |
+
self.num_patches = patches_resolution[0] * patches_resolution[1]
|
430 |
+
|
431 |
+
self.in_chans = in_chans
|
432 |
+
self.embed_dim = embed_dim
|
433 |
+
|
434 |
+
self.projs = nn.ModuleList()
|
435 |
+
for i, ps in enumerate(patch_size):
|
436 |
+
if i == len(patch_size) - 1:
|
437 |
+
dim = embed_dim // 2 ** i
|
438 |
+
else:
|
439 |
+
dim = embed_dim // 2 ** (i + 1)
|
440 |
+
stride = patch_size[0]
|
441 |
+
padding = (ps - patch_size[0]) // 2
|
442 |
+
self.projs.append(nn.Conv2d(in_chans, dim, kernel_size=ps, stride=stride, padding=padding))
|
443 |
+
if norm_layer is not None:
|
444 |
+
self.norm = norm_layer(embed_dim)
|
445 |
+
else:
|
446 |
+
self.norm = None
|
447 |
+
|
448 |
+
def forward(self, x):
|
449 |
+
B, C, H, W = x.shape
|
450 |
+
# FIXME look at relaxing size constraints
|
451 |
+
assert H == self.img_size[0] and W == self.img_size[1], \
|
452 |
+
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
453 |
+
xs = []
|
454 |
+
for i in range(len(self.projs)):
|
455 |
+
tx = self.projs[i](x).flatten(2).transpose(1, 2)
|
456 |
+
xs.append(tx) # B Ph*Pw C
|
457 |
+
x = torch.cat(xs, dim=2)
|
458 |
+
if self.norm is not None:
|
459 |
+
x = self.norm(x)
|
460 |
+
return x
|
461 |
+
|
462 |
+
def flops(self):
|
463 |
+
Ho, Wo = self.patches_resolution
|
464 |
+
flops = 0
|
465 |
+
for i, ps in enumerate(self.patch_size):
|
466 |
+
if i == len(self.patch_size) - 1:
|
467 |
+
dim = self.embed_dim // 2 ** i
|
468 |
+
else:
|
469 |
+
dim = self.embed_dim // 2 ** (i + 1)
|
470 |
+
flops += Ho * Wo * dim * self.in_chans * (self.patch_size[i] * self.patch_size[i])
|
471 |
+
if self.norm is not None:
|
472 |
+
flops += Ho * Wo * self.embed_dim
|
473 |
+
return flops
|
474 |
+
|
475 |
+
|
476 |
+
class CrossFormer(nn.Module):
|
477 |
+
r""" CrossFormer
|
478 |
+
A PyTorch impl of : `CrossFormer: A Versatile Vision Transformer Based on Cross-scale Attention` -
|
479 |
+
|
480 |
+
Args:
|
481 |
+
img_size (int | tuple(int)): Input image size. Default 224
|
482 |
+
patch_size (int | tuple(int)): Patch size. Default: 4
|
483 |
+
in_chans (int): Number of input image channels. Default: 3
|
484 |
+
num_classes (int): Number of classes for classification head. Default: 1000
|
485 |
+
embed_dim (int): Patch embedding dimension. Default: 96
|
486 |
+
depths (tuple(int)): Depth of each stage.
|
487 |
+
num_heads (tuple(int)): Number of attention heads in different layers.
|
488 |
+
group_size (int): Group size. Default: 7
|
489 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
|
490 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
491 |
+
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
|
492 |
+
drop_rate (float): Dropout rate. Default: 0
|
493 |
+
attn_drop_rate (float): Attention dropout rate. Default: 0
|
494 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.1
|
495 |
+
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
496 |
+
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
|
497 |
+
patch_norm (bool): If True, add normalization after patch embedding. Default: True
|
498 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
|
499 |
+
"""
|
500 |
+
|
501 |
+
def __init__(self, img_size=224, patch_size=[4], in_chans=3, num_classes=1000,
|
502 |
+
embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
|
503 |
+
group_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
|
504 |
+
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
505 |
+
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
|
506 |
+
use_checkpoint=False, merge_size=[[2], [2], [2]], **kwargs):
|
507 |
+
super().__init__()
|
508 |
+
|
509 |
+
self.num_classes = num_classes
|
510 |
+
self.num_layers = len(depths)
|
511 |
+
self.embed_dim = embed_dim
|
512 |
+
self.ape = ape
|
513 |
+
self.patch_norm = patch_norm
|
514 |
+
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
|
515 |
+
self.mlp_ratio = mlp_ratio
|
516 |
+
|
517 |
+
# split image into non-overlapping patches
|
518 |
+
self.patch_embed = PatchEmbed(
|
519 |
+
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
|
520 |
+
norm_layer=norm_layer if self.patch_norm else None)
|
521 |
+
num_patches = self.patch_embed.num_patches
|
522 |
+
patches_resolution = self.patch_embed.patches_resolution
|
523 |
+
self.patches_resolution = patches_resolution
|
524 |
+
|
525 |
+
# absolute position embedding
|
526 |
+
if self.ape:
|
527 |
+
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
|
528 |
+
trunc_normal_(self.absolute_pos_embed, std=.02)
|
529 |
+
|
530 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
531 |
+
|
532 |
+
# stochastic depth
|
533 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
534 |
+
|
535 |
+
# build layers
|
536 |
+
self.layers = nn.ModuleList()
|
537 |
+
|
538 |
+
num_patch_sizes = [len(patch_size)] + [len(m) for m in merge_size]
|
539 |
+
for i_layer in range(self.num_layers):
|
540 |
+
patch_size_end = merge_size[i_layer] if i_layer < self.num_layers - 1 else None
|
541 |
+
num_patch_size = num_patch_sizes[i_layer]
|
542 |
+
layer = Stage(dim=int(embed_dim * 2 ** i_layer),
|
543 |
+
input_resolution=(patches_resolution[0] // (2 ** i_layer),
|
544 |
+
patches_resolution[1] // (2 ** i_layer)),
|
545 |
+
depth=depths[i_layer],
|
546 |
+
num_heads=num_heads[i_layer],
|
547 |
+
group_size=group_size[i_layer],
|
548 |
+
mlp_ratio=self.mlp_ratio,
|
549 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
550 |
+
drop=drop_rate, attn_drop=attn_drop_rate,
|
551 |
+
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
|
552 |
+
norm_layer=norm_layer,
|
553 |
+
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
|
554 |
+
use_checkpoint=use_checkpoint,
|
555 |
+
patch_size_end=patch_size_end,
|
556 |
+
num_patch_size=num_patch_size)
|
557 |
+
self.layers.append(layer)
|
558 |
+
|
559 |
+
self.norm = norm_layer(self.num_features)
|
560 |
+
self.avgpool = nn.AdaptiveAvgPool1d(1)
|
561 |
+
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
562 |
+
|
563 |
+
self.apply(self._init_weights)
|
564 |
+
|
565 |
+
def _init_weights(self, m):
|
566 |
+
if isinstance(m, nn.Linear):
|
567 |
+
trunc_normal_(m.weight, std=.02)
|
568 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
569 |
+
nn.init.constant_(m.bias, 0)
|
570 |
+
elif isinstance(m, nn.LayerNorm):
|
571 |
+
nn.init.constant_(m.bias, 0)
|
572 |
+
nn.init.constant_(m.weight, 1.0)
|
573 |
+
|
574 |
+
@torch.jit.ignore
|
575 |
+
def no_weight_decay(self):
|
576 |
+
return {'absolute_pos_embed'}
|
577 |
+
|
578 |
+
@torch.jit.ignore
|
579 |
+
def no_weight_decay_keywords(self):
|
580 |
+
return {'relative_position_bias_table'}
|
581 |
+
|
582 |
+
def forward_features(self, x):
|
583 |
+
x = self.patch_embed(x)
|
584 |
+
if self.ape:
|
585 |
+
x = x + self.absolute_pos_embed
|
586 |
+
x = self.pos_drop(x)
|
587 |
+
|
588 |
+
for layer in self.layers:
|
589 |
+
x = layer(x)
|
590 |
+
|
591 |
+
x = self.norm(x) # B L C
|
592 |
+
x = self.avgpool(x.transpose(1, 2)) # B C 1
|
593 |
+
x = torch.flatten(x, 1)
|
594 |
+
return x
|
595 |
+
|
596 |
+
def forward(self, x):
|
597 |
+
x = self.forward_features(x)
|
598 |
+
x = self.head(x)
|
599 |
+
return x
|
600 |
+
|
601 |
+
def flops(self):
|
602 |
+
flops = 0
|
603 |
+
flops += self.patch_embed.flops()
|
604 |
+
for i, layer in enumerate(self.layers):
|
605 |
+
flops += layer.flops()
|
606 |
+
flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
|
607 |
+
flops += self.num_features * self.num_classes
|
608 |
+
return flops
|
609 |
+
|
610 |
+
|
611 |
+
class cross_former_cls_head_warp(nn.Module):
|
612 |
+
def __init__(self, backbone, num_classes):
|
613 |
+
super().__init__()
|
614 |
+
embed_dim = 96
|
615 |
+
depths = [2, 2, 18, 2]
|
616 |
+
num_layers = len(depths)
|
617 |
+
num_features = int(embed_dim * 2 ** (num_layers - 1))
|
618 |
+
self.backbone = backbone
|
619 |
+
self.head = nn.Linear(num_features, num_classes)
|
620 |
+
|
621 |
+
def forward(self, x):
|
622 |
+
x = self.backbone(x)
|
623 |
+
x = self.head(x)
|
624 |
+
return x
|
PuzzleTuning/Backbone/counterpart_models/crossformer_backbone.py
ADDED
@@ -0,0 +1,659 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import torch.utils.checkpoint as checkpoint
|
6 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
7 |
+
from mmdet.utils import get_root_logger
|
8 |
+
from mmcv.runner import load_checkpoint
|
9 |
+
|
10 |
+
NEG_INF = -1000000
|
11 |
+
|
12 |
+
class Mlp(nn.Module):
|
13 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
14 |
+
super().__init__()
|
15 |
+
out_features = out_features or in_features
|
16 |
+
hidden_features = hidden_features or in_features
|
17 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
18 |
+
self.act = act_layer()
|
19 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
20 |
+
self.drop = nn.Dropout(drop)
|
21 |
+
|
22 |
+
def forward(self, x):
|
23 |
+
x = self.fc1(x)
|
24 |
+
x = self.act(x)
|
25 |
+
x = self.drop(x)
|
26 |
+
x = self.fc2(x)
|
27 |
+
x = self.drop(x)
|
28 |
+
return x
|
29 |
+
|
30 |
+
|
31 |
+
class DynamicPosBias(nn.Module):
|
32 |
+
def __init__(self, dim, num_heads, residual):
|
33 |
+
super().__init__()
|
34 |
+
self.residual = residual
|
35 |
+
self.num_heads = num_heads
|
36 |
+
self.pos_dim = dim // 4
|
37 |
+
self.pos_proj = nn.Linear(2, self.pos_dim)
|
38 |
+
self.pos1 = nn.Sequential(
|
39 |
+
nn.LayerNorm(self.pos_dim),
|
40 |
+
nn.ReLU(inplace=True),
|
41 |
+
nn.Linear(self.pos_dim, self.pos_dim),
|
42 |
+
)
|
43 |
+
self.pos2 = nn.Sequential(
|
44 |
+
nn.LayerNorm(self.pos_dim),
|
45 |
+
nn.ReLU(inplace=True),
|
46 |
+
nn.Linear(self.pos_dim, self.pos_dim)
|
47 |
+
)
|
48 |
+
self.pos3 = nn.Sequential(
|
49 |
+
nn.LayerNorm(self.pos_dim),
|
50 |
+
nn.ReLU(inplace=True),
|
51 |
+
nn.Linear(self.pos_dim, self.num_heads)
|
52 |
+
)
|
53 |
+
def forward(self, biases):
|
54 |
+
if self.residual:
|
55 |
+
pos = self.pos_proj(biases) # 2Gh-1 * 2Gw-1, heads
|
56 |
+
pos = pos + self.pos1(pos)
|
57 |
+
pos = pos + self.pos2(pos)
|
58 |
+
pos = self.pos3(pos)
|
59 |
+
else:
|
60 |
+
pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases))))
|
61 |
+
return pos
|
62 |
+
|
63 |
+
def flops(self, N):
|
64 |
+
flops = N * 2 * self.pos_dim
|
65 |
+
flops += N * self.pos_dim * self.pos_dim
|
66 |
+
flops += N * self.pos_dim * self.pos_dim
|
67 |
+
flops += N * self.pos_dim * self.num_heads
|
68 |
+
return flops
|
69 |
+
|
70 |
+
class Attention(nn.Module):
|
71 |
+
r""" Multi-head self attention module with relative position bias.
|
72 |
+
|
73 |
+
Args:
|
74 |
+
dim (int): Number of input channels.
|
75 |
+
num_heads (int): Number of attention heads.
|
76 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
77 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
78 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
79 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
80 |
+
"""
|
81 |
+
|
82 |
+
def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.,
|
83 |
+
position_bias=True):
|
84 |
+
|
85 |
+
super().__init__()
|
86 |
+
self.dim = dim
|
87 |
+
self.num_heads = num_heads
|
88 |
+
head_dim = dim // num_heads
|
89 |
+
self.scale = qk_scale or head_dim ** -0.5
|
90 |
+
self.position_bias = position_bias
|
91 |
+
if self.position_bias:
|
92 |
+
self.pos = DynamicPosBias(self.dim // 4, self.num_heads, residual=False)
|
93 |
+
|
94 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
95 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
96 |
+
self.proj = nn.Linear(dim, dim)
|
97 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
98 |
+
|
99 |
+
self.softmax = nn.Softmax(dim=-1)
|
100 |
+
|
101 |
+
def forward(self, x, H, W, mask=None):
|
102 |
+
"""
|
103 |
+
Args:
|
104 |
+
x: input features with shape of (num_windows*B, N, C)
|
105 |
+
mask: (0/-inf) mask with shape of (num_windows, Gh*Gw, Gh*Gw) or None
|
106 |
+
"""
|
107 |
+
group_size = (H, W)
|
108 |
+
B_, N, C = x.shape
|
109 |
+
assert H*W == N
|
110 |
+
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous()
|
111 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
112 |
+
|
113 |
+
q = q * self.scale
|
114 |
+
attn = (q @ k.transpose(-2, -1)) # (num_windows*B, N, N), N = Gh*Gw
|
115 |
+
|
116 |
+
if self.position_bias:
|
117 |
+
# generate mother-set
|
118 |
+
position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device)
|
119 |
+
position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device)
|
120 |
+
biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1
|
121 |
+
biases = biases.flatten(1).transpose(0, 1).contiguous().float()
|
122 |
+
|
123 |
+
# get pair-wise relative position index for each token inside the window
|
124 |
+
coords_h = torch.arange(group_size[0], device=attn.device)
|
125 |
+
coords_w = torch.arange(group_size[1], device=attn.device)
|
126 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw
|
127 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw
|
128 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw
|
129 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2
|
130 |
+
relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0
|
131 |
+
relative_coords[:, :, 1] += group_size[1] - 1
|
132 |
+
relative_coords[:, :, 0] *= 2 * group_size[1] - 1
|
133 |
+
relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw
|
134 |
+
|
135 |
+
pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads
|
136 |
+
# select position bias
|
137 |
+
relative_position_bias = pos[relative_position_index.view(-1)].view(
|
138 |
+
group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH
|
139 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw
|
140 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
141 |
+
|
142 |
+
if mask is not None:
|
143 |
+
nG = mask.shape[0]
|
144 |
+
attn = attn.view(B_ // nG, nG, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) # (B, nG, nHead, N, N)
|
145 |
+
attn = attn.view(-1, self.num_heads, N, N)
|
146 |
+
attn = self.softmax(attn)
|
147 |
+
else:
|
148 |
+
attn = self.softmax(attn)
|
149 |
+
|
150 |
+
attn = self.attn_drop(attn)
|
151 |
+
|
152 |
+
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
153 |
+
x = self.proj(x)
|
154 |
+
x = self.proj_drop(x)
|
155 |
+
return x
|
156 |
+
|
157 |
+
def extra_repr(self) -> str:
|
158 |
+
return f'dim={self.dim}, num_heads={self.num_heads}'
|
159 |
+
|
160 |
+
def flops(self, N):
|
161 |
+
# calculate flops for 1 window with token length of N
|
162 |
+
flops = 0
|
163 |
+
excluded_flops = 0
|
164 |
+
# qkv = self.qkv(x)
|
165 |
+
flops += N * self.dim * 3 * self.dim
|
166 |
+
# attn = (q @ k.transpose(-2, -1))
|
167 |
+
flops += self.num_heads * N * (self.dim // self.num_heads) * N
|
168 |
+
excluded_flops += self.num_heads * N * (self.dim // self.num_heads) * N
|
169 |
+
# x = (attn @ v)
|
170 |
+
flops += self.num_heads * N * N * (self.dim // self.num_heads)
|
171 |
+
excluded_flops += self.num_heads * N * N * (self.dim // self.num_heads)
|
172 |
+
# x = self.proj(x)
|
173 |
+
flops += N * self.dim * self.dim
|
174 |
+
if self.position_bias:
|
175 |
+
flops += self.pos.flops(N)
|
176 |
+
return flops, excluded_flops
|
177 |
+
|
178 |
+
|
179 |
+
class CrossFormerBlock(nn.Module):
|
180 |
+
r""" CrossFormer Block.
|
181 |
+
|
182 |
+
Args:
|
183 |
+
dim (int): Number of input channels.
|
184 |
+
input_resolution (tuple[int]): Input resulotion.
|
185 |
+
num_heads (int): Number of attention heads.
|
186 |
+
group_size (int): Window size.
|
187 |
+
lsda_flag (int): use SDA or LDA, 0 for SDA and 1 for LDA.
|
188 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
189 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
190 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
191 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
192 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
193 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
194 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
195 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
196 |
+
"""
|
197 |
+
|
198 |
+
def __init__(self, dim, input_resolution, num_heads, group_size=7, interval=8, lsda_flag=0,
|
199 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
|
200 |
+
act_layer=nn.GELU, norm_layer=nn.LayerNorm, num_patch_size=1):
|
201 |
+
super().__init__()
|
202 |
+
self.dim = dim
|
203 |
+
self.input_resolution = input_resolution
|
204 |
+
self.num_heads = num_heads
|
205 |
+
self.group_size = group_size
|
206 |
+
self.interval = interval
|
207 |
+
self.lsda_flag = lsda_flag
|
208 |
+
self.mlp_ratio = mlp_ratio
|
209 |
+
self.num_patch_size = num_patch_size
|
210 |
+
|
211 |
+
self.norm1 = norm_layer(dim)
|
212 |
+
|
213 |
+
self.attn = Attention(
|
214 |
+
dim, num_heads=num_heads,
|
215 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
|
216 |
+
position_bias=True)
|
217 |
+
|
218 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
219 |
+
self.norm2 = norm_layer(dim)
|
220 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
221 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
222 |
+
|
223 |
+
def forward(self, x, H, W):
|
224 |
+
B, L, C = x.shape
|
225 |
+
assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W)
|
226 |
+
|
227 |
+
if min(H, W) <= self.group_size:
|
228 |
+
# if window size is larger than input resolution, we don't partition windows
|
229 |
+
self.lsda_flag = 0
|
230 |
+
self.group_size = min(H, W)
|
231 |
+
|
232 |
+
shortcut = x
|
233 |
+
x = self.norm1(x)
|
234 |
+
x = x.view(B, H, W, C)
|
235 |
+
|
236 |
+
# padding
|
237 |
+
size_div = self.interval if self.lsda_flag == 1 else self.group_size
|
238 |
+
pad_l = pad_t = 0
|
239 |
+
pad_r = (size_div - W % size_div) % size_div
|
240 |
+
pad_b = (size_div - H % size_div) % size_div
|
241 |
+
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
|
242 |
+
_, Hp, Wp, _ = x.shape
|
243 |
+
|
244 |
+
mask = torch.zeros((1, Hp, Wp, 1), device=x.device)
|
245 |
+
if pad_b > 0:
|
246 |
+
mask[:, -pad_b:, :, :] = -1
|
247 |
+
if pad_r > 0:
|
248 |
+
mask[:, :, -pad_r:, :] = -1
|
249 |
+
|
250 |
+
# group embeddings and generate attn_mask
|
251 |
+
if self.lsda_flag == 0: # SDA
|
252 |
+
G = Gh = Gw = self.group_size
|
253 |
+
x = x.reshape(B, Hp // G, G, Wp // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous()
|
254 |
+
x = x.reshape(B * Hp * Wp // G**2, G**2, C)
|
255 |
+
nG = Hp * Wp // G**2
|
256 |
+
# attn_mask
|
257 |
+
if pad_r > 0 or pad_b > 0:
|
258 |
+
mask = mask.reshape(1, Hp // G, G, Wp // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous()
|
259 |
+
mask = mask.reshape(nG, 1, G * G)
|
260 |
+
attn_mask = torch.zeros((nG, G * G, G * G), device=x.device)
|
261 |
+
attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF)
|
262 |
+
else:
|
263 |
+
attn_mask = None
|
264 |
+
else: # LDA
|
265 |
+
I, Gh, Gw = self.interval, Hp // self.interval, Wp // self.interval
|
266 |
+
x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous()
|
267 |
+
x = x.reshape(B * I * I, Gh * Gw, C)
|
268 |
+
nG = I ** 2
|
269 |
+
# attn_mask
|
270 |
+
if pad_r > 0 or pad_b > 0:
|
271 |
+
mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous()
|
272 |
+
mask = mask.reshape(nG, 1, Gh * Gw)
|
273 |
+
attn_mask = torch.zeros((nG, Gh * Gw, Gh * Gw), device=x.device)
|
274 |
+
attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF)
|
275 |
+
else:
|
276 |
+
attn_mask = None
|
277 |
+
|
278 |
+
# multi-head self-attention
|
279 |
+
x = self.attn(x, Gh, Gw, mask=attn_mask) # nG*B, G*G, C
|
280 |
+
|
281 |
+
# ungroup embeddings
|
282 |
+
if self.lsda_flag == 0:
|
283 |
+
x = x.reshape(B, Hp // G, Wp // G, G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() # B, Hp//G, G, Wp//G, G, C
|
284 |
+
else:
|
285 |
+
x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C
|
286 |
+
x = x.reshape(B, Hp, Wp, C)
|
287 |
+
|
288 |
+
# remove padding
|
289 |
+
if pad_r > 0 or pad_b > 0:
|
290 |
+
x = x[:, :H, :W, :].contiguous()
|
291 |
+
x = x.view(B, H * W, C)
|
292 |
+
|
293 |
+
# FFN
|
294 |
+
x = shortcut + self.drop_path(x)
|
295 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
296 |
+
|
297 |
+
return x
|
298 |
+
|
299 |
+
def extra_repr(self) -> str:
|
300 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
|
301 |
+
f"group_size={self.group_size}, lsda_flag={self.lsda_flag}, mlp_ratio={self.mlp_ratio}"
|
302 |
+
|
303 |
+
def flops(self):
|
304 |
+
flops = 0
|
305 |
+
H, W = self.input_resolution
|
306 |
+
# norm1
|
307 |
+
flops += self.dim * H * W
|
308 |
+
# Attention
|
309 |
+
size_div = self.interval if self.lsda_flag == 1 else self.group_size
|
310 |
+
Hp = math.ceil(H / size_div) * size_div
|
311 |
+
Wp = math.ceil(W / size_div) * size_div
|
312 |
+
Gh = Hp / size_div if self.lsda_flag == 1 else self.group_size
|
313 |
+
Gw = Wp / size_div if self.lsda_flag == 1 else self.group_size
|
314 |
+
nG = Hp * Wp / Gh / Gw
|
315 |
+
attn_flops, attn_excluded_flops = self.attn.flops(Gh * Gw)
|
316 |
+
flops += nG * attn_flops
|
317 |
+
excluded_flops = nG * attn_excluded_flops
|
318 |
+
# mlp
|
319 |
+
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
|
320 |
+
# norm2
|
321 |
+
flops += self.dim * H * W
|
322 |
+
return flops, excluded_flops
|
323 |
+
|
324 |
+
class PatchMerging(nn.Module):
|
325 |
+
r""" Patch Merging Layer.
|
326 |
+
|
327 |
+
Args:
|
328 |
+
input_resolution (tuple[int]): Resolution of input feature.
|
329 |
+
dim (int): Number of input channels.
|
330 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
331 |
+
"""
|
332 |
+
|
333 |
+
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm, patch_size=[2], num_input_patch_size=1):
|
334 |
+
super().__init__()
|
335 |
+
self.input_resolution = input_resolution
|
336 |
+
self.dim = dim
|
337 |
+
self.reductions = nn.ModuleList()
|
338 |
+
self.patch_size = patch_size
|
339 |
+
self.norm = norm_layer(dim)
|
340 |
+
|
341 |
+
for i, ps in enumerate(patch_size):
|
342 |
+
if i == len(patch_size) - 1:
|
343 |
+
out_dim = 2 * dim // 2 ** i
|
344 |
+
else:
|
345 |
+
out_dim = 2 * dim // 2 ** (i + 1)
|
346 |
+
stride = 2
|
347 |
+
padding = (ps - stride) // 2
|
348 |
+
self.reductions.append(nn.Conv2d(dim, out_dim, kernel_size=ps,
|
349 |
+
stride=stride, padding=padding))
|
350 |
+
|
351 |
+
def forward(self, x, H, W):
|
352 |
+
"""
|
353 |
+
x: B, H*W, C
|
354 |
+
"""
|
355 |
+
B, L, C = x.shape
|
356 |
+
assert L == H * W, "input feature has wrong size"
|
357 |
+
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
|
358 |
+
|
359 |
+
x = self.norm(x)
|
360 |
+
x = x.view(B, H, W, C).permute(0, 3, 1, 2).contiguous()
|
361 |
+
|
362 |
+
xs = []
|
363 |
+
for i in range(len(self.reductions)):
|
364 |
+
tmp_x = self.reductions[i](x).flatten(2).transpose(1, 2).contiguous()
|
365 |
+
xs.append(tmp_x)
|
366 |
+
x = torch.cat(xs, dim=2)
|
367 |
+
return x
|
368 |
+
|
369 |
+
def extra_repr(self) -> str:
|
370 |
+
return f"input_resolution={self.input_resolution}, dim={self.dim}"
|
371 |
+
|
372 |
+
def flops(self):
|
373 |
+
H, W = self.input_resolution
|
374 |
+
flops = H * W * self.dim
|
375 |
+
for i, ps in enumerate(self.patch_size):
|
376 |
+
if i == len(self.patch_size) - 1:
|
377 |
+
out_dim = 2 * self.dim // 2 ** i
|
378 |
+
else:
|
379 |
+
out_dim = 2 * self.dim // 2 ** (i + 1)
|
380 |
+
flops += (H // 2) * (W // 2) * ps * ps * out_dim * self.dim
|
381 |
+
return flops
|
382 |
+
|
383 |
+
|
384 |
+
class Stage(nn.Module):
|
385 |
+
""" CrossFormer blocks for one stage.
|
386 |
+
|
387 |
+
Args:
|
388 |
+
dim (int): Number of input channels.
|
389 |
+
input_resolution (tuple[int]): Input resolution.
|
390 |
+
depth (int): Number of blocks.
|
391 |
+
num_heads (int): Number of attention heads.
|
392 |
+
group_size (int): Group size.
|
393 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
394 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
395 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
396 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
397 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
398 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
399 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
400 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
401 |
+
use_checkpoint (bool): Ghether to use checkpointing to save memory. Default: False.
|
402 |
+
"""
|
403 |
+
|
404 |
+
def __init__(self, dim, input_resolution, depth, num_heads, group_size, interval,
|
405 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
|
406 |
+
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
|
407 |
+
patch_size_end=[4], num_patch_size=None):
|
408 |
+
|
409 |
+
super().__init__()
|
410 |
+
self.dim = dim
|
411 |
+
self.depth = depth
|
412 |
+
self.use_checkpoint = use_checkpoint
|
413 |
+
|
414 |
+
# build blocks
|
415 |
+
self.blocks = nn.ModuleList()
|
416 |
+
for i in range(depth):
|
417 |
+
lsda_flag = 0 if (i % 2 == 0) else 1
|
418 |
+
self.blocks.append(CrossFormerBlock(dim=dim, input_resolution=input_resolution,
|
419 |
+
num_heads=num_heads, group_size=group_size, interval=interval,
|
420 |
+
lsda_flag=lsda_flag,
|
421 |
+
mlp_ratio=mlp_ratio,
|
422 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
423 |
+
drop=drop, attn_drop=attn_drop,
|
424 |
+
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
425 |
+
norm_layer=norm_layer,
|
426 |
+
num_patch_size=num_patch_size))
|
427 |
+
|
428 |
+
# patch merging layer
|
429 |
+
if downsample is not None:
|
430 |
+
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer,
|
431 |
+
patch_size=patch_size_end, num_input_patch_size=num_patch_size)
|
432 |
+
else:
|
433 |
+
self.downsample = None
|
434 |
+
|
435 |
+
def forward(self, x, H, W):
|
436 |
+
for blk in self.blocks:
|
437 |
+
if self.use_checkpoint:
|
438 |
+
x = checkpoint.checkpoint(blk, x)
|
439 |
+
else:
|
440 |
+
x = blk(x, H, W)
|
441 |
+
|
442 |
+
B, _, C = x.shape
|
443 |
+
feat = x.view(B, H, W, C).permute(0, 3, 1, 2).contiguous()
|
444 |
+
if self.downsample is not None:
|
445 |
+
x = self.downsample(x, H, W)
|
446 |
+
return feat, x
|
447 |
+
|
448 |
+
def extra_repr(self) -> str:
|
449 |
+
return f"dim={self.dim}, depth={self.depth}"
|
450 |
+
|
451 |
+
def flops(self):
|
452 |
+
flops = 0
|
453 |
+
excluded_flops = 0
|
454 |
+
for blk in self.blocks:
|
455 |
+
blk_flops, blk_excluded_flops = blk.flops()
|
456 |
+
flops += blk_flops
|
457 |
+
excluded_flops += blk_excluded_flops
|
458 |
+
if self.downsample is not None:
|
459 |
+
flops += self.downsample.flops()
|
460 |
+
return flops, excluded_flops
|
461 |
+
|
462 |
+
|
463 |
+
class PatchEmbed(nn.Module):
|
464 |
+
r""" Image to Patch Embedding
|
465 |
+
|
466 |
+
Args:
|
467 |
+
img_size (int): Image size. Default: 224.
|
468 |
+
patch_size (int): Patch token size. Default: 4.
|
469 |
+
in_chans (int): Number of input image channels. Default: 3.
|
470 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
471 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
472 |
+
"""
|
473 |
+
|
474 |
+
def __init__(self, img_size=224, patch_size=[4], in_chans=3, embed_dim=96, norm_layer=None):
|
475 |
+
super().__init__()
|
476 |
+
img_size = to_2tuple(img_size)
|
477 |
+
# patch_size = to_2tuple(patch_size)
|
478 |
+
patches_resolution = [img_size[0] // 4, img_size[1] // 4] # only for flops calculation
|
479 |
+
self.img_size = img_size
|
480 |
+
self.patch_size = patch_size
|
481 |
+
self.patches_resolution = patches_resolution
|
482 |
+
|
483 |
+
self.in_chans = in_chans
|
484 |
+
self.embed_dim = embed_dim
|
485 |
+
|
486 |
+
self.projs = nn.ModuleList()
|
487 |
+
for i, ps in enumerate(patch_size):
|
488 |
+
if i == len(patch_size) - 1:
|
489 |
+
dim = embed_dim // 2 ** i
|
490 |
+
else:
|
491 |
+
dim = embed_dim // 2 ** (i + 1)
|
492 |
+
stride = 4
|
493 |
+
padding = (ps - 4) // 2
|
494 |
+
self.projs.append(nn.Conv2d(in_chans, dim, kernel_size=ps, stride=stride, padding=padding))
|
495 |
+
if norm_layer is not None:
|
496 |
+
self.norm = norm_layer(embed_dim)
|
497 |
+
else:
|
498 |
+
self.norm = None
|
499 |
+
|
500 |
+
def forward(self, x):
|
501 |
+
B, C, H, W = x.shape
|
502 |
+
xs = []
|
503 |
+
for i in range(len(self.projs)):
|
504 |
+
tx = self.projs[i](x).flatten(2).transpose(1, 2)
|
505 |
+
xs.append(tx) # B Ph*Pw C
|
506 |
+
x = torch.cat(xs, dim=2)
|
507 |
+
if self.norm is not None:
|
508 |
+
x = self.norm(x)
|
509 |
+
return x, H, W
|
510 |
+
|
511 |
+
def flops(self):
|
512 |
+
Ho, Wo = self.patches_resolution
|
513 |
+
flops = 0
|
514 |
+
for i, ps in enumerate(self.patch_size):
|
515 |
+
if i == len(self.patch_size) - 1:
|
516 |
+
dim = self.embed_dim // 2 ** i
|
517 |
+
else:
|
518 |
+
dim = self.embed_dim // 2 ** (i + 1)
|
519 |
+
flops += Ho * Wo * dim * self.in_chans * (self.patch_size[i] * self.patch_size[i])
|
520 |
+
if self.norm is not None:
|
521 |
+
flops += Ho * Wo * self.embed_dim
|
522 |
+
return flops
|
523 |
+
|
524 |
+
|
525 |
+
class CrossFormer(nn.Module):
|
526 |
+
r""" CrossFormer
|
527 |
+
A PyTorch impl of : `CrossFormer: A Versatile Vision Transformer Based on Cross-scale Attention` -
|
528 |
+
|
529 |
+
Args:
|
530 |
+
img_size (int | tuple(int)): Input image size. Default 224
|
531 |
+
patch_size (int | tuple(int)): Patch size. Default: 4
|
532 |
+
in_chans (int): Number of input image channels. Default: 3
|
533 |
+
num_classes (int): Number of classes for classification head. Default: 1000
|
534 |
+
embed_dim (int): Patch embedding dimension. Default: 96
|
535 |
+
depths (tuple(int)): Depth of each stage.
|
536 |
+
num_heads (tuple(int)): Number of attention heads in different layers.
|
537 |
+
group_size (int): Group size. Default: 7
|
538 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
|
539 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
540 |
+
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
|
541 |
+
drop_rate (float): Dropout rate. Default: 0
|
542 |
+
attn_drop_rate (float): Attention dropout rate. Default: 0
|
543 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.1
|
544 |
+
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
545 |
+
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
|
546 |
+
patch_norm (bool): If True, add normalization after patch embedding. Default: True
|
547 |
+
use_checkpoint (bool): Ghether to use checkpointing to save memory. Default: False
|
548 |
+
"""
|
549 |
+
|
550 |
+
def __init__(self, img_size=224, patch_size=[4], in_chans=3, num_classes=1000,
|
551 |
+
embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
|
552 |
+
group_size=7, crs_interval=[8, 4, 2, 1], mlp_ratio=4., qkv_bias=True, qk_scale=None,
|
553 |
+
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
554 |
+
norm_layer=nn.LayerNorm, patch_norm=True,
|
555 |
+
use_checkpoint=False, merge_size=[[2], [2], [2]], **kwargs):
|
556 |
+
super().__init__()
|
557 |
+
|
558 |
+
self.num_classes = num_classes
|
559 |
+
self.num_layers = len(depths)
|
560 |
+
self.embed_dim = embed_dim
|
561 |
+
self.patch_norm = patch_norm
|
562 |
+
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
|
563 |
+
self.mlp_ratio = mlp_ratio
|
564 |
+
|
565 |
+
# split image into non-overlapping patches
|
566 |
+
self.patch_embed = PatchEmbed(
|
567 |
+
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
|
568 |
+
norm_layer=norm_layer if self.patch_norm else None)
|
569 |
+
patches_resolution = self.patch_embed.patches_resolution
|
570 |
+
self.patches_resolution = patches_resolution # [H//4, W//4] of original image size
|
571 |
+
|
572 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
573 |
+
|
574 |
+
# stochastic depth
|
575 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
576 |
+
|
577 |
+
# build layers
|
578 |
+
self.layers = nn.ModuleList()
|
579 |
+
|
580 |
+
num_patch_sizes = [len(patch_size)] + [len(m) for m in merge_size]
|
581 |
+
for i_layer in range(self.num_layers):
|
582 |
+
patch_size_end = merge_size[i_layer] if i_layer < self.num_layers - 1 else None
|
583 |
+
num_patch_size = num_patch_sizes[i_layer]
|
584 |
+
layer = Stage(dim=int(embed_dim * 2 ** i_layer),
|
585 |
+
input_resolution=(patches_resolution[0] // (2 ** i_layer),
|
586 |
+
patches_resolution[1] // (2 ** i_layer)),
|
587 |
+
depth=depths[i_layer],
|
588 |
+
num_heads=num_heads[i_layer],
|
589 |
+
group_size=group_size[i_layer],
|
590 |
+
interval=crs_interval[i_layer],
|
591 |
+
mlp_ratio=self.mlp_ratio,
|
592 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
593 |
+
drop=drop_rate, attn_drop=attn_drop_rate,
|
594 |
+
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
|
595 |
+
norm_layer=norm_layer,
|
596 |
+
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
|
597 |
+
use_checkpoint=use_checkpoint,
|
598 |
+
patch_size_end=patch_size_end,
|
599 |
+
num_patch_size=num_patch_size)
|
600 |
+
self.layers.append(layer)
|
601 |
+
|
602 |
+
# # classification
|
603 |
+
# self.norm = norm_layer(self.num_features)
|
604 |
+
# self.avgpool = nn.AdaptiveAvgPool1d(1)
|
605 |
+
# self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
606 |
+
|
607 |
+
self.apply(self._init_weights)
|
608 |
+
|
609 |
+
def init_weights(self, pretrained=None):
|
610 |
+
if isinstance(pretrained, str):
|
611 |
+
logger = get_root_logger()
|
612 |
+
load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger)
|
613 |
+
|
614 |
+
def _init_weights(self, m):
|
615 |
+
if isinstance(m, nn.Linear):
|
616 |
+
trunc_normal_(m.weight, std=.02)
|
617 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
618 |
+
nn.init.constant_(m.bias, 0)
|
619 |
+
elif isinstance(m, nn.LayerNorm):
|
620 |
+
nn.init.constant_(m.bias, 0)
|
621 |
+
nn.init.constant_(m.weight, 1.0)
|
622 |
+
|
623 |
+
@torch.jit.ignore
|
624 |
+
def no_weight_decay(self):
|
625 |
+
return {'absolute_pos_embed'}
|
626 |
+
|
627 |
+
@torch.jit.ignore
|
628 |
+
def no_weight_decay_keywords(self):
|
629 |
+
return {'relative_position_bias_table'}
|
630 |
+
|
631 |
+
def forward(self, x):
|
632 |
+
x, H, W = self.patch_embed(x)
|
633 |
+
x = self.pos_drop(x)
|
634 |
+
|
635 |
+
outs = []
|
636 |
+
for i, layer in enumerate(self.layers):
|
637 |
+
feat, x = layer(x, H //4 //(2 ** i), W //4 //(2 ** i))
|
638 |
+
outs.append(feat)
|
639 |
+
|
640 |
+
# # classification
|
641 |
+
# x = self.norm(x) # B L C
|
642 |
+
# x = self.avgpool(x.transpose(1, 2)) # B C 1
|
643 |
+
# x = torch.flatten(x, 1)
|
644 |
+
# x = self.head(x)
|
645 |
+
# return x
|
646 |
+
|
647 |
+
return outs
|
648 |
+
|
649 |
+
def flops(self):
|
650 |
+
flops = 0
|
651 |
+
excluded_flops = 0
|
652 |
+
flops += self.patch_embed.flops()
|
653 |
+
for i, layer in enumerate(self.layers):
|
654 |
+
layer_flops, layer_excluded_flops = layer.flops()
|
655 |
+
flops += layer_flops
|
656 |
+
excluded_flops += layer_excluded_flops
|
657 |
+
# flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
|
658 |
+
# flops += self.num_features * self.num_classes
|
659 |
+
return flops, excluded_flops
|
PuzzleTuning/Backbone/getmodel.py
ADDED
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
get model func Script ver: Dec 5th 14:20
|
3 |
+
"""
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
sys.path.append(os.path.realpath('.'))
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
from torchvision import models
|
11 |
+
from Backbone import ResHybrid
|
12 |
+
|
13 |
+
|
14 |
+
# get model
|
15 |
+
def get_model(num_classes=1000, edge_size=224, model_idx=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0,
|
16 |
+
pretrained_backbone=True, use_cls_token=True, use_pos_embedding=True, use_att_module='SimAM'):
|
17 |
+
"""
|
18 |
+
:param num_classes: classification required number of your dataset
|
19 |
+
:param edge_size: the input edge size of the dataloder
|
20 |
+
:param model_idx: the model we are going to use. by the format of Model_size_other_info
|
21 |
+
|
22 |
+
:param drop_rate: The dropout layer's probility of proposed models
|
23 |
+
:param attn_drop_rate: The dropout layer(right after the MHSA block or MHGA block)'s probility of proposed models
|
24 |
+
:param drop_path_rate: The probility of stochastic depth
|
25 |
+
|
26 |
+
:param pretrained_backbone: The backbone CNN is initiate randomly or by its official Pretrained models
|
27 |
+
|
28 |
+
:param use_cls_token: To use the class token
|
29 |
+
:param use_pos_embedding: To use the positional enbedding
|
30 |
+
:param use_att_module: To use which attention module in the FGD Focus block
|
31 |
+
|
32 |
+
:return: prepared model
|
33 |
+
"""
|
34 |
+
if model_idx[0:5] == 'ViT_h':
|
35 |
+
# Transfer learning for ViT
|
36 |
+
import timm
|
37 |
+
from pprint import pprint
|
38 |
+
model_names = timm.list_models('*vit*')
|
39 |
+
pprint(model_names)
|
40 |
+
if edge_size == 224:
|
41 |
+
model = timm.create_model('vit_huge_patch14_224_in21k', pretrained=pretrained_backbone, num_classes=num_classes)
|
42 |
+
else:
|
43 |
+
print('not a avaliable image size with', model_idx)
|
44 |
+
|
45 |
+
elif model_idx[0:5] == 'ViT_l':
|
46 |
+
# Transfer learning for ViT
|
47 |
+
import timm
|
48 |
+
from pprint import pprint
|
49 |
+
model_names = timm.list_models('*vit*')
|
50 |
+
pprint(model_names)
|
51 |
+
if edge_size == 224:
|
52 |
+
model = timm.create_model('vit_large_patch16_224', pretrained=pretrained_backbone, num_classes=num_classes)
|
53 |
+
elif edge_size == 384:
|
54 |
+
model = timm.create_model('vit_large_patch16_384', pretrained=pretrained_backbone, num_classes=num_classes)
|
55 |
+
else:
|
56 |
+
print('not a avaliable image size with', model_idx)
|
57 |
+
|
58 |
+
elif model_idx[0:5] == 'ViT_s':
|
59 |
+
# Transfer learning for ViT
|
60 |
+
import timm
|
61 |
+
from pprint import pprint
|
62 |
+
model_names = timm.list_models('*vit*')
|
63 |
+
pprint(model_names)
|
64 |
+
if edge_size == 224:
|
65 |
+
model = timm.create_model('vit_small_patch16_224', pretrained=pretrained_backbone, num_classes=num_classes)
|
66 |
+
elif edge_size == 384:
|
67 |
+
model = timm.create_model('vit_small_patch16_384', pretrained=pretrained_backbone, num_classes=num_classes)
|
68 |
+
else:
|
69 |
+
print('not a avaliable image size with', model_idx)
|
70 |
+
|
71 |
+
elif model_idx[0:5] == 'ViT_t':
|
72 |
+
# Transfer learning for ViT
|
73 |
+
import timm
|
74 |
+
from pprint import pprint
|
75 |
+
model_names = timm.list_models('*vit*')
|
76 |
+
pprint(model_names)
|
77 |
+
if edge_size == 224:
|
78 |
+
model = timm.create_model('vit_tiny_patch16_224', pretrained=pretrained_backbone, num_classes=num_classes)
|
79 |
+
elif edge_size == 384:
|
80 |
+
model = timm.create_model('vit_tiny_patch16_384', pretrained=pretrained_backbone, num_classes=num_classes)
|
81 |
+
else:
|
82 |
+
print('not a avaliable image size with', model_idx)
|
83 |
+
|
84 |
+
elif model_idx[0:5] == 'ViT_b' or model_idx[0:3] == 'ViT': # vit_base
|
85 |
+
# Transfer learning for ViT
|
86 |
+
import timm
|
87 |
+
from pprint import pprint
|
88 |
+
model_names = timm.list_models('*vit*')
|
89 |
+
pprint(model_names)
|
90 |
+
if edge_size == 224:
|
91 |
+
model = timm.create_model('vit_base_patch16_224', pretrained=pretrained_backbone, num_classes=num_classes)
|
92 |
+
elif edge_size == 384:
|
93 |
+
model = timm.create_model('vit_base_patch16_384', pretrained=pretrained_backbone, num_classes=num_classes)
|
94 |
+
else:
|
95 |
+
print('not a avaliable image size with', model_idx)
|
96 |
+
|
97 |
+
elif model_idx[0:3] == 'vgg':
|
98 |
+
# Transfer learning for vgg16_bn
|
99 |
+
import timm
|
100 |
+
from pprint import pprint
|
101 |
+
model_names = timm.list_models('*vgg*')
|
102 |
+
pprint(model_names)
|
103 |
+
if model_idx[0:8] == 'vgg16_bn':
|
104 |
+
model = timm.create_model('vgg16_bn', pretrained=pretrained_backbone, num_classes=num_classes)
|
105 |
+
elif model_idx[0:5] == 'vgg16':
|
106 |
+
model = timm.create_model('vgg16', pretrained=pretrained_backbone, num_classes=num_classes)
|
107 |
+
elif model_idx[0:8] == 'vgg19_bn':
|
108 |
+
model = timm.create_model('vgg19_bn', pretrained=pretrained_backbone, num_classes=num_classes)
|
109 |
+
elif model_idx[0:5] == 'vgg19':
|
110 |
+
model = timm.create_model('vgg19', pretrained=pretrained_backbone, num_classes=num_classes)
|
111 |
+
|
112 |
+
elif model_idx[0:4] == 'deit': # Transfer learning for DeiT
|
113 |
+
import timm
|
114 |
+
from pprint import pprint
|
115 |
+
model_names = timm.list_models('*deit*')
|
116 |
+
pprint(model_names)
|
117 |
+
if edge_size == 384:
|
118 |
+
model = timm.create_model('deit_base_patch16_384', pretrained=pretrained_backbone, num_classes=2)
|
119 |
+
elif edge_size == 224:
|
120 |
+
model = timm.create_model('deit_base_patch16_224', pretrained=pretrained_backbone, num_classes=2)
|
121 |
+
else:
|
122 |
+
pass
|
123 |
+
|
124 |
+
elif model_idx[0:5] == 'twins': # Transfer learning for twins
|
125 |
+
import timm
|
126 |
+
from pprint import pprint
|
127 |
+
|
128 |
+
model_names = timm.list_models('*twins*')
|
129 |
+
pprint(model_names)
|
130 |
+
model = timm.create_model('twins_pcpvt_base', pretrained=pretrained_backbone, num_classes=num_classes)
|
131 |
+
|
132 |
+
elif model_idx[0:5] == 'pit_b' and edge_size == 224: # Transfer learning for PiT
|
133 |
+
import timm
|
134 |
+
from pprint import pprint
|
135 |
+
|
136 |
+
model_names = timm.list_models('*pit*')
|
137 |
+
pprint(model_names)
|
138 |
+
model = timm.create_model('pit_b_224', pretrained=pretrained_backbone, num_classes=num_classes)
|
139 |
+
|
140 |
+
elif model_idx[0:5] == 'gcvit' and edge_size == 224: # Transfer learning for gcvit
|
141 |
+
import timm
|
142 |
+
from pprint import pprint
|
143 |
+
|
144 |
+
model_names = timm.list_models('*gcvit*')
|
145 |
+
pprint(model_names)
|
146 |
+
model = timm.create_model('gcvit_base', pretrained=pretrained_backbone, num_classes=num_classes)
|
147 |
+
|
148 |
+
elif model_idx[0:6] == 'xcit_s': # Transfer learning for XCiT
|
149 |
+
import timm
|
150 |
+
from pprint import pprint
|
151 |
+
model_names = timm.list_models('*xcit*')
|
152 |
+
pprint(model_names)
|
153 |
+
if edge_size == 384:
|
154 |
+
model = timm.create_model('xcit_small_12_p16_384_dist', pretrained=pretrained_backbone,
|
155 |
+
num_classes=num_classes)
|
156 |
+
elif edge_size == 224:
|
157 |
+
model = timm.create_model('xcit_small_12_p16_224_dist', pretrained=pretrained_backbone,
|
158 |
+
num_classes=num_classes)
|
159 |
+
else:
|
160 |
+
pass
|
161 |
+
|
162 |
+
elif model_idx[0:6] == 'xcit_m': # Transfer learning for XCiT
|
163 |
+
import timm
|
164 |
+
from pprint import pprint
|
165 |
+
model_names = timm.list_models('*xcit*')
|
166 |
+
pprint(model_names)
|
167 |
+
if edge_size == 384:
|
168 |
+
model = timm.create_model('xcit_medium_24_p16_384_dist', pretrained=pretrained_backbone,
|
169 |
+
num_classes=num_classes)
|
170 |
+
elif edge_size == 224:
|
171 |
+
model = timm.create_model('xcit_medium_24_p16_224_dist', pretrained=pretrained_backbone,
|
172 |
+
num_classes=num_classes)
|
173 |
+
else:
|
174 |
+
pass
|
175 |
+
|
176 |
+
elif model_idx[0:6] == 'mvitv2': # Transfer learning for MViT v2 small fixme bug in model!
|
177 |
+
import timm
|
178 |
+
from pprint import pprint
|
179 |
+
model_names = timm.list_models('*mvitv2*')
|
180 |
+
pprint(model_names)
|
181 |
+
model = timm.create_model('mvitv2_small_cls', pretrained=pretrained_backbone, num_classes=num_classes)
|
182 |
+
|
183 |
+
elif model_idx[0:6] == 'convit' and edge_size == 224: # Transfer learning for ConViT fixme bug in model!
|
184 |
+
import timm
|
185 |
+
from pprint import pprint
|
186 |
+
|
187 |
+
model_names = timm.list_models('*convit*')
|
188 |
+
pprint(model_names)
|
189 |
+
model = timm.create_model('convit_base', pretrained=pretrained_backbone, num_classes=num_classes)
|
190 |
+
|
191 |
+
elif model_idx[0:6] == 'ResNet': # Transfer learning for the ResNets
|
192 |
+
if model_idx[0:8] == 'ResNet34':
|
193 |
+
model = models.resnet34(pretrained=pretrained_backbone)
|
194 |
+
elif model_idx[0:8] == 'ResNet50':
|
195 |
+
model = models.resnet50(pretrained=pretrained_backbone)
|
196 |
+
elif model_idx[0:9] == 'ResNet101':
|
197 |
+
model = models.resnet101(pretrained=pretrained_backbone)
|
198 |
+
else:
|
199 |
+
print('this model is not defined in get model')
|
200 |
+
return -1
|
201 |
+
num_ftrs = model.fc.in_features
|
202 |
+
model.fc = nn.Linear(num_ftrs, num_classes)
|
203 |
+
|
204 |
+
elif model_idx[0:6] == 'Backbone': # ours: MSHT
|
205 |
+
# NOTICE: HERE 'pretrained' controls only The backbone CNN is initiate randomly
|
206 |
+
# or by its official Pretrained models
|
207 |
+
model = ResHybrid.create_model(model_idx, edge_size, pretrained=pretrained_backbone, num_classes=num_classes,
|
208 |
+
drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
|
209 |
+
drop_path_rate=drop_path_rate, use_cls_token=use_cls_token,
|
210 |
+
use_pos_embedding=use_pos_embedding, use_att_module=use_att_module)
|
211 |
+
|
212 |
+
elif model_idx[0:7] == 'bot_256' and edge_size == 256: # Model: BoT
|
213 |
+
import timm
|
214 |
+
from pprint import pprint
|
215 |
+
model_names = timm.list_models('*bot*')
|
216 |
+
pprint(model_names)
|
217 |
+
# NOTICE: we find no weight for BoT in timm
|
218 |
+
# ['botnet26t_256', 'botnet50ts_256', 'eca_botnext26ts_256']
|
219 |
+
model = timm.create_model('botnet26t_256', pretrained=pretrained_backbone, num_classes=num_classes)
|
220 |
+
|
221 |
+
elif model_idx[0:8] == 'densenet': # Transfer learning for densenet
|
222 |
+
import timm
|
223 |
+
from pprint import pprint
|
224 |
+
|
225 |
+
model_names = timm.list_models('*densenet*')
|
226 |
+
pprint(model_names)
|
227 |
+
model = timm.create_model('densenet121', pretrained=pretrained_backbone, num_classes=num_classes)
|
228 |
+
|
229 |
+
elif model_idx[0:8] == 'xception': # Transfer learning for Xception
|
230 |
+
import timm
|
231 |
+
from pprint import pprint
|
232 |
+
model_names = timm.list_models('*xception*')
|
233 |
+
pprint(model_names)
|
234 |
+
model = timm.create_model('xception', pretrained=pretrained_backbone, num_classes=num_classes)
|
235 |
+
|
236 |
+
elif model_idx[0:9] == 'pvt_v2_b0': # Transfer learning for PVT v2 (todo not okey with torch summary)
|
237 |
+
import timm
|
238 |
+
from pprint import pprint
|
239 |
+
model_names = timm.list_models('*pvt_v2*')
|
240 |
+
pprint(model_names)
|
241 |
+
model = timm.create_model('pvt_v2_b0', pretrained=pretrained_backbone, num_classes=num_classes)
|
242 |
+
|
243 |
+
elif model_idx[0:9] == 'visformer' and edge_size == 224: # Transfer learning for Visformer
|
244 |
+
import timm
|
245 |
+
from pprint import pprint
|
246 |
+
model_names = timm.list_models('*visformer*')
|
247 |
+
pprint(model_names)
|
248 |
+
model = timm.create_model('visformer_small', pretrained=pretrained_backbone, num_classes=num_classes)
|
249 |
+
|
250 |
+
elif model_idx[0:9] == 'conformer': # Transfer learning for Conformer base
|
251 |
+
from Backbone.counterpart_models import conformer
|
252 |
+
|
253 |
+
embed_dim = 576
|
254 |
+
channel_ratio = 6
|
255 |
+
|
256 |
+
if pretrained_backbone:
|
257 |
+
model = conformer.Conformer(num_classes=1000, patch_size=16, channel_ratio=channel_ratio,
|
258 |
+
embed_dim=embed_dim, depth=12, num_heads=9, mlp_ratio=4, qkv_bias=True)
|
259 |
+
# this is the related path to <code>, not <Backbone>
|
260 |
+
save_model_path = '../saved_models/Conformer_base_patch16.pth' # fixme model is downloaded at this path
|
261 |
+
# downloaded from official model state at https://github.com/pengzhiliang/Conformer
|
262 |
+
model.load_state_dict(torch.load(save_model_path), False)
|
263 |
+
|
264 |
+
model.trans_cls_head = nn.Linear(embed_dim, num_classes)
|
265 |
+
model.conv_cls_head = nn.Linear(int(256 * channel_ratio), num_classes)
|
266 |
+
model.cls_head = nn.Linear(int(2 * num_classes), num_classes)
|
267 |
+
|
268 |
+
else:
|
269 |
+
model = conformer.Conformer(num_classes=num_classes, patch_size=16, channel_ratio=channel_ratio,
|
270 |
+
embed_dim=embed_dim, depth=12, num_heads=9, mlp_ratio=4, qkv_bias=True)
|
271 |
+
|
272 |
+
elif model_idx[0:9] == 'coat_mini' and edge_size == 224: # Transfer learning for coat_mini
|
273 |
+
import timm
|
274 |
+
from pprint import pprint
|
275 |
+
|
276 |
+
model_names = timm.list_models('*coat*')
|
277 |
+
pprint(model_names)
|
278 |
+
model = timm.create_model('coat_mini', pretrained=pretrained_backbone, num_classes=num_classes)
|
279 |
+
|
280 |
+
elif model_idx[0:10] == 'swin_b_384' and edge_size == 384: # Transfer learning for Swin Transformer (swin_b_384)
|
281 |
+
import timm
|
282 |
+
from pprint import pprint
|
283 |
+
model_names = timm.list_models('*swin*')
|
284 |
+
pprint(model_names) # swin_base_patch4_window12_384 swin_base_patch4_window12_384_in22k
|
285 |
+
model = timm.create_model('swin_base_patch4_window12_384', pretrained=pretrained_backbone,
|
286 |
+
num_classes=num_classes)
|
287 |
+
|
288 |
+
elif model_idx[0:10] == 'swin_b_224' and edge_size == 224: # Transfer learning for Swin Transformer (swin_b_384)
|
289 |
+
import timm
|
290 |
+
from pprint import pprint
|
291 |
+
model_names = timm.list_models('*swin*')
|
292 |
+
pprint(model_names) # swin_base_patch4_window7_224 swin_base_patch4_window7_224_in22k
|
293 |
+
model = timm.create_model('swin_base_patch4_window7_224', pretrained=pretrained_backbone,
|
294 |
+
num_classes=num_classes)
|
295 |
+
|
296 |
+
elif model_idx[0:11] == 'mobilenetv3': # Transfer learning for mobilenetv3
|
297 |
+
import timm
|
298 |
+
from pprint import pprint
|
299 |
+
model_names = timm.list_models('*mobilenet*')
|
300 |
+
pprint(model_names)
|
301 |
+
model = timm.create_model('mobilenetv3_large_100', pretrained=pretrained_backbone, num_classes=num_classes)
|
302 |
+
|
303 |
+
elif model_idx[0:11] == 'mobilevit_s': # Transfer learning for mobilevit_s
|
304 |
+
import timm
|
305 |
+
from pprint import pprint
|
306 |
+
model_names = timm.list_models('*mobilevit*')
|
307 |
+
pprint(model_names)
|
308 |
+
model = timm.create_model('mobilevit_s', pretrained=pretrained_backbone, num_classes=num_classes)
|
309 |
+
|
310 |
+
elif model_idx[0:11] == 'inceptionv3': # Transfer learning for Inception v3
|
311 |
+
import timm
|
312 |
+
from pprint import pprint
|
313 |
+
model_names = timm.list_models('*inception*')
|
314 |
+
pprint(model_names)
|
315 |
+
model = timm.create_model('inception_v3', pretrained=pretrained_backbone, num_classes=num_classes)
|
316 |
+
|
317 |
+
elif model_idx[0:12] == 'cross_former' and edge_size == 224: # Transfer learning for crossformer base
|
318 |
+
from Backbone.counterpart_models import crossformer
|
319 |
+
backbone = crossformer.CrossFormer(img_size=edge_size,
|
320 |
+
patch_size=[4, 8, 16, 32],
|
321 |
+
in_chans=3,
|
322 |
+
num_classes=0, # get backbone only
|
323 |
+
embed_dim=96,
|
324 |
+
depths=[2, 2, 18, 2],
|
325 |
+
num_heads=[3, 6, 12, 24],
|
326 |
+
group_size=[7, 7, 7, 7],
|
327 |
+
mlp_ratio=4.,
|
328 |
+
qkv_bias=True,
|
329 |
+
qk_scale=None,
|
330 |
+
drop_rate=0.0,
|
331 |
+
drop_path_rate=0.3,
|
332 |
+
ape=False,
|
333 |
+
patch_norm=True,
|
334 |
+
use_checkpoint=False,
|
335 |
+
merge_size=[[2, 4], [2, 4], [2, 4]], )
|
336 |
+
if pretrained_backbone:
|
337 |
+
save_model_path = '../saved_models/crossformer-b.pth' # fixme model is downloaded at this path
|
338 |
+
# downloaded from official model state at https://github.com/cheerss/CrossFormer
|
339 |
+
backbone.load_state_dict(torch.load(save_model_path)['model'], False)
|
340 |
+
model = crossformer.cross_former_cls_head_warp(backbone, num_classes)
|
341 |
+
|
342 |
+
elif model_idx[0:13] == 'crossvit_base': # Transfer learning for crossvit_base (todo not okey with torch summary)
|
343 |
+
import timm
|
344 |
+
from pprint import pprint
|
345 |
+
model_names = timm.list_models('*crossvit_base*')
|
346 |
+
pprint(model_names)
|
347 |
+
model = timm.create_model('crossvit_base_240', pretrained=pretrained_backbone, num_classes=num_classes)
|
348 |
+
|
349 |
+
elif model_idx[0:14] == 'efficientnet_b': # Transfer learning for efficientnet_b3,4
|
350 |
+
import timm
|
351 |
+
from pprint import pprint
|
352 |
+
model_names = timm.list_models('*efficientnet*')
|
353 |
+
pprint(model_names)
|
354 |
+
model = timm.create_model(model_idx[0:15], pretrained=pretrained_backbone, num_classes=num_classes)
|
355 |
+
|
356 |
+
elif model_idx[0:14] == 'ResN50_ViT_384': # ResNet+ViT融合模型384
|
357 |
+
import timm
|
358 |
+
from pprint import pprint
|
359 |
+
model_names = timm.list_models('*vit_base_resnet*')
|
360 |
+
pprint(model_names)
|
361 |
+
model = timm.create_model('vit_base_resnet50_384', pretrained=pretrained_backbone, num_classes=num_classes)
|
362 |
+
|
363 |
+
elif model_idx[0:15] == 'coat_lite_small' and edge_size == 224: # Transfer learning for coat_lite_small
|
364 |
+
import timm
|
365 |
+
from pprint import pprint
|
366 |
+
|
367 |
+
model_names = timm.list_models('*coat*')
|
368 |
+
pprint(model_names)
|
369 |
+
model = timm.create_model('coat_lite_small', pretrained=pretrained_backbone, num_classes=num_classes)
|
370 |
+
|
371 |
+
elif model_idx[0:17] == 'efficientformer_l' and edge_size == 224: # Transfer learning for efficientnet_b3,4
|
372 |
+
import timm
|
373 |
+
from pprint import pprint
|
374 |
+
model_names = timm.list_models('*efficientformer*')
|
375 |
+
pprint(model_names)
|
376 |
+
model = timm.create_model(model_idx[0:18], pretrained=pretrained_backbone, num_classes=num_classes)
|
377 |
+
|
378 |
+
else:
|
379 |
+
print('\nThe model', model_idx, 'with the edge size of', edge_size)
|
380 |
+
print("is not defined in the script!!", '\n')
|
381 |
+
return -1
|
382 |
+
|
383 |
+
try:
|
384 |
+
img = torch.randn(1, 3, edge_size, edge_size)
|
385 |
+
preds = model(img) # (1, class_number)
|
386 |
+
print('test model output:', preds)
|
387 |
+
except:
|
388 |
+
print("Problem exist in the model defining process!!")
|
389 |
+
return -1
|
390 |
+
else:
|
391 |
+
print('model is ready now!')
|
392 |
+
return model
|
PuzzleTuning/Counterpart PreTrain Methods/ReadMe.md
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
we have used MOCO-V3, Dino, MAE to pretrain the ViT-base-224 model.
|
2 |
+
The official codes are implemented here.
|
PuzzleTuning/Counterpart PreTrain Methods/dino-main/LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
PuzzleTuning/Counterpart PreTrain Methods/dino-main/README.md
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Self-Supervised Vision Transformers with DINO
|
2 |
+
|
3 |
+
The original repo of DINO could be found [here](https://github.com/facebookresearch/dino "DINO")
|
4 |
+
|
5 |
+
Pip requirements: timm == 0.4.9, PyTorch == 1.7.1, Torchvision == 0.8.2, Cuda == 11.0
|
6 |
+
|
7 |
+
Typical BASH:
|
8 |
+
```console
|
9 |
+
python -m torch.distributed.launch \
|
10 |
+
--nproc_per_node=2 main_dino.py --arch vit_base --batch_size_per_gpu 128 \
|
11 |
+
--lr 1.5e-4 --epochs 100 --data_path /root/autodl-tmp/All \
|
12 |
+
--basic_state_dict /root/autodl-tmp/ViT_b16_224_Imagenet.pth \
|
13 |
+
--num_workers 32 --output_dir the/path/of/CPIA
|
14 |
+
```
|
PuzzleTuning/Counterpart PreTrain Methods/dino-main/eval_copy_detection.py
ADDED
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import os
|
15 |
+
import sys
|
16 |
+
import pickle
|
17 |
+
import argparse
|
18 |
+
|
19 |
+
import torch
|
20 |
+
from torch import nn
|
21 |
+
import torch.distributed as dist
|
22 |
+
import torch.backends.cudnn as cudnn
|
23 |
+
from torchvision import models as torchvision_models
|
24 |
+
from torchvision import transforms as pth_transforms
|
25 |
+
from PIL import Image, ImageFile
|
26 |
+
import numpy as np
|
27 |
+
|
28 |
+
import utils
|
29 |
+
import vision_transformer as vits
|
30 |
+
from eval_knn import extract_features
|
31 |
+
|
32 |
+
|
33 |
+
class CopydaysDataset():
|
34 |
+
def __init__(self, basedir):
|
35 |
+
self.basedir = basedir
|
36 |
+
self.block_names = (
|
37 |
+
['original', 'strong'] +
|
38 |
+
['jpegqual/%d' % i for i in
|
39 |
+
[3, 5, 8, 10, 15, 20, 30, 50, 75]] +
|
40 |
+
['crops/%d' % i for i in
|
41 |
+
[10, 15, 20, 30, 40, 50, 60, 70, 80]])
|
42 |
+
self.nblocks = len(self.block_names)
|
43 |
+
|
44 |
+
self.query_blocks = range(self.nblocks)
|
45 |
+
self.q_block_sizes = np.ones(self.nblocks, dtype=int) * 157
|
46 |
+
self.q_block_sizes[1] = 229
|
47 |
+
# search only among originals
|
48 |
+
self.database_blocks = [0]
|
49 |
+
|
50 |
+
def get_block(self, i):
|
51 |
+
dirname = self.basedir + '/' + self.block_names[i]
|
52 |
+
fnames = [dirname + '/' + fname
|
53 |
+
for fname in sorted(os.listdir(dirname))
|
54 |
+
if fname.endswith('.jpg')]
|
55 |
+
return fnames
|
56 |
+
|
57 |
+
def get_block_filenames(self, subdir_name):
|
58 |
+
dirname = self.basedir + '/' + subdir_name
|
59 |
+
return [fname
|
60 |
+
for fname in sorted(os.listdir(dirname))
|
61 |
+
if fname.endswith('.jpg')]
|
62 |
+
|
63 |
+
def eval_result(self, ids, distances):
|
64 |
+
j0 = 0
|
65 |
+
for i in range(self.nblocks):
|
66 |
+
j1 = j0 + self.q_block_sizes[i]
|
67 |
+
block_name = self.block_names[i]
|
68 |
+
I = ids[j0:j1] # block size
|
69 |
+
sum_AP = 0
|
70 |
+
if block_name != 'strong':
|
71 |
+
# 1:1 mapping of files to names
|
72 |
+
positives_per_query = [[i] for i in range(j1 - j0)]
|
73 |
+
else:
|
74 |
+
originals = self.get_block_filenames('original')
|
75 |
+
strongs = self.get_block_filenames('strong')
|
76 |
+
|
77 |
+
# check if prefixes match
|
78 |
+
positives_per_query = [
|
79 |
+
[j for j, bname in enumerate(originals)
|
80 |
+
if bname[:4] == qname[:4]]
|
81 |
+
for qname in strongs]
|
82 |
+
|
83 |
+
for qno, Iline in enumerate(I):
|
84 |
+
positives = positives_per_query[qno]
|
85 |
+
ranks = []
|
86 |
+
for rank, bno in enumerate(Iline):
|
87 |
+
if bno in positives:
|
88 |
+
ranks.append(rank)
|
89 |
+
sum_AP += score_ap_from_ranks_1(ranks, len(positives))
|
90 |
+
|
91 |
+
print("eval on %s mAP=%.3f" % (
|
92 |
+
block_name, sum_AP / (j1 - j0)))
|
93 |
+
j0 = j1
|
94 |
+
|
95 |
+
|
96 |
+
# from the Holidays evaluation package
|
97 |
+
def score_ap_from_ranks_1(ranks, nres):
|
98 |
+
""" Compute the average precision of one search.
|
99 |
+
ranks = ordered list of ranks of true positives
|
100 |
+
nres = total number of positives in dataset
|
101 |
+
"""
|
102 |
+
|
103 |
+
# accumulate trapezoids in PR-plot
|
104 |
+
ap = 0.0
|
105 |
+
|
106 |
+
# All have an x-size of:
|
107 |
+
recall_step = 1.0 / nres
|
108 |
+
|
109 |
+
for ntp, rank in enumerate(ranks):
|
110 |
+
|
111 |
+
# y-size on left side of trapezoid:
|
112 |
+
# ntp = nb of true positives so far
|
113 |
+
# rank = nb of retrieved items so far
|
114 |
+
if rank == 0:
|
115 |
+
precision_0 = 1.0
|
116 |
+
else:
|
117 |
+
precision_0 = ntp / float(rank)
|
118 |
+
|
119 |
+
# y-size on right side of trapezoid:
|
120 |
+
# ntp and rank are increased by one
|
121 |
+
precision_1 = (ntp + 1) / float(rank + 1)
|
122 |
+
|
123 |
+
ap += (precision_1 + precision_0) * recall_step / 2.0
|
124 |
+
|
125 |
+
return ap
|
126 |
+
|
127 |
+
|
128 |
+
class ImgListDataset(torch.utils.data.Dataset):
|
129 |
+
def __init__(self, img_list, transform=None):
|
130 |
+
self.samples = img_list
|
131 |
+
self.transform = transform
|
132 |
+
|
133 |
+
def __getitem__(self, i):
|
134 |
+
with open(self.samples[i], 'rb') as f:
|
135 |
+
img = Image.open(f)
|
136 |
+
img = img.convert('RGB')
|
137 |
+
if self.transform is not None:
|
138 |
+
img = self.transform(img)
|
139 |
+
return img, i
|
140 |
+
|
141 |
+
def __len__(self):
|
142 |
+
return len(self.samples)
|
143 |
+
|
144 |
+
|
145 |
+
def is_image_file(s):
|
146 |
+
ext = s.split(".")[-1]
|
147 |
+
if ext in ['jpg', 'jpeg', 'png', 'ppm', 'bmp', 'pgm', 'tif', 'tiff', 'webp']:
|
148 |
+
return True
|
149 |
+
return False
|
150 |
+
|
151 |
+
|
152 |
+
@torch.no_grad()
|
153 |
+
def extract_features(image_list, model, args):
|
154 |
+
transform = pth_transforms.Compose([
|
155 |
+
pth_transforms.Resize((args.imsize, args.imsize), interpolation=3),
|
156 |
+
pth_transforms.ToTensor(),
|
157 |
+
pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
158 |
+
])
|
159 |
+
tempdataset = ImgListDataset(image_list, transform=transform)
|
160 |
+
data_loader = torch.utils.data.DataLoader(tempdataset, batch_size=args.batch_size_per_gpu,
|
161 |
+
num_workers=args.num_workers, drop_last=False,
|
162 |
+
sampler=torch.utils.data.DistributedSampler(tempdataset, shuffle=False))
|
163 |
+
features = None
|
164 |
+
for samples, index in utils.MetricLogger(delimiter=" ").log_every(data_loader, 10):
|
165 |
+
samples, index = samples.cuda(non_blocking=True), index.cuda(non_blocking=True)
|
166 |
+
feats = model.get_intermediate_layers(samples, n=1)[0].clone()
|
167 |
+
|
168 |
+
cls_output_token = feats[:, 0, :] # [CLS] token
|
169 |
+
# GeM with exponent 4 for output patch tokens
|
170 |
+
b, h, w, d = len(samples), int(samples.shape[-2] / model.patch_embed.patch_size), int(samples.shape[-1] / model.patch_embed.patch_size), feats.shape[-1]
|
171 |
+
feats = feats[:, 1:, :].reshape(b, h, w, d)
|
172 |
+
feats = feats.clamp(min=1e-6).permute(0, 3, 1, 2)
|
173 |
+
feats = nn.functional.avg_pool2d(feats.pow(4), (h, w)).pow(1. / 4).reshape(b, -1)
|
174 |
+
# concatenate [CLS] token and GeM pooled patch tokens
|
175 |
+
feats = torch.cat((cls_output_token, feats), dim=1)
|
176 |
+
|
177 |
+
# init storage feature matrix
|
178 |
+
if dist.get_rank() == 0 and features is None:
|
179 |
+
features = torch.zeros(len(data_loader.dataset), feats.shape[-1])
|
180 |
+
if args.use_cuda:
|
181 |
+
features = features.cuda(non_blocking=True)
|
182 |
+
|
183 |
+
# get indexes from all processes
|
184 |
+
y_all = torch.empty(dist.get_world_size(), index.size(0), dtype=index.dtype, device=index.device)
|
185 |
+
y_l = list(y_all.unbind(0))
|
186 |
+
y_all_reduce = torch.distributed.all_gather(y_l, index, async_op=True)
|
187 |
+
y_all_reduce.wait()
|
188 |
+
index_all = torch.cat(y_l)
|
189 |
+
|
190 |
+
# share features between processes
|
191 |
+
feats_all = torch.empty(dist.get_world_size(), feats.size(0), feats.size(1),
|
192 |
+
dtype=feats.dtype, device=feats.device)
|
193 |
+
output_l = list(feats_all.unbind(0))
|
194 |
+
output_all_reduce = torch.distributed.all_gather(output_l, feats, async_op=True)
|
195 |
+
output_all_reduce.wait()
|
196 |
+
|
197 |
+
# update storage feature matrix
|
198 |
+
if dist.get_rank() == 0:
|
199 |
+
if args.use_cuda:
|
200 |
+
features.index_copy_(0, index_all, torch.cat(output_l))
|
201 |
+
else:
|
202 |
+
features.index_copy_(0, index_all.cpu(), torch.cat(output_l).cpu())
|
203 |
+
return features # features is still None for every rank which is not 0 (main)
|
204 |
+
|
205 |
+
|
206 |
+
if __name__ == '__main__':
|
207 |
+
parser = argparse.ArgumentParser('Copy detection on Copydays')
|
208 |
+
parser.add_argument('--data_path', default='/path/to/copydays/', type=str,
|
209 |
+
help="See https://lear.inrialpes.fr/~jegou/data.php#copydays")
|
210 |
+
parser.add_argument('--whitening_path', default='/path/to/whitening_data/', type=str,
|
211 |
+
help="""Path to directory with images used for computing the whitening operator.
|
212 |
+
In our paper, we use 20k random images from YFCC100M.""")
|
213 |
+
parser.add_argument('--distractors_path', default='/path/to/distractors/', type=str,
|
214 |
+
help="Path to directory with distractors images. In our paper, we use 10k random images from YFCC100M.")
|
215 |
+
parser.add_argument('--imsize', default=320, type=int, help='Image size (square image)')
|
216 |
+
parser.add_argument('--batch_size_per_gpu', default=16, type=int, help='Per-GPU batch-size')
|
217 |
+
parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.")
|
218 |
+
parser.add_argument('--use_cuda', default=True, type=utils.bool_flag)
|
219 |
+
parser.add_argument('--arch', default='vit_base', type=str, help='Architecture')
|
220 |
+
parser.add_argument('--patch_size', default=8, type=int, help='Patch resolution of the model.')
|
221 |
+
parser.add_argument("--checkpoint_key", default="teacher", type=str,
|
222 |
+
help='Key to use in the checkpoint (example: "teacher")')
|
223 |
+
parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
|
224 |
+
parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
|
225 |
+
distributed training; see https://pytorch.org/docs/stable/distributed.html""")
|
226 |
+
parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
|
227 |
+
args = parser.parse_args()
|
228 |
+
|
229 |
+
utils.init_distributed_mode(args)
|
230 |
+
print("git:\n {}\n".format(utils.get_sha()))
|
231 |
+
print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
|
232 |
+
cudnn.benchmark = True
|
233 |
+
|
234 |
+
# ============ building network ... ============
|
235 |
+
if "vit" in args.arch:
|
236 |
+
model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
|
237 |
+
print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.")
|
238 |
+
else:
|
239 |
+
print(f"Architecture {args.arch} non supported")
|
240 |
+
sys.exit(1)
|
241 |
+
if args.use_cuda:
|
242 |
+
model.cuda()
|
243 |
+
model.eval()
|
244 |
+
utils.load_pretrained_weights(model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size)
|
245 |
+
|
246 |
+
dataset = CopydaysDataset(args.data_path)
|
247 |
+
|
248 |
+
# ============ Extract features ... ============
|
249 |
+
# extract features for queries
|
250 |
+
queries = []
|
251 |
+
for q in dataset.query_blocks:
|
252 |
+
queries.append(extract_features(dataset.get_block(q), model, args))
|
253 |
+
if utils.get_rank() == 0:
|
254 |
+
queries = torch.cat(queries)
|
255 |
+
print(f"Extraction of queries features done. Shape: {queries.shape}")
|
256 |
+
|
257 |
+
# extract features for database
|
258 |
+
database = []
|
259 |
+
for b in dataset.database_blocks:
|
260 |
+
database.append(extract_features(dataset.get_block(b), model, args))
|
261 |
+
|
262 |
+
# extract features for distractors
|
263 |
+
if os.path.isdir(args.distractors_path):
|
264 |
+
print("Using distractors...")
|
265 |
+
list_distractors = [os.path.join(args.distractors_path, s) for s in os.listdir(args.distractors_path) if is_image_file(s)]
|
266 |
+
database.append(extract_features(list_distractors, model, args))
|
267 |
+
if utils.get_rank() == 0:
|
268 |
+
database = torch.cat(database)
|
269 |
+
print(f"Extraction of database and distractors features done. Shape: {database.shape}")
|
270 |
+
|
271 |
+
# ============ Whitening ... ============
|
272 |
+
if os.path.isdir(args.whitening_path):
|
273 |
+
print(f"Extracting features on images from {args.whitening_path} for learning the whitening operator.")
|
274 |
+
list_whit = [os.path.join(args.whitening_path, s) for s in os.listdir(args.whitening_path) if is_image_file(s)]
|
275 |
+
features_for_whitening = extract_features(list_whit, model, args)
|
276 |
+
if utils.get_rank() == 0:
|
277 |
+
# center
|
278 |
+
mean_feature = torch.mean(features_for_whitening, dim=0)
|
279 |
+
database -= mean_feature
|
280 |
+
queries -= mean_feature
|
281 |
+
pca = utils.PCA(dim=database.shape[-1], whit=0.5)
|
282 |
+
# compute covariance
|
283 |
+
cov = torch.mm(features_for_whitening.T, features_for_whitening) / features_for_whitening.shape[0]
|
284 |
+
pca.train_pca(cov.cpu().numpy())
|
285 |
+
database = pca.apply(database)
|
286 |
+
queries = pca.apply(queries)
|
287 |
+
|
288 |
+
# ============ Copy detection ... ============
|
289 |
+
if utils.get_rank() == 0:
|
290 |
+
# l2 normalize the features
|
291 |
+
database = nn.functional.normalize(database, dim=1, p=2)
|
292 |
+
queries = nn.functional.normalize(queries, dim=1, p=2)
|
293 |
+
|
294 |
+
# similarity
|
295 |
+
similarity = torch.mm(queries, database.T)
|
296 |
+
distances, indices = similarity.topk(20, largest=True, sorted=True)
|
297 |
+
|
298 |
+
# evaluate
|
299 |
+
retrieved = dataset.eval_result(indices, distances)
|
300 |
+
dist.barrier()
|
301 |
+
|
PuzzleTuning/Counterpart PreTrain Methods/dino-main/eval_image_retrieval.py
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import os
|
15 |
+
import sys
|
16 |
+
import pickle
|
17 |
+
import argparse
|
18 |
+
|
19 |
+
import torch
|
20 |
+
from torch import nn
|
21 |
+
import torch.distributed as dist
|
22 |
+
import torch.backends.cudnn as cudnn
|
23 |
+
from torchvision import models as torchvision_models
|
24 |
+
from torchvision import transforms as pth_transforms
|
25 |
+
from PIL import Image, ImageFile
|
26 |
+
import numpy as np
|
27 |
+
|
28 |
+
import utils
|
29 |
+
import vision_transformer as vits
|
30 |
+
from eval_knn import extract_features
|
31 |
+
|
32 |
+
|
33 |
+
class OxfordParisDataset(torch.utils.data.Dataset):
|
34 |
+
def __init__(self, dir_main, dataset, split, transform=None, imsize=None):
|
35 |
+
if dataset not in ['roxford5k', 'rparis6k']:
|
36 |
+
raise ValueError('Unknown dataset: {}!'.format(dataset))
|
37 |
+
|
38 |
+
# loading imlist, qimlist, and gnd, in cfg as a dict
|
39 |
+
gnd_fname = os.path.join(dir_main, dataset, 'gnd_{}.pkl'.format(dataset))
|
40 |
+
with open(gnd_fname, 'rb') as f:
|
41 |
+
cfg = pickle.load(f)
|
42 |
+
cfg['gnd_fname'] = gnd_fname
|
43 |
+
cfg['ext'] = '.jpg'
|
44 |
+
cfg['qext'] = '.jpg'
|
45 |
+
cfg['dir_data'] = os.path.join(dir_main, dataset)
|
46 |
+
cfg['dir_images'] = os.path.join(cfg['dir_data'], 'jpg')
|
47 |
+
cfg['n'] = len(cfg['imlist'])
|
48 |
+
cfg['nq'] = len(cfg['qimlist'])
|
49 |
+
cfg['im_fname'] = config_imname
|
50 |
+
cfg['qim_fname'] = config_qimname
|
51 |
+
cfg['dataset'] = dataset
|
52 |
+
self.cfg = cfg
|
53 |
+
|
54 |
+
self.samples = cfg["qimlist"] if split == "query" else cfg["imlist"]
|
55 |
+
self.transform = transform
|
56 |
+
self.imsize = imsize
|
57 |
+
|
58 |
+
def __len__(self):
|
59 |
+
return len(self.samples)
|
60 |
+
|
61 |
+
def __getitem__(self, index):
|
62 |
+
path = os.path.join(self.cfg["dir_images"], self.samples[index] + ".jpg")
|
63 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
64 |
+
with open(path, 'rb') as f:
|
65 |
+
img = Image.open(f)
|
66 |
+
img = img.convert('RGB')
|
67 |
+
if self.imsize is not None:
|
68 |
+
img.thumbnail((self.imsize, self.imsize), Image.ANTIALIAS)
|
69 |
+
if self.transform is not None:
|
70 |
+
img = self.transform(img)
|
71 |
+
return img, index
|
72 |
+
|
73 |
+
|
74 |
+
def config_imname(cfg, i):
|
75 |
+
return os.path.join(cfg['dir_images'], cfg['imlist'][i] + cfg['ext'])
|
76 |
+
|
77 |
+
|
78 |
+
def config_qimname(cfg, i):
|
79 |
+
return os.path.join(cfg['dir_images'], cfg['qimlist'][i] + cfg['qext'])
|
80 |
+
|
81 |
+
|
82 |
+
if __name__ == '__main__':
|
83 |
+
parser = argparse.ArgumentParser('Image Retrieval on revisited Paris and Oxford')
|
84 |
+
parser.add_argument('--data_path', default='/path/to/revisited_paris_oxford/', type=str)
|
85 |
+
parser.add_argument('--dataset', default='roxford5k', type=str, choices=['roxford5k', 'rparis6k'])
|
86 |
+
parser.add_argument('--multiscale', default=False, type=utils.bool_flag)
|
87 |
+
parser.add_argument('--imsize', default=224, type=int, help='Image size')
|
88 |
+
parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.")
|
89 |
+
parser.add_argument('--use_cuda', default=True, type=utils.bool_flag)
|
90 |
+
parser.add_argument('--arch', default='vit_small', type=str, help='Architecture')
|
91 |
+
parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.')
|
92 |
+
parser.add_argument("--checkpoint_key", default="teacher", type=str,
|
93 |
+
help='Key to use in the checkpoint (example: "teacher")')
|
94 |
+
parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
|
95 |
+
parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
|
96 |
+
distributed training; see https://pytorch.org/docs/stable/distributed.html""")
|
97 |
+
parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
|
98 |
+
args = parser.parse_args()
|
99 |
+
|
100 |
+
utils.init_distributed_mode(args)
|
101 |
+
print("git:\n {}\n".format(utils.get_sha()))
|
102 |
+
print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
|
103 |
+
cudnn.benchmark = True
|
104 |
+
|
105 |
+
# ============ preparing data ... ============
|
106 |
+
transform = pth_transforms.Compose([
|
107 |
+
pth_transforms.ToTensor(),
|
108 |
+
pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
109 |
+
])
|
110 |
+
dataset_train = OxfordParisDataset(args.data_path, args.dataset, split="train", transform=transform, imsize=args.imsize)
|
111 |
+
dataset_query = OxfordParisDataset(args.data_path, args.dataset, split="query", transform=transform, imsize=args.imsize)
|
112 |
+
sampler = torch.utils.data.DistributedSampler(dataset_train, shuffle=False)
|
113 |
+
data_loader_train = torch.utils.data.DataLoader(
|
114 |
+
dataset_train,
|
115 |
+
sampler=sampler,
|
116 |
+
batch_size=1,
|
117 |
+
num_workers=args.num_workers,
|
118 |
+
pin_memory=True,
|
119 |
+
drop_last=False,
|
120 |
+
)
|
121 |
+
data_loader_query = torch.utils.data.DataLoader(
|
122 |
+
dataset_query,
|
123 |
+
batch_size=1,
|
124 |
+
num_workers=args.num_workers,
|
125 |
+
pin_memory=True,
|
126 |
+
drop_last=False,
|
127 |
+
)
|
128 |
+
print(f"train: {len(dataset_train)} imgs / query: {len(dataset_query)} imgs")
|
129 |
+
|
130 |
+
# ============ building network ... ============
|
131 |
+
if "vit" in args.arch:
|
132 |
+
model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
|
133 |
+
print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.")
|
134 |
+
elif "xcit" in args.arch:
|
135 |
+
model = torch.hub.load('facebookresearch/xcit:main', args.arch, num_classes=0)
|
136 |
+
elif args.arch in torchvision_models.__dict__.keys():
|
137 |
+
model = torchvision_models.__dict__[args.arch](num_classes=0)
|
138 |
+
else:
|
139 |
+
print(f"Architecture {args.arch} non supported")
|
140 |
+
sys.exit(1)
|
141 |
+
if args.use_cuda:
|
142 |
+
model.cuda()
|
143 |
+
model.eval()
|
144 |
+
|
145 |
+
# load pretrained weights
|
146 |
+
if os.path.isfile(args.pretrained_weights):
|
147 |
+
state_dict = torch.load(args.pretrained_weights, map_location="cpu")
|
148 |
+
if args.checkpoint_key is not None and args.checkpoint_key in state_dict:
|
149 |
+
print(f"Take key {args.checkpoint_key} in provided checkpoint dict")
|
150 |
+
state_dict = state_dict[args.checkpoint_key]
|
151 |
+
# remove `module.` prefix
|
152 |
+
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
|
153 |
+
# remove `backbone.` prefix induced by multicrop wrapper
|
154 |
+
state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
|
155 |
+
msg = model.load_state_dict(state_dict, strict=False)
|
156 |
+
print('Pretrained weights found at {} and loaded with msg: {}'.format(args.pretrained_weights, msg))
|
157 |
+
elif args.arch == "vit_small" and args.patch_size == 16:
|
158 |
+
print("Since no pretrained weights have been provided, we load pretrained DINO weights on Google Landmark v2.")
|
159 |
+
model.load_state_dict(torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/dino_vitsmall16_googlelandmark_pretrain/dino_vitsmall16_googlelandmark_pretrain.pth"))
|
160 |
+
else:
|
161 |
+
print("Warning: We use random weights.")
|
162 |
+
|
163 |
+
############################################################################
|
164 |
+
# Step 1: extract features
|
165 |
+
train_features = extract_features(model, data_loader_train, args.use_cuda, multiscale=args.multiscale)
|
166 |
+
query_features = extract_features(model, data_loader_query, args.use_cuda, multiscale=args.multiscale)
|
167 |
+
|
168 |
+
if utils.get_rank() == 0: # only rank 0 will work from now on
|
169 |
+
# normalize features
|
170 |
+
train_features = nn.functional.normalize(train_features, dim=1, p=2)
|
171 |
+
query_features = nn.functional.normalize(query_features, dim=1, p=2)
|
172 |
+
|
173 |
+
############################################################################
|
174 |
+
# Step 2: similarity
|
175 |
+
sim = torch.mm(train_features, query_features.T)
|
176 |
+
ranks = torch.argsort(-sim, dim=0).cpu().numpy()
|
177 |
+
|
178 |
+
############################################################################
|
179 |
+
# Step 3: evaluate
|
180 |
+
gnd = dataset_train.cfg['gnd']
|
181 |
+
# evaluate ranks
|
182 |
+
ks = [1, 5, 10]
|
183 |
+
# search for easy & hard
|
184 |
+
gnd_t = []
|
185 |
+
for i in range(len(gnd)):
|
186 |
+
g = {}
|
187 |
+
g['ok'] = np.concatenate([gnd[i]['easy'], gnd[i]['hard']])
|
188 |
+
g['junk'] = np.concatenate([gnd[i]['junk']])
|
189 |
+
gnd_t.append(g)
|
190 |
+
mapM, apsM, mprM, prsM = utils.compute_map(ranks, gnd_t, ks)
|
191 |
+
# search for hard
|
192 |
+
gnd_t = []
|
193 |
+
for i in range(len(gnd)):
|
194 |
+
g = {}
|
195 |
+
g['ok'] = np.concatenate([gnd[i]['hard']])
|
196 |
+
g['junk'] = np.concatenate([gnd[i]['junk'], gnd[i]['easy']])
|
197 |
+
gnd_t.append(g)
|
198 |
+
mapH, apsH, mprH, prsH = utils.compute_map(ranks, gnd_t, ks)
|
199 |
+
print('>> {}: mAP M: {}, H: {}'.format(args.dataset, np.around(mapM*100, decimals=2), np.around(mapH*100, decimals=2)))
|
200 |
+
print('>> {}: mP@k{} M: {}, H: {}'.format(args.dataset, np.array(ks), np.around(mprM*100, decimals=2), np.around(mprH*100, decimals=2)))
|
201 |
+
dist.barrier()
|
PuzzleTuning/Counterpart PreTrain Methods/dino-main/eval_knn.py
ADDED
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import os
|
15 |
+
import sys
|
16 |
+
import argparse
|
17 |
+
|
18 |
+
import torch
|
19 |
+
from torch import nn
|
20 |
+
import torch.distributed as dist
|
21 |
+
import torch.backends.cudnn as cudnn
|
22 |
+
from torchvision import datasets
|
23 |
+
from torchvision import transforms as pth_transforms
|
24 |
+
from torchvision import models as torchvision_models
|
25 |
+
|
26 |
+
import utils
|
27 |
+
import vision_transformer as vits
|
28 |
+
|
29 |
+
|
30 |
+
def extract_feature_pipeline(args):
|
31 |
+
# ============ preparing data ... ============
|
32 |
+
transform = pth_transforms.Compose([
|
33 |
+
pth_transforms.Resize(256, interpolation=3),
|
34 |
+
pth_transforms.CenterCrop(224),
|
35 |
+
pth_transforms.ToTensor(),
|
36 |
+
pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
37 |
+
])
|
38 |
+
dataset_train = ReturnIndexDataset(os.path.join(args.data_path, "train"), transform=transform)
|
39 |
+
dataset_val = ReturnIndexDataset(os.path.join(args.data_path, "val"), transform=transform)
|
40 |
+
sampler = torch.utils.data.DistributedSampler(dataset_train, shuffle=False)
|
41 |
+
data_loader_train = torch.utils.data.DataLoader(
|
42 |
+
dataset_train,
|
43 |
+
sampler=sampler,
|
44 |
+
batch_size=args.batch_size_per_gpu,
|
45 |
+
num_workers=args.num_workers,
|
46 |
+
pin_memory=True,
|
47 |
+
drop_last=False,
|
48 |
+
)
|
49 |
+
data_loader_val = torch.utils.data.DataLoader(
|
50 |
+
dataset_val,
|
51 |
+
batch_size=args.batch_size_per_gpu,
|
52 |
+
num_workers=args.num_workers,
|
53 |
+
pin_memory=True,
|
54 |
+
drop_last=False,
|
55 |
+
)
|
56 |
+
print(f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs.")
|
57 |
+
|
58 |
+
# ============ building network ... ============
|
59 |
+
if "vit" in args.arch:
|
60 |
+
model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
|
61 |
+
print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.")
|
62 |
+
elif "xcit" in args.arch:
|
63 |
+
model = torch.hub.load('facebookresearch/xcit:main', args.arch, num_classes=0)
|
64 |
+
elif args.arch in torchvision_models.__dict__.keys():
|
65 |
+
model = torchvision_models.__dict__[args.arch](num_classes=0)
|
66 |
+
model.fc = nn.Identity()
|
67 |
+
else:
|
68 |
+
print(f"Architecture {args.arch} non supported")
|
69 |
+
sys.exit(1)
|
70 |
+
model.cuda()
|
71 |
+
utils.load_pretrained_weights(model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size)
|
72 |
+
model.eval()
|
73 |
+
|
74 |
+
# ============ extract features ... ============
|
75 |
+
print("Extracting features for train set...")
|
76 |
+
train_features = extract_features(model, data_loader_train, args.use_cuda)
|
77 |
+
print("Extracting features for val set...")
|
78 |
+
test_features = extract_features(model, data_loader_val, args.use_cuda)
|
79 |
+
|
80 |
+
if utils.get_rank() == 0:
|
81 |
+
train_features = nn.functional.normalize(train_features, dim=1, p=2)
|
82 |
+
test_features = nn.functional.normalize(test_features, dim=1, p=2)
|
83 |
+
|
84 |
+
train_labels = torch.tensor([s[-1] for s in dataset_train.samples]).long()
|
85 |
+
test_labels = torch.tensor([s[-1] for s in dataset_val.samples]).long()
|
86 |
+
# save features and labels
|
87 |
+
if args.dump_features and dist.get_rank() == 0:
|
88 |
+
torch.save(train_features.cpu(), os.path.join(args.dump_features, "trainfeat.pth"))
|
89 |
+
torch.save(test_features.cpu(), os.path.join(args.dump_features, "testfeat.pth"))
|
90 |
+
torch.save(train_labels.cpu(), os.path.join(args.dump_features, "trainlabels.pth"))
|
91 |
+
torch.save(test_labels.cpu(), os.path.join(args.dump_features, "testlabels.pth"))
|
92 |
+
return train_features, test_features, train_labels, test_labels
|
93 |
+
|
94 |
+
|
95 |
+
@torch.no_grad()
|
96 |
+
def extract_features(model, data_loader, use_cuda=True, multiscale=False):
|
97 |
+
metric_logger = utils.MetricLogger(delimiter=" ")
|
98 |
+
features = None
|
99 |
+
for samples, index in metric_logger.log_every(data_loader, 10):
|
100 |
+
samples = samples.cuda(non_blocking=True)
|
101 |
+
index = index.cuda(non_blocking=True)
|
102 |
+
if multiscale:
|
103 |
+
feats = utils.multi_scale(samples, model)
|
104 |
+
else:
|
105 |
+
feats = model(samples).clone()
|
106 |
+
|
107 |
+
# init storage feature matrix
|
108 |
+
if dist.get_rank() == 0 and features is None:
|
109 |
+
features = torch.zeros(len(data_loader.dataset), feats.shape[-1])
|
110 |
+
if use_cuda:
|
111 |
+
features = features.cuda(non_blocking=True)
|
112 |
+
print(f"Storing features into tensor of shape {features.shape}")
|
113 |
+
|
114 |
+
# get indexes from all processes
|
115 |
+
y_all = torch.empty(dist.get_world_size(), index.size(0), dtype=index.dtype, device=index.device)
|
116 |
+
y_l = list(y_all.unbind(0))
|
117 |
+
y_all_reduce = torch.distributed.all_gather(y_l, index, async_op=True)
|
118 |
+
y_all_reduce.wait()
|
119 |
+
index_all = torch.cat(y_l)
|
120 |
+
|
121 |
+
# share features between processes
|
122 |
+
feats_all = torch.empty(
|
123 |
+
dist.get_world_size(),
|
124 |
+
feats.size(0),
|
125 |
+
feats.size(1),
|
126 |
+
dtype=feats.dtype,
|
127 |
+
device=feats.device,
|
128 |
+
)
|
129 |
+
output_l = list(feats_all.unbind(0))
|
130 |
+
output_all_reduce = torch.distributed.all_gather(output_l, feats, async_op=True)
|
131 |
+
output_all_reduce.wait()
|
132 |
+
|
133 |
+
# update storage feature matrix
|
134 |
+
if dist.get_rank() == 0:
|
135 |
+
if use_cuda:
|
136 |
+
features.index_copy_(0, index_all, torch.cat(output_l))
|
137 |
+
else:
|
138 |
+
features.index_copy_(0, index_all.cpu(), torch.cat(output_l).cpu())
|
139 |
+
return features
|
140 |
+
|
141 |
+
|
142 |
+
@torch.no_grad()
|
143 |
+
def knn_classifier(train_features, train_labels, test_features, test_labels, k, T, num_classes=1000):
|
144 |
+
top1, top5, total = 0.0, 0.0, 0
|
145 |
+
train_features = train_features.t()
|
146 |
+
num_test_images, num_chunks = test_labels.shape[0], 100
|
147 |
+
imgs_per_chunk = num_test_images // num_chunks
|
148 |
+
retrieval_one_hot = torch.zeros(k, num_classes).to(train_features.device)
|
149 |
+
for idx in range(0, num_test_images, imgs_per_chunk):
|
150 |
+
# get the features for test images
|
151 |
+
features = test_features[
|
152 |
+
idx : min((idx + imgs_per_chunk), num_test_images), :
|
153 |
+
]
|
154 |
+
targets = test_labels[idx : min((idx + imgs_per_chunk), num_test_images)]
|
155 |
+
batch_size = targets.shape[0]
|
156 |
+
|
157 |
+
# calculate the dot product and compute top-k neighbors
|
158 |
+
similarity = torch.mm(features, train_features)
|
159 |
+
distances, indices = similarity.topk(k, largest=True, sorted=True)
|
160 |
+
candidates = train_labels.view(1, -1).expand(batch_size, -1)
|
161 |
+
retrieved_neighbors = torch.gather(candidates, 1, indices)
|
162 |
+
|
163 |
+
retrieval_one_hot.resize_(batch_size * k, num_classes).zero_()
|
164 |
+
retrieval_one_hot.scatter_(1, retrieved_neighbors.view(-1, 1), 1)
|
165 |
+
distances_transform = distances.clone().div_(T).exp_()
|
166 |
+
probs = torch.sum(
|
167 |
+
torch.mul(
|
168 |
+
retrieval_one_hot.view(batch_size, -1, num_classes),
|
169 |
+
distances_transform.view(batch_size, -1, 1),
|
170 |
+
),
|
171 |
+
1,
|
172 |
+
)
|
173 |
+
_, predictions = probs.sort(1, True)
|
174 |
+
|
175 |
+
# find the predictions that match the target
|
176 |
+
correct = predictions.eq(targets.data.view(-1, 1))
|
177 |
+
top1 = top1 + correct.narrow(1, 0, 1).sum().item()
|
178 |
+
top5 = top5 + correct.narrow(1, 0, min(5, k)).sum().item() # top5 does not make sense if k < 5
|
179 |
+
total += targets.size(0)
|
180 |
+
top1 = top1 * 100.0 / total
|
181 |
+
top5 = top5 * 100.0 / total
|
182 |
+
return top1, top5
|
183 |
+
|
184 |
+
|
185 |
+
class ReturnIndexDataset(datasets.ImageFolder):
|
186 |
+
def __getitem__(self, idx):
|
187 |
+
img, lab = super(ReturnIndexDataset, self).__getitem__(idx)
|
188 |
+
return img, idx
|
189 |
+
|
190 |
+
|
191 |
+
if __name__ == '__main__':
|
192 |
+
parser = argparse.ArgumentParser('Evaluation with weighted k-NN on ImageNet')
|
193 |
+
parser.add_argument('--batch_size_per_gpu', default=128, type=int, help='Per-GPU batch-size')
|
194 |
+
parser.add_argument('--nb_knn', default=[10, 20, 100, 200], nargs='+', type=int,
|
195 |
+
help='Number of NN to use. 20 is usually working the best.')
|
196 |
+
parser.add_argument('--temperature', default=0.07, type=float,
|
197 |
+
help='Temperature used in the voting coefficient')
|
198 |
+
parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.")
|
199 |
+
parser.add_argument('--use_cuda', default=True, type=utils.bool_flag,
|
200 |
+
help="Should we store the features on GPU? We recommend setting this to False if you encounter OOM")
|
201 |
+
parser.add_argument('--arch', default='vit_small', type=str, help='Architecture')
|
202 |
+
parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.')
|
203 |
+
parser.add_argument("--checkpoint_key", default="teacher", type=str,
|
204 |
+
help='Key to use in the checkpoint (example: "teacher")')
|
205 |
+
parser.add_argument('--dump_features', default=None,
|
206 |
+
help='Path where to save computed features, empty for no saving')
|
207 |
+
parser.add_argument('--load_features', default=None, help="""If the features have
|
208 |
+
already been computed, where to find them.""")
|
209 |
+
parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
|
210 |
+
parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
|
211 |
+
distributed training; see https://pytorch.org/docs/stable/distributed.html""")
|
212 |
+
parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
|
213 |
+
parser.add_argument('--data_path', default='/path/to/imagenet/', type=str)
|
214 |
+
args = parser.parse_args()
|
215 |
+
|
216 |
+
utils.init_distributed_mode(args)
|
217 |
+
print("git:\n {}\n".format(utils.get_sha()))
|
218 |
+
print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
|
219 |
+
cudnn.benchmark = True
|
220 |
+
|
221 |
+
if args.load_features:
|
222 |
+
train_features = torch.load(os.path.join(args.load_features, "trainfeat.pth"))
|
223 |
+
test_features = torch.load(os.path.join(args.load_features, "testfeat.pth"))
|
224 |
+
train_labels = torch.load(os.path.join(args.load_features, "trainlabels.pth"))
|
225 |
+
test_labels = torch.load(os.path.join(args.load_features, "testlabels.pth"))
|
226 |
+
else:
|
227 |
+
# need to extract features !
|
228 |
+
train_features, test_features, train_labels, test_labels = extract_feature_pipeline(args)
|
229 |
+
|
230 |
+
if utils.get_rank() == 0:
|
231 |
+
if args.use_cuda:
|
232 |
+
train_features = train_features.cuda()
|
233 |
+
test_features = test_features.cuda()
|
234 |
+
train_labels = train_labels.cuda()
|
235 |
+
test_labels = test_labels.cuda()
|
236 |
+
|
237 |
+
print("Features are ready!\nStart the k-NN classification.")
|
238 |
+
for k in args.nb_knn:
|
239 |
+
top1, top5 = knn_classifier(train_features, train_labels,
|
240 |
+
test_features, test_labels, k, args.temperature)
|
241 |
+
print(f"{k}-NN classifier result: Top1: {top1}, Top5: {top5}")
|
242 |
+
dist.barrier()
|
PuzzleTuning/Counterpart PreTrain Methods/dino-main/eval_linear.py
ADDED
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import os
|
15 |
+
import argparse
|
16 |
+
import json
|
17 |
+
from pathlib import Path
|
18 |
+
|
19 |
+
import torch
|
20 |
+
from torch import nn
|
21 |
+
import torch.distributed as dist
|
22 |
+
import torch.backends.cudnn as cudnn
|
23 |
+
from torchvision import datasets
|
24 |
+
from torchvision import transforms as pth_transforms
|
25 |
+
from torchvision import models as torchvision_models
|
26 |
+
|
27 |
+
import utils
|
28 |
+
import vision_transformer as vits
|
29 |
+
|
30 |
+
|
31 |
+
def eval_linear(args):
|
32 |
+
utils.init_distributed_mode(args)
|
33 |
+
print("git:\n {}\n".format(utils.get_sha()))
|
34 |
+
print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
|
35 |
+
cudnn.benchmark = True
|
36 |
+
|
37 |
+
# ============ building network ... ============
|
38 |
+
# if the network is a Vision Transformer (i.e. vit_tiny, vit_small, vit_base)
|
39 |
+
if args.arch in vits.__dict__.keys():
|
40 |
+
model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
|
41 |
+
embed_dim = model.embed_dim * (args.n_last_blocks + int(args.avgpool_patchtokens))
|
42 |
+
# if the network is a XCiT
|
43 |
+
elif "xcit" in args.arch:
|
44 |
+
model = torch.hub.load('facebookresearch/xcit:main', args.arch, num_classes=0)
|
45 |
+
embed_dim = model.embed_dim
|
46 |
+
# otherwise, we check if the architecture is in torchvision models
|
47 |
+
elif args.arch in torchvision_models.__dict__.keys():
|
48 |
+
model = torchvision_models.__dict__[args.arch]()
|
49 |
+
embed_dim = model.fc.weight.shape[1]
|
50 |
+
model.fc = nn.Identity()
|
51 |
+
else:
|
52 |
+
print(f"Unknow architecture: {args.arch}")
|
53 |
+
sys.exit(1)
|
54 |
+
model.cuda()
|
55 |
+
model.eval()
|
56 |
+
# load weights to evaluate
|
57 |
+
utils.load_pretrained_weights(model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size)
|
58 |
+
print(f"Model {args.arch} built.")
|
59 |
+
|
60 |
+
linear_classifier = LinearClassifier(embed_dim, num_labels=args.num_labels)
|
61 |
+
linear_classifier = linear_classifier.cuda()
|
62 |
+
linear_classifier = nn.parallel.DistributedDataParallel(linear_classifier, device_ids=[args.gpu])
|
63 |
+
|
64 |
+
# ============ preparing data ... ============
|
65 |
+
val_transform = pth_transforms.Compose([
|
66 |
+
pth_transforms.Resize(256, interpolation=3),
|
67 |
+
pth_transforms.CenterCrop(224),
|
68 |
+
pth_transforms.ToTensor(),
|
69 |
+
pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
70 |
+
])
|
71 |
+
dataset_val = datasets.ImageFolder(os.path.join(args.data_path, "val"), transform=val_transform)
|
72 |
+
val_loader = torch.utils.data.DataLoader(
|
73 |
+
dataset_val,
|
74 |
+
batch_size=args.batch_size_per_gpu,
|
75 |
+
num_workers=args.num_workers,
|
76 |
+
pin_memory=True,
|
77 |
+
)
|
78 |
+
|
79 |
+
if args.evaluate:
|
80 |
+
utils.load_pretrained_linear_weights(linear_classifier, args.arch, args.patch_size)
|
81 |
+
test_stats = validate_network(val_loader, model, linear_classifier, args.n_last_blocks, args.avgpool_patchtokens)
|
82 |
+
print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
|
83 |
+
return
|
84 |
+
|
85 |
+
train_transform = pth_transforms.Compose([
|
86 |
+
pth_transforms.RandomResizedCrop(224),
|
87 |
+
pth_transforms.RandomHorizontalFlip(),
|
88 |
+
pth_transforms.ToTensor(),
|
89 |
+
pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
90 |
+
])
|
91 |
+
dataset_train = datasets.ImageFolder(os.path.join(args.data_path, "train"), transform=train_transform)
|
92 |
+
sampler = torch.utils.data.distributed.DistributedSampler(dataset_train)
|
93 |
+
train_loader = torch.utils.data.DataLoader(
|
94 |
+
dataset_train,
|
95 |
+
sampler=sampler,
|
96 |
+
batch_size=args.batch_size_per_gpu,
|
97 |
+
num_workers=args.num_workers,
|
98 |
+
pin_memory=True,
|
99 |
+
)
|
100 |
+
print(f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs.")
|
101 |
+
|
102 |
+
# set optimizer
|
103 |
+
optimizer = torch.optim.SGD(
|
104 |
+
linear_classifier.parameters(),
|
105 |
+
args.lr * (args.batch_size_per_gpu * utils.get_world_size()) / 256., # linear scaling rule
|
106 |
+
momentum=0.9,
|
107 |
+
weight_decay=0, # we do not apply weight decay
|
108 |
+
)
|
109 |
+
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min=0)
|
110 |
+
|
111 |
+
# Optionally resume from a checkpoint
|
112 |
+
to_restore = {"epoch": 0, "best_acc": 0.}
|
113 |
+
utils.restart_from_checkpoint(
|
114 |
+
os.path.join(args.output_dir, "checkpoint.pth.tar"),
|
115 |
+
run_variables=to_restore,
|
116 |
+
state_dict=linear_classifier,
|
117 |
+
optimizer=optimizer,
|
118 |
+
scheduler=scheduler,
|
119 |
+
)
|
120 |
+
start_epoch = to_restore["epoch"]
|
121 |
+
best_acc = to_restore["best_acc"]
|
122 |
+
|
123 |
+
for epoch in range(start_epoch, args.epochs):
|
124 |
+
train_loader.sampler.set_epoch(epoch)
|
125 |
+
|
126 |
+
train_stats = train(model, linear_classifier, optimizer, train_loader, epoch, args.n_last_blocks, args.avgpool_patchtokens)
|
127 |
+
scheduler.step()
|
128 |
+
|
129 |
+
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
|
130 |
+
'epoch': epoch}
|
131 |
+
if epoch % args.val_freq == 0 or epoch == args.epochs - 1:
|
132 |
+
test_stats = validate_network(val_loader, model, linear_classifier, args.n_last_blocks, args.avgpool_patchtokens)
|
133 |
+
print(f"Accuracy at epoch {epoch} of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
|
134 |
+
best_acc = max(best_acc, test_stats["acc1"])
|
135 |
+
print(f'Max accuracy so far: {best_acc:.2f}%')
|
136 |
+
log_stats = {**{k: v for k, v in log_stats.items()},
|
137 |
+
**{f'test_{k}': v for k, v in test_stats.items()}}
|
138 |
+
if utils.is_main_process():
|
139 |
+
with (Path(args.output_dir) / "log.txt").open("a") as f:
|
140 |
+
f.write(json.dumps(log_stats) + "\n")
|
141 |
+
save_dict = {
|
142 |
+
"epoch": epoch + 1,
|
143 |
+
"state_dict": linear_classifier.state_dict(),
|
144 |
+
"optimizer": optimizer.state_dict(),
|
145 |
+
"scheduler": scheduler.state_dict(),
|
146 |
+
"best_acc": best_acc,
|
147 |
+
}
|
148 |
+
torch.save(save_dict, os.path.join(args.output_dir, "checkpoint.pth.tar"))
|
149 |
+
print("Training of the supervised linear classifier on frozen features completed.\n"
|
150 |
+
"Top-1 test accuracy: {acc:.1f}".format(acc=best_acc))
|
151 |
+
|
152 |
+
|
153 |
+
def train(model, linear_classifier, optimizer, loader, epoch, n, avgpool):
|
154 |
+
linear_classifier.train()
|
155 |
+
metric_logger = utils.MetricLogger(delimiter=" ")
|
156 |
+
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
|
157 |
+
header = 'Epoch: [{}]'.format(epoch)
|
158 |
+
for (inp, target) in metric_logger.log_every(loader, 20, header):
|
159 |
+
# move to gpu
|
160 |
+
inp = inp.cuda(non_blocking=True)
|
161 |
+
target = target.cuda(non_blocking=True)
|
162 |
+
|
163 |
+
# forward
|
164 |
+
with torch.no_grad():
|
165 |
+
if "vit" in args.arch:
|
166 |
+
intermediate_output = model.get_intermediate_layers(inp, n)
|
167 |
+
output = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
|
168 |
+
if avgpool:
|
169 |
+
output = torch.cat((output.unsqueeze(-1), torch.mean(intermediate_output[-1][:, 1:], dim=1).unsqueeze(-1)), dim=-1)
|
170 |
+
output = output.reshape(output.shape[0], -1)
|
171 |
+
else:
|
172 |
+
output = model(inp)
|
173 |
+
output = linear_classifier(output)
|
174 |
+
|
175 |
+
# compute cross entropy loss
|
176 |
+
loss = nn.CrossEntropyLoss()(output, target)
|
177 |
+
|
178 |
+
# compute the gradients
|
179 |
+
optimizer.zero_grad()
|
180 |
+
loss.backward()
|
181 |
+
|
182 |
+
# step
|
183 |
+
optimizer.step()
|
184 |
+
|
185 |
+
# log
|
186 |
+
torch.cuda.synchronize()
|
187 |
+
metric_logger.update(loss=loss.item())
|
188 |
+
metric_logger.update(lr=optimizer.param_groups[0]["lr"])
|
189 |
+
# gather the stats from all processes
|
190 |
+
metric_logger.synchronize_between_processes()
|
191 |
+
print("Averaged stats:", metric_logger)
|
192 |
+
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
193 |
+
|
194 |
+
|
195 |
+
@torch.no_grad()
|
196 |
+
def validate_network(val_loader, model, linear_classifier, n, avgpool):
|
197 |
+
linear_classifier.eval()
|
198 |
+
metric_logger = utils.MetricLogger(delimiter=" ")
|
199 |
+
header = 'Test:'
|
200 |
+
for inp, target in metric_logger.log_every(val_loader, 20, header):
|
201 |
+
# move to gpu
|
202 |
+
inp = inp.cuda(non_blocking=True)
|
203 |
+
target = target.cuda(non_blocking=True)
|
204 |
+
|
205 |
+
# forward
|
206 |
+
with torch.no_grad():
|
207 |
+
if "vit" in args.arch:
|
208 |
+
intermediate_output = model.get_intermediate_layers(inp, n)
|
209 |
+
output = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
|
210 |
+
if avgpool:
|
211 |
+
output = torch.cat((output.unsqueeze(-1), torch.mean(intermediate_output[-1][:, 1:], dim=1).unsqueeze(-1)), dim=-1)
|
212 |
+
output = output.reshape(output.shape[0], -1)
|
213 |
+
else:
|
214 |
+
output = model(inp)
|
215 |
+
output = linear_classifier(output)
|
216 |
+
loss = nn.CrossEntropyLoss()(output, target)
|
217 |
+
|
218 |
+
if linear_classifier.module.num_labels >= 5:
|
219 |
+
acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
|
220 |
+
else:
|
221 |
+
acc1, = utils.accuracy(output, target, topk=(1,))
|
222 |
+
|
223 |
+
batch_size = inp.shape[0]
|
224 |
+
metric_logger.update(loss=loss.item())
|
225 |
+
metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
|
226 |
+
if linear_classifier.module.num_labels >= 5:
|
227 |
+
metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
|
228 |
+
if linear_classifier.module.num_labels >= 5:
|
229 |
+
print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
|
230 |
+
.format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
|
231 |
+
else:
|
232 |
+
print('* Acc@1 {top1.global_avg:.3f} loss {losses.global_avg:.3f}'
|
233 |
+
.format(top1=metric_logger.acc1, losses=metric_logger.loss))
|
234 |
+
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
235 |
+
|
236 |
+
|
237 |
+
class LinearClassifier(nn.Module):
|
238 |
+
"""Linear layer to train on top of frozen features"""
|
239 |
+
def __init__(self, dim, num_labels=1000):
|
240 |
+
super(LinearClassifier, self).__init__()
|
241 |
+
self.num_labels = num_labels
|
242 |
+
self.linear = nn.Linear(dim, num_labels)
|
243 |
+
self.linear.weight.data.normal_(mean=0.0, std=0.01)
|
244 |
+
self.linear.bias.data.zero_()
|
245 |
+
|
246 |
+
def forward(self, x):
|
247 |
+
# flatten
|
248 |
+
x = x.view(x.size(0), -1)
|
249 |
+
|
250 |
+
# linear layer
|
251 |
+
return self.linear(x)
|
252 |
+
|
253 |
+
|
254 |
+
if __name__ == '__main__':
|
255 |
+
parser = argparse.ArgumentParser('Evaluation with linear classification on ImageNet')
|
256 |
+
parser.add_argument('--n_last_blocks', default=4, type=int, help="""Concatenate [CLS] tokens
|
257 |
+
for the `n` last blocks. We use `n=4` when evaluating ViT-Small and `n=1` with ViT-Base.""")
|
258 |
+
parser.add_argument('--avgpool_patchtokens', default=False, type=utils.bool_flag,
|
259 |
+
help="""Whether ot not to concatenate the global average pooled features to the [CLS] token.
|
260 |
+
We typically set this to False for ViT-Small and to True with ViT-Base.""")
|
261 |
+
parser.add_argument('--arch', default='vit_small', type=str, help='Architecture')
|
262 |
+
parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.')
|
263 |
+
parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.")
|
264 |
+
parser.add_argument("--checkpoint_key", default="teacher", type=str, help='Key to use in the checkpoint (example: "teacher")')
|
265 |
+
parser.add_argument('--epochs', default=100, type=int, help='Number of epochs of training.')
|
266 |
+
parser.add_argument("--lr", default=0.001, type=float, help="""Learning rate at the beginning of
|
267 |
+
training (highest LR used during training). The learning rate is linearly scaled
|
268 |
+
with the batch size, and specified here for a reference batch size of 256.
|
269 |
+
We recommend tweaking the LR depending on the checkpoint evaluated.""")
|
270 |
+
parser.add_argument('--batch_size_per_gpu', default=128, type=int, help='Per-GPU batch-size')
|
271 |
+
parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
|
272 |
+
distributed training; see https://pytorch.org/docs/stable/distributed.html""")
|
273 |
+
parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
|
274 |
+
parser.add_argument('--data_path', default='/path/to/imagenet/', type=str)
|
275 |
+
parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
|
276 |
+
parser.add_argument('--val_freq', default=1, type=int, help="Epoch frequency for validation.")
|
277 |
+
parser.add_argument('--output_dir', default=".", help='Path to save logs and checkpoints')
|
278 |
+
parser.add_argument('--num_labels', default=1000, type=int, help='Number of labels for linear classifier')
|
279 |
+
parser.add_argument('--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set')
|
280 |
+
args = parser.parse_args()
|
281 |
+
eval_linear(args)
|
PuzzleTuning/Counterpart PreTrain Methods/dino-main/eval_video_segmentation.py
ADDED
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""
|
15 |
+
Some parts are taken from https://github.com/Liusifei/UVC
|
16 |
+
"""
|
17 |
+
import os
|
18 |
+
import copy
|
19 |
+
import glob
|
20 |
+
import queue
|
21 |
+
from urllib.request import urlopen
|
22 |
+
import argparse
|
23 |
+
import numpy as np
|
24 |
+
from tqdm import tqdm
|
25 |
+
|
26 |
+
import cv2
|
27 |
+
import torch
|
28 |
+
import torch.nn as nn
|
29 |
+
from torch.nn import functional as F
|
30 |
+
from PIL import Image
|
31 |
+
from torchvision import transforms
|
32 |
+
|
33 |
+
import utils
|
34 |
+
import vision_transformer as vits
|
35 |
+
|
36 |
+
|
37 |
+
@torch.no_grad()
|
38 |
+
def eval_video_tracking_davis(args, model, frame_list, video_dir, first_seg, seg_ori, color_palette):
|
39 |
+
"""
|
40 |
+
Evaluate tracking on a video given first frame & segmentation
|
41 |
+
"""
|
42 |
+
video_folder = os.path.join(args.output_dir, video_dir.split('/')[-1])
|
43 |
+
os.makedirs(video_folder, exist_ok=True)
|
44 |
+
|
45 |
+
# The queue stores the n preceeding frames
|
46 |
+
que = queue.Queue(args.n_last_frames)
|
47 |
+
|
48 |
+
# first frame
|
49 |
+
frame1, ori_h, ori_w = read_frame(frame_list[0])
|
50 |
+
# extract first frame feature
|
51 |
+
frame1_feat = extract_feature(model, frame1).T # dim x h*w
|
52 |
+
|
53 |
+
# saving first segmentation
|
54 |
+
out_path = os.path.join(video_folder, "00000.png")
|
55 |
+
imwrite_indexed(out_path, seg_ori, color_palette)
|
56 |
+
mask_neighborhood = None
|
57 |
+
for cnt in tqdm(range(1, len(frame_list))):
|
58 |
+
frame_tar = read_frame(frame_list[cnt])[0]
|
59 |
+
|
60 |
+
# we use the first segmentation and the n previous ones
|
61 |
+
used_frame_feats = [frame1_feat] + [pair[0] for pair in list(que.queue)]
|
62 |
+
used_segs = [first_seg] + [pair[1] for pair in list(que.queue)]
|
63 |
+
|
64 |
+
frame_tar_avg, feat_tar, mask_neighborhood = label_propagation(args, model, frame_tar, used_frame_feats, used_segs, mask_neighborhood)
|
65 |
+
|
66 |
+
# pop out oldest frame if neccessary
|
67 |
+
if que.qsize() == args.n_last_frames:
|
68 |
+
que.get()
|
69 |
+
# push current results into queue
|
70 |
+
seg = copy.deepcopy(frame_tar_avg)
|
71 |
+
que.put([feat_tar, seg])
|
72 |
+
|
73 |
+
# upsampling & argmax
|
74 |
+
frame_tar_avg = F.interpolate(frame_tar_avg, scale_factor=args.patch_size, mode='bilinear', align_corners=False, recompute_scale_factor=False)[0]
|
75 |
+
frame_tar_avg = norm_mask(frame_tar_avg)
|
76 |
+
_, frame_tar_seg = torch.max(frame_tar_avg, dim=0)
|
77 |
+
|
78 |
+
# saving to disk
|
79 |
+
frame_tar_seg = np.array(frame_tar_seg.squeeze().cpu(), dtype=np.uint8)
|
80 |
+
frame_tar_seg = np.array(Image.fromarray(frame_tar_seg).resize((ori_w, ori_h), 0))
|
81 |
+
frame_nm = frame_list[cnt].split('/')[-1].replace(".jpg", ".png")
|
82 |
+
imwrite_indexed(os.path.join(video_folder, frame_nm), frame_tar_seg, color_palette)
|
83 |
+
|
84 |
+
|
85 |
+
def restrict_neighborhood(h, w):
|
86 |
+
# We restrict the set of source nodes considered to a spatial neighborhood of the query node (i.e. ``local attention'')
|
87 |
+
mask = torch.zeros(h, w, h, w)
|
88 |
+
for i in range(h):
|
89 |
+
for j in range(w):
|
90 |
+
for p in range(2 * args.size_mask_neighborhood + 1):
|
91 |
+
for q in range(2 * args.size_mask_neighborhood + 1):
|
92 |
+
if i - args.size_mask_neighborhood + p < 0 or i - args.size_mask_neighborhood + p >= h:
|
93 |
+
continue
|
94 |
+
if j - args.size_mask_neighborhood + q < 0 or j - args.size_mask_neighborhood + q >= w:
|
95 |
+
continue
|
96 |
+
mask[i, j, i - args.size_mask_neighborhood + p, j - args.size_mask_neighborhood + q] = 1
|
97 |
+
|
98 |
+
mask = mask.reshape(h * w, h * w)
|
99 |
+
return mask.cuda(non_blocking=True)
|
100 |
+
|
101 |
+
|
102 |
+
def norm_mask(mask):
|
103 |
+
c, h, w = mask.size()
|
104 |
+
for cnt in range(c):
|
105 |
+
mask_cnt = mask[cnt,:,:]
|
106 |
+
if(mask_cnt.max() > 0):
|
107 |
+
mask_cnt = (mask_cnt - mask_cnt.min())
|
108 |
+
mask_cnt = mask_cnt/mask_cnt.max()
|
109 |
+
mask[cnt,:,:] = mask_cnt
|
110 |
+
return mask
|
111 |
+
|
112 |
+
|
113 |
+
def label_propagation(args, model, frame_tar, list_frame_feats, list_segs, mask_neighborhood=None):
|
114 |
+
"""
|
115 |
+
propagate segs of frames in list_frames to frame_tar
|
116 |
+
"""
|
117 |
+
## we only need to extract feature of the target frame
|
118 |
+
feat_tar, h, w = extract_feature(model, frame_tar, return_h_w=True)
|
119 |
+
|
120 |
+
return_feat_tar = feat_tar.T # dim x h*w
|
121 |
+
|
122 |
+
ncontext = len(list_frame_feats)
|
123 |
+
feat_sources = torch.stack(list_frame_feats) # nmb_context x dim x h*w
|
124 |
+
|
125 |
+
feat_tar = F.normalize(feat_tar, dim=1, p=2)
|
126 |
+
feat_sources = F.normalize(feat_sources, dim=1, p=2)
|
127 |
+
|
128 |
+
feat_tar = feat_tar.unsqueeze(0).repeat(ncontext, 1, 1)
|
129 |
+
aff = torch.exp(torch.bmm(feat_tar, feat_sources) / 0.1) # nmb_context x h*w (tar: query) x h*w (source: keys)
|
130 |
+
|
131 |
+
if args.size_mask_neighborhood > 0:
|
132 |
+
if mask_neighborhood is None:
|
133 |
+
mask_neighborhood = restrict_neighborhood(h, w)
|
134 |
+
mask_neighborhood = mask_neighborhood.unsqueeze(0).repeat(ncontext, 1, 1)
|
135 |
+
aff *= mask_neighborhood
|
136 |
+
|
137 |
+
aff = aff.transpose(2, 1).reshape(-1, h * w) # nmb_context*h*w (source: keys) x h*w (tar: queries)
|
138 |
+
tk_val, _ = torch.topk(aff, dim=0, k=args.topk)
|
139 |
+
tk_val_min, _ = torch.min(tk_val, dim=0)
|
140 |
+
aff[aff < tk_val_min] = 0
|
141 |
+
|
142 |
+
aff = aff / torch.sum(aff, keepdim=True, axis=0)
|
143 |
+
|
144 |
+
list_segs = [s.cuda() for s in list_segs]
|
145 |
+
segs = torch.cat(list_segs)
|
146 |
+
nmb_context, C, h, w = segs.shape
|
147 |
+
segs = segs.reshape(nmb_context, C, -1).transpose(2, 1).reshape(-1, C).T # C x nmb_context*h*w
|
148 |
+
seg_tar = torch.mm(segs, aff)
|
149 |
+
seg_tar = seg_tar.reshape(1, C, h, w)
|
150 |
+
return seg_tar, return_feat_tar, mask_neighborhood
|
151 |
+
|
152 |
+
|
153 |
+
def extract_feature(model, frame, return_h_w=False):
|
154 |
+
"""Extract one frame feature everytime."""
|
155 |
+
out = model.get_intermediate_layers(frame.unsqueeze(0).cuda(), n=1)[0]
|
156 |
+
out = out[:, 1:, :] # we discard the [CLS] token
|
157 |
+
h, w = int(frame.shape[1] / model.patch_embed.patch_size), int(frame.shape[2] / model.patch_embed.patch_size)
|
158 |
+
dim = out.shape[-1]
|
159 |
+
out = out[0].reshape(h, w, dim)
|
160 |
+
out = out.reshape(-1, dim)
|
161 |
+
if return_h_w:
|
162 |
+
return out, h, w
|
163 |
+
return out
|
164 |
+
|
165 |
+
|
166 |
+
def imwrite_indexed(filename, array, color_palette):
|
167 |
+
""" Save indexed png for DAVIS."""
|
168 |
+
if np.atleast_3d(array).shape[2] != 1:
|
169 |
+
raise Exception("Saving indexed PNGs requires 2D array.")
|
170 |
+
|
171 |
+
im = Image.fromarray(array)
|
172 |
+
im.putpalette(color_palette.ravel())
|
173 |
+
im.save(filename, format='PNG')
|
174 |
+
|
175 |
+
|
176 |
+
def to_one_hot(y_tensor, n_dims=None):
|
177 |
+
"""
|
178 |
+
Take integer y (tensor or variable) with n dims &
|
179 |
+
convert it to 1-hot representation with n+1 dims.
|
180 |
+
"""
|
181 |
+
if(n_dims is None):
|
182 |
+
n_dims = int(y_tensor.max()+ 1)
|
183 |
+
_,h,w = y_tensor.size()
|
184 |
+
y_tensor = y_tensor.type(torch.LongTensor).view(-1, 1)
|
185 |
+
n_dims = n_dims if n_dims is not None else int(torch.max(y_tensor)) + 1
|
186 |
+
y_one_hot = torch.zeros(y_tensor.size()[0], n_dims).scatter_(1, y_tensor, 1)
|
187 |
+
y_one_hot = y_one_hot.view(h,w,n_dims)
|
188 |
+
return y_one_hot.permute(2, 0, 1).unsqueeze(0)
|
189 |
+
|
190 |
+
|
191 |
+
def read_frame_list(video_dir):
|
192 |
+
frame_list = [img for img in glob.glob(os.path.join(video_dir,"*.jpg"))]
|
193 |
+
frame_list = sorted(frame_list)
|
194 |
+
return frame_list
|
195 |
+
|
196 |
+
|
197 |
+
def read_frame(frame_dir, scale_size=[480]):
|
198 |
+
"""
|
199 |
+
read a single frame & preprocess
|
200 |
+
"""
|
201 |
+
img = cv2.imread(frame_dir)
|
202 |
+
ori_h, ori_w, _ = img.shape
|
203 |
+
if len(scale_size) == 1:
|
204 |
+
if(ori_h > ori_w):
|
205 |
+
tw = scale_size[0]
|
206 |
+
th = (tw * ori_h) / ori_w
|
207 |
+
th = int((th // 64) * 64)
|
208 |
+
else:
|
209 |
+
th = scale_size[0]
|
210 |
+
tw = (th * ori_w) / ori_h
|
211 |
+
tw = int((tw // 64) * 64)
|
212 |
+
else:
|
213 |
+
th, tw = scale_size
|
214 |
+
img = cv2.resize(img, (tw, th))
|
215 |
+
img = img.astype(np.float32)
|
216 |
+
img = img / 255.0
|
217 |
+
img = img[:, :, ::-1]
|
218 |
+
img = np.transpose(img.copy(), (2, 0, 1))
|
219 |
+
img = torch.from_numpy(img).float()
|
220 |
+
img = color_normalize(img)
|
221 |
+
return img, ori_h, ori_w
|
222 |
+
|
223 |
+
|
224 |
+
def read_seg(seg_dir, factor, scale_size=[480]):
|
225 |
+
seg = Image.open(seg_dir)
|
226 |
+
_w, _h = seg.size # note PIL.Image.Image's size is (w, h)
|
227 |
+
if len(scale_size) == 1:
|
228 |
+
if(_w > _h):
|
229 |
+
_th = scale_size[0]
|
230 |
+
_tw = (_th * _w) / _h
|
231 |
+
_tw = int((_tw // 64) * 64)
|
232 |
+
else:
|
233 |
+
_tw = scale_size[0]
|
234 |
+
_th = (_tw * _h) / _w
|
235 |
+
_th = int((_th // 64) * 64)
|
236 |
+
else:
|
237 |
+
_th = scale_size[1]
|
238 |
+
_tw = scale_size[0]
|
239 |
+
small_seg = np.array(seg.resize((_tw // factor, _th // factor), 0))
|
240 |
+
small_seg = torch.from_numpy(small_seg.copy()).contiguous().float().unsqueeze(0)
|
241 |
+
return to_one_hot(small_seg), np.asarray(seg)
|
242 |
+
|
243 |
+
|
244 |
+
def color_normalize(x, mean=[0.485, 0.456, 0.406], std=[0.228, 0.224, 0.225]):
|
245 |
+
for t, m, s in zip(x, mean, std):
|
246 |
+
t.sub_(m)
|
247 |
+
t.div_(s)
|
248 |
+
return x
|
249 |
+
|
250 |
+
|
251 |
+
if __name__ == '__main__':
|
252 |
+
parser = argparse.ArgumentParser('Evaluation with video object segmentation on DAVIS 2017')
|
253 |
+
parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.")
|
254 |
+
parser.add_argument('--arch', default='vit_small', type=str,
|
255 |
+
choices=['vit_tiny', 'vit_small', 'vit_base'], help='Architecture (support only ViT atm).')
|
256 |
+
parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.')
|
257 |
+
parser.add_argument("--checkpoint_key", default="teacher", type=str, help='Key to use in the checkpoint (example: "teacher")')
|
258 |
+
parser.add_argument('--output_dir', default=".", help='Path where to save segmentations')
|
259 |
+
parser.add_argument('--data_path', default='/path/to/davis/', type=str)
|
260 |
+
parser.add_argument("--n_last_frames", type=int, default=7, help="number of preceeding frames")
|
261 |
+
parser.add_argument("--size_mask_neighborhood", default=12, type=int,
|
262 |
+
help="We restrict the set of source nodes considered to a spatial neighborhood of the query node")
|
263 |
+
parser.add_argument("--topk", type=int, default=5, help="accumulate label from top k neighbors")
|
264 |
+
parser.add_argument("--bs", type=int, default=6, help="Batch size, try to reduce if OOM")
|
265 |
+
args = parser.parse_args()
|
266 |
+
|
267 |
+
print("git:\n {}\n".format(utils.get_sha()))
|
268 |
+
print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
|
269 |
+
|
270 |
+
# building network
|
271 |
+
model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
|
272 |
+
print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.")
|
273 |
+
model.cuda()
|
274 |
+
utils.load_pretrained_weights(model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size)
|
275 |
+
for param in model.parameters():
|
276 |
+
param.requires_grad = False
|
277 |
+
model.eval()
|
278 |
+
|
279 |
+
color_palette = []
|
280 |
+
for line in urlopen("https://raw.githubusercontent.com/Liusifei/UVC/master/libs/data/palette.txt"):
|
281 |
+
color_palette.append([int(i) for i in line.decode("utf-8").split('\n')[0].split(" ")])
|
282 |
+
color_palette = np.asarray(color_palette, dtype=np.uint8).reshape(-1,3)
|
283 |
+
|
284 |
+
video_list = open(os.path.join(args.data_path, "ImageSets/2017/val.txt")).readlines()
|
285 |
+
for i, video_name in enumerate(video_list):
|
286 |
+
video_name = video_name.strip()
|
287 |
+
print(f'[{i}/{len(video_list)}] Begin to segmentate video {video_name}.')
|
288 |
+
video_dir = os.path.join(args.data_path, "JPEGImages/480p/", video_name)
|
289 |
+
frame_list = read_frame_list(video_dir)
|
290 |
+
seg_path = frame_list[0].replace("JPEGImages", "Annotations").replace("jpg", "png")
|
291 |
+
first_seg, seg_ori = read_seg(seg_path, args.patch_size)
|
292 |
+
eval_video_tracking_davis(args, model, frame_list, video_dir, first_seg, seg_ori, color_palette)
|
PuzzleTuning/Counterpart PreTrain Methods/dino-main/hubconf.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import torch
|
15 |
+
from torchvision.models.resnet import resnet50
|
16 |
+
|
17 |
+
import vision_transformer as vits
|
18 |
+
|
19 |
+
dependencies = ["torch", "torchvision"]
|
20 |
+
|
21 |
+
|
22 |
+
def dino_vits16(pretrained=True, **kwargs):
|
23 |
+
"""
|
24 |
+
ViT-Small/16x16 pre-trained with DINO.
|
25 |
+
Achieves 74.5% top-1 accuracy on ImageNet with k-NN classification.
|
26 |
+
"""
|
27 |
+
model = vits.__dict__["vit_small"](patch_size=16, num_classes=0, **kwargs)
|
28 |
+
if pretrained:
|
29 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
30 |
+
url="https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth",
|
31 |
+
map_location="cpu",
|
32 |
+
)
|
33 |
+
model.load_state_dict(state_dict, strict=True)
|
34 |
+
return model
|
35 |
+
|
36 |
+
|
37 |
+
def dino_vits8(pretrained=True, **kwargs):
|
38 |
+
"""
|
39 |
+
ViT-Small/8x8 pre-trained with DINO.
|
40 |
+
Achieves 78.3% top-1 accuracy on ImageNet with k-NN classification.
|
41 |
+
"""
|
42 |
+
model = vits.__dict__["vit_small"](patch_size=8, num_classes=0, **kwargs)
|
43 |
+
if pretrained:
|
44 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
45 |
+
url="https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth",
|
46 |
+
map_location="cpu",
|
47 |
+
)
|
48 |
+
model.load_state_dict(state_dict, strict=True)
|
49 |
+
return model
|
50 |
+
|
51 |
+
|
52 |
+
def dino_vitb16(pretrained=True, **kwargs):
|
53 |
+
"""
|
54 |
+
ViT-Base/16x16 pre-trained with DINO.
|
55 |
+
Achieves 76.1% top-1 accuracy on ImageNet with k-NN classification.
|
56 |
+
"""
|
57 |
+
model = vits.__dict__["vit_base"](patch_size=16, num_classes=0, **kwargs)
|
58 |
+
if pretrained:
|
59 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
60 |
+
url="https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth",
|
61 |
+
map_location="cpu",
|
62 |
+
)
|
63 |
+
model.load_state_dict(state_dict, strict=True)
|
64 |
+
return model
|
65 |
+
|
66 |
+
|
67 |
+
def dino_vitb8(pretrained=True, **kwargs):
|
68 |
+
"""
|
69 |
+
ViT-Base/8x8 pre-trained with DINO.
|
70 |
+
Achieves 77.4% top-1 accuracy on ImageNet with k-NN classification.
|
71 |
+
"""
|
72 |
+
model = vits.__dict__["vit_base"](patch_size=8, num_classes=0, **kwargs)
|
73 |
+
if pretrained:
|
74 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
75 |
+
url="https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth",
|
76 |
+
map_location="cpu",
|
77 |
+
)
|
78 |
+
model.load_state_dict(state_dict, strict=True)
|
79 |
+
return model
|
80 |
+
|
81 |
+
|
82 |
+
def dino_resnet50(pretrained=True, **kwargs):
|
83 |
+
"""
|
84 |
+
ResNet-50 pre-trained with DINO.
|
85 |
+
Achieves 75.3% top-1 accuracy on ImageNet linear evaluation benchmark (requires to train `fc`).
|
86 |
+
"""
|
87 |
+
model = resnet50(pretrained=False, **kwargs)
|
88 |
+
model.fc = torch.nn.Identity()
|
89 |
+
if pretrained:
|
90 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
91 |
+
url="https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain.pth",
|
92 |
+
map_location="cpu",
|
93 |
+
)
|
94 |
+
model.load_state_dict(state_dict, strict=False)
|
95 |
+
return model
|
96 |
+
|
97 |
+
|
98 |
+
def dino_xcit_small_12_p16(pretrained=True, **kwargs):
|
99 |
+
"""
|
100 |
+
XCiT-Small-12/16 pre-trained with DINO.
|
101 |
+
"""
|
102 |
+
model = torch.hub.load('facebookresearch/xcit:main', "xcit_small_12_p16", num_classes=0, **kwargs)
|
103 |
+
if pretrained:
|
104 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
105 |
+
url="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p16_pretrain/dino_xcit_small_12_p16_pretrain.pth",
|
106 |
+
map_location="cpu",
|
107 |
+
)
|
108 |
+
model.load_state_dict(state_dict, strict=True)
|
109 |
+
return model
|
110 |
+
|
111 |
+
|
112 |
+
def dino_xcit_small_12_p8(pretrained=True, **kwargs):
|
113 |
+
"""
|
114 |
+
XCiT-Small-12/8 pre-trained with DINO.
|
115 |
+
"""
|
116 |
+
model = torch.hub.load('facebookresearch/xcit:main', "xcit_small_12_p8", num_classes=0, **kwargs)
|
117 |
+
if pretrained:
|
118 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
119 |
+
url="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p8_pretrain/dino_xcit_small_12_p8_pretrain.pth",
|
120 |
+
map_location="cpu",
|
121 |
+
)
|
122 |
+
model.load_state_dict(state_dict, strict=True)
|
123 |
+
return model
|
124 |
+
|
125 |
+
|
126 |
+
def dino_xcit_medium_24_p16(pretrained=True, **kwargs):
|
127 |
+
"""
|
128 |
+
XCiT-Medium-24/16 pre-trained with DINO.
|
129 |
+
"""
|
130 |
+
model = torch.hub.load('facebookresearch/xcit:main', "xcit_medium_24_p16", num_classes=0, **kwargs)
|
131 |
+
if pretrained:
|
132 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
133 |
+
url="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p16_pretrain/dino_xcit_medium_24_p16_pretrain.pth",
|
134 |
+
map_location="cpu",
|
135 |
+
)
|
136 |
+
model.load_state_dict(state_dict, strict=True)
|
137 |
+
return model
|
138 |
+
|
139 |
+
|
140 |
+
def dino_xcit_medium_24_p8(pretrained=True, **kwargs):
|
141 |
+
"""
|
142 |
+
XCiT-Medium-24/8 pre-trained with DINO.
|
143 |
+
"""
|
144 |
+
model = torch.hub.load('facebookresearch/xcit:main', "xcit_medium_24_p8", num_classes=0, **kwargs)
|
145 |
+
if pretrained:
|
146 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
147 |
+
url="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p8_pretrain/dino_xcit_medium_24_p8_pretrain.pth",
|
148 |
+
map_location="cpu",
|
149 |
+
)
|
150 |
+
model.load_state_dict(state_dict, strict=True)
|
151 |
+
return model
|
PuzzleTuning/Counterpart PreTrain Methods/dino-main/main_dino.py
ADDED
@@ -0,0 +1,515 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import argparse
|
15 |
+
import os
|
16 |
+
import sys
|
17 |
+
import datetime
|
18 |
+
import time
|
19 |
+
import math
|
20 |
+
import json
|
21 |
+
from pathlib import Path
|
22 |
+
|
23 |
+
import numpy as np
|
24 |
+
from PIL import Image
|
25 |
+
import torch
|
26 |
+
import torch.nn as nn
|
27 |
+
import torch.distributed as dist
|
28 |
+
import torch.backends.cudnn as cudnn
|
29 |
+
import torch.nn.functional as F
|
30 |
+
from torchvision import datasets, transforms
|
31 |
+
from torchvision import models as torchvision_models
|
32 |
+
|
33 |
+
import utils
|
34 |
+
import vision_transformer as vits
|
35 |
+
from vision_transformer import DINOHead
|
36 |
+
|
37 |
+
torchvision_archs = sorted(name for name in torchvision_models.__dict__
|
38 |
+
if name.islower() and not name.startswith("__")
|
39 |
+
and callable(torchvision_models.__dict__[name]))
|
40 |
+
|
41 |
+
def get_args_parser():
|
42 |
+
parser = argparse.ArgumentParser('DINO', add_help=False)
|
43 |
+
|
44 |
+
# Model parameters
|
45 |
+
parser.add_argument('--arch', default='vit_base', type=str,
|
46 |
+
help="""Name of architecture to train. For quick experiments with ViTs,
|
47 |
+
we recommend using vit_tiny or vit_small.""")
|
48 |
+
parser.add_argument('--patch_size', default=16, type=int, help="""Size in pixels
|
49 |
+
of input square patches - default 16 (for 16x16 patches). Using smaller
|
50 |
+
values leads to better performance but requires more memory. Applies only
|
51 |
+
for ViTs (vit_tiny, vit_small and vit_base). If <16, we recommend disabling
|
52 |
+
mixed precision training (--use_fp16 false) to avoid unstabilities.""")
|
53 |
+
parser.add_argument('--input_size', default=224, type=int)
|
54 |
+
parser.add_argument('--out_dim', default=65536, type=int, help="""Dimensionality of
|
55 |
+
the DINO head output. For complex and large datasets large values (like 65k) work well.""")
|
56 |
+
parser.add_argument('--norm_last_layer', default=True, type=utils.bool_flag,
|
57 |
+
help="""Whether or not to weight normalize the last layer of the DINO head.
|
58 |
+
Not normalizing leads to better performance but can make the training unstable.
|
59 |
+
In our experiments, we typically set this paramater to False with vit_small and True with vit_base.""")
|
60 |
+
parser.add_argument('--momentum_teacher', default=0.996, type=float, help="""Base EMA
|
61 |
+
parameter for teacher update. The value is increased to 1 during training with cosine schedule.
|
62 |
+
We recommend setting a higher value with small batches: for example use 0.9995 with batch size of 256.""")
|
63 |
+
parser.add_argument('--use_bn_in_head', default=False, type=utils.bool_flag,
|
64 |
+
help="Whether to use batch normalizations in projection head (Default: False)")
|
65 |
+
|
66 |
+
# Temperature teacher parameters
|
67 |
+
parser.add_argument('--warmup_teacher_temp', default=0.04, type=float,
|
68 |
+
help="""Initial value for the teacher temperature: 0.04 works well in most cases.
|
69 |
+
Try decreasing it if the training loss does not decrease.""")
|
70 |
+
parser.add_argument('--teacher_temp', default=0.04, type=float, help="""Final value (after linear warmup)
|
71 |
+
of the teacher temperature. For most experiments, anything above 0.07 is unstable. We recommend
|
72 |
+
starting with the default value of 0.04 and increase this slightly if needed.""")
|
73 |
+
parser.add_argument('--warmup_teacher_temp_epochs', default=0, type=int,
|
74 |
+
help='Number of warmup epochs for the teacher temperature (Default: 30).')
|
75 |
+
|
76 |
+
# Training/Optimization parameters
|
77 |
+
parser.add_argument('--use_fp16', type=utils.bool_flag, default=True, help="""Whether or not
|
78 |
+
to use half precision for training. Improves training time and memory requirements,
|
79 |
+
but can provoke instability and slight decay of performance. We recommend disabling
|
80 |
+
mixed precision if the loss is unstable, if reducing the patch size or if training with bigger ViTs.""")
|
81 |
+
parser.add_argument('--weight_decay', type=float, default=0.04, help="""Initial value of the
|
82 |
+
weight decay. With ViT, a smaller value at the beginning of training works well.""")
|
83 |
+
parser.add_argument('--weight_decay_end', type=float, default=0.4, help="""Final value of the
|
84 |
+
weight decay. We use a cosine schedule for WD and using a larger decay by
|
85 |
+
the end of training improves performance for ViTs.""")
|
86 |
+
parser.add_argument('--clip_grad', type=float, default=3.0, help="""Maximal parameter
|
87 |
+
gradient norm if using gradient clipping. Clipping with norm .3 ~ 1.0 can
|
88 |
+
help optimization for larger ViT architectures. 0 for disabling.""")
|
89 |
+
parser.add_argument('--batch_size_per_gpu', default=512, type=int,
|
90 |
+
help='Per-GPU batch-size : number of distinct images loaded on one GPU.')
|
91 |
+
parser.add_argument('--epochs', default=150, type=int, help='Number of epochs of training.')
|
92 |
+
parser.add_argument('--freeze_last_layer', default=1, type=int, help="""Number of epochs
|
93 |
+
during which we keep the output layer fixed. Typically doing so during
|
94 |
+
the first epoch helps training. Try increasing this value if the loss does not decrease.""")
|
95 |
+
parser.add_argument("--lr", default=1.5e-4, type=float, help="""Learning rate at the end of
|
96 |
+
linear warmup (highest LR used during training). The learning rate is linearly scaled
|
97 |
+
with the batch size, and specified here for a reference batch size of 512.""")
|
98 |
+
parser.add_argument("--warmup_epochs", default=20, type=int,
|
99 |
+
help="Number of epochs for the linear learning-rate warm up.")
|
100 |
+
parser.add_argument('--min_lr', type=float, default=1e-6, help="""Target LR at the
|
101 |
+
end of optimization. We use a cosine LR schedule with linear warmup.""")
|
102 |
+
parser.add_argument('--optimizer', default='adamw', type=str,
|
103 |
+
choices=['adamw', 'sgd', 'lars'], help="""Type of optimizer. We recommend using adamw with ViTs.""")
|
104 |
+
parser.add_argument('--drop_path_rate', type=float, default=0.1, help="stochastic depth rate")
|
105 |
+
|
106 |
+
# Multi-crop parameters
|
107 |
+
parser.add_argument('--global_crops_scale', type=float, nargs='+', default=(0.4, 1.),
|
108 |
+
help="""Scale range of the cropped image before resizing, relatively to the origin image.
|
109 |
+
Used for large global view cropping. When disabling multi-crop (--local_crops_number 0), we
|
110 |
+
recommand using a wider range of scale ("--global_crops_scale 0.14 1." for example)""")
|
111 |
+
parser.add_argument('--local_crops_number', type=int, default=8, help="""Number of small
|
112 |
+
local views to generate. Set this parameter to 0 to disable multi-crop training.
|
113 |
+
When disabling multi-crop we recommend to use "--global_crops_scale 0.14 1." """)
|
114 |
+
parser.add_argument('--local_crops_scale', type=float, nargs='+', default=(0.05, 0.4),
|
115 |
+
help="""Scale range of the cropped image before resizing, relatively to the origin image.
|
116 |
+
Used for small local view cropping of multi-crop.""")
|
117 |
+
|
118 |
+
# Misc
|
119 |
+
parser.add_argument('--data_path', default='/root/autodl-tmp/All', type=str,
|
120 |
+
help='Please specify path to the ImageNet training data.')
|
121 |
+
parser.add_argument('--basic_state_dict', default='/root/autodl-tmp/ViT_b16_224_Imagenet.pth', type=str,
|
122 |
+
help='Load in pretrained or un-pretrained model pth')
|
123 |
+
parser.add_argument('--output_dir', default="/home/CPIA/saved_models/DINO", type=str, help='Path to save logs and checkpoints.')
|
124 |
+
parser.add_argument('--saveckp_freq', default=50, type=int, help='Save checkpoint every x epochs.')
|
125 |
+
parser.add_argument('--seed', default=0, type=int, help='Random seed.')
|
126 |
+
parser.add_argument('--num_workers', default=32, type=int, help='Number of data loading workers per GPU.')
|
127 |
+
parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
|
128 |
+
distributed training; see https://pytorch.org/docs/stable/distributed.html""")
|
129 |
+
parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
|
130 |
+
return parser
|
131 |
+
|
132 |
+
|
133 |
+
def train_dino(args):
|
134 |
+
utils.init_distributed_mode(args)
|
135 |
+
utils.fix_random_seeds(args.seed)
|
136 |
+
print('ok')
|
137 |
+
print("git:\n {}\n".format(utils.get_sha()))
|
138 |
+
print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
|
139 |
+
cudnn.benchmark = True
|
140 |
+
|
141 |
+
# ============ preparing data ... ============
|
142 |
+
transform = DataAugmentationDINO(
|
143 |
+
args.global_crops_scale,
|
144 |
+
args.local_crops_scale,
|
145 |
+
args.local_crops_number,
|
146 |
+
)
|
147 |
+
dataset = datasets.ImageFolder(args.data_path, transform=transform)
|
148 |
+
sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
|
149 |
+
data_loader = torch.utils.data.DataLoader(
|
150 |
+
dataset,
|
151 |
+
sampler=sampler,
|
152 |
+
batch_size=args.batch_size_per_gpu,
|
153 |
+
num_workers=args.num_workers,
|
154 |
+
pin_memory=True,
|
155 |
+
drop_last=True,
|
156 |
+
)
|
157 |
+
print(f"Data loaded: there are {len(dataset)} images.")
|
158 |
+
|
159 |
+
# ============ building student and teacher networks ... ============
|
160 |
+
# we changed the name DeiT-S for ViT-S to avoid confusions
|
161 |
+
args.arch = args.arch.replace("deit", "vit")
|
162 |
+
# if the network is a Vision Transformer (i.e. vit_tiny, vit_small, vit_base)
|
163 |
+
if args.arch in vits.__dict__.keys():
|
164 |
+
student = vits.__dict__[args.arch](
|
165 |
+
patch_size=args.patch_size,
|
166 |
+
drop_path_rate=args.drop_path_rate, # stochastic depth
|
167 |
+
)
|
168 |
+
if args.basic_state_dict is not None: # Transfer-learning
|
169 |
+
try:
|
170 |
+
basic_state_dict = torch.load(args.basic_state_dict)
|
171 |
+
if 'model' in basic_state_dict:
|
172 |
+
basic_state_dict = basic_state_dict['model']
|
173 |
+
else:
|
174 |
+
pass
|
175 |
+
student.load_state_dict(basic_state_dict, False)
|
176 |
+
|
177 |
+
except:
|
178 |
+
print('erro in args.basic_state_dict:', args.basic_state_dict)
|
179 |
+
print('Student PreTuning Restart') # 没倒进去
|
180 |
+
|
181 |
+
else:
|
182 |
+
print('Student PreTuning with Transfer-learning with:', args.basic_state_dict)
|
183 |
+
|
184 |
+
else:
|
185 |
+
print('Student PreTuning Restart')
|
186 |
+
teacher = vits.__dict__[args.arch](patch_size=args.patch_size)
|
187 |
+
if args.basic_state_dict is not None: # Transfer-learning
|
188 |
+
try:
|
189 |
+
basic_state_dict = torch.load(args.basic_state_dict)
|
190 |
+
if 'model' in basic_state_dict:
|
191 |
+
basic_state_dict = basic_state_dict['model']
|
192 |
+
else:
|
193 |
+
pass
|
194 |
+
teacher.load_state_dict(basic_state_dict, False)
|
195 |
+
|
196 |
+
except:
|
197 |
+
print('erro in args.basic_state_dict:', args.basic_state_dict)
|
198 |
+
print('Teacher PreTuning Restart') # 没倒进去
|
199 |
+
|
200 |
+
else:
|
201 |
+
print('Teacher PreTuning with Transfer-learning with:', args.basic_state_dict)
|
202 |
+
|
203 |
+
else:
|
204 |
+
print('Teacher PreTuning Restart')
|
205 |
+
embed_dim = student.embed_dim
|
206 |
+
# if the network is a XCiT
|
207 |
+
"""elif args.arch in torch.hub.list("facebookresearch/xcit:main"):
|
208 |
+
student = torch.hub.load('facebookresearch/xcit:main', args.arch,
|
209 |
+
pretrained=False, drop_path_rate=args.drop_path_rate)
|
210 |
+
teacher = torch.hub.load('facebookresearch/xcit:main', args.arch, pretrained=False)
|
211 |
+
embed_dim = student.embed_dim
|
212 |
+
# otherwise, we check if the architecture is in torchvision models
|
213 |
+
elif args.arch in torchvision_models.__dict__.keys():
|
214 |
+
student = torchvision_models.__dict__[args.arch]()
|
215 |
+
teacher = torchvision_models.__dict__[args.arch]()
|
216 |
+
embed_dim = student.fc.weight.shape[1]
|
217 |
+
else:
|
218 |
+
print(f"Unknow architecture: {args.arch}")"""
|
219 |
+
|
220 |
+
# multi-crop wrapper handles forward with inputs of different resolutions
|
221 |
+
student = utils.MultiCropWrapper(student, DINOHead(
|
222 |
+
embed_dim,
|
223 |
+
args.out_dim,
|
224 |
+
use_bn=args.use_bn_in_head,
|
225 |
+
norm_last_layer=args.norm_last_layer,
|
226 |
+
))
|
227 |
+
teacher = utils.MultiCropWrapper(
|
228 |
+
teacher,
|
229 |
+
DINOHead(embed_dim, args.out_dim, args.use_bn_in_head),
|
230 |
+
)
|
231 |
+
# move networks to gpu
|
232 |
+
student, teacher = student.cuda(), teacher.cuda()
|
233 |
+
# synchronize batch norms (if any)
|
234 |
+
if utils.has_batchnorms(student):
|
235 |
+
student = nn.SyncBatchNorm.convert_sync_batchnorm(student)
|
236 |
+
teacher = nn.SyncBatchNorm.convert_sync_batchnorm(teacher)
|
237 |
+
|
238 |
+
# we need DDP wrapper to have synchro batch norms working...
|
239 |
+
teacher = nn.parallel.DistributedDataParallel(teacher, device_ids=[args.gpu])
|
240 |
+
teacher_without_ddp = teacher.module
|
241 |
+
else:
|
242 |
+
# teacher_without_ddp and teacher are the same thing
|
243 |
+
teacher_without_ddp = teacher
|
244 |
+
student = nn.parallel.DistributedDataParallel(student, device_ids=[args.gpu])
|
245 |
+
# teacher and student start with the same weights
|
246 |
+
teacher_without_ddp.load_state_dict(student.module.state_dict())
|
247 |
+
# there is no backpropagation through the teacher, so no need for gradients
|
248 |
+
for p in teacher.parameters():
|
249 |
+
p.requires_grad = False
|
250 |
+
print(f"Student and Teacher are built: they are both {args.arch} network.")
|
251 |
+
|
252 |
+
# ============ preparing loss ... ============
|
253 |
+
dino_loss = DINOLoss(
|
254 |
+
args.out_dim,
|
255 |
+
args.local_crops_number + 2, # total number of crops = 2 global crops + local_crops_number
|
256 |
+
args.warmup_teacher_temp,
|
257 |
+
args.teacher_temp,
|
258 |
+
args.warmup_teacher_temp_epochs,
|
259 |
+
args.epochs,
|
260 |
+
).cuda()
|
261 |
+
|
262 |
+
# ============ preparing optimizer ... ============
|
263 |
+
params_groups = utils.get_params_groups(student)
|
264 |
+
if args.optimizer == "adamw":
|
265 |
+
optimizer = torch.optim.AdamW(params_groups) # to use with ViTs
|
266 |
+
elif args.optimizer == "sgd":
|
267 |
+
optimizer = torch.optim.SGD(params_groups, lr=0, momentum=0.9) # lr is set by scheduler
|
268 |
+
elif args.optimizer == "lars":
|
269 |
+
optimizer = utils.LARS(params_groups) # to use with convnet and large batches
|
270 |
+
# for mixed precision training
|
271 |
+
fp16_scaler = None
|
272 |
+
if args.use_fp16:
|
273 |
+
fp16_scaler = torch.cuda.amp.GradScaler()
|
274 |
+
|
275 |
+
# ============ init schedulers ... ============
|
276 |
+
lr_schedule = utils.cosine_scheduler(
|
277 |
+
args.lr * (args.batch_size_per_gpu * utils.get_world_size()) / 256., # linear scaling rule
|
278 |
+
args.min_lr,
|
279 |
+
args.epochs, len(data_loader),
|
280 |
+
warmup_epochs=args.warmup_epochs,
|
281 |
+
)
|
282 |
+
wd_schedule = utils.cosine_scheduler(
|
283 |
+
args.weight_decay,
|
284 |
+
args.weight_decay_end,
|
285 |
+
args.epochs, len(data_loader),
|
286 |
+
)
|
287 |
+
# momentum parameter is increased to 1. during training with a cosine schedule
|
288 |
+
momentum_schedule = utils.cosine_scheduler(args.momentum_teacher, 1,
|
289 |
+
args.epochs, len(data_loader))
|
290 |
+
print(f"Loss, optimizer and schedulers ready.")
|
291 |
+
|
292 |
+
# ============ optionally resume training ... ============
|
293 |
+
to_restore = {"epoch": 0}
|
294 |
+
utils.restart_from_checkpoint(
|
295 |
+
os.path.join(args.output_dir, "checkpoint.pth"),
|
296 |
+
run_variables=to_restore,
|
297 |
+
student=student,
|
298 |
+
teacher=teacher,
|
299 |
+
optimizer=optimizer,
|
300 |
+
fp16_scaler=fp16_scaler,
|
301 |
+
dino_loss=dino_loss,
|
302 |
+
)
|
303 |
+
start_epoch = to_restore["epoch"]
|
304 |
+
|
305 |
+
start_time = time.time()
|
306 |
+
print("Starting DINO training !")
|
307 |
+
for epoch in range(start_epoch, args.epochs):
|
308 |
+
data_loader.sampler.set_epoch(epoch)
|
309 |
+
|
310 |
+
# ============ training one epoch of DINO ... ============
|
311 |
+
train_stats = train_one_epoch(student, teacher, teacher_without_ddp, dino_loss,
|
312 |
+
data_loader, optimizer, lr_schedule, wd_schedule, momentum_schedule,
|
313 |
+
epoch, fp16_scaler, args)
|
314 |
+
|
315 |
+
# ============ writing logs ... ============
|
316 |
+
save_dict = {
|
317 |
+
'student': student.state_dict(),
|
318 |
+
'teacher': teacher.state_dict(),
|
319 |
+
'optimizer': optimizer.state_dict(),
|
320 |
+
'epoch': epoch + 1,
|
321 |
+
'args': args,
|
322 |
+
'dino_loss': dino_loss.state_dict(),
|
323 |
+
}
|
324 |
+
if fp16_scaler is not None:
|
325 |
+
save_dict['fp16_scaler'] = fp16_scaler.state_dict()
|
326 |
+
utils.save_on_master(save_dict, os.path.join(args.output_dir, 'checkpoint.pth'))
|
327 |
+
if args.saveckp_freq and (epoch % args.saveckp_freq == 0 or epoch + 1 == args.epochs):
|
328 |
+
initial_setting = os.path.split(args.basic_state_dict)[1].split('.')[0] # 'ViT_b16_224_Imagenet'
|
329 |
+
dataset_using = os.path.split(args.data_path)[1]
|
330 |
+
utils.save_on_master(save_dict, os.path.join(args.output_dir, f'dino_'+initial_setting + '_' + dataset_using+f'_checkpoint{epoch:04}.pth'))
|
331 |
+
|
332 |
+
|
333 |
+
|
334 |
+
|
335 |
+
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
|
336 |
+
'epoch': epoch}
|
337 |
+
if utils.is_main_process():
|
338 |
+
with (Path(args.output_dir) / "log.txt").open("a") as f:
|
339 |
+
f.write(json.dumps(log_stats) + "\n")
|
340 |
+
total_time = time.time() - start_time
|
341 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
342 |
+
print('Training time {}'.format(total_time_str))
|
343 |
+
|
344 |
+
|
345 |
+
def train_one_epoch(student, teacher, teacher_without_ddp, dino_loss, data_loader,
|
346 |
+
optimizer, lr_schedule, wd_schedule, momentum_schedule,epoch,
|
347 |
+
fp16_scaler, args):
|
348 |
+
metric_logger = utils.MetricLogger(delimiter=" ")
|
349 |
+
header = 'Epoch: [{}/{}]'.format(epoch, args.epochs)
|
350 |
+
for it, (images, _) in enumerate(metric_logger.log_every(data_loader, 10, header)):
|
351 |
+
# update weight decay and learning rate according to their schedule
|
352 |
+
it = len(data_loader) * epoch + it # global training iteration
|
353 |
+
for i, param_group in enumerate(optimizer.param_groups):
|
354 |
+
param_group["lr"] = lr_schedule[it]
|
355 |
+
if i == 0: # only the first group is regularized
|
356 |
+
param_group["weight_decay"] = wd_schedule[it]
|
357 |
+
|
358 |
+
# move images to gpu
|
359 |
+
images = [im.cuda(non_blocking=True) for im in images]
|
360 |
+
# teacher and student forward passes + compute dino loss
|
361 |
+
with torch.cuda.amp.autocast(fp16_scaler is not None):
|
362 |
+
teacher_output = teacher(images[:2]) # only the 2 global views pass through the teacher
|
363 |
+
student_output = student(images)
|
364 |
+
loss = dino_loss(student_output, teacher_output, epoch)
|
365 |
+
|
366 |
+
if not math.isfinite(loss.item()):
|
367 |
+
print("Loss is {}, stopping training".format(loss.item()), force=True)
|
368 |
+
sys.exit(1)
|
369 |
+
|
370 |
+
# student update
|
371 |
+
optimizer.zero_grad()
|
372 |
+
param_norms = None
|
373 |
+
if fp16_scaler is None:
|
374 |
+
loss.backward()
|
375 |
+
if args.clip_grad:
|
376 |
+
param_norms = utils.clip_gradients(student, args.clip_grad)
|
377 |
+
utils.cancel_gradients_last_layer(epoch, student,
|
378 |
+
args.freeze_last_layer)
|
379 |
+
optimizer.step()
|
380 |
+
else:
|
381 |
+
fp16_scaler.scale(loss).backward()
|
382 |
+
if args.clip_grad:
|
383 |
+
fp16_scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
|
384 |
+
param_norms = utils.clip_gradients(student, args.clip_grad)
|
385 |
+
utils.cancel_gradients_last_layer(epoch, student,
|
386 |
+
args.freeze_last_layer)
|
387 |
+
fp16_scaler.step(optimizer)
|
388 |
+
fp16_scaler.update()
|
389 |
+
|
390 |
+
# EMA update for the teacher
|
391 |
+
with torch.no_grad():
|
392 |
+
m = momentum_schedule[it] # momentum parameter
|
393 |
+
for param_q, param_k in zip(student.module.parameters(), teacher_without_ddp.parameters()):
|
394 |
+
param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)
|
395 |
+
|
396 |
+
# logging
|
397 |
+
torch.cuda.synchronize()
|
398 |
+
metric_logger.update(loss=loss.item())
|
399 |
+
metric_logger.update(lr=optimizer.param_groups[0]["lr"])
|
400 |
+
metric_logger.update(wd=optimizer.param_groups[0]["weight_decay"])
|
401 |
+
# gather the stats from all processes
|
402 |
+
metric_logger.synchronize_between_processes()
|
403 |
+
print("Averaged stats:", metric_logger)
|
404 |
+
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
405 |
+
|
406 |
+
|
407 |
+
class DINOLoss(nn.Module):
|
408 |
+
def __init__(self, out_dim, ncrops, warmup_teacher_temp, teacher_temp,
|
409 |
+
warmup_teacher_temp_epochs, nepochs, student_temp=0.1,
|
410 |
+
center_momentum=0.9):
|
411 |
+
super().__init__()
|
412 |
+
self.student_temp = student_temp
|
413 |
+
self.center_momentum = center_momentum
|
414 |
+
self.ncrops = ncrops
|
415 |
+
self.register_buffer("center", torch.zeros(1, out_dim))
|
416 |
+
# we apply a warm up for the teacher temperature because
|
417 |
+
# a too high temperature makes the training instable at the beginning
|
418 |
+
self.teacher_temp_schedule = np.concatenate((
|
419 |
+
np.linspace(warmup_teacher_temp,
|
420 |
+
teacher_temp, warmup_teacher_temp_epochs),
|
421 |
+
np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
|
422 |
+
))
|
423 |
+
|
424 |
+
def forward(self, student_output, teacher_output, epoch):
|
425 |
+
"""
|
426 |
+
Cross-entropy between softmax outputs of the teacher and student networks.
|
427 |
+
"""
|
428 |
+
student_out = student_output / self.student_temp
|
429 |
+
student_out = student_out.chunk(self.ncrops)
|
430 |
+
|
431 |
+
# teacher centering and sharpening
|
432 |
+
temp = self.teacher_temp_schedule[epoch]
|
433 |
+
teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
|
434 |
+
teacher_out = teacher_out.detach().chunk(2)
|
435 |
+
|
436 |
+
total_loss = 0
|
437 |
+
n_loss_terms = 0
|
438 |
+
for iq, q in enumerate(teacher_out):
|
439 |
+
for v in range(len(student_out)):
|
440 |
+
if v == iq:
|
441 |
+
# we skip cases where student and teacher operate on the same view
|
442 |
+
continue
|
443 |
+
loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1)
|
444 |
+
total_loss += loss.mean()
|
445 |
+
n_loss_terms += 1
|
446 |
+
total_loss /= n_loss_terms
|
447 |
+
self.update_center(teacher_output)
|
448 |
+
return total_loss
|
449 |
+
|
450 |
+
@torch.no_grad()
|
451 |
+
def update_center(self, teacher_output):
|
452 |
+
"""
|
453 |
+
Update center used for teacher output.
|
454 |
+
"""
|
455 |
+
batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
|
456 |
+
dist.all_reduce(batch_center)
|
457 |
+
batch_center = batch_center / (len(teacher_output) * dist.get_world_size())
|
458 |
+
|
459 |
+
# ema update
|
460 |
+
self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)
|
461 |
+
|
462 |
+
|
463 |
+
class DataAugmentationDINO(object):
|
464 |
+
def __init__(self, global_crops_scale, local_crops_scale, local_crops_number):
|
465 |
+
flip_and_color_jitter = transforms.Compose([
|
466 |
+
transforms.RandomHorizontalFlip(p=0.5),
|
467 |
+
transforms.RandomApply(
|
468 |
+
[transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
|
469 |
+
p=0.8
|
470 |
+
),
|
471 |
+
transforms.RandomGrayscale(p=0.2),
|
472 |
+
])
|
473 |
+
normalize = transforms.Compose([
|
474 |
+
transforms.ToTensor(),
|
475 |
+
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
476 |
+
])
|
477 |
+
|
478 |
+
# first global crop
|
479 |
+
self.global_transfo1 = transforms.Compose([
|
480 |
+
transforms.RandomResizedCrop(224, scale=global_crops_scale, interpolation=Image.BICUBIC),
|
481 |
+
flip_and_color_jitter,
|
482 |
+
utils.GaussianBlur(1.0),
|
483 |
+
normalize,
|
484 |
+
])
|
485 |
+
# second global crop
|
486 |
+
self.global_transfo2 = transforms.Compose([
|
487 |
+
transforms.RandomResizedCrop(224, scale=global_crops_scale, interpolation=Image.BICUBIC),
|
488 |
+
flip_and_color_jitter,
|
489 |
+
utils.GaussianBlur(0.1),
|
490 |
+
utils.Solarization(0.2),
|
491 |
+
normalize,
|
492 |
+
])
|
493 |
+
# transformation for the local small crops
|
494 |
+
self.local_crops_number = local_crops_number
|
495 |
+
self.local_transfo = transforms.Compose([
|
496 |
+
transforms.RandomResizedCrop(96, scale=local_crops_scale, interpolation=Image.BICUBIC),
|
497 |
+
flip_and_color_jitter,
|
498 |
+
utils.GaussianBlur(p=0.5),
|
499 |
+
normalize,
|
500 |
+
])
|
501 |
+
|
502 |
+
def __call__(self, image):
|
503 |
+
crops = []
|
504 |
+
crops.append(self.global_transfo1(image))
|
505 |
+
crops.append(self.global_transfo2(image))
|
506 |
+
for _ in range(self.local_crops_number):
|
507 |
+
crops.append(self.local_transfo(image))
|
508 |
+
return crops
|
509 |
+
|
510 |
+
|
511 |
+
if __name__ == '__main__':
|
512 |
+
parser = argparse.ArgumentParser('DINO', parents=[get_args_parser()])
|
513 |
+
args = parser.parse_args()
|
514 |
+
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
|
515 |
+
train_dino(args)
|
PuzzleTuning/Counterpart PreTrain Methods/dino-main/run_with_submitit.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""
|
15 |
+
A script to run multinode training with submitit.
|
16 |
+
Almost copy-paste from https://github.com/facebookresearch/deit/blob/main/run_with_submitit.py
|
17 |
+
"""
|
18 |
+
import argparse
|
19 |
+
import os
|
20 |
+
import uuid
|
21 |
+
from pathlib import Path
|
22 |
+
|
23 |
+
import main_dino
|
24 |
+
import submitit
|
25 |
+
|
26 |
+
|
27 |
+
def parse_args():
|
28 |
+
parser = argparse.ArgumentParser("Submitit for DINO", parents=[main_dino.get_args_parser()])
|
29 |
+
parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node")
|
30 |
+
parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request")
|
31 |
+
parser.add_argument("--timeout", default=2800, type=int, help="Duration of the job")
|
32 |
+
|
33 |
+
parser.add_argument("--partition", default="learnfair", type=str, help="Partition where to submit")
|
34 |
+
parser.add_argument("--use_volta32", action='store_true', help="Big models? Use this")
|
35 |
+
parser.add_argument('--comment', default="", type=str,
|
36 |
+
help='Comment to pass to scheduler, e.g. priority message')
|
37 |
+
return parser.parse_args()
|
38 |
+
|
39 |
+
|
40 |
+
def get_shared_folder() -> Path:
|
41 |
+
user = os.getenv("USER")
|
42 |
+
if Path("/checkpoint/").is_dir():
|
43 |
+
p = Path(f"/checkpoint/{user}/experiments")
|
44 |
+
p.mkdir(exist_ok=True)
|
45 |
+
return p
|
46 |
+
raise RuntimeError("No shared folder available")
|
47 |
+
|
48 |
+
|
49 |
+
def get_init_file():
|
50 |
+
# Init file must not exist, but it's parent dir must exist.
|
51 |
+
os.makedirs(str(get_shared_folder()), exist_ok=True)
|
52 |
+
init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init"
|
53 |
+
if init_file.exists():
|
54 |
+
os.remove(str(init_file))
|
55 |
+
return init_file
|
56 |
+
|
57 |
+
|
58 |
+
class Trainer(object):
|
59 |
+
def __init__(self, args):
|
60 |
+
self.args = args
|
61 |
+
|
62 |
+
def __call__(self):
|
63 |
+
import main_dino
|
64 |
+
|
65 |
+
self._setup_gpu_args()
|
66 |
+
main_dino.train_dino(self.args)
|
67 |
+
|
68 |
+
def checkpoint(self):
|
69 |
+
import os
|
70 |
+
import submitit
|
71 |
+
|
72 |
+
self.args.dist_url = get_init_file().as_uri()
|
73 |
+
print("Requeuing ", self.args)
|
74 |
+
empty_trainer = type(self)(self.args)
|
75 |
+
return submitit.helpers.DelayedSubmission(empty_trainer)
|
76 |
+
|
77 |
+
def _setup_gpu_args(self):
|
78 |
+
import submitit
|
79 |
+
from pathlib import Path
|
80 |
+
|
81 |
+
job_env = submitit.JobEnvironment()
|
82 |
+
self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id)))
|
83 |
+
self.args.gpu = job_env.local_rank
|
84 |
+
self.args.rank = job_env.global_rank
|
85 |
+
self.args.world_size = job_env.num_tasks
|
86 |
+
print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")
|
87 |
+
|
88 |
+
|
89 |
+
def main():
|
90 |
+
args = parse_args()
|
91 |
+
if args.output_dir == "":
|
92 |
+
args.output_dir = get_shared_folder() / "%j"
|
93 |
+
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
|
94 |
+
executor = submitit.AutoExecutor(folder=args.output_dir, slurm_max_num_timeout=30)
|
95 |
+
|
96 |
+
num_gpus_per_node = args.ngpus
|
97 |
+
nodes = args.nodes
|
98 |
+
timeout_min = args.timeout
|
99 |
+
|
100 |
+
partition = args.partition
|
101 |
+
kwargs = {}
|
102 |
+
if args.use_volta32:
|
103 |
+
kwargs['slurm_constraint'] = 'volta32gb'
|
104 |
+
if args.comment:
|
105 |
+
kwargs['slurm_comment'] = args.comment
|
106 |
+
|
107 |
+
executor.update_parameters(
|
108 |
+
mem_gb=40 * num_gpus_per_node,
|
109 |
+
gpus_per_node=num_gpus_per_node,
|
110 |
+
tasks_per_node=num_gpus_per_node, # one task per GPU
|
111 |
+
cpus_per_task=10,
|
112 |
+
nodes=nodes,
|
113 |
+
timeout_min=timeout_min, # max is 60 * 72
|
114 |
+
# Below are cluster dependent parameters
|
115 |
+
slurm_partition=partition,
|
116 |
+
slurm_signal_delay_s=120,
|
117 |
+
**kwargs
|
118 |
+
)
|
119 |
+
|
120 |
+
executor.update_parameters(name="dino")
|
121 |
+
|
122 |
+
args.dist_url = get_init_file().as_uri()
|
123 |
+
|
124 |
+
trainer = Trainer(args)
|
125 |
+
job = executor.submit(trainer)
|
126 |
+
|
127 |
+
print(f"Submitted job_id: {job.job_id}")
|
128 |
+
print(f"Logs and checkpoints will be saved at: {args.output_dir}")
|
129 |
+
|
130 |
+
|
131 |
+
if __name__ == "__main__":
|
132 |
+
main()
|
PuzzleTuning/Counterpart PreTrain Methods/dino-main/utils.py
ADDED
@@ -0,0 +1,829 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""
|
15 |
+
Misc functions.
|
16 |
+
|
17 |
+
Mostly copy-paste from torchvision references or other public repos like DETR:
|
18 |
+
https://github.com/facebookresearch/detr/blob/master/util/misc.py
|
19 |
+
"""
|
20 |
+
import os
|
21 |
+
import sys
|
22 |
+
import time
|
23 |
+
import math
|
24 |
+
import random
|
25 |
+
import datetime
|
26 |
+
import subprocess
|
27 |
+
from collections import defaultdict, deque
|
28 |
+
|
29 |
+
import numpy as np
|
30 |
+
import torch
|
31 |
+
from torch import nn
|
32 |
+
import torch.distributed as dist
|
33 |
+
from PIL import ImageFilter, ImageOps
|
34 |
+
|
35 |
+
|
36 |
+
class GaussianBlur(object):
|
37 |
+
"""
|
38 |
+
Apply Gaussian Blur to the PIL image.
|
39 |
+
"""
|
40 |
+
def __init__(self, p=0.5, radius_min=0.1, radius_max=2.):
|
41 |
+
self.prob = p
|
42 |
+
self.radius_min = radius_min
|
43 |
+
self.radius_max = radius_max
|
44 |
+
|
45 |
+
def __call__(self, img):
|
46 |
+
do_it = random.random() <= self.prob
|
47 |
+
if not do_it:
|
48 |
+
return img
|
49 |
+
|
50 |
+
return img.filter(
|
51 |
+
ImageFilter.GaussianBlur(
|
52 |
+
radius=random.uniform(self.radius_min, self.radius_max)
|
53 |
+
)
|
54 |
+
)
|
55 |
+
|
56 |
+
|
57 |
+
class Solarization(object):
|
58 |
+
"""
|
59 |
+
Apply Solarization to the PIL image.
|
60 |
+
"""
|
61 |
+
def __init__(self, p):
|
62 |
+
self.p = p
|
63 |
+
|
64 |
+
def __call__(self, img):
|
65 |
+
if random.random() < self.p:
|
66 |
+
return ImageOps.solarize(img)
|
67 |
+
else:
|
68 |
+
return img
|
69 |
+
|
70 |
+
|
71 |
+
def load_pretrained_weights(model, pretrained_weights, checkpoint_key, model_name, patch_size):
|
72 |
+
if os.path.isfile(pretrained_weights):
|
73 |
+
state_dict = torch.load(pretrained_weights, map_location="cpu")
|
74 |
+
if checkpoint_key is not None and checkpoint_key in state_dict:
|
75 |
+
print(f"Take key {checkpoint_key} in provided checkpoint dict")
|
76 |
+
state_dict = state_dict[checkpoint_key]
|
77 |
+
# remove `module.` prefix
|
78 |
+
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
|
79 |
+
# remove `backbone.` prefix induced by multicrop wrapper
|
80 |
+
state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
|
81 |
+
msg = model.load_state_dict(state_dict, strict=False)
|
82 |
+
print('Pretrained weights found at {} and loaded with msg: {}'.format(pretrained_weights, msg))
|
83 |
+
else:
|
84 |
+
print("Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.")
|
85 |
+
url = None
|
86 |
+
if model_name == "vit_small" and patch_size == 16:
|
87 |
+
url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
|
88 |
+
elif model_name == "vit_small" and patch_size == 8:
|
89 |
+
url = "dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth"
|
90 |
+
elif model_name == "vit_base" and patch_size == 16:
|
91 |
+
url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
|
92 |
+
elif model_name == "vit_base" and patch_size == 8:
|
93 |
+
url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
|
94 |
+
elif model_name == "xcit_small_12_p16":
|
95 |
+
url = "dino_xcit_small_12_p16_pretrain/dino_xcit_small_12_p16_pretrain.pth"
|
96 |
+
elif model_name == "xcit_small_12_p8":
|
97 |
+
url = "dino_xcit_small_12_p8_pretrain/dino_xcit_small_12_p8_pretrain.pth"
|
98 |
+
elif model_name == "xcit_medium_24_p16":
|
99 |
+
url = "dino_xcit_medium_24_p16_pretrain/dino_xcit_medium_24_p16_pretrain.pth"
|
100 |
+
elif model_name == "xcit_medium_24_p8":
|
101 |
+
url = "dino_xcit_medium_24_p8_pretrain/dino_xcit_medium_24_p8_pretrain.pth"
|
102 |
+
elif model_name == "resnet50":
|
103 |
+
url = "dino_resnet50_pretrain/dino_resnet50_pretrain.pth"
|
104 |
+
if url is not None:
|
105 |
+
print("Since no pretrained weights have been provided, we load the reference pretrained DINO weights.")
|
106 |
+
state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
|
107 |
+
model.load_state_dict(state_dict, strict=True)
|
108 |
+
else:
|
109 |
+
print("There is no reference weights available for this model => We use random weights.")
|
110 |
+
|
111 |
+
|
112 |
+
def load_pretrained_linear_weights(linear_classifier, model_name, patch_size):
|
113 |
+
url = None
|
114 |
+
if model_name == "vit_small" and patch_size == 16:
|
115 |
+
url = "dino_deitsmall16_pretrain/dino_deitsmall16_linearweights.pth"
|
116 |
+
elif model_name == "vit_small" and patch_size == 8:
|
117 |
+
url = "dino_deitsmall8_pretrain/dino_deitsmall8_linearweights.pth"
|
118 |
+
elif model_name == "vit_base" and patch_size == 16:
|
119 |
+
url = "dino_vitbase16_pretrain/dino_vitbase16_linearweights.pth"
|
120 |
+
elif model_name == "vit_base" and patch_size == 8:
|
121 |
+
url = "dino_vitbase8_pretrain/dino_vitbase8_linearweights.pth"
|
122 |
+
elif model_name == "resnet50":
|
123 |
+
url = "dino_resnet50_pretrain/dino_resnet50_linearweights.pth"
|
124 |
+
if url is not None:
|
125 |
+
print("We load the reference pretrained linear weights.")
|
126 |
+
state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)["state_dict"]
|
127 |
+
linear_classifier.load_state_dict(state_dict, strict=True)
|
128 |
+
else:
|
129 |
+
print("We use random linear weights.")
|
130 |
+
|
131 |
+
|
132 |
+
def clip_gradients(model, clip):
|
133 |
+
norms = []
|
134 |
+
for name, p in model.named_parameters():
|
135 |
+
if p.grad is not None:
|
136 |
+
param_norm = p.grad.data.norm(2)
|
137 |
+
norms.append(param_norm.item())
|
138 |
+
clip_coef = clip / (param_norm + 1e-6)
|
139 |
+
if clip_coef < 1:
|
140 |
+
p.grad.data.mul_(clip_coef)
|
141 |
+
return norms
|
142 |
+
|
143 |
+
|
144 |
+
def cancel_gradients_last_layer(epoch, model, freeze_last_layer):
|
145 |
+
if epoch >= freeze_last_layer:
|
146 |
+
return
|
147 |
+
for n, p in model.named_parameters():
|
148 |
+
if "last_layer" in n:
|
149 |
+
p.grad = None
|
150 |
+
|
151 |
+
|
152 |
+
def restart_from_checkpoint(ckp_path, run_variables=None, **kwargs):
|
153 |
+
"""
|
154 |
+
Re-start from checkpoint
|
155 |
+
"""
|
156 |
+
if not os.path.isfile(ckp_path):
|
157 |
+
return
|
158 |
+
print("Found checkpoint at {}".format(ckp_path))
|
159 |
+
|
160 |
+
# open checkpoint file
|
161 |
+
checkpoint = torch.load(ckp_path, map_location="cpu")
|
162 |
+
|
163 |
+
# key is what to look for in the checkpoint file
|
164 |
+
# value is the object to load
|
165 |
+
# example: {'state_dict': model}
|
166 |
+
for key, value in kwargs.items():
|
167 |
+
if key in checkpoint and value is not None:
|
168 |
+
try:
|
169 |
+
msg = value.load_state_dict(checkpoint[key], strict=False)
|
170 |
+
print("=> loaded '{}' from checkpoint '{}' with msg {}".format(key, ckp_path, msg))
|
171 |
+
except TypeError:
|
172 |
+
try:
|
173 |
+
msg = value.load_state_dict(checkpoint[key])
|
174 |
+
print("=> loaded '{}' from checkpoint: '{}'".format(key, ckp_path))
|
175 |
+
except ValueError:
|
176 |
+
print("=> failed to load '{}' from checkpoint: '{}'".format(key, ckp_path))
|
177 |
+
else:
|
178 |
+
print("=> key '{}' not found in checkpoint: '{}'".format(key, ckp_path))
|
179 |
+
|
180 |
+
# re load variable important for the run
|
181 |
+
if run_variables is not None:
|
182 |
+
for var_name in run_variables:
|
183 |
+
if var_name in checkpoint:
|
184 |
+
run_variables[var_name] = checkpoint[var_name]
|
185 |
+
|
186 |
+
|
187 |
+
def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0):
|
188 |
+
warmup_schedule = np.array([])
|
189 |
+
warmup_iters = warmup_epochs * niter_per_ep
|
190 |
+
if warmup_epochs > 0:
|
191 |
+
warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
|
192 |
+
|
193 |
+
iters = np.arange(epochs * niter_per_ep - warmup_iters)
|
194 |
+
schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
|
195 |
+
|
196 |
+
schedule = np.concatenate((warmup_schedule, schedule))
|
197 |
+
assert len(schedule) == epochs * niter_per_ep
|
198 |
+
return schedule
|
199 |
+
|
200 |
+
|
201 |
+
def bool_flag(s):
|
202 |
+
"""
|
203 |
+
Parse boolean arguments from the command line.
|
204 |
+
"""
|
205 |
+
FALSY_STRINGS = {"off", "false", "0"}
|
206 |
+
TRUTHY_STRINGS = {"on", "true", "1"}
|
207 |
+
if s.lower() in FALSY_STRINGS:
|
208 |
+
return False
|
209 |
+
elif s.lower() in TRUTHY_STRINGS:
|
210 |
+
return True
|
211 |
+
else:
|
212 |
+
raise argparse.ArgumentTypeError("invalid value for a boolean flag")
|
213 |
+
|
214 |
+
|
215 |
+
def fix_random_seeds(seed=31):
|
216 |
+
"""
|
217 |
+
Fix random seeds.
|
218 |
+
"""
|
219 |
+
torch.manual_seed(seed)
|
220 |
+
torch.cuda.manual_seed_all(seed)
|
221 |
+
np.random.seed(seed)
|
222 |
+
|
223 |
+
|
224 |
+
class SmoothedValue(object):
|
225 |
+
"""Track a series of values and provide access to smoothed values over a
|
226 |
+
window or the global series average.
|
227 |
+
"""
|
228 |
+
|
229 |
+
def __init__(self, window_size=20, fmt=None):
|
230 |
+
if fmt is None:
|
231 |
+
fmt = "{median:.6f} ({global_avg:.6f})"
|
232 |
+
self.deque = deque(maxlen=window_size)
|
233 |
+
self.total = 0.0
|
234 |
+
self.count = 0
|
235 |
+
self.fmt = fmt
|
236 |
+
|
237 |
+
def update(self, value, n=1):
|
238 |
+
self.deque.append(value)
|
239 |
+
self.count += n
|
240 |
+
self.total += value * n
|
241 |
+
|
242 |
+
def synchronize_between_processes(self):
|
243 |
+
"""
|
244 |
+
Warning: does not synchronize the deque!
|
245 |
+
"""
|
246 |
+
if not is_dist_avail_and_initialized():
|
247 |
+
return
|
248 |
+
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
|
249 |
+
dist.barrier()
|
250 |
+
dist.all_reduce(t)
|
251 |
+
t = t.tolist()
|
252 |
+
self.count = int(t[0])
|
253 |
+
self.total = t[1]
|
254 |
+
|
255 |
+
@property
|
256 |
+
def median(self):
|
257 |
+
d = torch.tensor(list(self.deque))
|
258 |
+
return d.median().item()
|
259 |
+
|
260 |
+
@property
|
261 |
+
def avg(self):
|
262 |
+
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
263 |
+
return d.mean().item()
|
264 |
+
|
265 |
+
@property
|
266 |
+
def global_avg(self):
|
267 |
+
return self.total / self.count
|
268 |
+
|
269 |
+
@property
|
270 |
+
def max(self):
|
271 |
+
return max(self.deque)
|
272 |
+
|
273 |
+
@property
|
274 |
+
def value(self):
|
275 |
+
return self.deque[-1]
|
276 |
+
|
277 |
+
def __str__(self):
|
278 |
+
return self.fmt.format(
|
279 |
+
median=self.median,
|
280 |
+
avg=self.avg,
|
281 |
+
global_avg=self.global_avg,
|
282 |
+
max=self.max,
|
283 |
+
value=self.value)
|
284 |
+
|
285 |
+
|
286 |
+
def reduce_dict(input_dict, average=True):
|
287 |
+
"""
|
288 |
+
Args:
|
289 |
+
input_dict (dict): all the values will be reduced
|
290 |
+
average (bool): whether to do average or sum
|
291 |
+
Reduce the values in the dictionary from all processes so that all processes
|
292 |
+
have the averaged results. Returns a dict with the same fields as
|
293 |
+
input_dict, after reduction.
|
294 |
+
"""
|
295 |
+
world_size = get_world_size()
|
296 |
+
if world_size < 2:
|
297 |
+
return input_dict
|
298 |
+
with torch.no_grad():
|
299 |
+
names = []
|
300 |
+
values = []
|
301 |
+
# sort the keys so that they are consistent across processes
|
302 |
+
for k in sorted(input_dict.keys()):
|
303 |
+
names.append(k)
|
304 |
+
values.append(input_dict[k])
|
305 |
+
values = torch.stack(values, dim=0)
|
306 |
+
dist.all_reduce(values)
|
307 |
+
if average:
|
308 |
+
values /= world_size
|
309 |
+
reduced_dict = {k: v for k, v in zip(names, values)}
|
310 |
+
return reduced_dict
|
311 |
+
|
312 |
+
|
313 |
+
class MetricLogger(object):
|
314 |
+
def __init__(self, delimiter="\t"):
|
315 |
+
self.meters = defaultdict(SmoothedValue)
|
316 |
+
self.delimiter = delimiter
|
317 |
+
|
318 |
+
def update(self, **kwargs):
|
319 |
+
for k, v in kwargs.items():
|
320 |
+
if isinstance(v, torch.Tensor):
|
321 |
+
v = v.item()
|
322 |
+
assert isinstance(v, (float, int))
|
323 |
+
self.meters[k].update(v)
|
324 |
+
|
325 |
+
def __getattr__(self, attr):
|
326 |
+
if attr in self.meters:
|
327 |
+
return self.meters[attr]
|
328 |
+
if attr in self.__dict__:
|
329 |
+
return self.__dict__[attr]
|
330 |
+
raise AttributeError("'{}' object has no attribute '{}'".format(
|
331 |
+
type(self).__name__, attr))
|
332 |
+
|
333 |
+
def __str__(self):
|
334 |
+
loss_str = []
|
335 |
+
for name, meter in self.meters.items():
|
336 |
+
loss_str.append(
|
337 |
+
"{}: {}".format(name, str(meter))
|
338 |
+
)
|
339 |
+
return self.delimiter.join(loss_str)
|
340 |
+
|
341 |
+
def synchronize_between_processes(self):
|
342 |
+
for meter in self.meters.values():
|
343 |
+
meter.synchronize_between_processes()
|
344 |
+
|
345 |
+
def add_meter(self, name, meter):
|
346 |
+
self.meters[name] = meter
|
347 |
+
|
348 |
+
def log_every(self, iterable, print_freq, header=None):
|
349 |
+
i = 0
|
350 |
+
if not header:
|
351 |
+
header = ''
|
352 |
+
start_time = time.time()
|
353 |
+
end = time.time()
|
354 |
+
iter_time = SmoothedValue(fmt='{avg:.6f}')
|
355 |
+
data_time = SmoothedValue(fmt='{avg:.6f}')
|
356 |
+
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
|
357 |
+
if torch.cuda.is_available():
|
358 |
+
log_msg = self.delimiter.join([
|
359 |
+
header,
|
360 |
+
'[{0' + space_fmt + '}/{1}]',
|
361 |
+
'eta: {eta}',
|
362 |
+
'{meters}',
|
363 |
+
'time: {time}',
|
364 |
+
'data: {data}',
|
365 |
+
'max mem: {memory:.0f}'
|
366 |
+
])
|
367 |
+
else:
|
368 |
+
log_msg = self.delimiter.join([
|
369 |
+
header,
|
370 |
+
'[{0' + space_fmt + '}/{1}]',
|
371 |
+
'eta: {eta}',
|
372 |
+
'{meters}',
|
373 |
+
'time: {time}',
|
374 |
+
'data: {data}'
|
375 |
+
])
|
376 |
+
MB = 1024.0 * 1024.0
|
377 |
+
for obj in iterable:
|
378 |
+
data_time.update(time.time() - end)
|
379 |
+
yield obj
|
380 |
+
iter_time.update(time.time() - end)
|
381 |
+
if i % print_freq == 0 or i == len(iterable) - 1:
|
382 |
+
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
383 |
+
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
384 |
+
if torch.cuda.is_available():
|
385 |
+
print(log_msg.format(
|
386 |
+
i, len(iterable), eta=eta_string,
|
387 |
+
meters=str(self),
|
388 |
+
time=str(iter_time), data=str(data_time),
|
389 |
+
memory=torch.cuda.max_memory_allocated() / MB))
|
390 |
+
else:
|
391 |
+
print(log_msg.format(
|
392 |
+
i, len(iterable), eta=eta_string,
|
393 |
+
meters=str(self),
|
394 |
+
time=str(iter_time), data=str(data_time)))
|
395 |
+
i += 1
|
396 |
+
end = time.time()
|
397 |
+
total_time = time.time() - start_time
|
398 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
399 |
+
print('{} Total time: {} ({:.6f} s / it)'.format(
|
400 |
+
header, total_time_str, total_time / len(iterable)))
|
401 |
+
|
402 |
+
|
403 |
+
def get_sha():
|
404 |
+
cwd = os.path.dirname(os.path.abspath(__file__))
|
405 |
+
|
406 |
+
def _run(command):
|
407 |
+
return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
|
408 |
+
sha = 'N/A'
|
409 |
+
diff = "clean"
|
410 |
+
branch = 'N/A'
|
411 |
+
try:
|
412 |
+
sha = _run(['git', 'rev-parse', 'HEAD'])
|
413 |
+
subprocess.check_output(['git', 'diff'], cwd=cwd)
|
414 |
+
diff = _run(['git', 'diff-index', 'HEAD'])
|
415 |
+
diff = "has uncommited changes" if diff else "clean"
|
416 |
+
branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
|
417 |
+
except Exception:
|
418 |
+
pass
|
419 |
+
message = f"sha: {sha}, status: {diff}, branch: {branch}"
|
420 |
+
return message
|
421 |
+
|
422 |
+
|
423 |
+
def is_dist_avail_and_initialized():
|
424 |
+
if not dist.is_available():
|
425 |
+
return False
|
426 |
+
if not dist.is_initialized():
|
427 |
+
return False
|
428 |
+
return True
|
429 |
+
|
430 |
+
|
431 |
+
def get_world_size():
|
432 |
+
if not is_dist_avail_and_initialized():
|
433 |
+
return 1
|
434 |
+
return dist.get_world_size()
|
435 |
+
|
436 |
+
|
437 |
+
def get_rank():
|
438 |
+
if not is_dist_avail_and_initialized():
|
439 |
+
return 0
|
440 |
+
return dist.get_rank()
|
441 |
+
|
442 |
+
|
443 |
+
def is_main_process():
|
444 |
+
return get_rank() == 0
|
445 |
+
|
446 |
+
|
447 |
+
def save_on_master(*args, **kwargs):
|
448 |
+
if is_main_process():
|
449 |
+
torch.save(*args, **kwargs)
|
450 |
+
|
451 |
+
|
452 |
+
def setup_for_distributed(is_master):
|
453 |
+
"""
|
454 |
+
This function disables printing when not in master process
|
455 |
+
"""
|
456 |
+
import builtins as __builtin__
|
457 |
+
builtin_print = __builtin__.print
|
458 |
+
|
459 |
+
def print(*args, **kwargs):
|
460 |
+
force = kwargs.pop('force', False)
|
461 |
+
if is_master or force:
|
462 |
+
builtin_print(*args, **kwargs)
|
463 |
+
|
464 |
+
__builtin__.print = print
|
465 |
+
|
466 |
+
|
467 |
+
def init_distributed_mode(args):
|
468 |
+
# launched with torch.distributed.launch
|
469 |
+
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
|
470 |
+
args.rank = int(os.environ["RANK"])
|
471 |
+
args.world_size = int(os.environ['WORLD_SIZE'])
|
472 |
+
args.gpu = int(os.environ['LOCAL_RANK'])
|
473 |
+
# launched with submitit on a slurm cluster
|
474 |
+
elif 'SLURM_PROCID' in os.environ:
|
475 |
+
args.rank = int(os.environ['SLURM_PROCID'])
|
476 |
+
args.gpu = args.rank % torch.cuda.device_count()
|
477 |
+
# launched naively with `python main_dino.py`
|
478 |
+
# we manually add MASTER_ADDR and MASTER_PORT to env variables
|
479 |
+
elif torch.cuda.is_available():
|
480 |
+
print('Will run the code on one GPU.')
|
481 |
+
args.rank, args.gpu, args.world_size = 0, 0, 1
|
482 |
+
os.environ['MASTER_ADDR'] = '127.0.0.1'
|
483 |
+
os.environ['MASTER_PORT'] = '29500'
|
484 |
+
else:
|
485 |
+
print('Does not support training without GPU.')
|
486 |
+
sys.exit(1)
|
487 |
+
|
488 |
+
dist.init_process_group(
|
489 |
+
backend="nccl",
|
490 |
+
init_method=args.dist_url,
|
491 |
+
world_size=args.world_size,
|
492 |
+
rank=args.rank,
|
493 |
+
)
|
494 |
+
|
495 |
+
torch.cuda.set_device(args.gpu)
|
496 |
+
print('| distributed init (rank {}): {}'.format(
|
497 |
+
args.rank, args.dist_url), flush=True)
|
498 |
+
dist.barrier()
|
499 |
+
setup_for_distributed(args.rank == 0)
|
500 |
+
|
501 |
+
|
502 |
+
def accuracy(output, target, topk=(1,)):
|
503 |
+
"""Computes the accuracy over the k top predictions for the specified values of k"""
|
504 |
+
maxk = max(topk)
|
505 |
+
batch_size = target.size(0)
|
506 |
+
_, pred = output.topk(maxk, 1, True, True)
|
507 |
+
pred = pred.t()
|
508 |
+
correct = pred.eq(target.reshape(1, -1).expand_as(pred))
|
509 |
+
return [correct[:k].reshape(-1).float().sum(0) * 100. / batch_size for k in topk]
|
510 |
+
|
511 |
+
|
512 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
513 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
514 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
515 |
+
def norm_cdf(x):
|
516 |
+
# Computes standard normal cumulative distribution function
|
517 |
+
return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
518 |
+
|
519 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
520 |
+
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
521 |
+
"The distribution of values may be incorrect.",
|
522 |
+
stacklevel=2)
|
523 |
+
|
524 |
+
with torch.no_grad():
|
525 |
+
# Values are generated by using a truncated uniform distribution and
|
526 |
+
# then using the inverse CDF for the normal distribution.
|
527 |
+
# Get upper and lower cdf values
|
528 |
+
l = norm_cdf((a - mean) / std)
|
529 |
+
u = norm_cdf((b - mean) / std)
|
530 |
+
|
531 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
532 |
+
# [2l-1, 2u-1].
|
533 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
534 |
+
|
535 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
536 |
+
# standard normal
|
537 |
+
tensor.erfinv_()
|
538 |
+
|
539 |
+
# Transform to proper mean, std
|
540 |
+
tensor.mul_(std * math.sqrt(2.))
|
541 |
+
tensor.add_(mean)
|
542 |
+
|
543 |
+
# Clamp to ensure it's in the proper range
|
544 |
+
tensor.clamp_(min=a, max=b)
|
545 |
+
return tensor
|
546 |
+
|
547 |
+
|
548 |
+
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
|
549 |
+
# type: (Tensor, float, float, float, float) -> Tensor
|
550 |
+
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
551 |
+
|
552 |
+
|
553 |
+
class LARS(torch.optim.Optimizer):
|
554 |
+
"""
|
555 |
+
Almost copy-paste from https://github.com/facebookresearch/barlowtwins/blob/main/main.py
|
556 |
+
"""
|
557 |
+
def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, eta=0.001,
|
558 |
+
weight_decay_filter=None, lars_adaptation_filter=None):
|
559 |
+
defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum,
|
560 |
+
eta=eta, weight_decay_filter=weight_decay_filter,
|
561 |
+
lars_adaptation_filter=lars_adaptation_filter)
|
562 |
+
super().__init__(params, defaults)
|
563 |
+
|
564 |
+
@torch.no_grad()
|
565 |
+
def step(self):
|
566 |
+
for g in self.param_groups:
|
567 |
+
for p in g['params']:
|
568 |
+
dp = p.grad
|
569 |
+
|
570 |
+
if dp is None:
|
571 |
+
continue
|
572 |
+
|
573 |
+
if p.ndim != 1:
|
574 |
+
dp = dp.add(p, alpha=g['weight_decay'])
|
575 |
+
|
576 |
+
if p.ndim != 1:
|
577 |
+
param_norm = torch.norm(p)
|
578 |
+
update_norm = torch.norm(dp)
|
579 |
+
one = torch.ones_like(param_norm)
|
580 |
+
q = torch.where(param_norm > 0.,
|
581 |
+
torch.where(update_norm > 0,
|
582 |
+
(g['eta'] * param_norm / update_norm), one), one)
|
583 |
+
dp = dp.mul(q)
|
584 |
+
|
585 |
+
param_state = self.state[p]
|
586 |
+
if 'mu' not in param_state:
|
587 |
+
param_state['mu'] = torch.zeros_like(p)
|
588 |
+
mu = param_state['mu']
|
589 |
+
mu.mul_(g['momentum']).add_(dp)
|
590 |
+
|
591 |
+
p.add_(mu, alpha=-g['lr'])
|
592 |
+
|
593 |
+
|
594 |
+
class MultiCropWrapper(nn.Module):
|
595 |
+
"""
|
596 |
+
Perform forward pass separately on each resolution input.
|
597 |
+
The inputs corresponding to a single resolution are clubbed and single
|
598 |
+
forward is run on the same resolution inputs. Hence we do several
|
599 |
+
forward passes = number of different resolutions used. We then
|
600 |
+
concatenate all the output features and run the head forward on these
|
601 |
+
concatenated features.
|
602 |
+
"""
|
603 |
+
def __init__(self, backbone, head):
|
604 |
+
super(MultiCropWrapper, self).__init__()
|
605 |
+
# disable layers dedicated to ImageNet labels classification
|
606 |
+
backbone.fc, backbone.head = nn.Identity(), nn.Identity()
|
607 |
+
self.backbone = backbone
|
608 |
+
self.head = head
|
609 |
+
|
610 |
+
def forward(self, x):
|
611 |
+
# convert to list
|
612 |
+
if not isinstance(x, list):
|
613 |
+
x = [x]
|
614 |
+
idx_crops = torch.cumsum(torch.unique_consecutive(
|
615 |
+
torch.tensor([inp.shape[-1] for inp in x]),
|
616 |
+
return_counts=True,
|
617 |
+
)[1], 0)
|
618 |
+
start_idx, output = 0, torch.empty(0).to(x[0].device)
|
619 |
+
for end_idx in idx_crops:
|
620 |
+
_out = self.backbone(torch.cat(x[start_idx: end_idx]))
|
621 |
+
# The output is a tuple with XCiT model. See:
|
622 |
+
# https://github.com/facebookresearch/xcit/blob/master/xcit.py#L404-L405
|
623 |
+
if isinstance(_out, tuple):
|
624 |
+
_out = _out[0]
|
625 |
+
# accumulate outputs
|
626 |
+
output = torch.cat((output, _out))
|
627 |
+
start_idx = end_idx
|
628 |
+
# Run the head forward on the concatenated features.
|
629 |
+
return self.head(output)
|
630 |
+
|
631 |
+
|
632 |
+
def get_params_groups(model):
|
633 |
+
regularized = []
|
634 |
+
not_regularized = []
|
635 |
+
for name, param in model.named_parameters():
|
636 |
+
if not param.requires_grad:
|
637 |
+
continue
|
638 |
+
# we do not regularize biases nor Norm parameters
|
639 |
+
if name.endswith(".bias") or len(param.shape) == 1:
|
640 |
+
not_regularized.append(param)
|
641 |
+
else:
|
642 |
+
regularized.append(param)
|
643 |
+
return [{'params': regularized}, {'params': not_regularized, 'weight_decay': 0.}]
|
644 |
+
|
645 |
+
|
646 |
+
def has_batchnorms(model):
|
647 |
+
bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
|
648 |
+
for name, module in model.named_modules():
|
649 |
+
if isinstance(module, bn_types):
|
650 |
+
return True
|
651 |
+
return False
|
652 |
+
|
653 |
+
|
654 |
+
class PCA():
|
655 |
+
"""
|
656 |
+
Class to compute and apply PCA.
|
657 |
+
"""
|
658 |
+
def __init__(self, dim=256, whit=0.5):
|
659 |
+
self.dim = dim
|
660 |
+
self.whit = whit
|
661 |
+
self.mean = None
|
662 |
+
|
663 |
+
def train_pca(self, cov):
|
664 |
+
"""
|
665 |
+
Takes a covariance matrix (np.ndarray) as input.
|
666 |
+
"""
|
667 |
+
d, v = np.linalg.eigh(cov)
|
668 |
+
eps = d.max() * 1e-5
|
669 |
+
n_0 = (d < eps).sum()
|
670 |
+
if n_0 > 0:
|
671 |
+
d[d < eps] = eps
|
672 |
+
|
673 |
+
# total energy
|
674 |
+
totenergy = d.sum()
|
675 |
+
|
676 |
+
# sort eigenvectors with eigenvalues order
|
677 |
+
idx = np.argsort(d)[::-1][:self.dim]
|
678 |
+
d = d[idx]
|
679 |
+
v = v[:, idx]
|
680 |
+
|
681 |
+
print("keeping %.2f %% of the energy" % (d.sum() / totenergy * 100.0))
|
682 |
+
|
683 |
+
# for the whitening
|
684 |
+
d = np.diag(1. / d**self.whit)
|
685 |
+
|
686 |
+
# principal components
|
687 |
+
self.dvt = np.dot(d, v.T)
|
688 |
+
|
689 |
+
def apply(self, x):
|
690 |
+
# input is from numpy
|
691 |
+
if isinstance(x, np.ndarray):
|
692 |
+
if self.mean is not None:
|
693 |
+
x -= self.mean
|
694 |
+
return np.dot(self.dvt, x.T).T
|
695 |
+
|
696 |
+
# input is from torch and is on GPU
|
697 |
+
if x.is_cuda:
|
698 |
+
if self.mean is not None:
|
699 |
+
x -= torch.cuda.FloatTensor(self.mean)
|
700 |
+
return torch.mm(torch.cuda.FloatTensor(self.dvt), x.transpose(0, 1)).transpose(0, 1)
|
701 |
+
|
702 |
+
# input if from torch, on CPU
|
703 |
+
if self.mean is not None:
|
704 |
+
x -= torch.FloatTensor(self.mean)
|
705 |
+
return torch.mm(torch.FloatTensor(self.dvt), x.transpose(0, 1)).transpose(0, 1)
|
706 |
+
|
707 |
+
|
708 |
+
def compute_ap(ranks, nres):
|
709 |
+
"""
|
710 |
+
Computes average precision for given ranked indexes.
|
711 |
+
Arguments
|
712 |
+
---------
|
713 |
+
ranks : zerro-based ranks of positive images
|
714 |
+
nres : number of positive images
|
715 |
+
Returns
|
716 |
+
-------
|
717 |
+
ap : average precision
|
718 |
+
"""
|
719 |
+
|
720 |
+
# number of images ranked by the system
|
721 |
+
nimgranks = len(ranks)
|
722 |
+
|
723 |
+
# accumulate trapezoids in PR-plot
|
724 |
+
ap = 0
|
725 |
+
|
726 |
+
recall_step = 1. / nres
|
727 |
+
|
728 |
+
for j in np.arange(nimgranks):
|
729 |
+
rank = ranks[j]
|
730 |
+
|
731 |
+
if rank == 0:
|
732 |
+
precision_0 = 1.
|
733 |
+
else:
|
734 |
+
precision_0 = float(j) / rank
|
735 |
+
|
736 |
+
precision_1 = float(j + 1) / (rank + 1)
|
737 |
+
|
738 |
+
ap += (precision_0 + precision_1) * recall_step / 2.
|
739 |
+
|
740 |
+
return ap
|
741 |
+
|
742 |
+
|
743 |
+
def compute_map(ranks, gnd, kappas=[]):
|
744 |
+
"""
|
745 |
+
Computes the mAP for a given set of returned results.
|
746 |
+
Usage:
|
747 |
+
map = compute_map (ranks, gnd)
|
748 |
+
computes mean average precsion (map) only
|
749 |
+
map, aps, pr, prs = compute_map (ranks, gnd, kappas)
|
750 |
+
computes mean average precision (map), average precision (aps) for each query
|
751 |
+
computes mean precision at kappas (pr), precision at kappas (prs) for each query
|
752 |
+
Notes:
|
753 |
+
1) ranks starts from 0, ranks.shape = db_size X #queries
|
754 |
+
2) The junk results (e.g., the query itself) should be declared in the gnd stuct array
|
755 |
+
3) If there are no positive images for some query, that query is excluded from the evaluation
|
756 |
+
"""
|
757 |
+
|
758 |
+
map = 0.
|
759 |
+
nq = len(gnd) # number of queries
|
760 |
+
aps = np.zeros(nq)
|
761 |
+
pr = np.zeros(len(kappas))
|
762 |
+
prs = np.zeros((nq, len(kappas)))
|
763 |
+
nempty = 0
|
764 |
+
|
765 |
+
for i in np.arange(nq):
|
766 |
+
qgnd = np.array(gnd[i]['ok'])
|
767 |
+
|
768 |
+
# no positive images, skip from the average
|
769 |
+
if qgnd.shape[0] == 0:
|
770 |
+
aps[i] = float('nan')
|
771 |
+
prs[i, :] = float('nan')
|
772 |
+
nempty += 1
|
773 |
+
continue
|
774 |
+
|
775 |
+
try:
|
776 |
+
qgndj = np.array(gnd[i]['junk'])
|
777 |
+
except:
|
778 |
+
qgndj = np.empty(0)
|
779 |
+
|
780 |
+
# sorted positions of positive and junk images (0 based)
|
781 |
+
pos = np.arange(ranks.shape[0])[np.in1d(ranks[:,i], qgnd)]
|
782 |
+
junk = np.arange(ranks.shape[0])[np.in1d(ranks[:,i], qgndj)]
|
783 |
+
|
784 |
+
k = 0;
|
785 |
+
ij = 0;
|
786 |
+
if len(junk):
|
787 |
+
# decrease positions of positives based on the number of
|
788 |
+
# junk images appearing before them
|
789 |
+
ip = 0
|
790 |
+
while (ip < len(pos)):
|
791 |
+
while (ij < len(junk) and pos[ip] > junk[ij]):
|
792 |
+
k += 1
|
793 |
+
ij += 1
|
794 |
+
pos[ip] = pos[ip] - k
|
795 |
+
ip += 1
|
796 |
+
|
797 |
+
# compute ap
|
798 |
+
ap = compute_ap(pos, len(qgnd))
|
799 |
+
map = map + ap
|
800 |
+
aps[i] = ap
|
801 |
+
|
802 |
+
# compute precision @ k
|
803 |
+
pos += 1 # get it to 1-based
|
804 |
+
for j in np.arange(len(kappas)):
|
805 |
+
kq = min(max(pos), kappas[j]);
|
806 |
+
prs[i, j] = (pos <= kq).sum() / kq
|
807 |
+
pr = pr + prs[i, :]
|
808 |
+
|
809 |
+
map = map / (nq - nempty)
|
810 |
+
pr = pr / (nq - nempty)
|
811 |
+
|
812 |
+
return map, aps, pr, prs
|
813 |
+
|
814 |
+
|
815 |
+
def multi_scale(samples, model):
|
816 |
+
v = None
|
817 |
+
for s in [1, 1/2**(1/2), 1/2]: # we use 3 different scales
|
818 |
+
if s == 1:
|
819 |
+
inp = samples.clone()
|
820 |
+
else:
|
821 |
+
inp = nn.functional.interpolate(samples, scale_factor=s, mode='bilinear', align_corners=False)
|
822 |
+
feats = model(inp).clone()
|
823 |
+
if v is None:
|
824 |
+
v = feats
|
825 |
+
else:
|
826 |
+
v += feats
|
827 |
+
v /= 3
|
828 |
+
v /= v.norm()
|
829 |
+
return v
|
PuzzleTuning/Counterpart PreTrain Methods/dino-main/video_generation.py
ADDED
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import os
|
15 |
+
import glob
|
16 |
+
import sys
|
17 |
+
import argparse
|
18 |
+
import cv2
|
19 |
+
|
20 |
+
from tqdm import tqdm
|
21 |
+
import matplotlib.pyplot as plt
|
22 |
+
import torch
|
23 |
+
import torch.nn as nn
|
24 |
+
import torchvision
|
25 |
+
from torchvision import transforms as pth_transforms
|
26 |
+
import numpy as np
|
27 |
+
from PIL import Image
|
28 |
+
|
29 |
+
import utils
|
30 |
+
import vision_transformer as vits
|
31 |
+
|
32 |
+
|
33 |
+
FOURCC = {
|
34 |
+
"mp4": cv2.VideoWriter_fourcc(*"MP4V"),
|
35 |
+
"avi": cv2.VideoWriter_fourcc(*"XVID"),
|
36 |
+
}
|
37 |
+
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
38 |
+
|
39 |
+
|
40 |
+
class VideoGenerator:
|
41 |
+
def __init__(self, args):
|
42 |
+
self.args = args
|
43 |
+
# self.model = None
|
44 |
+
# Don't need to load model if you only want a video
|
45 |
+
if not self.args.video_only:
|
46 |
+
self.model = self.__load_model()
|
47 |
+
|
48 |
+
def run(self):
|
49 |
+
if self.args.input_path is None:
|
50 |
+
print(f"Provided input path {self.args.input_path} is non valid.")
|
51 |
+
sys.exit(1)
|
52 |
+
else:
|
53 |
+
if self.args.video_only:
|
54 |
+
self._generate_video_from_images(
|
55 |
+
self.args.input_path, self.args.output_path
|
56 |
+
)
|
57 |
+
else:
|
58 |
+
# If input path exists
|
59 |
+
if os.path.exists(self.args.input_path):
|
60 |
+
# If input is a video file
|
61 |
+
if os.path.isfile(self.args.input_path):
|
62 |
+
frames_folder = os.path.join(self.args.output_path, "frames")
|
63 |
+
attention_folder = os.path.join(
|
64 |
+
self.args.output_path, "attention"
|
65 |
+
)
|
66 |
+
|
67 |
+
os.makedirs(frames_folder, exist_ok=True)
|
68 |
+
os.makedirs(attention_folder, exist_ok=True)
|
69 |
+
|
70 |
+
self._extract_frames_from_video(
|
71 |
+
self.args.input_path, frames_folder
|
72 |
+
)
|
73 |
+
|
74 |
+
self._inference(
|
75 |
+
frames_folder,
|
76 |
+
attention_folder,
|
77 |
+
)
|
78 |
+
|
79 |
+
self._generate_video_from_images(
|
80 |
+
attention_folder, self.args.output_path
|
81 |
+
)
|
82 |
+
|
83 |
+
# If input is a folder of already extracted frames
|
84 |
+
if os.path.isdir(self.args.input_path):
|
85 |
+
attention_folder = os.path.join(
|
86 |
+
self.args.output_path, "attention"
|
87 |
+
)
|
88 |
+
|
89 |
+
os.makedirs(attention_folder, exist_ok=True)
|
90 |
+
|
91 |
+
self._inference(self.args.input_path, attention_folder)
|
92 |
+
|
93 |
+
self._generate_video_from_images(
|
94 |
+
attention_folder, self.args.output_path
|
95 |
+
)
|
96 |
+
|
97 |
+
# If input path doesn't exists
|
98 |
+
else:
|
99 |
+
print(f"Provided input path {self.args.input_path} doesn't exists.")
|
100 |
+
sys.exit(1)
|
101 |
+
|
102 |
+
def _extract_frames_from_video(self, inp: str, out: str):
|
103 |
+
vidcap = cv2.VideoCapture(inp)
|
104 |
+
self.args.fps = vidcap.get(cv2.CAP_PROP_FPS)
|
105 |
+
|
106 |
+
print(f"Video: {inp} ({self.args.fps} fps)")
|
107 |
+
print(f"Extracting frames to {out}")
|
108 |
+
|
109 |
+
success, image = vidcap.read()
|
110 |
+
count = 0
|
111 |
+
while success:
|
112 |
+
cv2.imwrite(
|
113 |
+
os.path.join(out, f"frame-{count:04}.jpg"),
|
114 |
+
image,
|
115 |
+
)
|
116 |
+
success, image = vidcap.read()
|
117 |
+
count += 1
|
118 |
+
|
119 |
+
def _generate_video_from_images(self, inp: str, out: str):
|
120 |
+
img_array = []
|
121 |
+
attention_images_list = sorted(glob.glob(os.path.join(inp, "attn-*.jpg")))
|
122 |
+
|
123 |
+
# Get size of the first image
|
124 |
+
with open(attention_images_list[0], "rb") as f:
|
125 |
+
img = Image.open(f)
|
126 |
+
img = img.convert("RGB")
|
127 |
+
size = (img.width, img.height)
|
128 |
+
img_array.append(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
|
129 |
+
|
130 |
+
print(f"Generating video {size} to {out}")
|
131 |
+
|
132 |
+
for filename in tqdm(attention_images_list[1:]):
|
133 |
+
with open(filename, "rb") as f:
|
134 |
+
img = Image.open(f)
|
135 |
+
img = img.convert("RGB")
|
136 |
+
img_array.append(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
|
137 |
+
|
138 |
+
out = cv2.VideoWriter(
|
139 |
+
os.path.join(out, "video." + self.args.video_format),
|
140 |
+
FOURCC[self.args.video_format],
|
141 |
+
self.args.fps,
|
142 |
+
size,
|
143 |
+
)
|
144 |
+
|
145 |
+
for i in range(len(img_array)):
|
146 |
+
out.write(img_array[i])
|
147 |
+
out.release()
|
148 |
+
print("Done")
|
149 |
+
|
150 |
+
def _inference(self, inp: str, out: str):
|
151 |
+
print(f"Generating attention images to {out}")
|
152 |
+
|
153 |
+
for img_path in tqdm(sorted(glob.glob(os.path.join(inp, "*.jpg")))):
|
154 |
+
with open(img_path, "rb") as f:
|
155 |
+
img = Image.open(f)
|
156 |
+
img = img.convert("RGB")
|
157 |
+
|
158 |
+
if self.args.resize is not None:
|
159 |
+
transform = pth_transforms.Compose(
|
160 |
+
[
|
161 |
+
pth_transforms.ToTensor(),
|
162 |
+
pth_transforms.Resize(self.args.resize),
|
163 |
+
pth_transforms.Normalize(
|
164 |
+
(0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
|
165 |
+
),
|
166 |
+
]
|
167 |
+
)
|
168 |
+
else:
|
169 |
+
transform = pth_transforms.Compose(
|
170 |
+
[
|
171 |
+
pth_transforms.ToTensor(),
|
172 |
+
pth_transforms.Normalize(
|
173 |
+
(0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
|
174 |
+
),
|
175 |
+
]
|
176 |
+
)
|
177 |
+
|
178 |
+
img = transform(img)
|
179 |
+
|
180 |
+
# make the image divisible by the patch size
|
181 |
+
w, h = (
|
182 |
+
img.shape[1] - img.shape[1] % self.args.patch_size,
|
183 |
+
img.shape[2] - img.shape[2] % self.args.patch_size,
|
184 |
+
)
|
185 |
+
img = img[:, :w, :h].unsqueeze(0)
|
186 |
+
|
187 |
+
w_featmap = img.shape[-2] // self.args.patch_size
|
188 |
+
h_featmap = img.shape[-1] // self.args.patch_size
|
189 |
+
|
190 |
+
attentions = self.model.get_last_selfattention(img.to(DEVICE))
|
191 |
+
|
192 |
+
nh = attentions.shape[1] # number of head
|
193 |
+
|
194 |
+
# we keep only the output patch attention
|
195 |
+
attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
|
196 |
+
|
197 |
+
# we keep only a certain percentage of the mass
|
198 |
+
val, idx = torch.sort(attentions)
|
199 |
+
val /= torch.sum(val, dim=1, keepdim=True)
|
200 |
+
cumval = torch.cumsum(val, dim=1)
|
201 |
+
th_attn = cumval > (1 - self.args.threshold)
|
202 |
+
idx2 = torch.argsort(idx)
|
203 |
+
for head in range(nh):
|
204 |
+
th_attn[head] = th_attn[head][idx2[head]]
|
205 |
+
th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()
|
206 |
+
# interpolate
|
207 |
+
th_attn = (
|
208 |
+
nn.functional.interpolate(
|
209 |
+
th_attn.unsqueeze(0),
|
210 |
+
scale_factor=self.args.patch_size,
|
211 |
+
mode="nearest",
|
212 |
+
)[0]
|
213 |
+
.cpu()
|
214 |
+
.numpy()
|
215 |
+
)
|
216 |
+
|
217 |
+
attentions = attentions.reshape(nh, w_featmap, h_featmap)
|
218 |
+
attentions = (
|
219 |
+
nn.functional.interpolate(
|
220 |
+
attentions.unsqueeze(0),
|
221 |
+
scale_factor=self.args.patch_size,
|
222 |
+
mode="nearest",
|
223 |
+
)[0]
|
224 |
+
.cpu()
|
225 |
+
.numpy()
|
226 |
+
)
|
227 |
+
|
228 |
+
# save attentions heatmaps
|
229 |
+
fname = os.path.join(out, "attn-" + os.path.basename(img_path))
|
230 |
+
plt.imsave(
|
231 |
+
fname=fname,
|
232 |
+
arr=sum(
|
233 |
+
attentions[i] * 1 / attentions.shape[0]
|
234 |
+
for i in range(attentions.shape[0])
|
235 |
+
),
|
236 |
+
cmap="inferno",
|
237 |
+
format="jpg",
|
238 |
+
)
|
239 |
+
|
240 |
+
def __load_model(self):
|
241 |
+
# build model
|
242 |
+
model = vits.__dict__[self.args.arch](
|
243 |
+
patch_size=self.args.patch_size, num_classes=0
|
244 |
+
)
|
245 |
+
for p in model.parameters():
|
246 |
+
p.requires_grad = False
|
247 |
+
model.eval()
|
248 |
+
model.to(DEVICE)
|
249 |
+
|
250 |
+
if os.path.isfile(self.args.pretrained_weights):
|
251 |
+
state_dict = torch.load(self.args.pretrained_weights, map_location="cpu")
|
252 |
+
if (
|
253 |
+
self.args.checkpoint_key is not None
|
254 |
+
and self.args.checkpoint_key in state_dict
|
255 |
+
):
|
256 |
+
print(
|
257 |
+
f"Take key {self.args.checkpoint_key} in provided checkpoint dict"
|
258 |
+
)
|
259 |
+
state_dict = state_dict[self.args.checkpoint_key]
|
260 |
+
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
|
261 |
+
# remove `backbone.` prefix induced by multicrop wrapper
|
262 |
+
state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
|
263 |
+
msg = model.load_state_dict(state_dict, strict=False)
|
264 |
+
print(
|
265 |
+
"Pretrained weights found at {} and loaded with msg: {}".format(
|
266 |
+
self.args.pretrained_weights, msg
|
267 |
+
)
|
268 |
+
)
|
269 |
+
else:
|
270 |
+
print(
|
271 |
+
"Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate."
|
272 |
+
)
|
273 |
+
url = None
|
274 |
+
if self.args.arch == "vit_small" and self.args.patch_size == 16:
|
275 |
+
url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
|
276 |
+
elif self.args.arch == "vit_small" and self.args.patch_size == 8:
|
277 |
+
url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth" # model used for visualizations in our paper
|
278 |
+
elif self.args.arch == "vit_base" and self.args.patch_size == 16:
|
279 |
+
url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
|
280 |
+
elif self.args.arch == "vit_base" and self.args.patch_size == 8:
|
281 |
+
url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
|
282 |
+
if url is not None:
|
283 |
+
print(
|
284 |
+
"Since no pretrained weights have been provided, we load the reference pretrained DINO weights."
|
285 |
+
)
|
286 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
287 |
+
url="https://dl.fbaipublicfiles.com/dino/" + url
|
288 |
+
)
|
289 |
+
model.load_state_dict(state_dict, strict=True)
|
290 |
+
else:
|
291 |
+
print(
|
292 |
+
"There is no reference weights available for this model => We use random weights."
|
293 |
+
)
|
294 |
+
return model
|
295 |
+
|
296 |
+
|
297 |
+
def parse_args():
|
298 |
+
parser = argparse.ArgumentParser("Generation self-attention video")
|
299 |
+
parser.add_argument(
|
300 |
+
"--arch",
|
301 |
+
default="vit_small",
|
302 |
+
type=str,
|
303 |
+
choices=["vit_tiny", "vit_small", "vit_base"],
|
304 |
+
help="Architecture (support only ViT atm).",
|
305 |
+
)
|
306 |
+
parser.add_argument(
|
307 |
+
"--patch_size", default=8, type=int, help="Patch resolution of the self.model."
|
308 |
+
)
|
309 |
+
parser.add_argument(
|
310 |
+
"--pretrained_weights",
|
311 |
+
default="",
|
312 |
+
type=str,
|
313 |
+
help="Path to pretrained weights to load.",
|
314 |
+
)
|
315 |
+
parser.add_argument(
|
316 |
+
"--checkpoint_key",
|
317 |
+
default="teacher",
|
318 |
+
type=str,
|
319 |
+
help='Key to use in the checkpoint (example: "teacher")',
|
320 |
+
)
|
321 |
+
parser.add_argument(
|
322 |
+
"--input_path",
|
323 |
+
required=True,
|
324 |
+
type=str,
|
325 |
+
help="""Path to a video file if you want to extract frames
|
326 |
+
or to a folder of images already extracted by yourself.
|
327 |
+
or to a folder of attention images.""",
|
328 |
+
)
|
329 |
+
parser.add_argument(
|
330 |
+
"--output_path",
|
331 |
+
default="./",
|
332 |
+
type=str,
|
333 |
+
help="""Path to store a folder of frames and / or a folder of attention images.
|
334 |
+
and / or a final video. Default to current directory.""",
|
335 |
+
)
|
336 |
+
parser.add_argument(
|
337 |
+
"--threshold",
|
338 |
+
type=float,
|
339 |
+
default=0.6,
|
340 |
+
help="""We visualize masks
|
341 |
+
obtained by thresholding the self-attention maps to keep xx percent of the mass.""",
|
342 |
+
)
|
343 |
+
parser.add_argument(
|
344 |
+
"--resize",
|
345 |
+
default=None,
|
346 |
+
type=int,
|
347 |
+
nargs="+",
|
348 |
+
help="""Apply a resize transformation to input image(s). Use if OOM error.
|
349 |
+
Usage (single or W H): --resize 512, --resize 720 1280""",
|
350 |
+
)
|
351 |
+
parser.add_argument(
|
352 |
+
"--video_only",
|
353 |
+
action="store_true",
|
354 |
+
help="""Use this flag if you only want to generate a video and not all attention images.
|
355 |
+
If used, --input_path must be set to the folder of attention images. Ex: ./attention/""",
|
356 |
+
)
|
357 |
+
parser.add_argument(
|
358 |
+
"--fps",
|
359 |
+
default=30.0,
|
360 |
+
type=float,
|
361 |
+
help="FPS of input / output video. Automatically set if you extract frames from a video.",
|
362 |
+
)
|
363 |
+
parser.add_argument(
|
364 |
+
"--video_format",
|
365 |
+
default="mp4",
|
366 |
+
type=str,
|
367 |
+
choices=["mp4", "avi"],
|
368 |
+
help="Format of generated video (mp4 or avi).",
|
369 |
+
)
|
370 |
+
|
371 |
+
return parser.parse_args()
|
372 |
+
|
373 |
+
|
374 |
+
if __name__ == "__main__":
|
375 |
+
args = parse_args()
|
376 |
+
|
377 |
+
vg = VideoGenerator(args)
|
378 |
+
vg.run()
|
PuzzleTuning/Counterpart PreTrain Methods/dino-main/vision_transformer.py
ADDED
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""
|
15 |
+
Mostly copy-paste from timm library.
|
16 |
+
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
17 |
+
"""
|
18 |
+
import math
|
19 |
+
from functools import partial
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import torch.nn as nn
|
23 |
+
|
24 |
+
from utils import trunc_normal_
|
25 |
+
|
26 |
+
|
27 |
+
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
28 |
+
if drop_prob == 0. or not training:
|
29 |
+
return x
|
30 |
+
keep_prob = 1 - drop_prob
|
31 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
32 |
+
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
33 |
+
random_tensor.floor_() # binarize
|
34 |
+
output = x.div(keep_prob) * random_tensor
|
35 |
+
return output
|
36 |
+
|
37 |
+
|
38 |
+
class DropPath(nn.Module):
|
39 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
40 |
+
"""
|
41 |
+
def __init__(self, drop_prob=None):
|
42 |
+
super(DropPath, self).__init__()
|
43 |
+
self.drop_prob = drop_prob
|
44 |
+
|
45 |
+
def forward(self, x):
|
46 |
+
return drop_path(x, self.drop_prob, self.training)
|
47 |
+
|
48 |
+
|
49 |
+
class Mlp(nn.Module):
|
50 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
51 |
+
super().__init__()
|
52 |
+
out_features = out_features or in_features
|
53 |
+
hidden_features = hidden_features or in_features
|
54 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
55 |
+
self.act = act_layer()
|
56 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
57 |
+
self.drop = nn.Dropout(drop)
|
58 |
+
|
59 |
+
def forward(self, x):
|
60 |
+
x = self.fc1(x)
|
61 |
+
x = self.act(x)
|
62 |
+
x = self.drop(x)
|
63 |
+
x = self.fc2(x)
|
64 |
+
x = self.drop(x)
|
65 |
+
return x
|
66 |
+
|
67 |
+
|
68 |
+
class Attention(nn.Module):
|
69 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
70 |
+
super().__init__()
|
71 |
+
self.num_heads = num_heads
|
72 |
+
head_dim = dim // num_heads
|
73 |
+
self.scale = qk_scale or head_dim ** -0.5
|
74 |
+
|
75 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
76 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
77 |
+
self.proj = nn.Linear(dim, dim)
|
78 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
79 |
+
|
80 |
+
def forward(self, x):
|
81 |
+
B, N, C = x.shape
|
82 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
83 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
84 |
+
|
85 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
86 |
+
attn = attn.softmax(dim=-1)
|
87 |
+
attn = self.attn_drop(attn)
|
88 |
+
|
89 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
90 |
+
x = self.proj(x)
|
91 |
+
x = self.proj_drop(x)
|
92 |
+
return x, attn
|
93 |
+
|
94 |
+
|
95 |
+
class Block(nn.Module):
|
96 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
97 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
98 |
+
super().__init__()
|
99 |
+
self.norm1 = norm_layer(dim)
|
100 |
+
self.attn = Attention(
|
101 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
102 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
103 |
+
self.norm2 = norm_layer(dim)
|
104 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
105 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
106 |
+
|
107 |
+
def forward(self, x, return_attention=False):
|
108 |
+
y, attn = self.attn(self.norm1(x))
|
109 |
+
if return_attention:
|
110 |
+
return attn
|
111 |
+
x = x + self.drop_path(y)
|
112 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
113 |
+
return x
|
114 |
+
|
115 |
+
|
116 |
+
class PatchEmbed(nn.Module):
|
117 |
+
""" Image to Patch Embedding
|
118 |
+
"""
|
119 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
|
120 |
+
super().__init__()
|
121 |
+
num_patches = (img_size // patch_size) * (img_size // patch_size)
|
122 |
+
self.img_size = img_size
|
123 |
+
self.patch_size = patch_size
|
124 |
+
self.num_patches = num_patches
|
125 |
+
|
126 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
127 |
+
|
128 |
+
def forward(self, x):
|
129 |
+
B, C, H, W = x.shape
|
130 |
+
x = self.proj(x).flatten(2).transpose(1, 2)
|
131 |
+
return x
|
132 |
+
|
133 |
+
|
134 |
+
class VisionTransformer(nn.Module):
|
135 |
+
""" Vision Transformer """
|
136 |
+
def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
|
137 |
+
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
|
138 |
+
drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs):
|
139 |
+
super().__init__()
|
140 |
+
self.num_features = self.embed_dim = embed_dim
|
141 |
+
|
142 |
+
self.patch_embed = PatchEmbed(
|
143 |
+
img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
144 |
+
num_patches = self.patch_embed.num_patches
|
145 |
+
|
146 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
147 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
148 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
149 |
+
|
150 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
151 |
+
self.blocks = nn.ModuleList([
|
152 |
+
Block(
|
153 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
154 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
|
155 |
+
for i in range(depth)])
|
156 |
+
self.norm = norm_layer(embed_dim)
|
157 |
+
|
158 |
+
# Classifier head
|
159 |
+
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
160 |
+
|
161 |
+
trunc_normal_(self.pos_embed, std=.02)
|
162 |
+
trunc_normal_(self.cls_token, std=.02)
|
163 |
+
self.apply(self._init_weights)
|
164 |
+
|
165 |
+
def _init_weights(self, m):
|
166 |
+
if isinstance(m, nn.Linear):
|
167 |
+
trunc_normal_(m.weight, std=.02)
|
168 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
169 |
+
nn.init.constant_(m.bias, 0)
|
170 |
+
elif isinstance(m, nn.LayerNorm):
|
171 |
+
nn.init.constant_(m.bias, 0)
|
172 |
+
nn.init.constant_(m.weight, 1.0)
|
173 |
+
|
174 |
+
def interpolate_pos_encoding(self, x, w, h):
|
175 |
+
npatch = x.shape[1] - 1
|
176 |
+
N = self.pos_embed.shape[1] - 1
|
177 |
+
if npatch == N and w == h:
|
178 |
+
return self.pos_embed
|
179 |
+
class_pos_embed = self.pos_embed[:, 0]
|
180 |
+
patch_pos_embed = self.pos_embed[:, 1:]
|
181 |
+
dim = x.shape[-1]
|
182 |
+
w0 = w // self.patch_embed.patch_size
|
183 |
+
h0 = h // self.patch_embed.patch_size
|
184 |
+
# we add a small number to avoid floating point error in the interpolation
|
185 |
+
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
186 |
+
w0, h0 = w0 + 0.1, h0 + 0.1
|
187 |
+
patch_pos_embed = nn.functional.interpolate(
|
188 |
+
patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
|
189 |
+
scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
|
190 |
+
mode='bicubic',
|
191 |
+
)
|
192 |
+
assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
|
193 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
194 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
195 |
+
|
196 |
+
def prepare_tokens(self, x):
|
197 |
+
B, nc, w, h = x.shape
|
198 |
+
x = self.patch_embed(x) # patch linear embedding
|
199 |
+
|
200 |
+
# add the [CLS] token to the embed patch tokens
|
201 |
+
cls_tokens = self.cls_token.expand(B, -1, -1)
|
202 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
203 |
+
|
204 |
+
# add positional encoding to each token
|
205 |
+
x = x + self.interpolate_pos_encoding(x, w, h)
|
206 |
+
|
207 |
+
return self.pos_drop(x)
|
208 |
+
|
209 |
+
def forward(self, x):
|
210 |
+
x = self.prepare_tokens(x)
|
211 |
+
for blk in self.blocks:
|
212 |
+
x = blk(x)
|
213 |
+
x = self.norm(x)
|
214 |
+
return x[:, 0]
|
215 |
+
|
216 |
+
def get_last_selfattention(self, x):
|
217 |
+
x = self.prepare_tokens(x)
|
218 |
+
for i, blk in enumerate(self.blocks):
|
219 |
+
if i < len(self.blocks) - 1:
|
220 |
+
x = blk(x)
|
221 |
+
else:
|
222 |
+
# return attention of the last block
|
223 |
+
return blk(x, return_attention=True)
|
224 |
+
|
225 |
+
def get_intermediate_layers(self, x, n=1):
|
226 |
+
x = self.prepare_tokens(x)
|
227 |
+
# we return the output tokens from the `n` last blocks
|
228 |
+
output = []
|
229 |
+
for i, blk in enumerate(self.blocks):
|
230 |
+
x = blk(x)
|
231 |
+
if len(self.blocks) - i <= n:
|
232 |
+
output.append(self.norm(x))
|
233 |
+
return output
|
234 |
+
|
235 |
+
|
236 |
+
def vit_tiny(patch_size=16, **kwargs):
|
237 |
+
model = VisionTransformer(
|
238 |
+
patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
|
239 |
+
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
240 |
+
return model
|
241 |
+
|
242 |
+
|
243 |
+
def vit_small(patch_size=16, **kwargs):
|
244 |
+
model = VisionTransformer(
|
245 |
+
patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
|
246 |
+
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
247 |
+
return model
|
248 |
+
|
249 |
+
|
250 |
+
def vit_base(patch_size=16, **kwargs):
|
251 |
+
model = VisionTransformer(
|
252 |
+
patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
|
253 |
+
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
254 |
+
return model
|
255 |
+
|
256 |
+
|
257 |
+
class DINOHead(nn.Module):
|
258 |
+
def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
|
259 |
+
super().__init__()
|
260 |
+
nlayers = max(nlayers, 1)
|
261 |
+
if nlayers == 1:
|
262 |
+
self.mlp = nn.Linear(in_dim, bottleneck_dim)
|
263 |
+
else:
|
264 |
+
layers = [nn.Linear(in_dim, hidden_dim)]
|
265 |
+
if use_bn:
|
266 |
+
layers.append(nn.BatchNorm1d(hidden_dim))
|
267 |
+
layers.append(nn.GELU())
|
268 |
+
for _ in range(nlayers - 2):
|
269 |
+
layers.append(nn.Linear(hidden_dim, hidden_dim))
|
270 |
+
if use_bn:
|
271 |
+
layers.append(nn.BatchNorm1d(hidden_dim))
|
272 |
+
layers.append(nn.GELU())
|
273 |
+
layers.append(nn.Linear(hidden_dim, bottleneck_dim))
|
274 |
+
self.mlp = nn.Sequential(*layers)
|
275 |
+
self.apply(self._init_weights)
|
276 |
+
self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
|
277 |
+
self.last_layer.weight_g.data.fill_(1)
|
278 |
+
if norm_last_layer:
|
279 |
+
self.last_layer.weight_g.requires_grad = False
|
280 |
+
|
281 |
+
def _init_weights(self, m):
|
282 |
+
if isinstance(m, nn.Linear):
|
283 |
+
trunc_normal_(m.weight, std=.02)
|
284 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
285 |
+
nn.init.constant_(m.bias, 0)
|
286 |
+
|
287 |
+
def forward(self, x):
|
288 |
+
x = self.mlp(x)
|
289 |
+
x = nn.functional.normalize(x, dim=-1, p=2)
|
290 |
+
x = self.last_layer(x)
|
291 |
+
return x
|
PuzzleTuning/Counterpart PreTrain Methods/dino-main/visualize_attention.py
ADDED
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import os
|
15 |
+
import sys
|
16 |
+
import argparse
|
17 |
+
import cv2
|
18 |
+
import random
|
19 |
+
import colorsys
|
20 |
+
import requests
|
21 |
+
from io import BytesIO
|
22 |
+
|
23 |
+
import skimage.io
|
24 |
+
from skimage.measure import find_contours
|
25 |
+
import matplotlib.pyplot as plt
|
26 |
+
from matplotlib.patches import Polygon
|
27 |
+
import torch
|
28 |
+
import torch.nn as nn
|
29 |
+
import torchvision
|
30 |
+
from torchvision import transforms as pth_transforms
|
31 |
+
import numpy as np
|
32 |
+
from PIL import Image
|
33 |
+
|
34 |
+
import utils
|
35 |
+
import vision_transformer as vits
|
36 |
+
|
37 |
+
|
38 |
+
def apply_mask(image, mask, color, alpha=0.5):
|
39 |
+
for c in range(3):
|
40 |
+
image[:, :, c] = image[:, :, c] * (1 - alpha * mask) + alpha * mask * color[c] * 255
|
41 |
+
return image
|
42 |
+
|
43 |
+
|
44 |
+
def random_colors(N, bright=True):
|
45 |
+
"""
|
46 |
+
Generate random colors.
|
47 |
+
"""
|
48 |
+
brightness = 1.0 if bright else 0.7
|
49 |
+
hsv = [(i / N, 1, brightness) for i in range(N)]
|
50 |
+
colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv))
|
51 |
+
random.shuffle(colors)
|
52 |
+
return colors
|
53 |
+
|
54 |
+
|
55 |
+
def display_instances(image, mask, fname="test", figsize=(5, 5), blur=False, contour=True, alpha=0.5):
|
56 |
+
fig = plt.figure(figsize=figsize, frameon=False)
|
57 |
+
ax = plt.Axes(fig, [0., 0., 1., 1.])
|
58 |
+
ax.set_axis_off()
|
59 |
+
fig.add_axes(ax)
|
60 |
+
ax = plt.gca()
|
61 |
+
|
62 |
+
N = 1
|
63 |
+
mask = mask[None, :, :]
|
64 |
+
# Generate random colors
|
65 |
+
colors = random_colors(N)
|
66 |
+
|
67 |
+
# Show area outside image boundaries.
|
68 |
+
height, width = image.shape[:2]
|
69 |
+
margin = 0
|
70 |
+
ax.set_ylim(height + margin, -margin)
|
71 |
+
ax.set_xlim(-margin, width + margin)
|
72 |
+
ax.axis('off')
|
73 |
+
masked_image = image.astype(np.uint32).copy()
|
74 |
+
for i in range(N):
|
75 |
+
color = colors[i]
|
76 |
+
_mask = mask[i]
|
77 |
+
if blur:
|
78 |
+
_mask = cv2.blur(_mask,(10,10))
|
79 |
+
# Mask
|
80 |
+
masked_image = apply_mask(masked_image, _mask, color, alpha)
|
81 |
+
# Mask Polygon
|
82 |
+
# Pad to ensure proper polygons for masks that touch image edges.
|
83 |
+
if contour:
|
84 |
+
padded_mask = np.zeros((_mask.shape[0] + 2, _mask.shape[1] + 2))
|
85 |
+
padded_mask[1:-1, 1:-1] = _mask
|
86 |
+
contours = find_contours(padded_mask, 0.5)
|
87 |
+
for verts in contours:
|
88 |
+
# Subtract the padding and flip (y, x) to (x, y)
|
89 |
+
verts = np.fliplr(verts) - 1
|
90 |
+
p = Polygon(verts, facecolor="none", edgecolor=color)
|
91 |
+
ax.add_patch(p)
|
92 |
+
ax.imshow(masked_image.astype(np.uint8), aspect='auto')
|
93 |
+
fig.savefig(fname)
|
94 |
+
print(f"{fname} saved.")
|
95 |
+
return
|
96 |
+
|
97 |
+
|
98 |
+
if __name__ == '__main__':
|
99 |
+
parser = argparse.ArgumentParser('Visualize Self-Attention maps')
|
100 |
+
parser.add_argument('--arch', default='vit_small', type=str,
|
101 |
+
choices=['vit_tiny', 'vit_small', 'vit_base'], help='Architecture (support only ViT atm).')
|
102 |
+
parser.add_argument('--patch_size', default=8, type=int, help='Patch resolution of the model.')
|
103 |
+
parser.add_argument('--pretrained_weights', default='', type=str,
|
104 |
+
help="Path to pretrained weights to load.")
|
105 |
+
parser.add_argument("--checkpoint_key", default="teacher", type=str,
|
106 |
+
help='Key to use in the checkpoint (example: "teacher")')
|
107 |
+
parser.add_argument("--image_path", default=None, type=str, help="Path of the image to load.")
|
108 |
+
parser.add_argument("--image_size", default=(480, 480), type=int, nargs="+", help="Resize image.")
|
109 |
+
parser.add_argument('--output_dir', default='.', help='Path where to save visualizations.')
|
110 |
+
parser.add_argument("--threshold", type=float, default=None, help="""We visualize masks
|
111 |
+
obtained by thresholding the self-attention maps to keep xx% of the mass.""")
|
112 |
+
args = parser.parse_args()
|
113 |
+
|
114 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
115 |
+
# build model
|
116 |
+
model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
|
117 |
+
for p in model.parameters():
|
118 |
+
p.requires_grad = False
|
119 |
+
model.eval()
|
120 |
+
model.to(device)
|
121 |
+
if os.path.isfile(args.pretrained_weights):
|
122 |
+
state_dict = torch.load(args.pretrained_weights, map_location="cpu")
|
123 |
+
if args.checkpoint_key is not None and args.checkpoint_key in state_dict:
|
124 |
+
print(f"Take key {args.checkpoint_key} in provided checkpoint dict")
|
125 |
+
state_dict = state_dict[args.checkpoint_key]
|
126 |
+
# remove `module.` prefix
|
127 |
+
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
|
128 |
+
# remove `backbone.` prefix induced by multicrop wrapper
|
129 |
+
state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
|
130 |
+
msg = model.load_state_dict(state_dict, strict=False)
|
131 |
+
print('Pretrained weights found at {} and loaded with msg: {}'.format(args.pretrained_weights, msg))
|
132 |
+
else:
|
133 |
+
print("Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.")
|
134 |
+
url = None
|
135 |
+
if args.arch == "vit_small" and args.patch_size == 16:
|
136 |
+
url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
|
137 |
+
elif args.arch == "vit_small" and args.patch_size == 8:
|
138 |
+
url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth" # model used for visualizations in our paper
|
139 |
+
elif args.arch == "vit_base" and args.patch_size == 16:
|
140 |
+
url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
|
141 |
+
elif args.arch == "vit_base" and args.patch_size == 8:
|
142 |
+
url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
|
143 |
+
if url is not None:
|
144 |
+
print("Since no pretrained weights have been provided, we load the reference pretrained DINO weights.")
|
145 |
+
state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
|
146 |
+
model.load_state_dict(state_dict, strict=True)
|
147 |
+
else:
|
148 |
+
print("There is no reference weights available for this model => We use random weights.")
|
149 |
+
|
150 |
+
# open image
|
151 |
+
if args.image_path is None:
|
152 |
+
# user has not specified any image - we use our own image
|
153 |
+
print("Please use the `--image_path` argument to indicate the path of the image you wish to visualize.")
|
154 |
+
print("Since no image path have been provided, we take the first image in our paper.")
|
155 |
+
response = requests.get("https://dl.fbaipublicfiles.com/dino/img.png")
|
156 |
+
img = Image.open(BytesIO(response.content))
|
157 |
+
img = img.convert('RGB')
|
158 |
+
elif os.path.isfile(args.image_path):
|
159 |
+
with open(args.image_path, 'rb') as f:
|
160 |
+
img = Image.open(f)
|
161 |
+
img = img.convert('RGB')
|
162 |
+
else:
|
163 |
+
print(f"Provided image path {args.image_path} is non valid.")
|
164 |
+
sys.exit(1)
|
165 |
+
transform = pth_transforms.Compose([
|
166 |
+
pth_transforms.Resize(args.image_size),
|
167 |
+
pth_transforms.ToTensor(),
|
168 |
+
pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
169 |
+
])
|
170 |
+
img = transform(img)
|
171 |
+
|
172 |
+
# make the image divisible by the patch size
|
173 |
+
w, h = img.shape[1] - img.shape[1] % args.patch_size, img.shape[2] - img.shape[2] % args.patch_size
|
174 |
+
img = img[:, :w, :h].unsqueeze(0)
|
175 |
+
|
176 |
+
w_featmap = img.shape[-2] // args.patch_size
|
177 |
+
h_featmap = img.shape[-1] // args.patch_size
|
178 |
+
|
179 |
+
attentions = model.get_last_selfattention(img.to(device))
|
180 |
+
|
181 |
+
nh = attentions.shape[1] # number of head
|
182 |
+
|
183 |
+
# we keep only the output patch attention
|
184 |
+
attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
|
185 |
+
|
186 |
+
if args.threshold is not None:
|
187 |
+
# we keep only a certain percentage of the mass
|
188 |
+
val, idx = torch.sort(attentions)
|
189 |
+
val /= torch.sum(val, dim=1, keepdim=True)
|
190 |
+
cumval = torch.cumsum(val, dim=1)
|
191 |
+
th_attn = cumval > (1 - args.threshold)
|
192 |
+
idx2 = torch.argsort(idx)
|
193 |
+
for head in range(nh):
|
194 |
+
th_attn[head] = th_attn[head][idx2[head]]
|
195 |
+
th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()
|
196 |
+
# interpolate
|
197 |
+
th_attn = nn.functional.interpolate(th_attn.unsqueeze(0), scale_factor=args.patch_size, mode="nearest")[0].cpu().numpy()
|
198 |
+
|
199 |
+
attentions = attentions.reshape(nh, w_featmap, h_featmap)
|
200 |
+
attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=args.patch_size, mode="nearest")[0].cpu().numpy()
|
201 |
+
|
202 |
+
# save attentions heatmaps
|
203 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
204 |
+
torchvision.utils.save_image(torchvision.utils.make_grid(img, normalize=True, scale_each=True), os.path.join(args.output_dir, "img.png"))
|
205 |
+
for j in range(nh):
|
206 |
+
fname = os.path.join(args.output_dir, "attn-head" + str(j) + ".png")
|
207 |
+
plt.imsave(fname=fname, arr=attentions[j], format='png')
|
208 |
+
print(f"{fname} saved.")
|
209 |
+
|
210 |
+
if args.threshold is not None:
|
211 |
+
image = skimage.io.imread(os.path.join(args.output_dir, "img.png"))
|
212 |
+
for j in range(nh):
|
213 |
+
display_instances(image, th_attn[j], fname=os.path.join(args.output_dir, "mask_th" + str(args.threshold) + "_head" + str(j) +".png"), blur=False)
|
PuzzleTuning/Counterpart PreTrain Methods/gcmae/FINETUNE.md
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Fine tune GCMAE
|
2 |
+
|
3 |
+
```
|
4 |
+
python main_finetune.py \
|
5 |
+
--data_path path/to/data \
|
6 |
+
--nb_classes 9 \
|
7 |
+
--output_dir path/to/ouput/dir \
|
8 |
+
--log_dir path/to/log/dir \
|
9 |
+
--batch_size 128 \
|
10 |
+
--model vit_base_patch16 \
|
11 |
+
--epochs 50 \
|
12 |
+
--finetune path/to/pth/path \
|
13 |
+
```
|
PuzzleTuning/Counterpart PreTrain Methods/gcmae/GCMAE.png
ADDED
PuzzleTuning/Counterpart PreTrain Methods/gcmae/LICENSE
ADDED
@@ -0,0 +1,399 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Attribution-NonCommercial 4.0 International
|
2 |
+
|
3 |
+
=======================================================================
|
4 |
+
|
5 |
+
Creative Commons Corporation ("Creative Commons") is not a law firm and
|
6 |
+
does not provide legal services or legal advice. Distribution of
|
7 |
+
Creative Commons public licenses does not create a lawyer-client or
|
8 |
+
other relationship. Creative Commons makes its licenses and related
|
9 |
+
information available on an "as-is" basis. Creative Commons gives no
|
10 |
+
warranties regarding its licenses, any material licensed under their
|
11 |
+
terms and conditions, or any related information. Creative Commons
|
12 |
+
disclaims all liability for damages resulting from their use to the
|
13 |
+
fullest extent possible.
|
14 |
+
|
15 |
+
Using Creative Commons Public Licenses
|
16 |
+
|
17 |
+
Creative Commons public licenses provide a standard set of terms and
|
18 |
+
conditions that creators and other rights holders may use to share
|
19 |
+
original works of authorship and other material subject to copyright
|
20 |
+
and certain other rights specified in the public license below. The
|
21 |
+
following considerations are for informational purposes only, are not
|
22 |
+
exhaustive, and do not form part of our licenses.
|
23 |
+
|
24 |
+
Considerations for licensors: Our public licenses are
|
25 |
+
intended for use by those authorized to give the public
|
26 |
+
permission to use material in ways otherwise restricted by
|
27 |
+
copyright and certain other rights. Our licenses are
|
28 |
+
irrevocable. Licensors should read and understand the terms
|
29 |
+
and conditions of the license they choose before applying it.
|
30 |
+
Licensors should also secure all rights necessary before
|
31 |
+
applying our licenses so that the public can reuse the
|
32 |
+
material as expected. Licensors should clearly mark any
|
33 |
+
material not subject to the license. This includes other CC-
|
34 |
+
licensed material, or material used under an exception or
|
35 |
+
limitation to copyright. More considerations for licensors:
|
36 |
+
wiki.creativecommons.org/Considerations_for_licensors
|
37 |
+
|
38 |
+
Considerations for the public: By using one of our public
|
39 |
+
licenses, a licensor grants the public permission to use the
|
40 |
+
licensed material under specified terms and conditions. If
|
41 |
+
the licensor's permission is not necessary for any reason--for
|
42 |
+
example, because of any applicable exception or limitation to
|
43 |
+
copyright--then that use is not regulated by the license. Our
|
44 |
+
licenses grant only permissions under copyright and certain
|
45 |
+
other rights that a licensor has authority to grant. Use of
|
46 |
+
the licensed material may still be restricted for other
|
47 |
+
reasons, including because others have copyright or other
|
48 |
+
rights in the material. A licensor may make special requests,
|
49 |
+
such as asking that all changes be marked or described.
|
50 |
+
Although not required by our licenses, you are encouraged to
|
51 |
+
respect those requests where reasonable. More_considerations
|
52 |
+
for the public:
|
53 |
+
wiki.creativecommons.org/Considerations_for_licensees
|
54 |
+
|
55 |
+
=======================================================================
|
56 |
+
|
57 |
+
Creative Commons Attribution-NonCommercial 4.0 International Public
|
58 |
+
License
|
59 |
+
|
60 |
+
By exercising the Licensed Rights (defined below), You accept and agree
|
61 |
+
to be bound by the terms and conditions of this Creative Commons
|
62 |
+
Attribution-NonCommercial 4.0 International Public License ("Public
|
63 |
+
License"). To the extent this Public License may be interpreted as a
|
64 |
+
contract, You are granted the Licensed Rights in consideration of Your
|
65 |
+
acceptance of these terms and conditions, and the Licensor grants You
|
66 |
+
such rights in consideration of benefits the Licensor receives from
|
67 |
+
making the Licensed Material available under these terms and
|
68 |
+
conditions.
|
69 |
+
|
70 |
+
Section 1 -- Definitions.
|
71 |
+
|
72 |
+
a. Adapted Material means material subject to Copyright and Similar
|
73 |
+
Rights that is derived from or based upon the Licensed Material
|
74 |
+
and in which the Licensed Material is translated, altered,
|
75 |
+
arranged, transformed, or otherwise modified in a manner requiring
|
76 |
+
permission under the Copyright and Similar Rights held by the
|
77 |
+
Licensor. For purposes of this Public License, where the Licensed
|
78 |
+
Material is a musical work, performance, or sound recording,
|
79 |
+
Adapted Material is always produced where the Licensed Material is
|
80 |
+
synched in timed relation with a moving image.
|
81 |
+
|
82 |
+
b. Adapter's License means the license You apply to Your Copyright
|
83 |
+
and Similar Rights in Your contributions to Adapted Material in
|
84 |
+
accordance with the terms and conditions of this Public License.
|
85 |
+
|
86 |
+
c. Copyright and Similar Rights means copyright and/or similar rights
|
87 |
+
closely related to copyright including, without limitation,
|
88 |
+
performance, broadcast, sound recording, and Sui Generis Database
|
89 |
+
Rights, without regard to how the rights are labeled or
|
90 |
+
categorized. For purposes of this Public License, the rights
|
91 |
+
specified in Section 2(b)(1)-(2) are not Copyright and Similar
|
92 |
+
Rights.
|
93 |
+
d. Effective Technological Measures means those measures that, in the
|
94 |
+
absence of proper authority, may not be circumvented under laws
|
95 |
+
fulfilling obligations under Article 11 of the WIPO Copyright
|
96 |
+
Treaty adopted on December 20, 1996, and/or similar international
|
97 |
+
agreements.
|
98 |
+
|
99 |
+
e. Exceptions and Limitations means fair use, fair dealing, and/or
|
100 |
+
any other exception or limitation to Copyright and Similar Rights
|
101 |
+
that applies to Your use of the Licensed Material.
|
102 |
+
|
103 |
+
f. Licensed Material means the artistic or literary work, database,
|
104 |
+
or other material to which the Licensor applied this Public
|
105 |
+
License.
|
106 |
+
|
107 |
+
g. Licensed Rights means the rights granted to You subject to the
|
108 |
+
terms and conditions of this Public License, which are limited to
|
109 |
+
all Copyright and Similar Rights that apply to Your use of the
|
110 |
+
Licensed Material and that the Licensor has authority to license.
|
111 |
+
|
112 |
+
h. Licensor means the individual(s) or entity(ies) granting rights
|
113 |
+
under this Public License.
|
114 |
+
|
115 |
+
i. NonCommercial means not primarily intended for or directed towards
|
116 |
+
commercial advantage or monetary compensation. For purposes of
|
117 |
+
this Public License, the exchange of the Licensed Material for
|
118 |
+
other material subject to Copyright and Similar Rights by digital
|
119 |
+
file-sharing or similar means is NonCommercial provided there is
|
120 |
+
no payment of monetary compensation in connection with the
|
121 |
+
exchange.
|
122 |
+
|
123 |
+
j. Share means to provide material to the public by any means or
|
124 |
+
process that requires permission under the Licensed Rights, such
|
125 |
+
as reproduction, public display, public performance, distribution,
|
126 |
+
dissemination, communication, or importation, and to make material
|
127 |
+
available to the public including in ways that members of the
|
128 |
+
public may access the material from a place and at a time
|
129 |
+
individually chosen by them.
|
130 |
+
|
131 |
+
k. Sui Generis Database Rights means rights other than copyright
|
132 |
+
resulting from Directive 96/9/EC of the European Parliament and of
|
133 |
+
the Council of 11 March 1996 on the legal protection of databases,
|
134 |
+
as amended and/or succeeded, as well as other essentially
|
135 |
+
equivalent rights anywhere in the world.
|
136 |
+
|
137 |
+
l. You means the individual or entity exercising the Licensed Rights
|
138 |
+
under this Public License. Your has a corresponding meaning.
|
139 |
+
|
140 |
+
Section 2 -- Scope.
|
141 |
+
|
142 |
+
a. License grant.
|
143 |
+
|
144 |
+
1. Subject to the terms and conditions of this Public License,
|
145 |
+
the Licensor hereby grants You a worldwide, royalty-free,
|
146 |
+
non-sublicensable, non-exclusive, irrevocable license to
|
147 |
+
exercise the Licensed Rights in the Licensed Material to:
|
148 |
+
|
149 |
+
a. reproduce and Share the Licensed Material, in whole or
|
150 |
+
in part, for NonCommercial purposes only; and
|
151 |
+
|
152 |
+
b. produce, reproduce, and Share Adapted Material for
|
153 |
+
NonCommercial purposes only.
|
154 |
+
|
155 |
+
2. Exceptions and Limitations. For the avoidance of doubt, where
|
156 |
+
Exceptions and Limitations apply to Your use, this Public
|
157 |
+
License does not apply, and You do not need to comply with
|
158 |
+
its terms and conditions.
|
159 |
+
|
160 |
+
3. Term. The term of this Public License is specified in Section
|
161 |
+
6(a).
|
162 |
+
|
163 |
+
4. Media and formats; technical modifications allowed. The
|
164 |
+
Licensor authorizes You to exercise the Licensed Rights in
|
165 |
+
all media and formats whether now known or hereafter created,
|
166 |
+
and to make technical modifications necessary to do so. The
|
167 |
+
Licensor waives and/or agrees not to assert any right or
|
168 |
+
authority to forbid You from making technical modifications
|
169 |
+
necessary to exercise the Licensed Rights, including
|
170 |
+
technical modifications necessary to circumvent Effective
|
171 |
+
Technological Measures. For purposes of this Public License,
|
172 |
+
simply making modifications authorized by this Section 2(a)
|
173 |
+
(4) never produces Adapted Material.
|
174 |
+
|
175 |
+
5. Downstream recipients.
|
176 |
+
|
177 |
+
a. Offer from the Licensor -- Licensed Material. Every
|
178 |
+
recipient of the Licensed Material automatically
|
179 |
+
receives an offer from the Licensor to exercise the
|
180 |
+
Licensed Rights under the terms and conditions of this
|
181 |
+
Public License.
|
182 |
+
|
183 |
+
b. No downstream restrictions. You may not offer or impose
|
184 |
+
any additional or different terms or conditions on, or
|
185 |
+
apply any Effective Technological Measures to, the
|
186 |
+
Licensed Material if doing so restricts exercise of the
|
187 |
+
Licensed Rights by any recipient of the Licensed
|
188 |
+
Material.
|
189 |
+
|
190 |
+
6. No endorsement. Nothing in this Public License constitutes or
|
191 |
+
may be construed as permission to assert or imply that You
|
192 |
+
are, or that Your use of the Licensed Material is, connected
|
193 |
+
with, or sponsored, endorsed, or granted official status by,
|
194 |
+
the Licensor or others designated to receive attribution as
|
195 |
+
provided in Section 3(a)(1)(A)(i).
|
196 |
+
|
197 |
+
b. Other rights.
|
198 |
+
|
199 |
+
1. Moral rights, such as the right of integrity, are not
|
200 |
+
licensed under this Public License, nor are publicity,
|
201 |
+
privacy, and/or other similar personality rights; however, to
|
202 |
+
the extent possible, the Licensor waives and/or agrees not to
|
203 |
+
assert any such rights held by the Licensor to the limited
|
204 |
+
extent necessary to allow You to exercise the Licensed
|
205 |
+
Rights, but not otherwise.
|
206 |
+
|
207 |
+
2. Patent and trademark rights are not licensed under this
|
208 |
+
Public License.
|
209 |
+
|
210 |
+
3. To the extent possible, the Licensor waives any right to
|
211 |
+
collect royalties from You for the exercise of the Licensed
|
212 |
+
Rights, whether directly or through a collecting society
|
213 |
+
under any voluntary or waivable statutory or compulsory
|
214 |
+
licensing scheme. In all other cases the Licensor expressly
|
215 |
+
reserves any right to collect such royalties, including when
|
216 |
+
the Licensed Material is used other than for NonCommercial
|
217 |
+
purposes.
|
218 |
+
|
219 |
+
Section 3 -- License Conditions.
|
220 |
+
|
221 |
+
Your exercise of the Licensed Rights is expressly made subject to the
|
222 |
+
following conditions.
|
223 |
+
|
224 |
+
a. Attribution.
|
225 |
+
|
226 |
+
1. If You Share the Licensed Material (including in modified
|
227 |
+
form), You must:
|
228 |
+
|
229 |
+
a. retain the following if it is supplied by the Licensor
|
230 |
+
with the Licensed Material:
|
231 |
+
|
232 |
+
i. identification of the creator(s) of the Licensed
|
233 |
+
Material and any others designated to receive
|
234 |
+
attribution, in any reasonable manner requested by
|
235 |
+
the Licensor (including by pseudonym if
|
236 |
+
designated);
|
237 |
+
|
238 |
+
ii. a copyright notice;
|
239 |
+
|
240 |
+
iii. a notice that refers to this Public License;
|
241 |
+
|
242 |
+
iv. a notice that refers to the disclaimer of
|
243 |
+
warranties;
|
244 |
+
|
245 |
+
v. a URI or hyperlink to the Licensed Material to the
|
246 |
+
extent reasonably practicable;
|
247 |
+
|
248 |
+
b. indicate if You modified the Licensed Material and
|
249 |
+
retain an indication of any previous modifications; and
|
250 |
+
|
251 |
+
c. indicate the Licensed Material is licensed under this
|
252 |
+
Public License, and include the text of, or the URI or
|
253 |
+
hyperlink to, this Public License.
|
254 |
+
|
255 |
+
2. You may satisfy the conditions in Section 3(a)(1) in any
|
256 |
+
reasonable manner based on the medium, means, and context in
|
257 |
+
which You Share the Licensed Material. For example, it may be
|
258 |
+
reasonable to satisfy the conditions by providing a URI or
|
259 |
+
hyperlink to a resource that includes the required
|
260 |
+
information.
|
261 |
+
|
262 |
+
3. If requested by the Licensor, You must remove any of the
|
263 |
+
information required by Section 3(a)(1)(A) to the extent
|
264 |
+
reasonably practicable.
|
265 |
+
|
266 |
+
4. If You Share Adapted Material You produce, the Adapter's
|
267 |
+
License You apply must not prevent recipients of the Adapted
|
268 |
+
Material from complying with this Public License.
|
269 |
+
|
270 |
+
Section 4 -- Sui Generis Database Rights.
|
271 |
+
|
272 |
+
Where the Licensed Rights include Sui Generis Database Rights that
|
273 |
+
apply to Your use of the Licensed Material:
|
274 |
+
|
275 |
+
a. for the avoidance of doubt, Section 2(a)(1) grants You the right
|
276 |
+
to extract, reuse, reproduce, and Share all or a substantial
|
277 |
+
portion of the contents of the database for NonCommercial purposes
|
278 |
+
only;
|
279 |
+
|
280 |
+
b. if You include all or a substantial portion of the database
|
281 |
+
contents in a database in which You have Sui Generis Database
|
282 |
+
Rights, then the database in which You have Sui Generis Database
|
283 |
+
Rights (but not its individual contents) is Adapted Material; and
|
284 |
+
|
285 |
+
c. You must comply with the conditions in Section 3(a) if You Share
|
286 |
+
all or a substantial portion of the contents of the database.
|
287 |
+
|
288 |
+
For the avoidance of doubt, this Section 4 supplements and does not
|
289 |
+
replace Your obligations under this Public License where the Licensed
|
290 |
+
Rights include other Copyright and Similar Rights.
|
291 |
+
|
292 |
+
Section 5 -- Disclaimer of Warranties and Limitation of Liability.
|
293 |
+
|
294 |
+
a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
|
295 |
+
EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
|
296 |
+
AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
|
297 |
+
ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
|
298 |
+
IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
|
299 |
+
WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
|
300 |
+
PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
|
301 |
+
ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
|
302 |
+
KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
|
303 |
+
ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
|
304 |
+
|
305 |
+
b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
|
306 |
+
TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
|
307 |
+
NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
|
308 |
+
INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
|
309 |
+
COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
|
310 |
+
USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
|
311 |
+
ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
|
312 |
+
DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
|
313 |
+
IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
|
314 |
+
|
315 |
+
c. The disclaimer of warranties and limitation of liability provided
|
316 |
+
above shall be interpreted in a manner that, to the extent
|
317 |
+
possible, most closely approximates an absolute disclaimer and
|
318 |
+
waiver of all liability.
|
319 |
+
|
320 |
+
Section 6 -- Term and Termination.
|
321 |
+
|
322 |
+
a. This Public License applies for the term of the Copyright and
|
323 |
+
Similar Rights licensed here. However, if You fail to comply with
|
324 |
+
this Public License, then Your rights under this Public License
|
325 |
+
terminate automatically.
|
326 |
+
|
327 |
+
b. Where Your right to use the Licensed Material has terminated under
|
328 |
+
Section 6(a), it reinstates:
|
329 |
+
|
330 |
+
1. automatically as of the date the violation is cured, provided
|
331 |
+
it is cured within 30 days of Your discovery of the
|
332 |
+
violation; or
|
333 |
+
|
334 |
+
2. upon express reinstatement by the Licensor.
|
335 |
+
|
336 |
+
For the avoidance of doubt, this Section 6(b) does not affect any
|
337 |
+
right the Licensor may have to seek remedies for Your violations
|
338 |
+
of this Public License.
|
339 |
+
|
340 |
+
c. For the avoidance of doubt, the Licensor may also offer the
|
341 |
+
Licensed Material under separate terms or conditions or stop
|
342 |
+
distributing the Licensed Material at any time; however, doing so
|
343 |
+
will not terminate this Public License.
|
344 |
+
|
345 |
+
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
|
346 |
+
License.
|
347 |
+
|
348 |
+
Section 7 -- Other Terms and Conditions.
|
349 |
+
|
350 |
+
a. The Licensor shall not be bound by any additional or different
|
351 |
+
terms or conditions communicated by You unless expressly agreed.
|
352 |
+
|
353 |
+
b. Any arrangements, understandings, or agreements regarding the
|
354 |
+
Licensed Material not stated herein are separate from and
|
355 |
+
independent of the terms and conditions of this Public License.
|
356 |
+
|
357 |
+
Section 8 -- Interpretation.
|
358 |
+
|
359 |
+
a. For the avoidance of doubt, this Public License does not, and
|
360 |
+
shall not be interpreted to, reduce, limit, restrict, or impose
|
361 |
+
conditions on any use of the Licensed Material that could lawfully
|
362 |
+
be made without permission under this Public License.
|
363 |
+
|
364 |
+
b. To the extent possible, if any provision of this Public License is
|
365 |
+
deemed unenforceable, it shall be automatically reformed to the
|
366 |
+
minimum extent necessary to make it enforceable. If the provision
|
367 |
+
cannot be reformed, it shall be severed from this Public License
|
368 |
+
without affecting the enforceability of the remaining terms and
|
369 |
+
conditions.
|
370 |
+
|
371 |
+
c. No term or condition of this Public License will be waived and no
|
372 |
+
failure to comply consented to unless expressly agreed to by the
|
373 |
+
Licensor.
|
374 |
+
|
375 |
+
d. Nothing in this Public License constitutes or may be interpreted
|
376 |
+
as a limitation upon, or waiver of, any privileges and immunities
|
377 |
+
that apply to the Licensor or You, including from the legal
|
378 |
+
processes of any jurisdiction or authority.
|
379 |
+
|
380 |
+
=======================================================================
|
381 |
+
|
382 |
+
Creative Commons is not a party to its public
|
383 |
+
licenses. Notwithstanding, Creative Commons may elect to apply one of
|
384 |
+
its public licenses to material it publishes and in those instances
|
385 |
+
will be considered the “Licensor.” The text of the Creative Commons
|
386 |
+
public licenses is dedicated to the public domain under the CC0 Public
|
387 |
+
Domain Dedication. Except for the limited purpose of indicating that
|
388 |
+
material is shared under a Creative Commons public license or as
|
389 |
+
otherwise permitted by the Creative Commons policies published at
|
390 |
+
creativecommons.org/policies, Creative Commons does not authorize the
|
391 |
+
use of the trademark "Creative Commons" or any other trademark or logo
|
392 |
+
of Creative Commons without its prior written consent including,
|
393 |
+
without limitation, in connection with any unauthorized modifications
|
394 |
+
to any of its public licenses or any other arrangements,
|
395 |
+
understandings, or agreements concerning use of licensed material. For
|
396 |
+
the avoidance of doubt, this paragraph does not form part of the
|
397 |
+
public licenses.
|
398 |
+
|
399 |
+
Creative Commons may be contacted at creativecommons.org.
|
PuzzleTuning/Counterpart PreTrain Methods/gcmae/LINPROBE.md
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Linear probe GCMAE
|
2 |
+
|
3 |
+
```
|
4 |
+
python main_linprobe.py \
|
5 |
+
--data_path_train path/to/train/data \
|
6 |
+
--data_path_val path/to/val/data \
|
7 |
+
--nb_classes 2 \
|
8 |
+
--output_dir path/to/ouput/dir \
|
9 |
+
--log_dir path/to/log/dir \
|
10 |
+
--batch_size 512 \
|
11 |
+
--model vit_base_patch16 \
|
12 |
+
--epochs 90 \
|
13 |
+
--finetune path/to/pth/path
|
14 |
+
```
|
15 |
+
|
PuzzleTuning/Counterpart PreTrain Methods/gcmae/PRETRAIN.md
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Pre-training GCMAE
|
2 |
+
|
3 |
+
To pre-train ViT-Base (recommended default)
|
4 |
+
```
|
5 |
+
python main_pretrain.py \
|
6 |
+
--data_path path/to/data \
|
7 |
+
--data_val_path path/to/data \
|
8 |
+
--output_dir path/to/ouput/dir \
|
9 |
+
--log_dir path/to/log/dir \
|
10 |
+
--batch_size 128 \
|
11 |
+
--model gcmae_vit_base_patch16 \
|
12 |
+
--norm_pix_loss \
|
13 |
+
--mask_ratio 0.5 \
|
14 |
+
--epochs 80 \
|
15 |
+
--warmup_epochs 40 \
|
16 |
+
--blr 1e-3 --weight_decay 0.05 \
|
17 |
+
--low_dim 768 \
|
18 |
+
--nce_k 8192 \
|
19 |
+
--nce_t 0.07 \
|
20 |
+
--nce_m 0.5 \
|
21 |
+
```
|
PuzzleTuning/Counterpart PreTrain Methods/gcmae/README.md
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# GCMAE
|
2 |
+
|
3 |
+
The original repo of GCMAE could be found [here](https://github.com/StarUniversus/gcmae)
|
4 |
+
|
5 |
+
To install environments:
|
6 |
+
```bash
|
7 |
+
pip install -r requirements.txt
|
8 |
+
```
|
9 |
+
|
10 |
+
|
11 |
+
To start pretraining:
|
12 |
+
```bash
|
13 |
+
# You need to alter the script according to your directories
|
14 |
+
bash pretrain.sh
|
15 |
+
```
|
PuzzleTuning/Counterpart PreTrain Methods/gcmae/README_origin.md
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# GCMAE
|
2 |
+
|
3 |
+
<p align="center">
|
4 |
+
<img src="./GCMAE.png" width="800">
|
5 |
+
</p>
|
6 |
+
|
7 |
+
The official implementation of the paper [Global Contrast Masked Autoencoders Are Powerful Pathological Representation Learners](https://arxiv.org/abs/2205.09048)
|
8 |
+
|
9 |
+
```
|
10 |
+
@article{li2022gcmae,
|
11 |
+
author = {Quan, Hao and Li, Xingyu and Chen, Weixing and Bai, Qun and Zou, Mingchen and Yang, Ruijie and Zheng, Tingting and Qi, Ruiqun and Gao, Xinghua and Cui, Xiaoyu},
|
12 |
+
title = {Global Contrast Masked Autoencoders Are Powerful Pathological Representation Learners},
|
13 |
+
journal={arXiv:2205.09048},
|
14 |
+
year = {2022},
|
15 |
+
}
|
16 |
+
```
|
17 |
+
## Abstract
|
18 |
+
Based on digital whole slide scanning technique, artificial intelligence algorithms represented by deep learning have achieved remarkable results in the field of computational pathology. Compared with other medical images such as Computed Tomography (CT) or Magnetic Resonance Imaging (MRI), pathological images are more difficult to annotate, thus there is an extreme lack of data sets that can be used for supervised learning. In this study, a self-supervised learning (SSL) model, Global Contrast Masked Autoencoders (GCMAE), is proposed, which has the ability to represent both global and local domain-specific features of whole slide image (WSI), as well as excellent cross-data transfer ability. The Camelyon16 and NCTCRC datasets are used to evaluate the performance of our model. When dealing with transfer learning tasks with different data sets, the experimental results show that GCMAE has better linear classification accuracy than MAE, which can reach **81.10%** and **89.22%** respectively. Our method outperforms the previous state of-the-art algorithm and even surpass supervised learning (improved by **3.86%** on NCTCRC data sets).
|
19 |
+
|
20 |
+
## Installation
|
21 |
+
This repo is a modification on the [mae repo](https://github.com/facebookresearch/mae). Installation and preparation follow that repo.
|
22 |
+
|
23 |
+
## Usage
|
24 |
+
|
25 |
+
* [PRETRAIN](PRETRAIN.md)
|
26 |
+
|
27 |
+
* [LINPROBE](LINPROBE.md)
|
28 |
+
|
29 |
+
* [FINETUNE](FINETUNE.md)
|
30 |
+
|
31 |
+
* [ Visual GCMAE feature representation](VISUAL.md)
|
32 |
+
|
33 |
+
## Dataset
|
34 |
+
* [Camelyon16](https://pan.baidu.com/s/1N0fqJR9u8yq-y6ZY0mSoUw?pwd=noms)
|
35 |
+
* [NCT-CRC-HE-100K](https://zenodo.org/record/1214456)
|
36 |
+
* [BreakHis](https://web.inf.ufpr.br/vri/databases/breast-cancer-histopathological-database-breakhis/)
|
37 |
+
|
38 |
+
## License
|
39 |
+
Distributed under the CC-BY-NC 4.0 License. See [LICENSE](LICENSE) for more information.
|
PuzzleTuning/Counterpart PreTrain Methods/gcmae/VISUAL.md
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Visual GCMAE feature representation
|
2 |
+
|
3 |
+
```
|
4 |
+
python tsne.py \
|
5 |
+
--batch_size 128 \
|
6 |
+
--model vit_base_patch16 \
|
7 |
+
--finetune path/to/pth \
|
8 |
+
--save_path path/to/png \
|
9 |
+
--data_path_val path/to/data \
|
10 |
+
```
|
PuzzleTuning/Counterpart PreTrain Methods/gcmae/engine_finetune.py
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# --------------------------------------------------------
|
7 |
+
# References:
|
8 |
+
# DeiT: https://github.com/facebookresearch/deit
|
9 |
+
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
|
10 |
+
# MAE: https://github.com/facebookresearch/mae
|
11 |
+
# --------------------------------------------------------
|
12 |
+
|
13 |
+
import math
|
14 |
+
import sys
|
15 |
+
from typing import Iterable, Optional
|
16 |
+
|
17 |
+
import torch
|
18 |
+
|
19 |
+
from timm.data import Mixup
|
20 |
+
from timm.utils import accuracy
|
21 |
+
|
22 |
+
import util.misc as misc
|
23 |
+
import util.lr_sched as lr_sched
|
24 |
+
from sklearn.metrics import roc_auc_score
|
25 |
+
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
|
26 |
+
import torch.nn.functional as F
|
27 |
+
import numpy
|
28 |
+
import numpy as np
|
29 |
+
from torchmetrics import Specificity, AUROC
|
30 |
+
|
31 |
+
|
32 |
+
|
33 |
+
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
|
34 |
+
data_loader: Iterable, optimizer: torch.optim.Optimizer,
|
35 |
+
device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
|
36 |
+
mixup_fn: Optional[Mixup] = None, log_writer=None,
|
37 |
+
args=None):
|
38 |
+
model.train(True)
|
39 |
+
metric_logger = misc.MetricLogger(delimiter=" ")
|
40 |
+
metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
|
41 |
+
header = 'Epoch: [{}]'.format(epoch)
|
42 |
+
print_freq = 20
|
43 |
+
|
44 |
+
accum_iter = args.accum_iter
|
45 |
+
|
46 |
+
optimizer.zero_grad()
|
47 |
+
|
48 |
+
if log_writer is not None:
|
49 |
+
print('log_dir: {}'.format(log_writer.log_dir))
|
50 |
+
|
51 |
+
for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
|
52 |
+
|
53 |
+
# we use a per iteration (instead of per epoch) lr scheduler
|
54 |
+
if data_iter_step % accum_iter == 0:
|
55 |
+
lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
|
56 |
+
|
57 |
+
samples = samples.to(device, non_blocking=True)
|
58 |
+
targets = targets.to(device, non_blocking=True)
|
59 |
+
|
60 |
+
if mixup_fn is not None:
|
61 |
+
samples, targets = mixup_fn(samples, targets)
|
62 |
+
|
63 |
+
with torch.cuda.amp.autocast():
|
64 |
+
outputs = model(samples)
|
65 |
+
loss = criterion(outputs, targets)
|
66 |
+
|
67 |
+
loss_value = loss.item()
|
68 |
+
|
69 |
+
if not math.isfinite(loss_value):
|
70 |
+
print("Loss is {}, stopping training".format(loss_value))
|
71 |
+
sys.exit(1)
|
72 |
+
|
73 |
+
loss /= accum_iter
|
74 |
+
loss_scaler(loss, optimizer, clip_grad=max_norm,
|
75 |
+
parameters=model.parameters(), create_graph=False,
|
76 |
+
update_grad=(data_iter_step + 1) % accum_iter == 0)
|
77 |
+
if (data_iter_step + 1) % accum_iter == 0:
|
78 |
+
optimizer.zero_grad()
|
79 |
+
|
80 |
+
torch.cuda.synchronize()
|
81 |
+
|
82 |
+
metric_logger.update(loss=loss_value)
|
83 |
+
min_lr = 10.
|
84 |
+
max_lr = 0.
|
85 |
+
for group in optimizer.param_groups:
|
86 |
+
min_lr = min(min_lr, group["lr"])
|
87 |
+
max_lr = max(max_lr, group["lr"])
|
88 |
+
|
89 |
+
metric_logger.update(lr=max_lr)
|
90 |
+
|
91 |
+
loss_value_reduce = misc.all_reduce_mean(loss_value)
|
92 |
+
if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
|
93 |
+
""" We use epoch_1000x as the x-axis in tensorboard.
|
94 |
+
This calibrates different curves when batch size changes.
|
95 |
+
"""
|
96 |
+
epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
|
97 |
+
log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x)
|
98 |
+
log_writer.add_scalar('lr', max_lr, epoch_1000x)
|
99 |
+
|
100 |
+
# gather the stats from all processes
|
101 |
+
metric_logger.synchronize_between_processes()
|
102 |
+
print("Averaged stats:", metric_logger)
|
103 |
+
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
104 |
+
|
105 |
+
def confusion_m(y_true, y_pred):
|
106 |
+
tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
|
107 |
+
return tn, fp, fn, tp
|
108 |
+
|
109 |
+
def compute_metrics_binary(probs, preds, targets):
|
110 |
+
auc = roc_auc_score(targets, probs) * 100
|
111 |
+
precision = precision_score(targets, preds) * 100
|
112 |
+
recall = recall_score(targets, preds) * 100
|
113 |
+
f1 = f1_score(targets, preds) * 100
|
114 |
+
tn, fp, fn, tp = confusion_m(targets, preds)
|
115 |
+
specificity = (tn / float(tn+fp)) * 100
|
116 |
+
return auc, precision, recall, f1, specificity
|
117 |
+
|
118 |
+
def compute_metrics_multiclass(probs, preds, targets, nb_classes):
|
119 |
+
preds_tensor, probs_tensor, targets_tensor = torch.tensor(preds), torch.tensor(probs), torch.tensor(targets)
|
120 |
+
auroc = AUROC(average='macro', num_classes=nb_classes)
|
121 |
+
auc = auroc(probs_tensor, targets_tensor) * 100
|
122 |
+
precision = precision_score(targets, preds, average='macro') * 100
|
123 |
+
recall = recall_score(targets, preds, average='macro') * 100
|
124 |
+
f1 = f1_score(targets, preds, average='macro') * 100
|
125 |
+
speci = Specificity(average='macro', num_classes=nb_classes)
|
126 |
+
specificity = speci(preds_tensor, targets_tensor) * 100
|
127 |
+
return auc, precision, recall, f1, specificity
|
128 |
+
|
129 |
+
@torch.no_grad()
|
130 |
+
def evaluate(data_loader, model, device, nb_classes):
|
131 |
+
criterion = torch.nn.CrossEntropyLoss()
|
132 |
+
m = torch.nn.Softmax(dim=1)
|
133 |
+
metric_logger = misc.MetricLogger(delimiter=" ")
|
134 |
+
header = 'Test:'
|
135 |
+
|
136 |
+
# switch to evaluation mode
|
137 |
+
model.eval()
|
138 |
+
probs = []
|
139 |
+
targets = []
|
140 |
+
preds = []
|
141 |
+
|
142 |
+
for batch in metric_logger.log_every(data_loader, 10, header):
|
143 |
+
images = batch[0]
|
144 |
+
target = batch[-1]
|
145 |
+
images = images.to(device, non_blocking=True)
|
146 |
+
target = target.to(device, non_blocking=True)
|
147 |
+
|
148 |
+
# compute output
|
149 |
+
with torch.cuda.amp.autocast():
|
150 |
+
output = model(images)
|
151 |
+
loss = criterion(output, target)
|
152 |
+
output = m(output)
|
153 |
+
score, pred = output.topk(1, 1, True, True)
|
154 |
+
if nb_classes == 2:
|
155 |
+
prob = output[:, 1]
|
156 |
+
elif nb_classes > 2:
|
157 |
+
prob = output
|
158 |
+
|
159 |
+
probs.extend(prob.detach().cpu().numpy())
|
160 |
+
targets.extend(target.detach().cpu().numpy())
|
161 |
+
preds.extend(pred.tolist())
|
162 |
+
if nb_classes < 5:
|
163 |
+
acc1, acc5 = accuracy(output, target, topk=(1, 1))
|
164 |
+
else:
|
165 |
+
acc1, acc5 = accuracy(output, target, topk=(1, 5))
|
166 |
+
|
167 |
+
batch_size = images.shape[0]
|
168 |
+
metric_logger.update(loss=loss.item())
|
169 |
+
metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
|
170 |
+
metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
|
171 |
+
# gather the stats from all processes
|
172 |
+
metric_logger.synchronize_between_processes()
|
173 |
+
print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
|
174 |
+
.format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
|
175 |
+
if nb_classes == 2:
|
176 |
+
print("binary class metrics!")
|
177 |
+
auc, precision, recall, f1, specificity = compute_metrics_binary(probs, preds, targets)
|
178 |
+
|
179 |
+
elif nb_classes > 2:
|
180 |
+
print("multi_class metrics!")
|
181 |
+
auc, precision, recall, f1, specificity = compute_metrics_multiclass(probs, preds, targets, nb_classes)
|
182 |
+
|
183 |
+
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}, auc, precision, recall, f1, specificity
|
PuzzleTuning/Counterpart PreTrain Methods/gcmae/engine_pretrain.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# --------------------------------------------------------
|
7 |
+
# References:
|
8 |
+
# DeiT: https://github.com/facebookresearch/deit
|
9 |
+
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
|
10 |
+
# --------------------------------------------------------
|
11 |
+
import math
|
12 |
+
import sys
|
13 |
+
from typing import Iterable
|
14 |
+
|
15 |
+
import torch
|
16 |
+
|
17 |
+
import util.misc as misc
|
18 |
+
import util.lr_sched as lr_sched
|
19 |
+
from test_npid import NN, kNN
|
20 |
+
|
21 |
+
def train_one_epoch(model: torch.nn.Module,
|
22 |
+
data_loader: Iterable,
|
23 |
+
optimizer: torch.optim.Optimizer,
|
24 |
+
device: torch.device, epoch: int, loss_scaler,
|
25 |
+
log_writer=None,
|
26 |
+
args=None,
|
27 |
+
lemniscate=None,
|
28 |
+
):
|
29 |
+
model.train(True)
|
30 |
+
metric_logger = misc.MetricLogger(delimiter=" ")
|
31 |
+
metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
|
32 |
+
header = 'Epoch: [{}]'.format(epoch)
|
33 |
+
print_freq = 20
|
34 |
+
|
35 |
+
accum_iter = args.accum_iter
|
36 |
+
|
37 |
+
optimizer.zero_grad()
|
38 |
+
|
39 |
+
if log_writer is not None:
|
40 |
+
print('log_dir: {}'.format(log_writer.log_dir))
|
41 |
+
|
42 |
+
for data_iter_step, (samples, _, index) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
|
43 |
+
|
44 |
+
# we use a per iteration (instead of per epoch) lr scheduler
|
45 |
+
if data_iter_step % accum_iter == 0:
|
46 |
+
lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
|
47 |
+
|
48 |
+
samples = samples.to(device, non_blocking=True)
|
49 |
+
index = index.to(device, non_blocking=True)
|
50 |
+
with torch.cuda.amp.autocast():
|
51 |
+
loss_mae, _, _, loss_npid, _= model(samples, mask_ratio=args.mask_ratio, index = index, is_train=True)
|
52 |
+
loss = loss_mae + 0.1 * loss_npid
|
53 |
+
|
54 |
+
loss_value = loss.item()
|
55 |
+
|
56 |
+
if not math.isfinite(loss_value):
|
57 |
+
print("Loss is {}, stopping training".format(loss_value))
|
58 |
+
sys.exit(1)
|
59 |
+
|
60 |
+
loss /= accum_iter
|
61 |
+
loss_scaler(loss, optimizer, parameters=model.parameters(),
|
62 |
+
update_grad=(data_iter_step + 1) % accum_iter == 0)
|
63 |
+
if (data_iter_step + 1) % accum_iter == 0:
|
64 |
+
optimizer.zero_grad()
|
65 |
+
|
66 |
+
torch.cuda.synchronize()
|
67 |
+
|
68 |
+
|
69 |
+
metric_logger.update(loss_all=loss_value)
|
70 |
+
metric_logger.update(loss_mae=loss_mae.item())
|
71 |
+
metric_logger.update(loss_npid=loss_npid.item())
|
72 |
+
|
73 |
+
|
74 |
+
lr = optimizer.param_groups[0]["lr"]
|
75 |
+
metric_logger.update(lr=lr)
|
76 |
+
|
77 |
+
loss_value_reduce = misc.all_reduce_mean(loss_value)
|
78 |
+
if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
|
79 |
+
""" We use epoch_1000x as the x-axis in tensorboard.
|
80 |
+
This calibrates different curves when batch size changes.
|
81 |
+
"""
|
82 |
+
|
83 |
+
epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
|
84 |
+
log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)
|
85 |
+
log_writer.add_scalar('origin_loss/train_loss_mae', loss_mae.item(), epoch_1000x)
|
86 |
+
log_writer.add_scalar('origin_loss/train_loss_npid', loss_npid.item(), epoch_1000x)
|
87 |
+
log_writer.add_scalar('lr', lr, epoch_1000x)
|
88 |
+
|
89 |
+
# pred1 = NN(epoch, model, lemniscate, data_loader, data_loader_val)
|
90 |
+
# log_writer.add_scalar('NN_ac', pred1, epoch)
|
91 |
+
# if args.output_dir and (epoch % 20 == 0 or epoch + 1 == args.epochs):
|
92 |
+
|
93 |
+
# top1 = kNN(0, model, lemniscate, data_loader, data_loader_val, 200, args.nce_t)
|
94 |
+
# log_writer.add_scalar('KNN_top1', top1, epoch)
|
95 |
+
|
96 |
+
metric_logger.synchronize_between_processes()
|
97 |
+
print("Averaged stats:", metric_logger)
|
98 |
+
# return {k: meter.global_avg for k, meter in metric_logger.meters.items()}, pred1
|
99 |
+
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
PuzzleTuning/Counterpart PreTrain Methods/gcmae/lib/NCEAverage.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.autograd import Function
|
3 |
+
from torch import nn
|
4 |
+
from .alias_multinomial import AliasMethod
|
5 |
+
import math
|
6 |
+
|
7 |
+
class NCEFunction(Function):
|
8 |
+
@staticmethod
|
9 |
+
def forward(self, x, y, memory, idx, params):
|
10 |
+
K = int(params[0].item())
|
11 |
+
T = params[1].item()
|
12 |
+
Z = params[2].item()
|
13 |
+
|
14 |
+
momentum = params[3].item()
|
15 |
+
batchSize = x.size(0)
|
16 |
+
outputSize = memory.size(0)
|
17 |
+
inputSize = memory.size(1)
|
18 |
+
|
19 |
+
# sample positives & negatives
|
20 |
+
idx.select(1,0).copy_(y.detach())
|
21 |
+
|
22 |
+
# sample correspoinding weights
|
23 |
+
weight = torch.index_select(memory, 0, idx.view(-1))
|
24 |
+
weight.resize_(batchSize, K+1, inputSize)
|
25 |
+
|
26 |
+
# inner product
|
27 |
+
with torch.cuda.amp.autocast(enabled = False):
|
28 |
+
out = torch.bmm(weight, x.detach().reshape(batchSize, inputSize, 1))
|
29 |
+
#print(out)
|
30 |
+
out.div_(T).exp_() # batchSize * self.K+1
|
31 |
+
#x.detach().resize_(batchSize, inputSize)
|
32 |
+
#print(out)
|
33 |
+
if Z < 0:
|
34 |
+
params[2] = out.mean() * outputSize
|
35 |
+
Z = params[2].item()
|
36 |
+
print("normalization constant Z is set to {:.1f}".format(Z))
|
37 |
+
|
38 |
+
out.div_(Z).resize_(batchSize, K+1)
|
39 |
+
|
40 |
+
self.save_for_backward(x, memory, y, weight, out, params)
|
41 |
+
|
42 |
+
return out
|
43 |
+
|
44 |
+
@staticmethod
|
45 |
+
def backward(self, gradOutput):
|
46 |
+
x, memory, y, weight, out, params = self.saved_tensors
|
47 |
+
K = int(params[0].item())
|
48 |
+
T = params[1].item()
|
49 |
+
Z = params[2].item()
|
50 |
+
momentum = params[3].item()
|
51 |
+
batchSize = gradOutput.size(0)
|
52 |
+
|
53 |
+
# gradients d Pm / d linear = exp(linear) / Z
|
54 |
+
gradOutput.detach().mul_(out.detach())
|
55 |
+
# add temperature
|
56 |
+
gradOutput.detach().div_(T)
|
57 |
+
|
58 |
+
|
59 |
+
|
60 |
+
# gradient of linear
|
61 |
+
with torch.cuda.amp.autocast(enabled = False):
|
62 |
+
gradInput = torch.bmm(gradOutput.detach().reshape(batchSize, 1, K+1), weight)
|
63 |
+
gradInput.resize_as_(x)
|
64 |
+
|
65 |
+
# update the non-parametric detach()
|
66 |
+
weight_pos = weight.select(1, 0).resize_as_(x)
|
67 |
+
weight_pos.mul_(momentum)
|
68 |
+
weight_pos.add_(torch.mul(x.detach(), 1-momentum))
|
69 |
+
w_norm = weight_pos.pow(2).sum(1, keepdim=True).pow(0.5)
|
70 |
+
updated_weight = weight_pos.div(w_norm)
|
71 |
+
memory.index_copy_(0, y, updated_weight)
|
72 |
+
|
73 |
+
return gradInput, None, None, None, None
|
74 |
+
|
75 |
+
class NCEAverage(nn.Module):
|
76 |
+
|
77 |
+
def __init__(self, inputSize, outputSize, K, T=0.07, momentum=0.5, Z=None):
|
78 |
+
super(NCEAverage, self).__init__()
|
79 |
+
self.nLem = outputSize
|
80 |
+
self.unigrams = torch.ones(self.nLem)
|
81 |
+
self.multinomial = AliasMethod(self.unigrams)
|
82 |
+
self.multinomial.cuda()
|
83 |
+
self.K = K
|
84 |
+
|
85 |
+
self.register_buffer('params',torch.tensor([K, T, -1, momentum]))
|
86 |
+
stdv = 1. / math.sqrt(inputSize/3)
|
87 |
+
self.register_buffer('memory', torch.rand(outputSize, inputSize).mul_(2*stdv).add_(-stdv))
|
88 |
+
|
89 |
+
def forward(self, x, y):
|
90 |
+
batchSize = x.size(0)
|
91 |
+
idx = self.multinomial.draw(batchSize * (self.K+1)).view(batchSize, -1)
|
92 |
+
out = NCEFunction.apply(x, y, self.memory, idx, self.params)
|
93 |
+
return out
|
94 |
+
|
PuzzleTuning/Counterpart PreTrain Methods/gcmae/lib/NCECriterion.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
|
4 |
+
eps = 1e-7
|
5 |
+
|
6 |
+
class NCECriterion(nn.Module):
|
7 |
+
|
8 |
+
def __init__(self, nLem):
|
9 |
+
super(NCECriterion, self).__init__()
|
10 |
+
self.nLem = nLem
|
11 |
+
|
12 |
+
def forward(self, x, targets):
|
13 |
+
batchSize = x.size(0)
|
14 |
+
K = x.size(1)-1
|
15 |
+
Pnt = 1 / float(self.nLem)
|
16 |
+
Pns = 1 / float(self.nLem)
|
17 |
+
|
18 |
+
# eq 5.1 : P(origin=model) = Pmt / (Pmt + k*Pnt)
|
19 |
+
Pmt = x.select(1,0)
|
20 |
+
Pmt_div = Pmt.add(K * Pnt + eps)
|
21 |
+
lnPmt = torch.div(Pmt, Pmt_div)
|
22 |
+
|
23 |
+
# eq 5.2 : P(origin=noise) = k*Pns / (Pms + k*Pns)
|
24 |
+
Pon_div = x.narrow(1,1,K).add(K * Pns + eps)
|
25 |
+
Pon = Pon_div.clone().fill_(K * Pns)
|
26 |
+
lnPon = torch.div(Pon, Pon_div)
|
27 |
+
|
28 |
+
# equation 6 in ref. A
|
29 |
+
lnPmt.log_()
|
30 |
+
lnPon.log_()
|
31 |
+
|
32 |
+
lnPmtsum = lnPmt.sum(0)
|
33 |
+
lnPonsum = lnPon.view(-1, 1).sum(0)
|
34 |
+
|
35 |
+
loss = - (lnPmtsum + lnPonsum) / batchSize
|
36 |
+
|
37 |
+
return loss
|
38 |
+
|
PuzzleTuning/Counterpart PreTrain Methods/gcmae/lib/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# nothing
|
PuzzleTuning/Counterpart PreTrain Methods/gcmae/lib/__pycache__/NCEAverage.cpython-38.pyc
ADDED
Binary file (2.96 kB). View file
|
|
PuzzleTuning/Counterpart PreTrain Methods/gcmae/lib/__pycache__/NCECriterion.cpython-38.pyc
ADDED
Binary file (1.13 kB). View file
|
|
PuzzleTuning/Counterpart PreTrain Methods/gcmae/lib/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (127 Bytes). View file
|
|
PuzzleTuning/Counterpart PreTrain Methods/gcmae/lib/__pycache__/alias_multinomial.cpython-38.pyc
ADDED
Binary file (1.68 kB). View file
|
|
PuzzleTuning/Counterpart PreTrain Methods/gcmae/lib/__pycache__/normalize.cpython-38.pyc
ADDED
Binary file (819 Bytes). View file
|
|
PuzzleTuning/Counterpart PreTrain Methods/gcmae/lib/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (873 Bytes). View file
|
|
PuzzleTuning/Counterpart PreTrain Methods/gcmae/lib/alias_multinomial.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
class AliasMethod(object):
|
5 |
+
'''
|
6 |
+
From: https://hips.seas.harvard.edu/blog/2013/03/03/the-alias-method-efficient-sampling-with-many-discrete-outcomes/
|
7 |
+
'''
|
8 |
+
def __init__(self, probs):
|
9 |
+
|
10 |
+
if probs.sum() > 1:
|
11 |
+
probs.div_(probs.sum())
|
12 |
+
K = len(probs)
|
13 |
+
self.prob = torch.zeros(K)
|
14 |
+
self.alias = torch.LongTensor([0]*K)
|
15 |
+
|
16 |
+
# Sort the data into the outcomes with probabilities
|
17 |
+
# that are larger and smaller than 1/K.
|
18 |
+
smaller = []
|
19 |
+
larger = []
|
20 |
+
for kk, prob in enumerate(probs):
|
21 |
+
self.prob[kk] = K*prob
|
22 |
+
if self.prob[kk] < 1.0:
|
23 |
+
smaller.append(kk)
|
24 |
+
else:
|
25 |
+
larger.append(kk)
|
26 |
+
|
27 |
+
# Loop though and create little binary mixtures that
|
28 |
+
# appropriately allocate the larger outcomes over the
|
29 |
+
# overall uniform mixture.
|
30 |
+
while len(smaller) > 0 and len(larger) > 0:
|
31 |
+
small = smaller.pop()
|
32 |
+
large = larger.pop()
|
33 |
+
|
34 |
+
self.alias[small] = large
|
35 |
+
self.prob[large] = (self.prob[large] - 1.0) + self.prob[small]
|
36 |
+
|
37 |
+
if self.prob[large] < 1.0:
|
38 |
+
smaller.append(large)
|
39 |
+
else:
|
40 |
+
larger.append(large)
|
41 |
+
|
42 |
+
for last_one in smaller+larger:
|
43 |
+
self.prob[last_one] = 1
|
44 |
+
|
45 |
+
def cuda(self):
|
46 |
+
self.prob = self.prob.cuda()
|
47 |
+
self.alias = self.alias.cuda()
|
48 |
+
|
49 |
+
def draw(self, N):
|
50 |
+
'''
|
51 |
+
Draw N samples from multinomial
|
52 |
+
'''
|
53 |
+
K = self.alias.size(0)
|
54 |
+
|
55 |
+
kk = torch.zeros(N, dtype=torch.long, device=self.prob.device).random_(0, K)
|
56 |
+
prob = self.prob.index_select(0, kk)
|
57 |
+
alias = self.alias.index_select(0, kk)
|
58 |
+
# b is whether a random number is greater than q
|
59 |
+
b = torch.bernoulli(prob)
|
60 |
+
oq = kk.mul(b.long())
|
61 |
+
oj = alias.mul((1-b).long())
|
62 |
+
|
63 |
+
return oq + oj
|
64 |
+
|
PuzzleTuning/Counterpart PreTrain Methods/gcmae/lib/normalize.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.autograd import Variable
|
3 |
+
from torch import nn
|
4 |
+
|
5 |
+
class Normalize(nn.Module):
|
6 |
+
|
7 |
+
def __init__(self, power=2):
|
8 |
+
super(Normalize, self).__init__()
|
9 |
+
self.power = power
|
10 |
+
|
11 |
+
def forward(self, x):
|
12 |
+
norm = x.pow(self.power).sum(1, keepdim=True).pow(1./self.power)
|
13 |
+
out = x.div(norm)
|
14 |
+
return out
|