English
XavierJiezou commited on
Commit
0467378
·
verified ·
1 Parent(s): 2e198d7

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. app.py +238 -0
  2. checkpoints/cloud-adapter/gf1_full_weight.pth +3 -0
  3. checkpoints/cloud-adapter/gf2_full_weight.pth +3 -0
  4. checkpoints/cloud-adapter/hrc_whu_full_weight.pth +3 -0
  5. checkpoints/cloud-adapter/l1c_full_weight.pth +3 -0
  6. checkpoints/cloud-adapter/l2a_full_weight.pth +3 -0
  7. checkpoints/cloud-adapter/l8_full_weight.pth +3 -0
  8. cloud-adapter-configs/binary_classes_256x256.py +205 -0
  9. cloud-adapter-configs/multi_classes_512x512.py +205 -0
  10. cloud_adapter/__init__.py +0 -0
  11. cloud_adapter/__pycache__/__init__.cpython-38.pyc +0 -0
  12. cloud_adapter/__pycache__/cloud_adapter.cpython-38.pyc +0 -0
  13. cloud_adapter/__pycache__/cloud_adapter_dinov2.cpython-38.pyc +0 -0
  14. cloud_adapter/__pycache__/dino_v2.cpython-38.pyc +0 -0
  15. cloud_adapter/__pycache__/utils.cpython-38.pyc +0 -0
  16. cloud_adapter/cdnetv1.py +389 -0
  17. cloud_adapter/cdnetv2.py +693 -0
  18. cloud_adapter/cloud_adapter.py +590 -0
  19. cloud_adapter/cloud_adapter_dinov2.py +115 -0
  20. cloud_adapter/dbnet.py +680 -0
  21. cloud_adapter/dino_layers/__init__.py +11 -0
  22. cloud_adapter/dino_layers/__pycache__/__init__.cpython-38.pyc +0 -0
  23. cloud_adapter/dino_layers/__pycache__/attention.cpython-38.pyc +0 -0
  24. cloud_adapter/dino_layers/__pycache__/block.cpython-38.pyc +0 -0
  25. cloud_adapter/dino_layers/__pycache__/dino_head.cpython-38.pyc +0 -0
  26. cloud_adapter/dino_layers/__pycache__/drop_path.cpython-38.pyc +0 -0
  27. cloud_adapter/dino_layers/__pycache__/layer_scale.cpython-38.pyc +0 -0
  28. cloud_adapter/dino_layers/__pycache__/mlp.cpython-38.pyc +0 -0
  29. cloud_adapter/dino_layers/__pycache__/patch_embed.cpython-38.pyc +0 -0
  30. cloud_adapter/dino_layers/__pycache__/swiglu_ffn.cpython-38.pyc +0 -0
  31. cloud_adapter/dino_layers/attention.py +89 -0
  32. cloud_adapter/dino_layers/block.py +260 -0
  33. cloud_adapter/dino_layers/dino_head.py +58 -0
  34. cloud_adapter/dino_layers/drop_path.py +34 -0
  35. cloud_adapter/dino_layers/layer_scale.py +27 -0
  36. cloud_adapter/dino_layers/mlp.py +40 -0
  37. cloud_adapter/dino_layers/patch_embed.py +88 -0
  38. cloud_adapter/dino_layers/swiglu_ffn.py +72 -0
  39. cloud_adapter/dino_v2.py +353 -0
  40. cloud_adapter/hrcloudnet.py +751 -0
  41. cloud_adapter/kappamask.py +152 -0
  42. cloud_adapter/mcdnet.py +435 -0
  43. cloud_adapter/scnn.py +36 -0
  44. cloud_adapter/unetmobv2.py +31 -0
  45. cloud_adapter/utils.py +58 -0
  46. example_inputs/gf1/11.png +0 -0
  47. example_inputs/gf1/48.png +0 -0
  48. example_inputs/gf1/9.png +0 -0
  49. example_inputs/gf2/160.png +0 -0
  50. example_inputs/gf2/2.png +0 -0
app.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from mmseg.apis import init_model
2
+ from typing import List
3
+ from glob import glob
4
+ from cloud_adapter.cloud_adapter_dinov2 import CloudAdapterDinoVisionTransformer
5
+ import numpy as np
6
+ from PIL import Image
7
+ from mmseg.models.segmentors.encoder_decoder import EncoderDecoder
8
+ import gradio as gr
9
+ import torch
10
+ import os
11
+
12
+
13
+ class CloudAdapterGradio:
14
+ def __init__(self, config_path=None, checkpoint_path=None, device="cpu", example_inputs=None, num_classes=2, palette=None):
15
+ self.config_path = config_path
16
+ self.checkpoint_path = checkpoint_path
17
+ self.device = device
18
+ self.model: EncoderDecoder = init_model(
19
+ self.config_path, self.checkpoint_path, device=self.device)
20
+ self.model.eval()
21
+ self.example_inputs = example_inputs
22
+ self.img_size = 256 if num_classes == 2 else 512
23
+ self.palette = palette
24
+ self.legend = self.html_legend(num_classes=num_classes)
25
+ self.name_mapping = {
26
+ "KappaMask": "kappamask",
27
+ "CDNetv1": "cdnetv1",
28
+ "CDNetv2": "cdnetv2",
29
+ "HRCloudNet": "hrcloudnet",
30
+ "MCDNet": "mcdnet",
31
+ "SCNN": "scnn",
32
+ "DBNet": "dbnet",
33
+ "UNetMobv2": "unetmobv2",
34
+ "Cloud-Adapter": "cloud-adapter",
35
+ }
36
+
37
+ self.create_ui()
38
+
39
+ def html_legend(self, num_classes=2):
40
+ if num_classes == 2:
41
+ return """
42
+ <div style="margin-top: 10px; text-align: left; display: flex; align-items: center; gap: 20px;justify-content: center;">
43
+ <div style="display: flex; align-items: center;">
44
+ <div style="width: 20px; height: 20px; background-color: rgb(79, 253, 199); margin-right: 10px; "></div>
45
+ <span>Clear</span>
46
+ </div>
47
+ <div style="display: flex; align-items: center;">
48
+ <div style="width: 20px; height: 20px; background-color: rgb(77, 2, 115); margin-right: 10px; "></div>
49
+ <span>Cloud</span>
50
+ </div>
51
+ </div>
52
+ """
53
+ return """
54
+ <div style="margin-top: 10px; text-align: left; display: flex; align-items: center; gap: 20px;justify-content: center;">
55
+ <div style="display: flex; align-items: center;">
56
+ <div style="width: 20px; height: 20px; background-color: rgb(79, 253, 199); margin-right: 10px; "></div>
57
+ <span>Clear Sky</span>
58
+ </div>
59
+ <div style="display: flex; align-items: center;">
60
+ <div style="width: 20px; height: 20px; background-color: rgb(77, 2, 115); margin-right: 10px; "></div>
61
+ <span>Thick Cloud</span>
62
+ </div>
63
+ <div style="display: flex; align-items: center;">
64
+ <div style="width: 20px; height: 20px; background-color: rgb(251, 255, 41); margin-right: 10px; "></div>
65
+ <span>Thin Cloud</span>
66
+ </div>
67
+ <div style="display: flex; align-items: center;">
68
+ <div style="width: 20px; height: 20px; background-color: rgb(221, 53, 223); margin-right: 10px; "></div>
69
+ <span>Cloud Shadow</span>
70
+ </div>
71
+ </div>
72
+ """
73
+
74
+ def create_ui(self):
75
+ with gr.Row():
76
+ # 左侧:输入图片和按钮
77
+ with gr.Column(scale=1): # 左侧列
78
+ in_image = gr.Image(
79
+ label='Input Image',
80
+ sources='upload',
81
+ elem_classes='input_image',
82
+ interactive=True,
83
+ type="pil",
84
+ )
85
+ with gr.Row():
86
+ run_button = gr.Button(
87
+ 'Run',
88
+ variant="primary",
89
+ )
90
+ # 示例输入列表
91
+ gr.Examples(
92
+ examples=self.example_inputs,
93
+ inputs=in_image,
94
+ label="Example Inputs"
95
+ )
96
+
97
+ # 右侧:输出图片
98
+ with gr.Column(scale=1): # 右侧列
99
+ with gr.Column():
100
+ # 输出图片
101
+ out_image = gr.Image(
102
+ label='Output Image',
103
+ elem_classes='output_image',
104
+ interactive=False
105
+ )
106
+ # 图例
107
+ legend = gr.HTML(
108
+ value=self.legend,
109
+ elem_classes="output_legend",
110
+ )
111
+
112
+ # 按钮点击逻辑:触发图像转换
113
+ run_button.click(
114
+ self.inference,
115
+ inputs=in_image,
116
+ outputs=out_image,
117
+ )
118
+
119
+ @torch.no_grad()
120
+ def inference(self, image: Image.Image) -> Image.Image:
121
+ return self.cloud_adapter_forward(image)
122
+
123
+ @torch.no_grad()
124
+ def cloud_adapter_forward(self, image: Image.Image) -> Image.Image:
125
+ """
126
+ Cloud Adapter Inference
127
+ """
128
+ ori_size = image.size
129
+ image = image.resize((self.img_size, self.img_size),
130
+ resample=Image.Resampling.BILINEAR)
131
+ image = np.array(image)
132
+ # print(image.shape)
133
+ image = (image - np.min(image)) / (np.max(image)-np.min(image))
134
+
135
+ image = torch.from_numpy(image).unsqueeze(0).to(self.device)
136
+ image = image.permute(0, 3, 1, 2).float()
137
+
138
+ outs = self.model.predict(image)
139
+ pred_mask = outs[0].pred_sem_seg.data.cpu().numpy().astype(np.uint8)
140
+
141
+ im = Image.fromarray(pred_mask[0]).convert("P")
142
+ im.putpalette(self.palette)
143
+
144
+ del image
145
+ del outs
146
+ if torch.cuda.is_available():
147
+ torch.cuda.empty_cache()
148
+ return im.resize(ori_size, resample=Image.Resampling.BILINEAR)
149
+
150
+
151
+ def get_palette(dataset_name: str) -> List[int]:
152
+ if dataset_name in ["cloudsen12_high_l1c", "cloudsen12_high_l2a"]:
153
+ return [79, 253, 199, 77, 2, 115, 251, 255, 41, 221, 53, 223]
154
+ if dataset_name == "l8_biome":
155
+ return [79, 253, 199, 221, 53, 223, 251, 255, 41, 77, 2, 115]
156
+ if dataset_name in ["gf12ms_whu_gf1", "gf12ms_whu_gf2", "hrc_whu"]:
157
+ return [79, 253, 199, 77, 2, 115]
158
+ raise Exception("dataset_name not supported")
159
+
160
+
161
+ if __name__ == '__main__':
162
+ title = 'Cloud Segmentation for Remote Sensing Images'
163
+ custom_css = """
164
+ h1 {
165
+ text-align: center;
166
+ font-size: 24px;
167
+ font-weight: bold;
168
+ margin-bottom: 20px;
169
+ }
170
+ """
171
+ hrc_whu_examples = glob("example_inputs/hrc_whu/*")
172
+ gf1_examples = glob("example_inputs/gf1/*")
173
+ gf2_examples = glob("example_inputs/gf2/*")
174
+ l1c_examples = glob("example_inputs/l1c/*")
175
+ l2a_examples = glob("example_inputs/l2a/*")
176
+ l8_examples = glob("example_inputs/l8/*")
177
+
178
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
179
+ with gr.Blocks(analytics_enabled=False, title=title,css=custom_css) as demo:
180
+ gr.Markdown(f'# {title}')
181
+ with gr.Tabs():
182
+ with gr.TabItem('Google Earth'):
183
+ CloudAdapterGradio(
184
+ config_path="cloud-adapter-configs/binary_classes_256x256.py",
185
+ checkpoint_path="checkpoints/cloud-adapter/hrc_whu_full_weight.pth",
186
+ device=device,
187
+ example_inputs=hrc_whu_examples,
188
+ num_classes=2,
189
+ palette=get_palette("hrc_whu"),
190
+ )
191
+ with gr.TabItem('Gaofen-1'):
192
+ CloudAdapterGradio(
193
+ config_path="cloud-adapter-configs/binary_classes_256x256.py",
194
+ checkpoint_path="checkpoints/cloud-adapter/gf1_full_weight.pth",
195
+ device=device,
196
+ example_inputs=gf1_examples,
197
+ num_classes=2,
198
+ palette=get_palette("gf12ms_whu_gf1"),
199
+ )
200
+ with gr.TabItem('Gaofen-2'):
201
+ CloudAdapterGradio(
202
+ config_path="cloud-adapter-configs/binary_classes_256x256.py",
203
+ checkpoint_path="checkpoints/cloud-adapter/gf2_full_weight.pth",
204
+ device=device,
205
+ example_inputs=gf2_examples,
206
+ num_classes=2,
207
+ palette=get_palette("gf12ms_whu_gf2"),
208
+ )
209
+
210
+ with gr.TabItem('Sentinel-2 (L1C)'):
211
+ CloudAdapterGradio(
212
+ config_path="cloud-adapter-configs/multi_classes_512x512.py",
213
+ checkpoint_path="checkpoints/cloud-adapter/l1c_full_weight.pth",
214
+ device=device,
215
+ example_inputs=l1c_examples,
216
+ num_classes=4,
217
+ palette=get_palette("cloudsen12_high_l1c"),
218
+ )
219
+ with gr.TabItem('Sentinel-2 (L2A)'):
220
+ CloudAdapterGradio(
221
+ config_path="cloud-adapter-configs/multi_classes_512x512.py",
222
+ checkpoint_path="checkpoints/cloud-adapter/l2a_full_weight.pth",
223
+ device=device,
224
+ example_inputs=l2a_examples,
225
+ num_classes=4,
226
+ palette=get_palette("cloudsen12_high_l2a"),
227
+ )
228
+ with gr.TabItem('Landsat-8'):
229
+ CloudAdapterGradio(
230
+ config_path="cloud-adapter-configs/multi_classes_512x512.py",
231
+ checkpoint_path="checkpoints/cloud-adapter/l8_full_weight.pth",
232
+ device=device,
233
+ example_inputs=l8_examples,
234
+ num_classes=4,
235
+ palette=get_palette("l8_biome"),
236
+ )
237
+
238
+ demo.launch(share=True, debug=True)
checkpoints/cloud-adapter/gf1_full_weight.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:89e224c3b518fc8f59874f85a8c01a470cdbe4d602e22caf7f1ad1ededa2899e
3
+ size 1326991459
checkpoints/cloud-adapter/gf2_full_weight.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4fd14e29f523e988743bc50e915816c80a69a526b032dab07e85433ace409d2f
3
+ size 1311500515
checkpoints/cloud-adapter/hrc_whu_full_weight.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:755b48d21763339284f4a9b6051c8dfb83a98babca75b39d6304b6f3e82f6c85
3
+ size 1316424759
checkpoints/cloud-adapter/l1c_full_weight.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d7d4bb1dd99e1995450894b985a8fb29b6b931419e3ea08674e1420a5a044804
3
+ size 1332592483
checkpoints/cloud-adapter/l2a_full_weight.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d88e426ff6222b6807315060db1ee1b65a2f98de85a62bf7d3814ba846427bd0
3
+ size 1327383395
checkpoints/cloud-adapter/l8_full_weight.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:268ff0297fa37cde78ec16884f33a126d4e7c37b6e37c69d6b18a0ba258a0cee
3
+ size 1461519710
cloud-adapter-configs/binary_classes_256x256.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ crop_size = (
2
+ 256,
3
+ 256,
4
+ )
5
+ model = dict(
6
+ backbone=dict(
7
+ adapter_index=[
8
+ 0,
9
+ 1,
10
+ 2,
11
+ 3,
12
+ 4,
13
+ 5,
14
+ 6,
15
+ 7,
16
+ 8,
17
+ 9,
18
+ 10,
19
+ 11,
20
+ 12,
21
+ 13,
22
+ 14,
23
+ 15,
24
+ 16,
25
+ 17,
26
+ 18,
27
+ 19,
28
+ 20,
29
+ 21,
30
+ 22,
31
+ 23,
32
+ ],
33
+ block_chunks=0,
34
+ depth=24,
35
+ embed_dim=1024,
36
+ ffn_bias=True,
37
+ ffn_layer='mlp',
38
+ has_cat=False,
39
+ img_size=512,
40
+ init_values=1e-05,
41
+ mlp_ratio=4,
42
+ num_heads=16,
43
+ cloud_adapter_config=dict(
44
+ cnn_type='pmaa',
45
+ context_dim=64,
46
+ depth=4,
47
+ emd_dim=1024,
48
+ global_groups=1,
49
+ hidden_channels=64,
50
+ int_type='convnext',
51
+ local_groups=1,
52
+ num_layers=24,
53
+ rank_dim=16,
54
+ return_last_feature=False,
55
+ return_multi_feats=False,
56
+ type='CloudAdapter'),
57
+ patch_size=16,
58
+ proj_bias=True,
59
+ qkv_bias=True,
60
+ type='CloudAdapterDinoVisionTransformer'),
61
+ data_preprocessor=dict(
62
+ bgr_to_rgb=True,
63
+ mean=[
64
+ 123.675,
65
+ 116.28,
66
+ 103.53,
67
+ ],
68
+ pad_val=0,
69
+ seg_pad_val=255,
70
+ size=(
71
+ 512,
72
+ 512,
73
+ ),
74
+ std=[
75
+ 58.395,
76
+ 57.12,
77
+ 57.375,
78
+ ],
79
+ type='SegDataPreProcessor'),
80
+ decode_head=dict(
81
+ align_corners=False,
82
+ enforce_decoder_input_project=False,
83
+ feat_channels=256,
84
+ in_channels=[
85
+ 1024,
86
+ 1024,
87
+ 1024,
88
+ 1024,
89
+ ],
90
+ loss_cls=dict(
91
+ class_weight=[
92
+ 1.0,
93
+ 1.0,
94
+ 1.0,
95
+ 1.0,
96
+ 0.1,
97
+ ],
98
+ loss_weight=2.0,
99
+ reduction='mean',
100
+ type='mmdet.CrossEntropyLoss',
101
+ use_sigmoid=False),
102
+ loss_dice=dict(
103
+ activate=True,
104
+ eps=1.0,
105
+ loss_weight=5.0,
106
+ naive_dice=True,
107
+ reduction='mean',
108
+ type='mmdet.DiceLoss',
109
+ use_sigmoid=True),
110
+ loss_mask=dict(
111
+ loss_weight=5.0,
112
+ reduction='mean',
113
+ type='mmdet.CrossEntropyLoss',
114
+ use_sigmoid=True),
115
+ num_classes=2,
116
+ num_queries=100,
117
+ num_transformer_feat_level=3,
118
+ out_channels=256,
119
+ pixel_decoder=dict(
120
+ act_cfg=dict(type='ReLU'),
121
+ encoder=dict(
122
+ init_cfg=None,
123
+ layer_cfg=dict(
124
+ ffn_cfg=dict(
125
+ act_cfg=dict(inplace=True, type='ReLU'),
126
+ embed_dims=256,
127
+ feedforward_channels=1024,
128
+ ffn_drop=0.0,
129
+ num_fcs=2),
130
+ self_attn_cfg=dict(
131
+ batch_first=True,
132
+ dropout=0.0,
133
+ embed_dims=256,
134
+ im2col_step=64,
135
+ init_cfg=None,
136
+ norm_cfg=None,
137
+ num_heads=8,
138
+ num_levels=3,
139
+ num_points=4)),
140
+ num_layers=6),
141
+ init_cfg=None,
142
+ norm_cfg=dict(num_groups=32, type='GN'),
143
+ num_outs=3,
144
+ positional_encoding=dict(normalize=True, num_feats=128),
145
+ type='mmdet.MSDeformAttnPixelDecoder'),
146
+ positional_encoding=dict(normalize=True, num_feats=128),
147
+ strides=[
148
+ 4,
149
+ 8,
150
+ 16,
151
+ 32,
152
+ ],
153
+ train_cfg=dict(
154
+ assigner=dict(
155
+ match_costs=[
156
+ dict(type='mmdet.ClassificationCost', weight=2.0),
157
+ dict(
158
+ type='mmdet.CrossEntropyLossCost',
159
+ use_sigmoid=True,
160
+ weight=5.0),
161
+ dict(
162
+ eps=1.0,
163
+ pred_act=True,
164
+ type='mmdet.DiceCost',
165
+ weight=5.0),
166
+ ],
167
+ type='mmdet.HungarianAssigner'),
168
+ importance_sample_ratio=0.75,
169
+ num_points=12544,
170
+ oversample_ratio=3.0,
171
+ sampler=dict(type='mmdet.MaskPseudoSampler')),
172
+ transformer_decoder=dict(
173
+ init_cfg=None,
174
+ layer_cfg=dict(
175
+ cross_attn_cfg=dict(
176
+ attn_drop=0.0,
177
+ batch_first=True,
178
+ dropout_layer=None,
179
+ embed_dims=256,
180
+ num_heads=8,
181
+ proj_drop=0.0),
182
+ ffn_cfg=dict(
183
+ act_cfg=dict(inplace=True, type='ReLU'),
184
+ add_identity=True,
185
+ dropout_layer=None,
186
+ embed_dims=256,
187
+ feedforward_channels=2048,
188
+ ffn_drop=0.0,
189
+ num_fcs=2),
190
+ self_attn_cfg=dict(
191
+ attn_drop=0.0,
192
+ batch_first=True,
193
+ dropout_layer=None,
194
+ embed_dims=256,
195
+ num_heads=8,
196
+ proj_drop=0.0)),
197
+ num_layers=9,
198
+ return_intermediate=True),
199
+ type='Mask2FormerHead'),
200
+ test_cfg=dict(mode='whole'),
201
+ train_cfg=dict(),
202
+ type='EncoderDecoder')
203
+
204
+
205
+
cloud-adapter-configs/multi_classes_512x512.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ crop_size = (
2
+ 512,
3
+ 512,
4
+ )
5
+ model = dict(
6
+ backbone=dict(
7
+ adapter_index=[
8
+ 0,
9
+ 1,
10
+ 2,
11
+ 3,
12
+ 4,
13
+ 5,
14
+ 6,
15
+ 7,
16
+ 8,
17
+ 9,
18
+ 10,
19
+ 11,
20
+ 12,
21
+ 13,
22
+ 14,
23
+ 15,
24
+ 16,
25
+ 17,
26
+ 18,
27
+ 19,
28
+ 20,
29
+ 21,
30
+ 22,
31
+ 23,
32
+ ],
33
+ block_chunks=0,
34
+ depth=24,
35
+ embed_dim=1024,
36
+ ffn_bias=True,
37
+ ffn_layer='mlp',
38
+ has_cat=False,
39
+ img_size=512,
40
+ init_values=1e-05,
41
+ mlp_ratio=4,
42
+ num_heads=16,
43
+ cloud_adapter_config=dict(
44
+ cnn_type='pmaa',
45
+ context_dim=64,
46
+ depth=4,
47
+ emd_dim=1024,
48
+ global_groups=1,
49
+ hidden_channels=64,
50
+ int_type='convnext',
51
+ local_groups=1,
52
+ num_layers=24,
53
+ rank_dim=16,
54
+ return_last_feature=False,
55
+ return_multi_feats=False,
56
+ type='CloudAdapter'),
57
+ patch_size=16,
58
+ proj_bias=True,
59
+ qkv_bias=True,
60
+ type='CloudAdapterDinoVisionTransformer'),
61
+ data_preprocessor=dict(
62
+ bgr_to_rgb=True,
63
+ mean=[
64
+ 123.675,
65
+ 116.28,
66
+ 103.53,
67
+ ],
68
+ pad_val=0,
69
+ seg_pad_val=255,
70
+ size=(
71
+ 512,
72
+ 512,
73
+ ),
74
+ std=[
75
+ 58.395,
76
+ 57.12,
77
+ 57.375,
78
+ ],
79
+ type='SegDataPreProcessor'),
80
+ decode_head=dict(
81
+ align_corners=False,
82
+ enforce_decoder_input_project=False,
83
+ feat_channels=256,
84
+ in_channels=[
85
+ 1024,
86
+ 1024,
87
+ 1024,
88
+ 1024,
89
+ ],
90
+ loss_cls=dict(
91
+ class_weight=[
92
+ 1.0,
93
+ 1.0,
94
+ 1.0,
95
+ 1.0,
96
+ 0.1,
97
+ ],
98
+ loss_weight=2.0,
99
+ reduction='mean',
100
+ type='mmdet.CrossEntropyLoss',
101
+ use_sigmoid=False),
102
+ loss_dice=dict(
103
+ activate=True,
104
+ eps=1.0,
105
+ loss_weight=5.0,
106
+ naive_dice=True,
107
+ reduction='mean',
108
+ type='mmdet.DiceLoss',
109
+ use_sigmoid=True),
110
+ loss_mask=dict(
111
+ loss_weight=5.0,
112
+ reduction='mean',
113
+ type='mmdet.CrossEntropyLoss',
114
+ use_sigmoid=True),
115
+ num_classes=4,
116
+ num_queries=100,
117
+ num_transformer_feat_level=3,
118
+ out_channels=256,
119
+ pixel_decoder=dict(
120
+ act_cfg=dict(type='ReLU'),
121
+ encoder=dict(
122
+ init_cfg=None,
123
+ layer_cfg=dict(
124
+ ffn_cfg=dict(
125
+ act_cfg=dict(inplace=True, type='ReLU'),
126
+ embed_dims=256,
127
+ feedforward_channels=1024,
128
+ ffn_drop=0.0,
129
+ num_fcs=2),
130
+ self_attn_cfg=dict(
131
+ batch_first=True,
132
+ dropout=0.0,
133
+ embed_dims=256,
134
+ im2col_step=64,
135
+ init_cfg=None,
136
+ norm_cfg=None,
137
+ num_heads=8,
138
+ num_levels=3,
139
+ num_points=4)),
140
+ num_layers=6),
141
+ init_cfg=None,
142
+ norm_cfg=dict(num_groups=32, type='GN'),
143
+ num_outs=3,
144
+ positional_encoding=dict(normalize=True, num_feats=128),
145
+ type='mmdet.MSDeformAttnPixelDecoder'),
146
+ positional_encoding=dict(normalize=True, num_feats=128),
147
+ strides=[
148
+ 4,
149
+ 8,
150
+ 16,
151
+ 32,
152
+ ],
153
+ train_cfg=dict(
154
+ assigner=dict(
155
+ match_costs=[
156
+ dict(type='mmdet.ClassificationCost', weight=2.0),
157
+ dict(
158
+ type='mmdet.CrossEntropyLossCost',
159
+ use_sigmoid=True,
160
+ weight=5.0),
161
+ dict(
162
+ eps=1.0,
163
+ pred_act=True,
164
+ type='mmdet.DiceCost',
165
+ weight=5.0),
166
+ ],
167
+ type='mmdet.HungarianAssigner'),
168
+ importance_sample_ratio=0.75,
169
+ num_points=12544,
170
+ oversample_ratio=3.0,
171
+ sampler=dict(type='mmdet.MaskPseudoSampler')),
172
+ transformer_decoder=dict(
173
+ init_cfg=None,
174
+ layer_cfg=dict(
175
+ cross_attn_cfg=dict(
176
+ attn_drop=0.0,
177
+ batch_first=True,
178
+ dropout_layer=None,
179
+ embed_dims=256,
180
+ num_heads=8,
181
+ proj_drop=0.0),
182
+ ffn_cfg=dict(
183
+ act_cfg=dict(inplace=True, type='ReLU'),
184
+ add_identity=True,
185
+ dropout_layer=None,
186
+ embed_dims=256,
187
+ feedforward_channels=2048,
188
+ ffn_drop=0.0,
189
+ num_fcs=2),
190
+ self_attn_cfg=dict(
191
+ attn_drop=0.0,
192
+ batch_first=True,
193
+ dropout_layer=None,
194
+ embed_dims=256,
195
+ num_heads=8,
196
+ proj_drop=0.0)),
197
+ num_layers=9,
198
+ return_intermediate=True),
199
+ type='Mask2FormerHead'),
200
+ test_cfg=dict(mode='whole'),
201
+ train_cfg=dict(),
202
+ type='EncoderDecoder')
203
+
204
+
205
+
cloud_adapter/__init__.py ADDED
File without changes
cloud_adapter/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (161 Bytes). View file
 
cloud_adapter/__pycache__/cloud_adapter.cpython-38.pyc ADDED
Binary file (18.3 kB). View file
 
cloud_adapter/__pycache__/cloud_adapter_dinov2.cpython-38.pyc ADDED
Binary file (4.08 kB). View file
 
cloud_adapter/__pycache__/dino_v2.cpython-38.pyc ADDED
Binary file (10.2 kB). View file
 
cloud_adapter/__pycache__/utils.cpython-38.pyc ADDED
Binary file (1.94 kB). View file
 
cloud_adapter/cdnetv1.py ADDED
@@ -0,0 +1,389 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2024/7/24 上午11:36
3
+ # @Author : xiaoshun
4
+ # @Email : [email protected]
5
+ # @File : cdnetv1.py
6
+ # @Software: PyCharm
7
+
8
+ """Cloud detection Network"""
9
+
10
+ """Cloud detection Network"""
11
+
12
+ """
13
+ This is the implementation of CDnetV1 without multi-scale inputs. This implementation uses ResNet by default.
14
+ """
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+
20
+ affine_par = True
21
+
22
+
23
+ def conv3x3(in_planes, out_planes, stride=1):
24
+ "3x3 convolution with padding"
25
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
26
+ padding=1, bias=False)
27
+
28
+
29
+ class BasicBlock(nn.Module):
30
+ expansion = 1
31
+
32
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
33
+ super(BasicBlock, self).__init__()
34
+ self.conv1 = conv3x3(inplanes, planes, stride)
35
+ self.bn1 = nn.BatchNorm2d(planes, affine=affine_par)
36
+ self.relu = nn.ReLU(inplace=True)
37
+ self.conv2 = conv3x3(planes, planes)
38
+ self.bn2 = nn.BatchNorm2d(planes, affine=affine_par)
39
+ self.downsample = downsample
40
+ self.stride = stride
41
+
42
+ def forward(self, x):
43
+ residual = x
44
+
45
+ out = self.conv1(x)
46
+ out = self.bn1(out)
47
+ out = self.relu(out)
48
+
49
+ out = self.conv2(out)
50
+ out = self.bn2(out)
51
+
52
+ if self.downsample is not None:
53
+ residual = self.downsample(x)
54
+
55
+ out += residual
56
+ out = self.relu(out)
57
+
58
+ return out
59
+
60
+
61
+ class Bottleneck(nn.Module):
62
+ expansion = 4
63
+
64
+ def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
65
+ super(Bottleneck, self).__init__()
66
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False) # change
67
+ self.bn1 = nn.BatchNorm2d(planes, affine=affine_par)
68
+ for i in self.bn1.parameters():
69
+ i.requires_grad = False
70
+
71
+ padding = dilation
72
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, # change
73
+ padding=padding, bias=False, dilation=dilation)
74
+ self.bn2 = nn.BatchNorm2d(planes, affine=affine_par)
75
+ for i in self.bn2.parameters():
76
+ i.requires_grad = False
77
+ self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
78
+ self.bn3 = nn.BatchNorm2d(planes * 4, affine=affine_par)
79
+ for i in self.bn3.parameters():
80
+ i.requires_grad = False
81
+ self.relu = nn.ReLU(inplace=True)
82
+ self.downsample = downsample
83
+ self.stride = stride
84
+
85
+ def forward(self, x):
86
+ residual = x
87
+
88
+ out = self.conv1(x)
89
+ out = self.bn1(out)
90
+ out = self.relu(out)
91
+
92
+ out = self.conv2(out)
93
+ out = self.bn2(out)
94
+ out = self.relu(out)
95
+
96
+ out = self.conv3(out)
97
+ out = self.bn3(out)
98
+
99
+ if self.downsample is not None:
100
+ residual = self.downsample(x)
101
+
102
+ out += residual
103
+ out = self.relu(out)
104
+
105
+ return out
106
+
107
+
108
+ class Classifier_Module(nn.Module):
109
+
110
+ def __init__(self, dilation_series, padding_series, num_classes):
111
+ super(Classifier_Module, self).__init__()
112
+ self.conv2d_list = nn.ModuleList()
113
+ for dilation, padding in zip(dilation_series, padding_series):
114
+ self.conv2d_list.append(
115
+ nn.Conv2d(2048, num_classes, kernel_size=3, stride=1, padding=padding, dilation=dilation, bias=True))
116
+
117
+ for m in self.conv2d_list:
118
+ m.weight.data.normal_(0, 0.01)
119
+
120
+ def forward(self, x):
121
+ out = self.conv2d_list[0](x)
122
+ for i in range(len(self.conv2d_list) - 1):
123
+ out += self.conv2d_list[i + 1](x)
124
+ return out
125
+
126
+
127
+ class _ConvBNReLU(nn.Module):
128
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
129
+ dilation=1, groups=1, norm_layer=nn.BatchNorm2d):
130
+ super(_ConvBNReLU, self).__init__()
131
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=False)
132
+ self.bn = norm_layer(out_channels)
133
+ self.relu = nn.ReLU(True)
134
+
135
+ def forward(self, x):
136
+ x = self.conv(x)
137
+ x = self.bn(x)
138
+ x = self.relu(x)
139
+ return x
140
+
141
+
142
+ class _ASPPConv(nn.Module):
143
+ def __init__(self, in_channels, out_channels, atrous_rate, norm_layer):
144
+ super(_ASPPConv, self).__init__()
145
+ self.block = nn.Sequential(
146
+ nn.Conv2d(in_channels, out_channels, 3, padding=atrous_rate, dilation=atrous_rate, bias=False),
147
+ norm_layer(out_channels),
148
+ nn.ReLU(True)
149
+ )
150
+
151
+ def forward(self, x):
152
+ return self.block(x)
153
+
154
+
155
+ class _AsppPooling(nn.Module):
156
+ def __init__(self, in_channels, out_channels, norm_layer):
157
+ super(_AsppPooling, self).__init__()
158
+ self.gap = nn.Sequential(
159
+ nn.AdaptiveAvgPool2d(1),
160
+ nn.Conv2d(in_channels, out_channels, 1, bias=False),
161
+ norm_layer(out_channels),
162
+ nn.ReLU(True)
163
+ )
164
+
165
+ def forward(self, x):
166
+ size = x.size()[2:]
167
+ pool = self.gap(x)
168
+ out = F.interpolate(pool, size, mode='bilinear', align_corners=True)
169
+ return out
170
+
171
+
172
+ class _ASPP(nn.Module):
173
+ def __init__(self, in_channels, atrous_rates, norm_layer):
174
+ super(_ASPP, self).__init__()
175
+ out_channels = 512 # changed from 256
176
+ self.b0 = nn.Sequential(
177
+ nn.Conv2d(in_channels, out_channels, 1, bias=False),
178
+ norm_layer(out_channels),
179
+ nn.ReLU(True)
180
+ )
181
+
182
+ rate1, rate2, rate3 = tuple(atrous_rates)
183
+ self.b1 = _ASPPConv(in_channels, out_channels, rate1, norm_layer)
184
+ self.b2 = _ASPPConv(in_channels, out_channels, rate2, norm_layer)
185
+ self.b3 = _ASPPConv(in_channels, out_channels, rate3, norm_layer)
186
+ self.b4 = _AsppPooling(in_channels, out_channels, norm_layer=norm_layer)
187
+
188
+ # self.project = nn.Sequential(
189
+ # nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
190
+ # norm_layer(out_channels),
191
+ # nn.ReLU(True),
192
+ # nn.Dropout(0.5))
193
+ self.dropout2d = nn.Dropout2d(0.3)
194
+
195
+ def forward(self, x):
196
+ feat1 = self.dropout2d(self.b0(x))
197
+ feat2 = self.dropout2d(self.b1(x))
198
+ feat3 = self.dropout2d(self.b2(x))
199
+ feat4 = self.dropout2d(self.b3(x))
200
+ feat5 = self.dropout2d(self.b4(x))
201
+ x = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1)
202
+ # x = self.project(x)
203
+ return x
204
+
205
+
206
+ class _FPM(nn.Module):
207
+ def __init__(self, in_channels, num_classes, norm_layer=nn.BatchNorm2d):
208
+ super(_FPM, self).__init__()
209
+ self.aspp = _ASPP(in_channels, [6, 12, 18], norm_layer=norm_layer)
210
+ # self.dropout2d = nn.Dropout2d(0.5)
211
+
212
+ def forward(self, x):
213
+ x = torch.cat((x, self.aspp(x)), dim=1)
214
+ # x = self.dropout2d(x) # added
215
+ return x
216
+
217
+
218
+ class BR(nn.Module):
219
+ def __init__(self, num_classes, stride=1, downsample=None):
220
+ super(BR, self).__init__()
221
+ self.conv1 = conv3x3(num_classes, num_classes * 16, stride)
222
+ self.relu = nn.ReLU(inplace=True)
223
+ self.conv2 = conv3x3(num_classes * 16, num_classes)
224
+ self.stride = stride
225
+
226
+ def forward(self, x):
227
+ residual = x
228
+
229
+ out = self.conv1(x)
230
+ out = self.relu(out)
231
+
232
+ out = self.conv2(out)
233
+ out += residual
234
+
235
+ return out
236
+
237
+
238
+ class CDnetV1(nn.Module):
239
+ def __init__(self, in_channels=3,block=Bottleneck, layers=[3, 4, 6, 3], num_classes=21, aux=True):
240
+ self.inplanes = 64
241
+ self.aux = aux
242
+ super().__init__()
243
+ # self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
244
+ # self.bn1 = nn.BatchNorm2d(64, affine = affine_par)
245
+
246
+ self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=2, padding=1, bias=False)
247
+ self.bn1 = nn.BatchNorm2d(64, affine=affine_par)
248
+ self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
249
+ self.bn2 = nn.BatchNorm2d(64, affine=affine_par)
250
+ self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
251
+ self.bn3 = nn.BatchNorm2d(64, affine=affine_par)
252
+
253
+ for i in self.bn1.parameters():
254
+ i.requires_grad = False
255
+ self.relu = nn.ReLU(inplace=True)
256
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True) # change
257
+ self.layer1 = self._make_layer(block, 64, layers[0])
258
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
259
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)
260
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4)
261
+ # self.layer5 = self._make_pred_layer(Classifier_Module, [6,12,18,24],[6,12,18,24],num_classes)
262
+
263
+ self.res5_con1x1 = nn.Sequential(
264
+ nn.Conv2d(1024 + 2048, 512, kernel_size=1, stride=1, padding=0),
265
+ nn.BatchNorm2d(512),
266
+ nn.ReLU(True)
267
+ )
268
+
269
+ self.fpm1 = _FPM(512, num_classes)
270
+ self.fpm2 = _FPM(512, num_classes)
271
+ self.fpm3 = _FPM(256, num_classes)
272
+
273
+ self.br1 = BR(num_classes)
274
+ self.br2 = BR(num_classes)
275
+ self.br3 = BR(num_classes)
276
+ self.br4 = BR(num_classes)
277
+ self.br5 = BR(num_classes)
278
+ self.br6 = BR(num_classes)
279
+ self.br7 = BR(num_classes)
280
+
281
+ self.predict1 = self._predict_layer(512 * 6, num_classes)
282
+ self.predict2 = self._predict_layer(512 * 6, num_classes)
283
+ self.predict3 = self._predict_layer(512 * 5 + 256, num_classes)
284
+
285
+ for m in self.modules():
286
+ if isinstance(m, nn.Conv2d):
287
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
288
+ m.weight.data.normal_(0, 0.01)
289
+ elif isinstance(m, nn.BatchNorm2d):
290
+ m.weight.data.fill_(1)
291
+ m.bias.data.zero_()
292
+ # for i in m.parameters():
293
+ # i.requires_grad = False
294
+
295
+ def _predict_layer(self, in_channels, num_classes):
296
+ return nn.Sequential(nn.Conv2d(in_channels, 256, kernel_size=1, stride=1, padding=0),
297
+ nn.BatchNorm2d(256),
298
+ nn.ReLU(True),
299
+ nn.Dropout2d(0.1),
300
+ nn.Conv2d(256, num_classes, kernel_size=3, stride=1, padding=1, bias=True))
301
+
302
+ def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
303
+ downsample = None
304
+ if stride != 1 or self.inplanes != planes * block.expansion or dilation == 2 or dilation == 4:
305
+ downsample = nn.Sequential(
306
+ nn.Conv2d(self.inplanes, planes * block.expansion,
307
+ kernel_size=1, stride=stride, bias=False),
308
+ nn.BatchNorm2d(planes * block.expansion, affine=affine_par))
309
+ for i in downsample._modules['1'].parameters():
310
+ i.requires_grad = False
311
+ layers = []
312
+ layers.append(block(self.inplanes, planes, stride, dilation=dilation, downsample=downsample))
313
+ self.inplanes = planes * block.expansion
314
+ for i in range(1, blocks):
315
+ layers.append(block(self.inplanes, planes, dilation=dilation))
316
+
317
+ return nn.Sequential(*layers)
318
+
319
+ # def _make_pred_layer(self,block, dilation_series, padding_series,num_classes):
320
+ # return block(dilation_series,padding_series,num_classes)
321
+
322
+ def base_forward(self, x):
323
+ x = self.relu(self.bn1(self.conv1(x)))
324
+ size_conv1 = x.size()[2:]
325
+ x = self.relu(self.bn2(self.conv2(x)))
326
+ x = self.relu(self.bn3(self.conv3(x)))
327
+ x = self.maxpool(x)
328
+ x = self.layer1(x)
329
+ res2 = x
330
+ x = self.layer2(x)
331
+ res3 = x
332
+ x = self.layer3(x)
333
+ res4 = x
334
+ x = self.layer4(x)
335
+ x = self.res5_con1x1(torch.cat([x, res4], dim=1))
336
+
337
+ return x, res3, res2, size_conv1
338
+
339
+ def forward(self, x):
340
+ size = x.size()[2:]
341
+ score1, score2, score3, size_conv1 = self.base_forward(x)
342
+ # outputs = list()
343
+ score1 = self.fpm1(score1)
344
+ score1 = self.predict1(score1) # 1/8
345
+ predict1 = score1
346
+ score1 = self.br1(score1)
347
+
348
+ score2 = self.fpm2(score2)
349
+ score2 = self.predict2(score2) # 1/8
350
+ predict2 = score2
351
+
352
+ # first fusion
353
+ score2 = self.br2(score2) + score1
354
+ score2 = self.br3(score2)
355
+
356
+ score3 = self.fpm3(score3)
357
+ score3 = self.predict3(score3) # 1/4
358
+ predict3 = score3
359
+ score3 = self.br4(score3)
360
+
361
+ # second fusion
362
+ size_score3 = score3.size()[2:]
363
+ score3 = score3 + F.interpolate(score2, size_score3, mode='bilinear', align_corners=True)
364
+ score3 = self.br5(score3)
365
+
366
+ # upsampling + BR
367
+ score3 = F.interpolate(score3, size_conv1, mode='bilinear', align_corners=True)
368
+ score3 = self.br6(score3)
369
+ score3 = F.interpolate(score3, size, mode='bilinear', align_corners=True)
370
+ score3 = self.br7(score3)
371
+
372
+ # if self.aux:
373
+ # auxout = self.dsn(mid)
374
+ # auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True)
375
+ # #outputs.append(auxout)
376
+ return score3
377
+ # return score3, predict1, predict2, predict3
378
+
379
+
380
+ if __name__ == '__main__':
381
+ model = CDnetV1(num_classes=21)
382
+ fake_image = torch.randn(2, 3, 224, 224)
383
+ outputs = model(fake_image)
384
+ for out in outputs:
385
+ print(out.shape)
386
+ # torch.Size([2, 21, 224, 224])
387
+ # torch.Size([2, 21, 29, 29])
388
+ # torch.Size([2, 21, 29, 29])
389
+ # torch.Size([2, 21, 57, 57])
cloud_adapter/cdnetv2.py ADDED
@@ -0,0 +1,693 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2024/7/24 下午3:41
3
+ # @Author : xiaoshun
4
+ # @Email : [email protected]
5
+ # @File : cdnetv2.py
6
+ # @Software: PyCharm
7
+
8
+ """Cloud detection Network"""
9
+
10
+ """
11
+ This is the implementation of CDnetV2 without multi-scale inputs. This implementation uses ResNet by default.
12
+ """
13
+ # nn.GroupNorm
14
+
15
+ import torch
16
+ # import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+
20
+ affine_par = True
21
+
22
+
23
+ def conv3x3(in_planes, out_planes, stride=1):
24
+ "3x3 convolution with padding"
25
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
26
+ padding=1, bias=False)
27
+
28
+
29
+ class BasicBlock(nn.Module):
30
+ expansion = 1
31
+
32
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
33
+ super(BasicBlock, self).__init__()
34
+ self.conv1 = conv3x3(inplanes, planes, stride)
35
+ self.bn1 = nn.BatchNorm2d(planes, affine=affine_par)
36
+ self.relu = nn.ReLU(inplace=True)
37
+ self.conv2 = conv3x3(planes, planes)
38
+ self.bn2 = nn.BatchNorm2d(planes, affine=affine_par)
39
+ self.downsample = downsample
40
+ self.stride = stride
41
+
42
+ def forward(self, x):
43
+ residual = x
44
+
45
+ out = self.conv1(x)
46
+ out = self.bn1(out)
47
+ out = self.relu(out)
48
+
49
+ out = self.conv2(out)
50
+ out = self.bn2(out)
51
+
52
+ if self.downsample is not None:
53
+ residual = self.downsample(x)
54
+
55
+ out += residual
56
+ out = self.relu(out)
57
+
58
+ return out
59
+
60
+
61
+ class Bottleneck(nn.Module):
62
+ expansion = 4
63
+
64
+ def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
65
+ super(Bottleneck, self).__init__()
66
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False) # change
67
+ self.bn1 = nn.BatchNorm2d(planes, affine=affine_par)
68
+ for i in self.bn1.parameters():
69
+ i.requires_grad = False
70
+
71
+ padding = dilation
72
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, # change
73
+ padding=padding, bias=False, dilation=dilation)
74
+ self.bn2 = nn.BatchNorm2d(planes, affine=affine_par)
75
+ for i in self.bn2.parameters():
76
+ i.requires_grad = False
77
+ self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
78
+ self.bn3 = nn.BatchNorm2d(planes * 4, affine=affine_par)
79
+ for i in self.bn3.parameters():
80
+ i.requires_grad = False
81
+ self.relu = nn.ReLU(inplace=True)
82
+ self.downsample = downsample
83
+ self.stride = stride
84
+
85
+ def forward(self, x):
86
+ residual = x
87
+
88
+ out = self.conv1(x)
89
+ out = self.bn1(out)
90
+ out = self.relu(out)
91
+
92
+ out = self.conv2(out)
93
+ out = self.bn2(out)
94
+ out = self.relu(out)
95
+
96
+ out = self.conv3(out)
97
+ out = self.bn3(out)
98
+
99
+ if self.downsample is not None:
100
+ residual = self.downsample(x)
101
+
102
+ out += residual
103
+ out = self.relu(out)
104
+
105
+ return out
106
+
107
+ # self.layerx_1 = Bottleneck_nosample(64, 64, stride=1, dilation=1)
108
+ # self.layerx_2 = Bottleneck(256, 64, stride=1, dilation=1, downsample=None)
109
+ # self.layerx_3 = Bottleneck_downsample(256, 64, stride=2, dilation=1)
110
+
111
+
112
+ class Res_block_1(nn.Module):
113
+ expansion = 4
114
+
115
+ def __init__(self, inplanes=64, planes=64, stride=1, dilation=1):
116
+ super(Res_block_1, self).__init__()
117
+
118
+ self.conv1 = nn.Sequential(
119
+ nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False),
120
+ nn.GroupNorm(8, planes),
121
+ nn.ReLU(inplace=True))
122
+
123
+ self.conv2 = nn.Sequential(
124
+ nn.Conv2d(planes, planes, kernel_size=3, stride=1,
125
+ padding=1, bias=False, dilation=1),
126
+ nn.GroupNorm(8, planes),
127
+ nn.ReLU(inplace=True))
128
+
129
+ self.conv3 = nn.Sequential(
130
+ nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False),
131
+ nn.GroupNorm(8, planes * 4))
132
+
133
+ self.relu = nn.ReLU(inplace=True)
134
+
135
+ self.down_sample = nn.Sequential(
136
+ nn.Conv2d(inplanes, planes * 4,
137
+ kernel_size=1, stride=1, bias=False),
138
+ nn.GroupNorm(8, planes * 4))
139
+
140
+ def forward(self, x):
141
+ # residual = x
142
+
143
+ out = self.conv1(x)
144
+ out = self.conv2(out)
145
+ out = self.conv3(out)
146
+ residual = self.down_sample(x)
147
+ out += residual
148
+ out = self.relu(out)
149
+
150
+ return out
151
+
152
+
153
+ class Res_block_2(nn.Module):
154
+ expansion = 4
155
+
156
+ def __init__(self, inplanes=256, planes=64, stride=1, dilation=1):
157
+ super(Res_block_2, self).__init__()
158
+
159
+ self.conv1 = nn.Sequential(
160
+ nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False),
161
+ nn.GroupNorm(8, planes),
162
+ nn.ReLU(inplace=True))
163
+
164
+ self.conv2 = nn.Sequential(
165
+ nn.Conv2d(planes, planes, kernel_size=3, stride=1,
166
+ padding=1, bias=False, dilation=1),
167
+ nn.GroupNorm(8, planes),
168
+ nn.ReLU(inplace=True))
169
+
170
+ self.conv3 = nn.Sequential(
171
+ nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False),
172
+ nn.GroupNorm(8, planes * 4))
173
+
174
+ self.relu = nn.ReLU(inplace=True)
175
+
176
+ def forward(self, x):
177
+ residual = x
178
+
179
+ out = self.conv1(x)
180
+ out = self.conv2(out)
181
+ out = self.conv3(out)
182
+
183
+ out += residual
184
+ out = self.relu(out)
185
+
186
+ return out
187
+
188
+
189
+ class Res_block_3(nn.Module):
190
+ expansion = 4
191
+
192
+ def __init__(self, inplanes=256, planes=64, stride=1, dilation=1):
193
+ super(Res_block_3, self).__init__()
194
+
195
+ self.conv1 = nn.Sequential(
196
+ nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False),
197
+ nn.GroupNorm(8, planes),
198
+ nn.ReLU(inplace=True))
199
+
200
+ self.conv2 = nn.Sequential(
201
+ nn.Conv2d(planes, planes, kernel_size=3, stride=1,
202
+ padding=1, bias=False, dilation=1),
203
+ nn.GroupNorm(8, planes),
204
+ nn.ReLU(inplace=True))
205
+
206
+ self.conv3 = nn.Sequential(
207
+ nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False),
208
+ nn.GroupNorm(8, planes * 4))
209
+
210
+ self.relu = nn.ReLU(inplace=True)
211
+
212
+ self.downsample = nn.Sequential(
213
+ nn.Conv2d(inplanes, planes * 4,
214
+ kernel_size=1, stride=stride, bias=False),
215
+ nn.GroupNorm(8, planes * 4))
216
+
217
+ def forward(self, x):
218
+ # residual = x
219
+
220
+ out = self.conv1(x)
221
+ out = self.conv2(out)
222
+ out = self.conv3(out)
223
+ # residual = self.downsample(x)
224
+ out += self.downsample(x)
225
+ out = self.relu(out)
226
+
227
+ return out
228
+
229
+
230
+ class Classifier_Module(nn.Module):
231
+
232
+ def __init__(self, dilation_series, padding_series, num_classes):
233
+ super(Classifier_Module, self).__init__()
234
+ self.conv2d_list = nn.ModuleList()
235
+ for dilation, padding in zip(dilation_series, padding_series):
236
+ self.conv2d_list.append(
237
+ nn.Conv2d(2048, num_classes, kernel_size=3, stride=1, padding=padding, dilation=dilation, bias=True))
238
+
239
+ for m in self.conv2d_list:
240
+ m.weight.data.normal_(0, 0.01)
241
+
242
+ def forward(self, x):
243
+ out = self.conv2d_list[0](x)
244
+ for i in range(len(self.conv2d_list) - 1):
245
+ out += self.conv2d_list[i + 1](x)
246
+ return out
247
+
248
+
249
+ class _ConvBNReLU(nn.Module):
250
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
251
+ dilation=1, groups=1, relu6=False, norm_layer=nn.BatchNorm2d):
252
+ super(_ConvBNReLU, self).__init__()
253
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=False)
254
+ self.bn = norm_layer(out_channels)
255
+ self.relu = nn.ReLU6(True) if relu6 else nn.ReLU(True)
256
+
257
+ def forward(self, x):
258
+ x = self.conv(x)
259
+ x = self.bn(x)
260
+ x = self.relu(x)
261
+ return x
262
+
263
+
264
+ class _ASPPConv(nn.Module):
265
+ def __init__(self, in_channels, out_channels, atrous_rate, norm_layer):
266
+ super(_ASPPConv, self).__init__()
267
+ self.block = nn.Sequential(
268
+ nn.Conv2d(in_channels, out_channels, 3, padding=atrous_rate, dilation=atrous_rate, bias=False),
269
+ norm_layer(out_channels),
270
+ nn.ReLU(True)
271
+ )
272
+
273
+ def forward(self, x):
274
+ return self.block(x)
275
+
276
+
277
+ class _AsppPooling(nn.Module):
278
+ def __init__(self, in_channels, out_channels, norm_layer):
279
+ super(_AsppPooling, self).__init__()
280
+ self.gap = nn.Sequential(
281
+ nn.AdaptiveAvgPool2d(1),
282
+ nn.Conv2d(in_channels, out_channels, 1, bias=False),
283
+ norm_layer(out_channels),
284
+ nn.ReLU(True)
285
+ )
286
+
287
+ def forward(self, x):
288
+ size = x.size()[2:]
289
+ pool = self.gap(x)
290
+ out = F.interpolate(pool, size, mode='bilinear', align_corners=True)
291
+ return out
292
+
293
+
294
+ class _ASPP(nn.Module):
295
+ def __init__(self, in_channels, atrous_rates, norm_layer):
296
+ super(_ASPP, self).__init__()
297
+ out_channels = 256
298
+ self.b0 = nn.Sequential(
299
+ nn.Conv2d(in_channels, out_channels, 1, bias=False),
300
+ norm_layer(out_channels),
301
+ nn.ReLU(True)
302
+ )
303
+
304
+ rate1, rate2, rate3 = tuple(atrous_rates)
305
+ self.b1 = _ASPPConv(in_channels, out_channels, rate1, norm_layer)
306
+ self.b2 = _ASPPConv(in_channels, out_channels, rate2, norm_layer)
307
+ self.b3 = _ASPPConv(in_channels, out_channels, rate3, norm_layer)
308
+ self.b4 = _AsppPooling(in_channels, out_channels, norm_layer=norm_layer)
309
+
310
+ self.project = nn.Sequential(
311
+ nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
312
+ norm_layer(out_channels),
313
+ nn.ReLU(True),
314
+ nn.Dropout(0.5)
315
+ )
316
+
317
+ def forward(self, x):
318
+ feat1 = self.b0(x)
319
+ feat2 = self.b1(x)
320
+ feat3 = self.b2(x)
321
+ feat4 = self.b3(x)
322
+ feat5 = self.b4(x)
323
+ x = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1)
324
+ x = self.project(x)
325
+ return x
326
+
327
+
328
+ class _DeepLabHead(nn.Module):
329
+ def __init__(self, num_classes, c1_channels=256, norm_layer=nn.BatchNorm2d):
330
+ super(_DeepLabHead, self).__init__()
331
+ self.aspp = _ASPP(2048, [12, 24, 36], norm_layer=norm_layer)
332
+ self.c1_block = _ConvBNReLU(c1_channels, 48, 3, padding=1, norm_layer=norm_layer)
333
+ self.block = nn.Sequential(
334
+ _ConvBNReLU(304, 256, 3, padding=1, norm_layer=norm_layer),
335
+ nn.Dropout(0.5),
336
+ _ConvBNReLU(256, 256, 3, padding=1, norm_layer=norm_layer),
337
+ nn.Dropout(0.1),
338
+ nn.Conv2d(256, num_classes, 1))
339
+
340
+ def forward(self, x, c1):
341
+ size = c1.size()[2:]
342
+ c1 = self.c1_block(c1)
343
+ x = self.aspp(x)
344
+ x = F.interpolate(x, size, mode='bilinear', align_corners=True)
345
+ return self.block(torch.cat([x, c1], dim=1))
346
+
347
+
348
+ class _CARM(nn.Module):
349
+ def __init__(self, in_planes, ratio=8):
350
+ super(_CARM, self).__init__()
351
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
352
+ self.max_pool = nn.AdaptiveMaxPool2d(1)
353
+
354
+ self.fc1_1 = nn.Linear(in_planes, in_planes // ratio)
355
+ self.fc1_2 = nn.Linear(in_planes // ratio, in_planes)
356
+
357
+ self.fc2_1 = nn.Linear(in_planes, in_planes // ratio)
358
+ self.fc2_2 = nn.Linear(in_planes // ratio, in_planes)
359
+ self.relu = nn.ReLU(True)
360
+
361
+ self.sigmoid = nn.Sigmoid()
362
+
363
+ def forward(self, x):
364
+ avg_out = self.avg_pool(x)
365
+ avg_out = avg_out.view(avg_out.size(0), -1)
366
+ avg_out = self.fc1_2(self.relu(self.fc1_1(avg_out)))
367
+
368
+ max_out = self.max_pool(x)
369
+ max_out = max_out.view(max_out.size(0), -1)
370
+ max_out = self.fc2_2(self.relu(self.fc2_1(max_out)))
371
+
372
+ max_out_size = max_out.size()[1]
373
+ avg_out = torch.reshape(avg_out, (-1, max_out_size, 1, 1))
374
+ max_out = torch.reshape(max_out, (-1, max_out_size, 1, 1))
375
+
376
+ out = self.sigmoid(avg_out + max_out)
377
+
378
+ x = out * x
379
+ return x
380
+
381
+
382
+ class FSFB_CH(nn.Module):
383
+ def __init__(self, in_planes, num, ratio=8):
384
+ super(FSFB_CH, self).__init__()
385
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
386
+ self.max_pool = nn.AdaptiveMaxPool2d(1)
387
+
388
+ self.fc1_1 = nn.Linear(in_planes, in_planes // ratio)
389
+ self.fc1_2 = nn.Linear(in_planes // ratio, num * in_planes)
390
+
391
+ self.fc2_1 = nn.Linear(in_planes, in_planes // ratio)
392
+ self.fc2_2 = nn.Linear(in_planes // ratio, num * in_planes)
393
+ self.relu = nn.ReLU(True)
394
+
395
+ self.fc3 = nn.Linear(num * in_planes, 2 * num * in_planes)
396
+ self.fc4 = nn.Linear(2 * num * in_planes, 2 * num * in_planes)
397
+ self.fc5 = nn.Linear(2 * num * in_planes, num * in_planes)
398
+
399
+ self.softmax = nn.Softmax(dim=3)
400
+
401
+ def forward(self, x, num):
402
+ avg_out = self.avg_pool(x)
403
+ avg_out = avg_out.view(avg_out.size(0), -1)
404
+ avg_out = self.fc1_2(self.relu(self.fc1_1(avg_out)))
405
+
406
+ max_out = self.max_pool(x)
407
+ max_out = max_out.view(max_out.size(0), -1)
408
+ max_out = self.fc2_2(self.relu(self.fc2_1(max_out)))
409
+
410
+ out = avg_out + max_out
411
+ out = self.relu(self.fc3(out))
412
+ out = self.relu(self.fc4(out))
413
+ out = self.relu(self.fc5(out)) # (N, num*in_planes)
414
+
415
+ out_size = out.size()[1]
416
+ out = torch.reshape(out, (-1, out_size // num, 1, num)) # (N, in_planes, 1, num )
417
+ out = self.softmax(out)
418
+
419
+ channel_scale = torch.chunk(out, num, dim=3) # (N, in_planes, 1, 1 )
420
+
421
+ return channel_scale
422
+
423
+
424
+ class FSFB_SP(nn.Module):
425
+ def __init__(self, num, norm_layer=nn.BatchNorm2d):
426
+ super(FSFB_SP, self).__init__()
427
+ self.conv = nn.Sequential(
428
+ nn.Conv2d(2, 2 * num, kernel_size=3, padding=1, bias=False),
429
+ norm_layer(2 * num),
430
+ nn.ReLU(True),
431
+ nn.Conv2d(2 * num, 4 * num, kernel_size=3, padding=1, bias=False),
432
+ norm_layer(4 * num),
433
+ nn.ReLU(True),
434
+ nn.Conv2d(4 * num, 4 * num, kernel_size=3, padding=1, bias=False),
435
+ norm_layer(4 * num),
436
+ nn.ReLU(True),
437
+ nn.Conv2d(4 * num, 2 * num, kernel_size=3, padding=1, bias=False),
438
+ norm_layer(2 * num),
439
+ nn.ReLU(True),
440
+ nn.Conv2d(2 * num, num, kernel_size=3, padding=1, bias=False)
441
+ )
442
+ self.softmax = nn.Softmax(dim=1)
443
+
444
+ def forward(self, x, num):
445
+ avg_out = torch.mean(x, dim=1, keepdim=True)
446
+ max_out, _ = torch.max(x, dim=1, keepdim=True)
447
+ x = torch.cat([avg_out, max_out], dim=1)
448
+ x = self.conv(x)
449
+ x = self.softmax(x)
450
+ spatial_scale = torch.chunk(x, num, dim=1)
451
+ return spatial_scale
452
+
453
+
454
+ ##################################################################################################################
455
+
456
+
457
+ class _HFFM(nn.Module):
458
+ def __init__(self, in_channels, atrous_rates, norm_layer=nn.BatchNorm2d):
459
+ super(_HFFM, self).__init__()
460
+ out_channels = 256
461
+ self.b0 = nn.Sequential(
462
+ nn.Conv2d(in_channels, out_channels, 1, bias=False),
463
+ norm_layer(out_channels),
464
+ nn.ReLU(True)
465
+ )
466
+
467
+ rate1, rate2, rate3 = tuple(atrous_rates)
468
+ self.b1 = _ASPPConv(in_channels, out_channels, rate1, norm_layer)
469
+ self.b2 = _ASPPConv(in_channels, out_channels, rate2, norm_layer)
470
+ self.b3 = _ASPPConv(in_channels, out_channels, rate3, norm_layer)
471
+ self.b4 = _AsppPooling(in_channels, out_channels, norm_layer=norm_layer)
472
+ self.carm = _CARM(in_channels)
473
+ self.sa = FSFB_SP(4, norm_layer)
474
+ self.ca = FSFB_CH(out_channels, 4, 8)
475
+
476
+ def forward(self, x, num):
477
+ x = self.carm(x)
478
+ # feat1 = self.b0(x)
479
+ feat1 = self.b1(x)
480
+ feat2 = self.b2(x)
481
+ feat3 = self.b3(x)
482
+ feat4 = self.b4(x)
483
+ feat = feat1 + feat2 + feat3 + feat4
484
+ spatial_atten = self.sa(feat, num)
485
+ channel_atten = self.ca(feat, num)
486
+
487
+ feat_ca = channel_atten[0] * feat1 + channel_atten[1] * feat2 + channel_atten[2] * feat3 + channel_atten[
488
+ 3] * feat4
489
+ feat_sa = spatial_atten[0] * feat1 + spatial_atten[1] * feat2 + spatial_atten[2] * feat3 + spatial_atten[
490
+ 3] * feat4
491
+ feat_sa = feat_sa + feat_ca
492
+
493
+ return feat_sa
494
+
495
+
496
+ class _AFFM(nn.Module):
497
+ def __init__(self, in_channels=256, norm_layer=nn.BatchNorm2d):
498
+ super(_AFFM, self).__init__()
499
+
500
+ self.sa = FSFB_SP(2, norm_layer)
501
+ self.ca = FSFB_CH(in_channels, 2, 8)
502
+ self.carm = _CARM(in_channels)
503
+
504
+ def forward(self, feat1, feat2, hffm, num):
505
+ feat = feat1 + feat2
506
+ spatial_atten = self.sa(feat, num)
507
+ channel_atten = self.ca(feat, num)
508
+
509
+ feat_ca = channel_atten[0] * feat1 + channel_atten[1] * feat2
510
+ feat_sa = spatial_atten[0] * feat1 + spatial_atten[1] * feat2
511
+ output = self.carm(feat_sa + feat_ca + hffm)
512
+ # output = self.carm (feat_sa + hffm)
513
+
514
+ return output, channel_atten, spatial_atten
515
+
516
+
517
+ class block_Conv3x3(nn.Module):
518
+ def __init__(self, in_channels):
519
+ super(block_Conv3x3, self).__init__()
520
+ self.block = nn.Sequential(
521
+ nn.Conv2d(in_channels, 256, kernel_size=3, stride=1, padding=1, bias=False),
522
+ nn.BatchNorm2d(256),
523
+ nn.ReLU(True)
524
+ )
525
+
526
+ def forward(self, x):
527
+ return self.block(x)
528
+
529
+
530
+ class CDnetV2(nn.Module):
531
+ def __init__(self, in_channels=3,block=Bottleneck, layers=[3, 4, 6, 3], num_classes=21, aux=True):
532
+ self.inplanes = 256 # change
533
+ self.aux = aux
534
+ super().__init__()
535
+ # self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
536
+ # self.bn1 = nn.BatchNorm2d(64, affine = affine_par)
537
+
538
+ self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=2, padding=1, bias=False)
539
+ self.bn1 = nn.BatchNorm2d(64, affine=affine_par)
540
+
541
+ self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
542
+ self.bn2 = nn.BatchNorm2d(64, affine=affine_par)
543
+
544
+ self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
545
+ self.bn3 = nn.BatchNorm2d(64, affine=affine_par)
546
+
547
+ self.relu = nn.ReLU(inplace=True)
548
+
549
+ self.dropout = nn.Dropout(0.3)
550
+ for i in self.bn1.parameters():
551
+ i.requires_grad = False
552
+
553
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True) # change
554
+
555
+ # self.layer1 = self._make_layer(block, 64, layers[0])
556
+
557
+ self.layerx_1 = Res_block_1(64, 64, stride=1, dilation=1)
558
+ self.layerx_2 = Res_block_2(256, 64, stride=1, dilation=1)
559
+ self.layerx_3 = Res_block_3(256, 64, stride=2, dilation=1)
560
+
561
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
562
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)
563
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4)
564
+ # self.layer5 = self._make_pred_layer(Classifier_Module, [6,12,18,24],[6,12,18,24],num_classes)
565
+
566
+ self.hffm = _HFFM(2048, [6, 12, 18])
567
+ self.affm_1 = _AFFM()
568
+ self.affm_2 = _AFFM()
569
+ self.affm_3 = _AFFM()
570
+ self.affm_4 = _AFFM()
571
+ self.carm = _CARM(256)
572
+
573
+ self.con_layer1_1 = block_Conv3x3(256)
574
+ self.con_res2 = block_Conv3x3(256)
575
+ self.con_res3 = block_Conv3x3(512)
576
+ self.con_res4 = block_Conv3x3(1024)
577
+ self.con_res5 = block_Conv3x3(2048)
578
+
579
+ self.dsn1 = nn.Sequential(
580
+ nn.Conv2d(256, num_classes, kernel_size=1, stride=1, padding=0)
581
+ )
582
+
583
+ self.dsn2 = nn.Sequential(
584
+ nn.Conv2d(256, num_classes, kernel_size=1, stride=1, padding=0)
585
+ )
586
+
587
+ for m in self.modules():
588
+ if isinstance(m, nn.Conv2d):
589
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
590
+ m.weight.data.normal_(0, 0.01)
591
+ elif isinstance(m, nn.BatchNorm2d):
592
+ m.weight.data.fill_(1)
593
+ m.bias.data.zero_()
594
+ # for i in m.parameters():
595
+ # i.requires_grad = False
596
+
597
+ # self.inplanes = 256 # change
598
+
599
+ def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
600
+ downsample = None
601
+ if stride != 1 or self.inplanes != planes * block.expansion or dilation == 2 or dilation == 4:
602
+ downsample = nn.Sequential(
603
+ nn.Conv2d(self.inplanes, planes * block.expansion,
604
+ kernel_size=1, stride=stride, bias=False),
605
+ nn.BatchNorm2d(planes * block.expansion, affine=affine_par))
606
+ for i in downsample._modules['1'].parameters():
607
+ i.requires_grad = False
608
+ layers = []
609
+ layers.append(block(self.inplanes, planes, stride, dilation=dilation, downsample=downsample))
610
+ self.inplanes = planes * block.expansion
611
+ for i in range(1, blocks):
612
+ layers.append(block(self.inplanes, planes, dilation=dilation))
613
+
614
+ return nn.Sequential(*layers)
615
+
616
+ # def _make_pred_layer(self,block, dilation_series, padding_series,num_classes):
617
+ # return block(dilation_series,padding_series,num_classes)
618
+
619
+ def base_forward(self, x):
620
+ x = self.relu(self.bn1(self.conv1(x))) # 1/2
621
+ x = self.relu(self.bn2(self.conv2(x)))
622
+ x = self.relu(self.bn3(self.conv3(x)))
623
+ x = self.maxpool(x) # 1/4
624
+
625
+ # x = self.layer1(x) # 1/8
626
+
627
+ # layer1
628
+ x = self.layerx_1(x) # 1/4
629
+ layer1_0 = x
630
+
631
+ x = self.layerx_2(x) # 1/4
632
+ layer1_0 = self.con_layer1_1(x + layer1_0) # 256
633
+ size_layer1_0 = layer1_0.size()[2:]
634
+
635
+ x = self.layerx_3(x) # 1/8
636
+ res2 = self.con_res2(x) # 256
637
+ size_res2 = res2.size()[2:]
638
+
639
+ # layer2-4
640
+ x = self.layer2(x) # 1/16
641
+ res3 = self.con_res3(x) # 256
642
+ x = self.layer3(x) # 1/16
643
+
644
+ res4 = self.con_res4(x) # 256
645
+ x = self.layer4(x) # 1/16
646
+ res5 = self.con_res5(x) # 256
647
+
648
+ # x = self.res5_con1x1(torch.cat([x, res4], dim=1))
649
+ return layer1_0, res2, res3, res4, res5, x, size_layer1_0, size_res2
650
+
651
+ # return res2, res3, res4, res5, x, layer_1024, size_res2
652
+
653
+ def forward(self, x):
654
+ # size = x.size()[2:]
655
+ layer1_0, res2, res3, res4, res5, layer4, size_layer1_0, size_res2 = self.base_forward(x)
656
+
657
+ hffm = self.hffm(layer4, 4) # 256 HFFM
658
+ res5 = res5 + hffm
659
+ aux_feature = res5 # loss_aux
660
+ # res5 = self.carm(res5)
661
+ res5, _, _ = self.affm_1(res4, res5, hffm, 2) # 1/16
662
+ # aux_feature = res5
663
+ res5, _, _ = self.affm_2(res3, res5, hffm, 2) # 1/16
664
+
665
+ res5 = F.interpolate(res5, size_res2, mode='bilinear', align_corners=True)
666
+ res5, _, _ = self.affm_3(res2, res5, F.interpolate(hffm, size_res2, mode='bilinear', align_corners=True), 2)
667
+
668
+ res5 = F.interpolate(res5, size_layer1_0, mode='bilinear', align_corners=True)
669
+ res5, _, _ = self.affm_4(layer1_0, res5,
670
+ F.interpolate(hffm, size_layer1_0, mode='bilinear', align_corners=True), 2)
671
+
672
+ output = self.dsn1(res5)
673
+
674
+ if self.aux:
675
+ auxout = self.dsn2(aux_feature)
676
+ # auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True)
677
+ # outputs.append(auxout)
678
+ size = x.size()[2:]
679
+ pred, pred_aux = output, auxout
680
+ pred = F.interpolate(pred, size, mode='bilinear', align_corners=True)
681
+ pred_aux = F.interpolate(pred_aux, size, mode='bilinear', align_corners=True)
682
+ return pred
683
+ return pred, pred_aux
684
+
685
+
686
+ if __name__ == '__main__':
687
+ model = CDnetV2(num_classes=3)
688
+ fake_image = torch.rand(2, 3, 256, 256)
689
+ output = model(fake_image)
690
+ for out in output:
691
+ print(out.shape)
692
+ # torch.Size([2, 3, 256, 256])
693
+ # torch.Size([2, 3, 256, 256])
cloud_adapter/cloud_adapter.py ADDED
@@ -0,0 +1,590 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from einops import rearrange
4
+ from torch import nn, einsum
5
+ from einops import rearrange
6
+ from mmseg.models.builder import MODELS
7
+ import math
8
+ import torch
9
+ from torch import nn as nn
10
+ from mmseg.models.builder import MODELS
11
+ from timm.layers import DropPath, trunc_normal_
12
+ from typing import List
13
+ from timm.layers import create_act_layer
14
+ from functools import partial
15
+ import torch.nn.functional as F
16
+
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ import math
22
+ from timm.layers import CondConv2d, get_condconv_initializer, create_conv2d, DropPath, get_norm_act_layer
23
+
24
+
25
+ class LoRaMLP(nn.Module):
26
+ def __init__(self, in_dim, out_dim, rank_dim=8):
27
+ super().__init__()
28
+ self.loramlp = nn.Sequential(
29
+ nn.Linear(in_dim, rank_dim, bias=False),
30
+ nn.Linear(rank_dim, out_dim, bias=False),
31
+ )
32
+
33
+ def forward(self, x):
34
+ return self.loramlp(x)
35
+
36
+
37
+ class CrossAttention(nn.Module):
38
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, rank_dim=None):
39
+ super().__init__()
40
+ inner_dim = dim_head * heads # 512
41
+ context_dim = query_dim if context_dim is None else context_dim
42
+
43
+ self.scale = dim_head ** -0.5
44
+ self.heads = heads
45
+
46
+ if not rank_dim:
47
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
48
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
49
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
50
+
51
+ self.to_out = nn.Linear(inner_dim, query_dim, bias=False)
52
+ else:
53
+ self.to_q = LoRaMLP(query_dim, inner_dim, rank_dim=rank_dim)
54
+ self.to_k = LoRaMLP(context_dim, inner_dim, rank_dim=rank_dim)
55
+ self.to_v = LoRaMLP(context_dim, inner_dim, rank_dim=rank_dim)
56
+
57
+ self.to_out = LoRaMLP(inner_dim, query_dim, rank_dim=rank_dim)
58
+
59
+ def forward(self, x, context):
60
+ h = self.heads
61
+
62
+ q = self.to_q(x)
63
+ k = self.to_k(context)
64
+ v = self.to_v(context)
65
+
66
+ q, k, v = map(lambda t: rearrange(
67
+ t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
68
+
69
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
70
+
71
+ attn = sim.softmax(dim=-1)
72
+
73
+ out = einsum('b i j, b j d -> b i d', attn, v)
74
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
75
+
76
+ return self.to_out(out)
77
+
78
+
79
+ def num_groups(group_size, channels):
80
+ if not group_size:
81
+ return 1
82
+ else:
83
+ assert channels % group_size == 0
84
+ return channels // group_size
85
+
86
+
87
+ def _init_weight_goog(m, n='', fix_group_fanout=True):
88
+ if isinstance(m, CondConv2d):
89
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
90
+ if fix_group_fanout:
91
+ fan_out //= m.groups
92
+ init_weight_fn = get_condconv_initializer(
93
+ lambda w: nn.init.normal_(w, 0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape)
94
+ init_weight_fn(m.weight)
95
+ if m.bias is not None:
96
+ nn.init.zeros_(m.bias)
97
+ elif isinstance(m, nn.Conv2d):
98
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
99
+ if fix_group_fanout:
100
+ fan_out //= m.groups
101
+ nn.init.normal_(m.weight, 0, math.sqrt(2.0 / fan_out))
102
+ if m.bias is not None:
103
+ nn.init.zeros_(m.bias)
104
+ elif isinstance(m, nn.BatchNorm2d):
105
+ nn.init.ones_(m.weight)
106
+ nn.init.zeros_(m.bias)
107
+ elif isinstance(m, nn.Linear):
108
+ fan_out = m.weight.size(0)
109
+ fan_in = 0
110
+ if 'routing_fn' in n:
111
+ fan_in = m.weight.size(1)
112
+ init_range = 1.0 / math.sqrt(fan_in + fan_out)
113
+ nn.init.uniform_(m.weight, -init_range, init_range)
114
+ if m.bias is not None:
115
+ nn.init.zeros_(m.bias)
116
+
117
+
118
+ class DepthwiseSeparableConv(nn.Module):
119
+ def __init__(
120
+ self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, group_size=1, pad_type='',
121
+ noskip=False, pw_kernel_size=1, pw_act=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
122
+ se_layer=None, drop_path_rate=0.):
123
+ super(DepthwiseSeparableConv, self).__init__()
124
+ norm_act_layer = get_norm_act_layer(norm_layer)
125
+ groups = num_groups(group_size, in_chs)
126
+ self.has_skip = (stride == 1 and in_chs == out_chs) and not noskip
127
+ self.has_pw_act = pw_act
128
+
129
+ self.conv_dw = create_conv2d(
130
+ in_chs, in_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=pad_type, groups=groups)
131
+ self.bn1 = norm_act_layer(in_chs, inplace=True)
132
+
133
+ self.se = se_layer(
134
+ in_chs, act_layer=act_layer) if se_layer else nn.Identity()
135
+
136
+ self.conv_pw = create_conv2d(
137
+ in_chs, out_chs, pw_kernel_size, padding=pad_type)
138
+ self.bn2 = norm_act_layer(
139
+ out_chs, inplace=True, apply_act=self.has_pw_act)
140
+ self.drop_path = DropPath(
141
+ drop_path_rate) if drop_path_rate else nn.Identity()
142
+
143
+ def feature_info(self, location):
144
+ if location == 'expansion':
145
+ return dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels)
146
+ else:
147
+ return dict(module='', hook_type='', num_chs=self.conv_pw.out_channels)
148
+
149
+ def forward(self, x):
150
+ shortcut = x
151
+ x = self.conv_dw(x)
152
+ x = self.bn1(x)
153
+ x = self.se(x)
154
+ x = self.conv_pw(x)
155
+ x = self.bn2(x)
156
+ if self.has_skip:
157
+ x = self.drop_path(x) + shortcut
158
+ return x
159
+
160
+
161
+ class PMAAConvBlock(nn.Module):
162
+ def __init__(self, in_channels=3, hidden_channels=256, depth=4, norm=nn.BatchNorm2d, act=nn.ReLU, return_multi_feats=False, return_last_feature=True, has_stem=True, has_block=True):
163
+ super().__init__()
164
+ self.return_last_feature = return_last_feature
165
+ self.depth = depth
166
+ self.has_stem = has_stem
167
+ self.return_multi_feats = return_multi_feats
168
+
169
+ self.proj_1x1 = DepthwiseSeparableConv(
170
+ in_channels, hidden_channels, dw_kernel_size=1, norm_layer=norm, act_layer=act)
171
+
172
+ self.spp_dw = nn.ModuleList()
173
+
174
+ if has_stem:
175
+ self.spp_dw.append(
176
+ DepthwiseSeparableConv(hidden_channels, hidden_channels, dw_kernel_size=3,
177
+ stride=1, group_size=hidden_channels, pad_type="same")
178
+ )
179
+ else:
180
+ self.spp_dw.append(nn.Identity())
181
+
182
+ if has_block:
183
+ for _ in range(self.depth):
184
+ self.spp_dw.append(
185
+ DepthwiseSeparableConv(
186
+ hidden_channels, hidden_channels, dw_kernel_size=3, stride=2, group_size=hidden_channels
187
+ )
188
+ )
189
+ else:
190
+ for _ in range(self.depth):
191
+ self.spp_dw.append(
192
+ nn.MaxPool2d(kernel_size=2, stride=2)
193
+ )
194
+ self._init_weights()
195
+
196
+ def forward(self, x):
197
+ B, C, H, W = x.shape
198
+ output1 = self.proj_1x1(x)
199
+ output = [self.spp_dw[0](output1)]
200
+
201
+ for k in range(1, self.depth+1):
202
+ out_k = self.spp_dw[k](output[-1])
203
+ output.append(out_k)
204
+
205
+ if self.return_multi_feats:
206
+ return output[1:]
207
+ else:
208
+ if self.return_last_feature:
209
+ return output[-1]
210
+ global_f = torch.zeros(
211
+ output[-1].shape, requires_grad=True, device=output1.device)
212
+ for fea in output:
213
+ global_f = global_f + F.adaptive_avg_pool2d(
214
+ fea, output_size=output[-1].shape[-2:]
215
+ )
216
+ return global_f
217
+
218
+ def _init_weights(self):
219
+ init_fn = _init_weight_goog
220
+ for n, m in self.named_modules():
221
+ init_fn(m, n)
222
+
223
+
224
+ class ConvnextInteractiveModule(nn.Module):
225
+ def __init__(self, emd_dim=1024, context_dim=256, rank_dim=None):
226
+ super().__init__()
227
+ self.attn = CrossAttention(emd_dim, context_dim, rank_dim=rank_dim)
228
+
229
+ def forward(self, x, cache, index):
230
+ # x: 1024 2 1024
231
+ if isinstance(cache, list) or isinstance(cache, tuple):
232
+ # len(cache) 4 cache[4]-23
233
+ # 0-5->0 6-11 -> 1 12-17->2 18-23->3
234
+ cache = cache[index]
235
+ cache = F.interpolate(
236
+ cache, (int(math.sqrt(x.shape[0])), int(math.sqrt(x.shape[0]))), mode="bilinear", align_corners=False
237
+ )
238
+ cache = cache.flatten(2) # B C N
239
+ cache = cache.permute(2, 0, 1) # N B C
240
+
241
+ # Reshape: batch first
242
+ x = x.permute(1, 0, 2) # B N C
243
+ cache = cache.permute(1, 0, 2) # B N C
244
+ return (x + self.attn(x, cache)).permute(1, 0, 2)
245
+
246
+
247
+ class PMAAInteractiveModule(nn.Module):
248
+ def __init__(self,
249
+ emd_dim=1024,
250
+ context_dim=64,
251
+ kernel: int = 1,
252
+ norm=nn.BatchNorm2d,
253
+ local_groups=32,
254
+ global_groups=2,
255
+ return_multi_feats=False,
256
+ ):
257
+ super().__init__()
258
+ self.return_multi_feats = return_multi_feats
259
+ self.local_embedding = nn.Sequential(
260
+ nn.Conv2d(emd_dim, emd_dim, kernel, groups=local_groups,
261
+ padding=int((kernel - 1) / 2), bias=False),
262
+ norm(emd_dim)
263
+ )
264
+ self.global_embedding = nn.Sequential(
265
+ nn.Conv2d(context_dim, emd_dim, kernel, groups=global_groups,
266
+ padding=int((kernel - 1) / 2), bias=False),
267
+ norm(emd_dim)
268
+ )
269
+ self.global_act = nn.Sequential(
270
+ nn.Conv2d(context_dim, emd_dim, kernel, groups=global_groups,
271
+ padding=int((kernel - 1) / 2), bias=False),
272
+ norm(emd_dim)
273
+ )
274
+ self.act = nn.Sigmoid()
275
+ self._init_weights()
276
+
277
+ def _init_weights(self):
278
+ init_fn = _init_weight_goog
279
+ for n, m in self.named_modules():
280
+ init_fn(m, n)
281
+
282
+ def forward(self, x, cache, index):
283
+ if isinstance(cache, list) or isinstance(cache, tuple):
284
+ cache = cache[index]
285
+ N, B, C = x.shape
286
+ H = W = int(math.sqrt(N))
287
+ # reshape x -> B, C, H, W
288
+ x = x.permute(1, 2, 0).reshape(B, C, H, W)
289
+ local_feat = self.local_embedding(x) # 32
290
+ global_act = self.global_act(cache)
291
+ sig_act = F.interpolate(self.act(global_act), size=(H, W)) # 32
292
+
293
+ global_feat = self.global_embedding(cache)
294
+ global_feat = F.interpolate(global_feat, size=(H, W)) # 32
295
+
296
+ out = local_feat * sig_act + global_feat
297
+
298
+ return out.permute(2, 3, 0, 1).reshape(N, B, C)
299
+
300
+
301
+ class LayerNorm(nn.Module):
302
+ r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
303
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
304
+ shape (batch_size, height, width, channels) while channels_first corresponds to inputs
305
+ with shape (batch_size, channels, height, width).
306
+ """
307
+
308
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
309
+ super().__init__()
310
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
311
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
312
+ self.eps = eps
313
+ self.data_format = data_format
314
+ if self.data_format not in ["channels_last", "channels_first"]:
315
+ raise NotImplementedError
316
+ self.normalized_shape = (normalized_shape, )
317
+
318
+ def forward(self, x):
319
+ if self.data_format == "channels_last":
320
+ return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
321
+ elif self.data_format == "channels_first":
322
+ u = x.mean(1, keepdim=True)
323
+ s = (x - u).pow(2).mean(1, keepdim=True)
324
+ x = (x - u) / torch.sqrt(s + self.eps)
325
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
326
+ return x
327
+
328
+
329
+ class Block(nn.Module):
330
+ r""" ConvNeXt Block. There are two equivalent implementations:
331
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
332
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
333
+ We use (2) as we find it slightly faster in PyTorch
334
+
335
+ Args:
336
+ dim (int): Number of input channels.
337
+ drop_path (float): Stochastic depth rate. Default: 0.0
338
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
339
+ """
340
+
341
+ def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
342
+ super().__init__()
343
+ self.dwconv = nn.Conv2d(dim, dim, kernel_size=7,
344
+ padding=3, groups=dim) # depthwise conv
345
+ self.norm = LayerNorm(dim, eps=1e-6)
346
+ # pointwise/1x1 convs, implemented with linear layers
347
+ self.pwconv1 = nn.Linear(dim, 4 * dim)
348
+ self.act = nn.GELU()
349
+ self.pwconv2 = nn.Linear(4 * dim, dim)
350
+ self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
351
+ requires_grad=True) if layer_scale_init_value > 0 else None
352
+ self.drop_path = DropPath(
353
+ drop_path) if drop_path > 0. else nn.Identity()
354
+
355
+ def forward(self, x):
356
+ input = x
357
+ x = self.dwconv(x)
358
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
359
+ x = self.norm(x)
360
+ x = self.pwconv1(x)
361
+ x = self.act(x)
362
+ x = self.pwconv2(x)
363
+ if self.gamma is not None:
364
+ x = self.gamma * x
365
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
366
+
367
+ x = input + self.drop_path(x)
368
+ return x
369
+
370
+
371
+ class ConvNeXt(nn.Module):
372
+ r""" ConvNeXt
373
+ A PyTorch impl of : `A ConvNet for the 2020s` -
374
+ https://arxiv.org/pdf/2201.03545.pdf
375
+
376
+ Args:
377
+ in_chans (int): Number of input image channels. Default: 3
378
+ num_classes (int): Number of classes for classification head. Default: 1000
379
+ depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
380
+ dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
381
+ drop_path_rate (float): Stochastic depth rate. Default: 0.
382
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
383
+ head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
384
+ """
385
+
386
+ def __init__(self, in_chans=3, depths=[3, 3, 9, 3], dims=[96, 192, 384, 768],
387
+ drop_path_rate=0., layer_scale_init_value=1e-6, out_indices=[0, 1, 2, 3],
388
+ return_multi_feats=False,
389
+ return_last_feature=True
390
+ ):
391
+ super().__init__()
392
+ self.return_last_feature = return_last_feature
393
+ self.return_multi_feats = return_multi_feats
394
+
395
+ # stem and 3 intermediate downsampling conv layers
396
+ self.downsample_layers = nn.ModuleList()
397
+ stem = nn.Sequential(
398
+ nn.Conv2d(in_chans, dims[0], kernel_size=2, stride=2),
399
+ LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
400
+ )
401
+ self.downsample_layers.append(stem)
402
+ for i in range(3):
403
+ downsample_layer = nn.Sequential(
404
+ LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
405
+ nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
406
+ )
407
+ self.downsample_layers.append(downsample_layer)
408
+
409
+ # 4 feature resolution stages, each consisting of multiple residual blocks
410
+ self.stages = nn.ModuleList()
411
+ dp_rates = [x.item()
412
+ for x in torch.linspace(0, drop_path_rate, sum(depths))]
413
+ cur = 0
414
+ for i in range(4):
415
+ stage = nn.Sequential(
416
+ *[Block(dim=dims[i], drop_path=dp_rates[cur + j],
417
+ layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
418
+ )
419
+ self.stages.append(stage)
420
+ cur += depths[i]
421
+
422
+ self.out_indices = out_indices
423
+
424
+ norm_layer = partial(LayerNorm, eps=1e-6, data_format="channels_first")
425
+ for i_layer in range(4):
426
+ layer = norm_layer(dims[i_layer])
427
+ layer_name = f'norm{i_layer}'
428
+ self.add_module(layer_name, layer)
429
+
430
+ self.apply(self._init_weights)
431
+
432
+ def _init_weights(self, m):
433
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
434
+ trunc_normal_(m.weight, std=.02)
435
+ nn.init.constant_(m.bias, 0)
436
+
437
+ def init_weights(self, pretrained=None):
438
+ """Initialize the weights in backbone.
439
+ Args:
440
+ pretrained (str, optional): Path to pre-trained weights.
441
+ Defaults to None.
442
+ """
443
+
444
+ def _init_weights(m):
445
+ if isinstance(m, nn.Linear):
446
+ trunc_normal_(m.weight, std=.02)
447
+ if isinstance(m, nn.Linear) and m.bias is not None:
448
+ nn.init.constant_(m.bias, 0)
449
+ elif isinstance(m, nn.LayerNorm):
450
+ nn.init.constant_(m.bias, 0)
451
+ nn.init.constant_(m.weight, 1.0)
452
+
453
+ if isinstance(pretrained, str):
454
+ self.apply(_init_weights)
455
+ # logger = get_root_logger()
456
+ # load_checkpoint(self, pretrained, strict=False, logger=logger)
457
+ elif pretrained is None:
458
+ self.apply(_init_weights)
459
+ else:
460
+ raise TypeError('pretrained must be a str or None')
461
+
462
+ def forward_features(self, x):
463
+ outs = []
464
+ for i in range(4):
465
+ x = self.downsample_layers[i](x)
466
+ x = self.stages[i](x)
467
+ if i in self.out_indices:
468
+ norm_layer = getattr(self, f'norm{i}')
469
+ x_out = norm_layer(x)
470
+ outs.append(x_out)
471
+ if self.return_multi_feats:
472
+ return tuple(outs)
473
+ if self.return_last_feature:
474
+ return outs[-1]
475
+ global_f = torch.zeros(
476
+ outs[-1].shape, requires_grad=True, device=outs[-1].device)
477
+ for fea in outs:
478
+ global_f = global_f + F.adaptive_avg_pool2d(
479
+ fea, output_size=outs[-1].shape[-2:]
480
+ )
481
+ return global_f
482
+
483
+ def forward(self, x):
484
+ x = self.forward_features(x)
485
+ return x
486
+
487
+
488
+ class NoAdaptingModule(nn.Identity):
489
+ def __init__(self):
490
+ super().__init__()
491
+
492
+ def forward(self, x, cache, layer):
493
+ return x
494
+
495
+
496
+ @MODELS.register_module()
497
+ class CloudAdapter(nn.Module):
498
+ def __init__(self,
499
+ cnn_type="convnext", # convnext or mobilenet
500
+ int_type="convnext", # cross_attention or
501
+ # 共同的参数 start
502
+ emd_dim=1024,
503
+ num_layers=24,
504
+
505
+ # 先判断是否返回多特征,之后再判断是否进行特征融合
506
+ return_multi_feats=True,
507
+ return_last_feature=False,
508
+
509
+ # 共同的参数 end
510
+
511
+ # pmaa 提取单个特征 or 多尺寸特征 start
512
+ hidden_channels=256,
513
+ depth=4,
514
+ norm=nn.BatchNorm2d,
515
+ act=nn.ReLU,
516
+ # pmaa 提取单个特征 or 多尺寸特征 end
517
+
518
+ # pmaa net start
519
+ local_groups=1,
520
+ global_groups=1,
521
+ # pmaa net end
522
+
523
+ # convnext 提取单个特征 or 多尺寸特征 start
524
+ context_dim=256,
525
+ rank_dim=None,
526
+ # convnext 提取单个特征 or 多尺寸特征 end,
527
+ has_stem=True,
528
+ has_block=True,
529
+ ):
530
+ super().__init__()
531
+ self.cnn = nn.Identity()
532
+ self.net = nn.Identity()
533
+ if cnn_type == "pmaa":
534
+ self.cnn = PMAAConvBlock(
535
+ hidden_channels=hidden_channels,
536
+ depth=depth,
537
+ norm=norm,
538
+ act=act,
539
+ return_multi_feats=return_multi_feats,
540
+ return_last_feature=return_last_feature,
541
+ has_stem=has_stem,
542
+ has_block=has_block
543
+ )
544
+ elif cnn_type == "convnext":
545
+ self.cnn = ConvNeXt(depths=[1]*4,
546
+ dims=[context_dim]*4,
547
+ return_multi_feats=return_multi_feats,
548
+ return_last_feature=return_last_feature
549
+ )
550
+
551
+ else:
552
+ raise ValueError(
553
+ f"cnn_type must in ['convnext','pmaa'],but got {cnn_type}")
554
+
555
+ if int_type == "convnext":
556
+ self.net = nn.ModuleList(
557
+ ConvnextInteractiveModule(emd_dim, context_dim, rank_dim)
558
+ for _ in range(num_layers)
559
+ )
560
+ elif int_type == "pmaa":
561
+ self.net = nn.ModuleList(
562
+ PMAAInteractiveModule(
563
+ emd_dim, context_dim, local_groups=local_groups, global_groups=global_groups)
564
+ for _ in range(num_layers)
565
+ )
566
+
567
+ elif int_type == "no_adapting":
568
+ self.net = nn.ModuleList(
569
+ NoAdaptingModule() for _ in range(num_layers)
570
+ )
571
+ else:
572
+ raise ValueError(
573
+ f"int_type must in ['convnext','pmaa'],but got {int_type}")
574
+
575
+ def forward(self, feats, layer, batch_first=True, has_cls_token=True, cache=None):
576
+ if batch_first:
577
+ feats = feats.permute(1, 0, 2) # 1025 2 1024
578
+ if has_cls_token:
579
+ cls_token, feats = torch.tensor_split(feats, [1], dim=0)
580
+ # 24 // 1
581
+ # feat: 1024 2 1024
582
+ feats = self.net[layer].forward(
583
+ feats, cache, layer//(len(self.net) // 4))
584
+
585
+ if has_cls_token:
586
+ feats = torch.cat([cls_token, feats], dim=0)
587
+ if batch_first:
588
+ feats = feats.permute(1, 0, 2)
589
+ return feats
590
+
cloud_adapter/cloud_adapter_dinov2.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from mmseg.models.builder import BACKBONES, MODELS
2
+ from torch import nn as nn
3
+ from .cloud_adapter import CloudAdapter
4
+ from .dino_v2 import DinoVisionTransformer
5
+ from .utils import set_requires_grad, set_train
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+
10
+ @BACKBONES.register_module()
11
+ class CloudAdapterDinoVisionTransformer(DinoVisionTransformer):
12
+ def __init__(
13
+ self,
14
+ cloud_adapter_config=None,
15
+ has_cat=False,
16
+ # [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, ],
17
+ adapter_index=[0, 6, 12, 18], # Transformer Block 的索引
18
+ **kwargs,
19
+ ):
20
+ super().__init__(**kwargs)
21
+ self.cloud_adapter: CloudAdapter = MODELS.build(cloud_adapter_config)
22
+ self.has_cat = has_cat
23
+ self.adapter_index = adapter_index
24
+
25
+ def forward_features(self, x, masks=None):
26
+ B, _, h, w = x.shape
27
+ cache = self.cloud_adapter.cnn(x) # 得到多尺度特征或者单个特征
28
+ H, W = h // self.patch_size, w // self.patch_size
29
+ x = self.prepare_tokens_with_masks(x, masks)
30
+ outs = []
31
+ cur_idx = 0 # 交互模块的索引
32
+ for idx, blk in enumerate(self.blocks):
33
+ x = blk(x)
34
+ if idx in self.adapter_index:
35
+ x = self.cloud_adapter.forward(
36
+ x,
37
+ cur_idx,
38
+ batch_first=True,
39
+ has_cls_token=True,
40
+ cache=cache,
41
+ )
42
+ cur_idx += 1
43
+ if idx in self.out_indices:
44
+ outs.append(
45
+ x[:, 1:, :].permute(0, 2, 1).reshape(
46
+ B, -1, H, W).contiguous()
47
+ )
48
+ return outs, cache
49
+
50
+ def process_cache(self,ret,cache):
51
+ cache = F.interpolate(
52
+ cache,size=(ret.shape[-2],ret.shape[-1]),mode="bilinear",align_corners=False)
53
+ return cache
54
+
55
+ def forward(self, *args, **kwargs):
56
+ ret, cache = self.forward_features(*args, **kwargs)
57
+ if isinstance(ret[0], torch.Tensor):
58
+ ret[0] = F.interpolate(
59
+ ret[0], scale_factor=4, mode="bilinear", align_corners=False
60
+ )
61
+ ret[1] = F.interpolate(
62
+ ret[1], scale_factor=2, mode="bilinear", align_corners=False
63
+ )
64
+ ret[3] = F.interpolate(
65
+ ret[3], scale_factor=0.5, mode="bilinear", align_corners=False
66
+ )
67
+ if self.has_cat:
68
+ if isinstance(cache,tuple) or isinstance(cache,list):
69
+ ret[0] = torch.cat((ret[0], cache[0]), dim=1)
70
+ ret[1] = torch.cat((ret[1], cache[1]), dim=1)
71
+ ret[2] = torch.cat((ret[2], cache[2]), dim=1)
72
+ ret[3] = torch.cat((ret[3], cache[3]), dim=1)
73
+ else:
74
+ ret[0] = torch.cat((ret[0], self.process_cache(ret[0],cache)), dim=1)
75
+ ret[1] = torch.cat((ret[1], self.process_cache(ret[1],cache)), dim=1)
76
+ ret[2] = torch.cat((ret[2], self.process_cache(ret[2],cache)), dim=1)
77
+ ret[3] = torch.cat((ret[3], self.process_cache(ret[3],cache)), dim=1)
78
+ # ret[0] = torch.cat(ret[0], cache[0], dim=1) # bs 1024 128 128, bs 256 128 128
79
+ else:
80
+ ret[0][0] = F.interpolate(
81
+ ret[0][0], scale_factor=4, mode="bilinear", align_corners=False
82
+ )
83
+ ret[0][1] = F.interpolate(
84
+ ret[0][1], scale_factor=2, mode="bilinear", align_corners=False
85
+ )
86
+ ret[0][3] = F.interpolate(
87
+ ret[0][3], scale_factor=0.5, mode="bilinear", align_corners=False
88
+ )
89
+ if self.has_cat:
90
+ if isinstance(cache,tuple) or isinstance(cache,list):
91
+ ret[0][0] = torch.cat((ret[0][0], cache[0]), dim=1)
92
+ ret[0][1] = torch.cat((ret[0][1], cache[1]), dim=1)
93
+ ret[0][2] = torch.cat((ret[0][2], cache[2]), dim=1)
94
+ ret[0][3] = torch.cat((ret[0][3], cache[3]), dim=1)
95
+ else:
96
+ ret[0][0] = torch.cat((ret[0][0], self.process_cache(ret[0][0],cache)), dim=1)
97
+ ret[0][1] = torch.cat((ret[0][1], self.process_cache(ret[0][1],cache)), dim=1)
98
+ ret[0][2] = torch.cat((ret[0][2], self.process_cache(ret[0][2],cache)), dim=1)
99
+ ret[0][3] = torch.cat((ret[0][3], self.process_cache(ret[0][3],cache)), dim=1)
100
+ return ret
101
+
102
+ def train(self, mode: bool = True):
103
+ if not mode:
104
+ return super().train(mode)
105
+ set_requires_grad(self, ["cloud_adapter"])
106
+ set_train(self, ["cloud_adapter"])
107
+
108
+ def state_dict(self, destination, prefix, keep_vars):
109
+ state = super().state_dict(destination, prefix, keep_vars)
110
+ keys = [k for k in state.keys() if "cloud_adapter" not in k]
111
+ for key in keys:
112
+ state.pop(key)
113
+ if key in destination:
114
+ destination.pop(key)
115
+ return state
cloud_adapter/dbnet.py ADDED
@@ -0,0 +1,680 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2024/7/26 上午11:19
3
+ # @Author : xiaoshun
4
+ # @Email : [email protected]
5
+ # @File : dbnet.py
6
+ # @Software: PyCharm
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from einops import rearrange
12
+
13
+
14
+ # from models.Transformer.ViT import truncated_normal_
15
+
16
+ # Decoder细化卷积模块
17
+ class SBR(nn.Module):
18
+ def __init__(self, in_ch):
19
+ super(SBR, self).__init__()
20
+ self.conv1x3 = nn.Sequential(
21
+ nn.Conv2d(in_ch, in_ch, kernel_size=(1, 3), stride=1, padding=(0, 1)),
22
+ nn.BatchNorm2d(in_ch),
23
+ nn.ReLU(True)
24
+ )
25
+ self.conv3x1 = nn.Sequential(
26
+ nn.Conv2d(in_ch, in_ch, kernel_size=(3, 1), stride=1, padding=(1, 0)),
27
+ nn.BatchNorm2d(in_ch),
28
+ nn.ReLU(True)
29
+ )
30
+
31
+ def forward(self, x):
32
+ out = self.conv3x1(self.conv1x3(x)) # 先进行1x3的卷积,得到结果并将结果再进行3x1的卷积
33
+ return out + x
34
+
35
+
36
+ # 下采样卷积模块 stage 1,2,3
37
+ class c_stage123(nn.Module):
38
+ def __init__(self, in_chans, out_chans):
39
+ super().__init__()
40
+ self.stage123 = nn.Sequential(
41
+ nn.Conv2d(in_channels=in_chans, out_channels=out_chans, kernel_size=3, stride=2, padding=1),
42
+ nn.BatchNorm2d(out_chans),
43
+ nn.ReLU(),
44
+ nn.Conv2d(in_channels=out_chans, out_channels=out_chans, kernel_size=3, stride=1, padding=1),
45
+ nn.BatchNorm2d(out_chans),
46
+ nn.ReLU(),
47
+ )
48
+ self.conv1x1_123 = nn.Conv2d(in_channels=in_chans, out_channels=out_chans, kernel_size=1)
49
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
50
+
51
+ def forward(self, x):
52
+ stage123 = self.stage123(x) # 3*3卷积,两倍下采样 3*224*224-->64*112*112
53
+ max = self.maxpool(x) # 最大值池化,两倍下采样 3*224*224-->3*112*112
54
+ max = self.conv1x1_123(max) # 1*1卷积 3*112*112-->64*112*112
55
+ stage123 = stage123 + max # 残差结构,广播机制
56
+ return stage123
57
+
58
+
59
+ # 下采样卷积模块 stage4,5
60
+ class c_stage45(nn.Module):
61
+ def __init__(self, in_chans, out_chans):
62
+ super().__init__()
63
+ self.stage45 = nn.Sequential(
64
+ nn.Conv2d(in_channels=in_chans, out_channels=out_chans, kernel_size=3, stride=2, padding=1),
65
+ nn.BatchNorm2d(out_chans),
66
+ nn.ReLU(),
67
+ nn.Conv2d(in_channels=out_chans, out_channels=out_chans, kernel_size=3, stride=1, padding=1),
68
+ nn.BatchNorm2d(out_chans),
69
+ nn.ReLU(),
70
+ nn.Conv2d(in_channels=out_chans, out_channels=out_chans, kernel_size=3, stride=1, padding=1),
71
+ nn.BatchNorm2d(out_chans),
72
+ nn.ReLU(),
73
+ )
74
+ self.conv1x1_45 = nn.Conv2d(in_channels=in_chans, out_channels=out_chans, kernel_size=1)
75
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
76
+
77
+ def forward(self, x):
78
+ stage45 = self.stage45(x) # 3*3卷积模块 2倍下采样
79
+ max = self.maxpool(x) # 最大值池化,两倍下采样
80
+ max = self.conv1x1_45(max) # 1*1卷积模块 调整通道数
81
+ stage45 = stage45 + max # 残差结构
82
+ return stage45
83
+
84
+
85
+ class Identity(nn.Module): # 恒等映射
86
+ def __init__(self):
87
+ super().__init__()
88
+
89
+ def forward(self, x):
90
+ return x
91
+
92
+
93
+ # 轻量卷积模块
94
+ class DepthwiseConv2d(nn.Module): # 用于自注意力机制
95
+ def __init__(self, in_chans, out_chans, kernel_size=1, stride=1, padding=0, dilation=1):
96
+ super().__init__()
97
+ # depthwise conv
98
+ self.depthwise = nn.Conv2d(
99
+ in_channels=in_chans,
100
+ out_channels=in_chans,
101
+ kernel_size=kernel_size,
102
+ stride=stride,
103
+ padding=padding,
104
+ dilation=dilation, # 深层卷积的膨胀率
105
+ groups=in_chans # 指定分组卷积的组数
106
+ )
107
+ # batch norm
108
+ self.bn = nn.BatchNorm2d(num_features=in_chans)
109
+
110
+ # pointwise conv 逐点卷积
111
+ self.pointwise = nn.Conv2d(
112
+ in_channels=in_chans,
113
+ out_channels=out_chans,
114
+ kernel_size=1
115
+ )
116
+
117
+ def forward(self, x):
118
+ x = self.depthwise(x)
119
+ x = self.bn(x)
120
+ x = self.pointwise(x)
121
+ return x
122
+
123
+
124
+ # residual skip connection 残差跳跃连接
125
+ class Residual(nn.Module):
126
+ def __init__(self, fn):
127
+ super().__init__()
128
+ self.fn = fn
129
+
130
+ def forward(self, input, **kwargs):
131
+ x = self.fn(input, **kwargs)
132
+ return (x + input)
133
+
134
+
135
+ # layer norm plus 层归一化
136
+ class PreNorm(nn.Module): # 代表神经网络层
137
+ def __init__(self, dim, fn):
138
+ super().__init__()
139
+ self.norm = nn.LayerNorm(dim)
140
+ self.fn = fn
141
+
142
+ def forward(self, input, **kwargs):
143
+ return self.fn(self.norm(input), **kwargs)
144
+
145
+
146
+ # FeedForward层使得representation的表达能力更强
147
+ class FeedForward(nn.Module):
148
+ def __init__(self, dim, hidden_dim, dropout=0.):
149
+ super().__init__()
150
+ self.net = nn.Sequential(
151
+ nn.Linear(in_features=dim, out_features=hidden_dim),
152
+ nn.GELU(),
153
+ nn.Dropout(dropout),
154
+ nn.Linear(in_features=hidden_dim, out_features=dim),
155
+ nn.Dropout(dropout)
156
+ )
157
+
158
+ def forward(self, input):
159
+ return self.net(input)
160
+
161
+
162
+ class ConvAttnetion(nn.Module):
163
+ '''
164
+ using the Depth_Separable_Wise Conv2d to produce the q, k, v instead of using Linear Project in ViT
165
+ '''
166
+
167
+ def __init__(self, dim, img_size, heads=8, dim_head=64, kernel_size=3, q_stride=1, k_stride=1, v_stride=1,
168
+ dropout=0., last_stage=False):
169
+ super().__init__()
170
+ self.last_stage = last_stage
171
+ self.img_size = img_size
172
+ inner_dim = dim_head * heads # 512
173
+ project_out = not (heads == 1 and dim_head == dim)
174
+
175
+ self.heads = heads
176
+ self.scale = dim_head ** (-0.5)
177
+
178
+ pad = (kernel_size - q_stride) // 2
179
+
180
+ self.to_q = DepthwiseConv2d(in_chans=dim, out_chans=inner_dim, kernel_size=kernel_size, stride=q_stride,
181
+ padding=pad) # 自注意力机制
182
+ self.to_k = DepthwiseConv2d(in_chans=dim, out_chans=inner_dim, kernel_size=kernel_size, stride=k_stride,
183
+ padding=pad)
184
+ self.to_v = DepthwiseConv2d(in_chans=dim, out_chans=inner_dim, kernel_size=kernel_size, stride=v_stride,
185
+ padding=pad)
186
+
187
+ self.to_out = nn.Sequential(
188
+ nn.Linear(
189
+ in_features=inner_dim,
190
+ out_features=dim
191
+ ),
192
+ nn.Dropout(dropout)
193
+ ) if project_out else Identity()
194
+
195
+ def forward(self, x):
196
+ b, n, c, h = *x.shape, self.heads # * 星号的作用大概是去掉 tuple 属性吧
197
+
198
+ # print(x.shape)
199
+ # print('+++++++++++++++++++++++++++++++++')
200
+
201
+ # if语句内容没有使用
202
+ if self.last_stage:
203
+ cls_token = x[:, 0]
204
+ # print(cls_token.shape)
205
+ # print('+++++++++++++++++++++++++++++++++')
206
+ x = x[:, 1:] # 去掉每个数组的第一个元素
207
+
208
+ cls_token = rearrange(torch.unsqueeze(cls_token, dim=1), 'b n (h d) -> b h n d', h=h)
209
+
210
+ # rearrange:用于对张量的维度进行重新变换排序,可用于替换pytorch中的reshape,view,transpose和permute等操作
211
+ x = rearrange(x, 'b (l w) n -> b n l w', l=self.img_size, w=self.img_size) # [1, 3136, 64]-->1*64*56*56
212
+ # batch_size,N(通道数),h,w
213
+
214
+ q = self.to_q(x) # 1*64*56*56-->1*64*56*56
215
+ # print(q.shape)
216
+ # print('++++++++++++++')
217
+ q = rearrange(q, 'b (h d) l w -> b h (l w) d', h=h) # 1*64*56*56-->1*1*3136*64
218
+ # print(q.shape)
219
+ # print('=====================')
220
+ # batch_size,head,h*w,dim_head
221
+
222
+ k = self.to_k(x) # 操作和q一样
223
+ k = rearrange(k, 'b (h d) l w -> b h (l w) d', h=h)
224
+ # batch_size,head,h*w,dim_head
225
+
226
+ v = self.to_v(x) ##操作和q一样
227
+ # print(v.shape)
228
+ # print('[[[[[[[[[[[[[[[[[[[[[[[[[[[[')
229
+ v = rearrange(v, 'b (h d) l w -> b h (l w) d', h=h)
230
+ # print(v.shape)
231
+ # print(']]]]]]]]]]]]]]]]]]]]]]]]]]]')
232
+ # batch_size,head,h*w,dim_head
233
+
234
+ if self.last_stage:
235
+ # print(q.shape)
236
+ # print('================')
237
+ q = torch.cat([cls_token, q], dim=2)
238
+ # print(q.shape)
239
+ # print('++++++++++++++++++')
240
+ v = torch.cat([cls_token, v], dim=2)
241
+ k = torch.cat([cls_token, k], dim=2)
242
+
243
+ # calculate attention by matmul + scale
244
+ # permute:(batch_size,head,dim_head,h*w
245
+ # print(k.shape)
246
+ # print('++++++++++++++++++++')
247
+ k = k.permute(0, 1, 3, 2) # 1*1*3136*64-->1*1*64*3136
248
+ # print(k.shape)
249
+ # print('====================')
250
+ attention = (q.matmul(k)) # 1*1*3136*3136
251
+ # print(attention.shape)
252
+ # print('--------------------')
253
+ attention = attention * self.scale # 可以得到一个logit的向量,避免出现梯度下降和梯度爆炸
254
+ # print(attention.shape)
255
+ # print('####################')
256
+ # pass a softmax
257
+ attention = F.softmax(attention, dim=-1)
258
+ # print(attention.shape)
259
+ # print('********************')
260
+
261
+ # matmul v
262
+ # attention.matmul(v):(batch_size,head,h*w,dim_head)
263
+ # permute:(batch_size,h*w,head,dim_head)
264
+ out = (attention.matmul(v)).permute(0, 2, 1, 3).reshape(b, n,
265
+ c) # 1*3136*64 这些操作的目的是将注意力权重和值向量相乘后得到的结果进行重塑,得到一个形状为 (batch size, 序列长度, 值向量或矩阵的维度) 的张量
266
+
267
+ # linear project
268
+ out = self.to_out(out)
269
+ return out
270
+
271
+
272
+ # Reshape Layers
273
+ class Rearrange(nn.Module):
274
+ def __init__(self, string, h, w):
275
+ super().__init__()
276
+ self.string = string
277
+ self.h = h
278
+ self.w = w
279
+
280
+ def forward(self, input):
281
+
282
+ if self.string == 'b c h w -> b (h w) c':
283
+ N, C, H, W = input.shape
284
+ # print(input.shape)
285
+ x = torch.reshape(input, shape=(N, -1, self.h * self.w)).permute(0, 2, 1)
286
+ # print(x.shape)
287
+ # print('+++++++++++++++++++')
288
+ if self.string == 'b (h w) c -> b c h w':
289
+ N, _, C = input.shape
290
+ # print(input.shape)
291
+ x = torch.reshape(input, shape=(N, self.h, self.w, -1)).permute(0, 3, 1, 2)
292
+ # print(x.shape)
293
+ # print('=====================')
294
+ return x
295
+
296
+
297
+ # Transformer layers
298
+ class Transformer(nn.Module):
299
+ def __init__(self, dim, img_size, depth, heads, dim_head, mlp_dim, dropout=0., last_stage=False):
300
+ super().__init__()
301
+ self.layers = nn.ModuleList([ # 管理子模块,参数注册
302
+ nn.ModuleList([
303
+ PreNorm(dim=dim, fn=ConvAttnetion(dim, img_size, heads=heads, dim_head=dim_head, dropout=dropout,
304
+ last_stage=last_stage)), # 归一化,重参数化
305
+ PreNorm(dim=dim, fn=FeedForward(dim=dim, hidden_dim=mlp_dim, dropout=dropout))
306
+ ]) for _ in range(depth)
307
+ ])
308
+
309
+ def forward(self, x):
310
+ for attn, ff in self.layers:
311
+ x = x + attn(x)
312
+ x = x + ff(x)
313
+ return x
314
+
315
+
316
+ class DBNet(nn.Module): # 最主要的大函数
317
+ def __init__(self, img_size, in_channels, num_classes, dim=64, kernels=[7, 3, 3, 3], strides=[4, 2, 2, 2],
318
+ heads=[1, 3, 6, 6],
319
+ depth=[1, 2, 10, 10], pool='cls', dropout=0., emb_dropout=0., scale_dim=4, ):
320
+ super().__init__()
321
+
322
+ assert pool in ['cls', 'mean'], f'pool type must be either cls or mean pooling'
323
+ self.pool = pool
324
+ self.dim = dim
325
+
326
+ # stage1
327
+ # k:7 s:4 in: 1, 64, 56, 56 out: 1, 3136, 64
328
+ self.stage1_conv_embed = nn.Sequential(
329
+ nn.Conv2d( # 1*3*224*224-->[1, 64, 56, 56]
330
+ in_channels=in_channels,
331
+ out_channels=dim,
332
+ kernel_size=kernels[0],
333
+ stride=strides[0],
334
+ padding=2
335
+ ),
336
+ Rearrange('b c h w -> b (h w) c', h=img_size // 4, w=img_size // 4), # [1, 64, 56, 56]-->[1, 3136, 64]
337
+ nn.LayerNorm(dim) # 对每个batch归一化
338
+ )
339
+
340
+ self.stage1_transformer = nn.Sequential(
341
+ Transformer( #
342
+ dim=dim,
343
+ img_size=img_size // 4,
344
+ depth=depth[0], # Transformer层中的编码器和解码器层数。
345
+ heads=heads[0],
346
+ dim_head=self.dim, # 它是每个注意力头的维度大小,通常是嵌入维度除以头数。
347
+ mlp_dim=dim * scale_dim, # mlp_dim:它是Transformer中前馈神经网络的隐藏层维度大小,通常是嵌入维度乘以一个缩放因子。
348
+ dropout=dropout,
349
+ # last_stage=last_stage #它是一个标志位,用于表示该Transformer层是否是最后一层。
350
+ ),
351
+ Rearrange('b (h w) c -> b c h w', h=img_size // 4, w=img_size // 4)
352
+ )
353
+
354
+ # stage2
355
+ # k:3 s:2 in: 1, 192, 28, 28 out: 1, 784, 192
356
+ in_channels = dim
357
+ scale = heads[1] // heads[0]
358
+ dim = scale * dim
359
+
360
+ self.stage2_conv_embed = nn.Sequential(
361
+ nn.Conv2d(
362
+ in_channels=in_channels,
363
+ out_channels=dim,
364
+ kernel_size=kernels[1],
365
+ stride=strides[1],
366
+ padding=1
367
+ ),
368
+ Rearrange('b c h w -> b (h w) c', h=img_size // 8, w=img_size // 8),
369
+ nn.LayerNorm(dim)
370
+ )
371
+
372
+ self.stage2_transformer = nn.Sequential(
373
+ Transformer(
374
+ dim=dim,
375
+ img_size=img_size // 8,
376
+ depth=depth[1],
377
+ heads=heads[1],
378
+ dim_head=self.dim,
379
+ mlp_dim=dim * scale_dim,
380
+ dropout=dropout
381
+ ),
382
+ Rearrange('b (h w) c -> b c h w', h=img_size // 8, w=img_size // 8)
383
+ )
384
+
385
+ # stage3
386
+ in_channels = dim
387
+ scale = heads[2] // heads[1]
388
+ dim = scale * dim
389
+
390
+ self.stage3_conv_embed = nn.Sequential(
391
+ nn.Conv2d(
392
+ in_channels=in_channels,
393
+ out_channels=dim,
394
+ kernel_size=kernels[2],
395
+ stride=strides[2],
396
+ padding=1
397
+ ),
398
+ Rearrange('b c h w -> b (h w) c', h=img_size // 16, w=img_size // 16),
399
+ nn.LayerNorm(dim)
400
+ )
401
+
402
+ self.stage3_transformer = nn.Sequential(
403
+ Transformer(
404
+ dim=dim,
405
+ img_size=img_size // 16,
406
+ depth=depth[2],
407
+ heads=heads[2],
408
+ dim_head=self.dim,
409
+ mlp_dim=dim * scale_dim,
410
+ dropout=dropout
411
+ ),
412
+ Rearrange('b (h w) c -> b c h w', h=img_size // 16, w=img_size // 16)
413
+ )
414
+
415
+ # stage4
416
+ in_channels = dim
417
+ scale = heads[3] // heads[2]
418
+ dim = scale * dim
419
+
420
+ self.stage4_conv_embed = nn.Sequential(
421
+ nn.Conv2d(
422
+ in_channels=in_channels,
423
+ out_channels=dim,
424
+ kernel_size=kernels[3],
425
+ stride=strides[3],
426
+ padding=1
427
+ ),
428
+ Rearrange('b c h w -> b (h w) c', h=img_size // 32, w=img_size // 32),
429
+ nn.LayerNorm(dim)
430
+ )
431
+
432
+ self.stage4_transformer = nn.Sequential(
433
+ Transformer(
434
+ dim=dim, img_size=img_size // 32,
435
+ depth=depth[3],
436
+ heads=heads[3],
437
+ dim_head=self.dim,
438
+ mlp_dim=dim * scale_dim,
439
+ dropout=dropout,
440
+ ),
441
+ Rearrange('b (h w) c -> b c h w', h=img_size // 32, w=img_size // 32)
442
+ )
443
+
444
+ ### CNN Branch ###
445
+ self.c_stage1 = c_stage123(in_chans=3, out_chans=64)
446
+ self.c_stage2 = c_stage123(in_chans=64, out_chans=128)
447
+ self.c_stage3 = c_stage123(in_chans=128, out_chans=384)
448
+ self.c_stage4 = c_stage45(in_chans=384, out_chans=512)
449
+ self.c_stage5 = c_stage45(in_chans=512, out_chans=1024)
450
+ self.c_max = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
451
+ self.up_conv1 = nn.Conv2d(in_channels=192, out_channels=128, kernel_size=1)
452
+ self.up_conv2 = nn.Conv2d(in_channels=384, out_channels=512, kernel_size=1)
453
+
454
+ ### CTmerge ###
455
+ self.CTmerge1 = nn.Sequential(
456
+ nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1),
457
+ nn.BatchNorm2d(64),
458
+ nn.ReLU(),
459
+ nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
460
+ nn.BatchNorm2d(64),
461
+ nn.ReLU(),
462
+ )
463
+ self.CTmerge2 = nn.Sequential(
464
+ nn.Conv2d(in_channels=320, out_channels=128, kernel_size=3, stride=1, padding=1),
465
+ nn.BatchNorm2d(128),
466
+ nn.ReLU(),
467
+ nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
468
+ nn.BatchNorm2d(128),
469
+ nn.ReLU(),
470
+ )
471
+ self.CTmerge3 = nn.Sequential(
472
+ nn.Conv2d(in_channels=768, out_channels=512, kernel_size=3, stride=1, padding=1),
473
+ nn.BatchNorm2d(512),
474
+ nn.ReLU(),
475
+ nn.Conv2d(in_channels=512, out_channels=384, kernel_size=3, stride=1, padding=1),
476
+ nn.BatchNorm2d(384),
477
+ nn.ReLU(),
478
+ nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, stride=1, padding=1),
479
+ nn.BatchNorm2d(384),
480
+ nn.ReLU(),
481
+ )
482
+
483
+ self.CTmerge4 = nn.Sequential(
484
+ nn.Conv2d(in_channels=896, out_channels=640, kernel_size=3, stride=1, padding=1),
485
+ nn.BatchNorm2d(640),
486
+ nn.ReLU(),
487
+ nn.Conv2d(in_channels=640, out_channels=512, kernel_size=3, stride=1, padding=1),
488
+ nn.BatchNorm2d(512),
489
+ nn.ReLU(),
490
+ nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
491
+ nn.BatchNorm2d(512),
492
+ nn.ReLU(),
493
+ )
494
+
495
+ # decoder
496
+ self.decoder4 = nn.Sequential(
497
+ DepthwiseConv2d(
498
+ in_chans=1408,
499
+ out_chans=1024,
500
+ kernel_size=3,
501
+ stride=1,
502
+ padding=1
503
+ ),
504
+ DepthwiseConv2d(
505
+ in_chans=1024,
506
+ out_chans=512,
507
+ kernel_size=3,
508
+ stride=1,
509
+ padding=1
510
+ ),
511
+ nn.GELU()
512
+ )
513
+ self.decoder3 = nn.Sequential(
514
+ DepthwiseConv2d(
515
+ in_chans=896,
516
+ out_chans=512,
517
+ kernel_size=3,
518
+ stride=1,
519
+ padding=1
520
+ ),
521
+ DepthwiseConv2d(
522
+ in_chans=512,
523
+ out_chans=384,
524
+ kernel_size=3,
525
+ stride=1,
526
+ padding=1
527
+ ),
528
+ nn.GELU()
529
+ )
530
+
531
+ self.decoder2 = nn.Sequential(
532
+ DepthwiseConv2d(
533
+ in_chans=576,
534
+ out_chans=256,
535
+ kernel_size=3,
536
+ stride=1,
537
+ padding=1
538
+ ),
539
+ DepthwiseConv2d(
540
+ in_chans=256,
541
+ out_chans=192,
542
+ kernel_size=3,
543
+ stride=1,
544
+ padding=1
545
+ ),
546
+ nn.GELU()
547
+ )
548
+
549
+ self.decoder1 = nn.Sequential(
550
+ DepthwiseConv2d(
551
+ in_chans=256,
552
+ out_chans=64,
553
+ kernel_size=3,
554
+ stride=1,
555
+ padding=1
556
+ ),
557
+ DepthwiseConv2d(
558
+ in_chans=64,
559
+ out_chans=16,
560
+ kernel_size=3,
561
+ stride=1,
562
+ padding=1
563
+ ),
564
+ nn.GELU()
565
+ )
566
+ self.sbr4 = SBR(512)
567
+ self.sbr3 = SBR(384)
568
+ self.sbr2 = SBR(192)
569
+ self.sbr1 = SBR(16)
570
+
571
+ self.head = nn.Conv2d(in_channels=16, out_channels=num_classes, kernel_size=1)
572
+
573
+ def forward(self, input):
574
+ ### encoder ###
575
+ # stage1 = ts1 cat cs1
576
+ # t_s1 = self.t_stage1(input)
577
+ # print(input.shape)
578
+ # print('++++++++++++++++++++++')
579
+
580
+ t_s1 = self.stage1_conv_embed(input) # 1*3*224*224-->1*3136*64
581
+
582
+ # print(t_s1.shape)
583
+ # print('======================')
584
+
585
+ t_s1 = self.stage1_transformer(t_s1) # 1*3136*64-->1*64*56*56
586
+
587
+ # print(t_s1.shape)
588
+ # print('----------------------')
589
+
590
+ c_s1 = self.c_stage1(input) # 1*3*224*224-->1*64*112*112
591
+
592
+ # print(c_s1.shape)
593
+ # print('!!!!!!!!!!!!!!!!!!!!!!!')
594
+
595
+ stage1 = self.CTmerge1(torch.cat([t_s1, self.c_max(c_s1)], dim=1)) # 1*64*56*56 # 拼接两条分支
596
+
597
+ # print(stage1.shape)
598
+ # print('[[[[[[[[[[[[[[[[[[[[[[[')
599
+
600
+ # stage2 = ts2 up cs2
601
+ # t_s2 = self.t_stage2(stage1)
602
+ t_s2 = self.stage2_conv_embed(stage1) # 1*64*56*56-->1*784*192 # stage2_conv_embed是转化为序列操作
603
+
604
+ # print(t_s2.shape)
605
+ # print('[[[[[[[[[[[[[[[[[[[[[[[')
606
+ t_s2 = self.stage2_transformer(t_s2) # 1*784*192-->1*192*28*28
607
+ # print(t_s2.shape)
608
+ # print('+++++++++++++++++++++++++')
609
+
610
+ c_s2 = self.c_stage2(c_s1) # 1*64*112*112-->1*128*56*56
611
+ stage2 = self.CTmerge2(
612
+ torch.cat([c_s2, F.interpolate(t_s2, size=c_s2.size()[2:], mode='bilinear', align_corners=True)],
613
+ dim=1)) # mode='bilinear'表示使用双线性插值 1*128*56*56
614
+
615
+ # stage3 = ts3 cat cs3
616
+ # t_s3 = self.t_stage3(t_s2)
617
+ t_s3 = self.stage3_conv_embed(t_s2) # 1*192*28*28-->1*196*384
618
+ # print(t_s3.shape)
619
+ # print('///////////////////////')
620
+ t_s3 = self.stage3_transformer(t_s3) # 1*196*384-->1*384*14*14
621
+ # print(t_s3.shape)
622
+ # print('....................')
623
+ c_s3 = self.c_stage3(stage2) # 1*128*56*56-->1*384*28*28
624
+ stage3 = self.CTmerge3(torch.cat([t_s3, self.c_max(c_s3)], dim=1)) # 1*384*14*14
625
+
626
+ # stage4 = ts4 up cs4
627
+ # t_s4 = self.t_stage4(stage3)
628
+ t_s4 = self.stage4_conv_embed(stage3) # 1*384*14*14-->1*49*384
629
+ # print(t_s4.shape)
630
+ # print(';;;;;;;;;;;;;;;;;;;;;;;')
631
+ t_s4 = self.stage4_transformer(t_s4) # 1*49*384-->1*384*7*7
632
+ # print(t_s4.shape)
633
+ # print('::::::::::::::::::::')
634
+
635
+ c_s4 = self.c_stage4(c_s3) # 1*384*28*28-->1*512*14*14
636
+ stage4 = self.CTmerge4(
637
+ torch.cat([c_s4, F.interpolate(t_s4, size=c_s4.size()[2:], mode='bilinear', align_corners=True)],
638
+ dim=1)) # 1*512*14*14
639
+
640
+ # cs5
641
+ c_s5 = self.c_stage5(stage4) # 1*512*14*14-->1*1024*7*7
642
+
643
+ ### decoder ###
644
+ decoder4 = torch.cat([c_s5, t_s4], dim=1) # 1*1408*7*7
645
+ decoder4 = self.decoder4(decoder4) # 1*1408*7*7-->1*512*7*7
646
+ decoder4 = F.interpolate(decoder4, size=c_s3.size()[2:], mode='bilinear',
647
+ align_corners=True) # 1*512*7*7-->1*512*28*28
648
+ decoder4 = self.sbr4(decoder4) # 1*512*28*28
649
+ # print(decoder4.shape)
650
+
651
+ decoder3 = torch.cat([decoder4, c_s3], dim=1) # 1*896*28*28
652
+ decoder3 = self.decoder3(decoder3) # 1*384*28*28
653
+ decoder3 = F.interpolate(decoder3, size=t_s2.size()[2:], mode='bilinear', align_corners=True) # 1*384*28*28
654
+ decoder3 = self.sbr3(decoder3) # 1*384*28*28
655
+ # print(decoder3.shape)
656
+
657
+ decoder2 = torch.cat([decoder3, t_s2], dim=1) # 1*576*28*28
658
+ decoder2 = self.decoder2(decoder2) # 1*192*28*28
659
+ decoder2 = F.interpolate(decoder2, size=c_s1.size()[2:], mode='bilinear', align_corners=True) # 1*192*112*112
660
+ decoder2 = self.sbr2(decoder2) # 1*192*112*112
661
+ # print(decoder2.shape)
662
+
663
+ decoder1 = torch.cat([decoder2, c_s1], dim=1) # 1*256*112*112
664
+ decoder1 = self.decoder1(decoder1) # 1*16*112*112
665
+ # print(decoder1.shape)
666
+ final = F.interpolate(decoder1, size=input.size()[2:], mode='bilinear', align_corners=True) # 1*16*224*224
667
+ # print(final.shape)
668
+ # final = self.sbr1(decoder1)
669
+ # print(final.shape)
670
+ final = self.head(final) # 1*3*224*224
671
+
672
+ return final
673
+
674
+
675
+ if __name__ == '__main__':
676
+ x = torch.rand(1, 3, 224, 224).cuda()
677
+ model = DBNet(img_size=224, in_channels=3, num_classes=7).cuda()
678
+ y = model(x)
679
+ print(y.shape)
680
+ # torch.Size([1, 7, 224, 224])
cloud_adapter/dino_layers/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from .dino_head import DINOHead
7
+ from .mlp import Mlp
8
+ from .patch_embed import PatchEmbed
9
+ from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
10
+ from .block import NestedTensorBlock,drop_add_residual_stochastic_depth
11
+ from .attention import MemEffAttention
cloud_adapter/dino_layers/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (502 Bytes). View file
 
cloud_adapter/dino_layers/__pycache__/attention.cpython-38.pyc ADDED
Binary file (2.57 kB). View file
 
cloud_adapter/dino_layers/__pycache__/block.cpython-38.pyc ADDED
Binary file (8.24 kB). View file
 
cloud_adapter/dino_layers/__pycache__/dino_head.cpython-38.pyc ADDED
Binary file (1.97 kB). View file
 
cloud_adapter/dino_layers/__pycache__/drop_path.cpython-38.pyc ADDED
Binary file (1.21 kB). View file
 
cloud_adapter/dino_layers/__pycache__/layer_scale.cpython-38.pyc ADDED
Binary file (1.01 kB). View file
 
cloud_adapter/dino_layers/__pycache__/mlp.cpython-38.pyc ADDED
Binary file (1.19 kB). View file
 
cloud_adapter/dino_layers/__pycache__/patch_embed.cpython-38.pyc ADDED
Binary file (2.65 kB). View file
 
cloud_adapter/dino_layers/__pycache__/swiglu_ffn.cpython-38.pyc ADDED
Binary file (2.18 kB). View file
 
cloud_adapter/dino_layers/attention.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ import logging
11
+ import os
12
+ import warnings
13
+
14
+ from torch import Tensor
15
+ from torch import nn
16
+
17
+
18
+ logger = logging.getLogger("dinov2")
19
+
20
+
21
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
22
+ try:
23
+ if XFORMERS_ENABLED:
24
+ from xformers.ops import memory_efficient_attention, unbind
25
+
26
+ XFORMERS_AVAILABLE = True
27
+ warnings.warn("xFormers is available (Attention)")
28
+ else:
29
+ warnings.warn("xFormers is disabled (Attention)")
30
+ raise ImportError
31
+ except ImportError:
32
+ XFORMERS_AVAILABLE = False
33
+ warnings.warn("xFormers is not available (Attention)")
34
+
35
+
36
+ class Attention(nn.Module):
37
+ def __init__(
38
+ self,
39
+ dim: int,
40
+ num_heads: int = 8,
41
+ qkv_bias: bool = False,
42
+ proj_bias: bool = True,
43
+ attn_drop: float = 0.0,
44
+ proj_drop: float = 0.0,
45
+ ) -> None:
46
+ super().__init__()
47
+ self.num_heads = num_heads
48
+ head_dim = dim // num_heads
49
+ self.scale = head_dim**-0.5
50
+
51
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
52
+ self.attn_drop = nn.Dropout(attn_drop)
53
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
54
+ self.proj_drop = nn.Dropout(proj_drop)
55
+
56
+ def forward(self, x: Tensor) -> Tensor:
57
+ B, N, C = x.shape
58
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
59
+
60
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
61
+ attn = q @ k.transpose(-2, -1)
62
+
63
+ attn = attn.softmax(dim=-1)
64
+ attn = self.attn_drop(attn)
65
+
66
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
67
+ x = self.proj(x)
68
+ x = self.proj_drop(x)
69
+ return x
70
+
71
+
72
+ class MemEffAttention(Attention):
73
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
74
+ if not XFORMERS_AVAILABLE:
75
+ if attn_bias is not None:
76
+ raise AssertionError("xFormers is required for using nested tensors")
77
+ return super().forward(x)
78
+
79
+ B, N, C = x.shape
80
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
81
+
82
+ q, k, v = unbind(qkv, 2)
83
+
84
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
85
+ x = x.reshape([B, N, C])
86
+
87
+ x = self.proj(x)
88
+ x = self.proj_drop(x)
89
+ return x
cloud_adapter/dino_layers/block.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
+
10
+ import logging
11
+ import os
12
+ from typing import Callable, List, Any, Tuple, Dict
13
+ import warnings
14
+
15
+ import torch
16
+ from torch import nn, Tensor
17
+
18
+ from .attention import Attention, MemEffAttention
19
+ from .drop_path import DropPath
20
+ from .layer_scale import LayerScale
21
+ from .mlp import Mlp
22
+
23
+
24
+ logger = logging.getLogger("dinov2")
25
+
26
+
27
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
28
+ try:
29
+ if XFORMERS_ENABLED:
30
+ from xformers.ops import fmha, scaled_index_add, index_select_cat
31
+
32
+ XFORMERS_AVAILABLE = True
33
+ warnings.warn("xFormers is available (Block)")
34
+ else:
35
+ warnings.warn("xFormers is disabled (Block)")
36
+ raise ImportError
37
+ except ImportError:
38
+ XFORMERS_AVAILABLE = False
39
+
40
+ warnings.warn("xFormers is not available (Block)")
41
+
42
+
43
+ class Block(nn.Module):
44
+ def __init__(
45
+ self,
46
+ dim: int,
47
+ num_heads: int,
48
+ mlp_ratio: float = 4.0,
49
+ qkv_bias: bool = False,
50
+ proj_bias: bool = True,
51
+ ffn_bias: bool = True,
52
+ drop: float = 0.0,
53
+ attn_drop: float = 0.0,
54
+ init_values=None,
55
+ drop_path: float = 0.0,
56
+ act_layer: Callable[..., nn.Module] = nn.GELU,
57
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
58
+ attn_class: Callable[..., nn.Module] = Attention,
59
+ ffn_layer: Callable[..., nn.Module] = Mlp,
60
+ ) -> None:
61
+ super().__init__()
62
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
63
+ self.norm1 = norm_layer(dim)
64
+ self.attn = attn_class(
65
+ dim,
66
+ num_heads=num_heads,
67
+ qkv_bias=qkv_bias,
68
+ proj_bias=proj_bias,
69
+ attn_drop=attn_drop,
70
+ proj_drop=drop,
71
+ )
72
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
73
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
74
+
75
+ self.norm2 = norm_layer(dim)
76
+ mlp_hidden_dim = int(dim * mlp_ratio)
77
+ self.mlp = ffn_layer(
78
+ in_features=dim,
79
+ hidden_features=mlp_hidden_dim,
80
+ act_layer=act_layer,
81
+ drop=drop,
82
+ bias=ffn_bias,
83
+ )
84
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
85
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
86
+
87
+ self.sample_drop_ratio = drop_path
88
+
89
+ def forward(self, x: Tensor) -> Tensor:
90
+ def attn_residual_func(x: Tensor) -> Tensor:
91
+ return self.ls1(self.attn(self.norm1(x)))
92
+
93
+ def ffn_residual_func(x: Tensor) -> Tensor:
94
+ return self.ls2(self.mlp(self.norm2(x)))
95
+
96
+ if self.training and self.sample_drop_ratio > 0.1:
97
+ # the overhead is compensated only for a drop path rate larger than 0.1
98
+ x = drop_add_residual_stochastic_depth(
99
+ x,
100
+ residual_func=attn_residual_func,
101
+ sample_drop_ratio=self.sample_drop_ratio,
102
+ )
103
+ x = drop_add_residual_stochastic_depth(
104
+ x,
105
+ residual_func=ffn_residual_func,
106
+ sample_drop_ratio=self.sample_drop_ratio,
107
+ )
108
+ elif self.training and self.sample_drop_ratio > 0.0:
109
+ x = x + self.drop_path1(attn_residual_func(x))
110
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
111
+ else:
112
+ x = x + attn_residual_func(x)
113
+ x = x + ffn_residual_func(x)
114
+ return x
115
+
116
+
117
+ def drop_add_residual_stochastic_depth(
118
+ x: Tensor,
119
+ residual_func: Callable[[Tensor], Tensor],
120
+ sample_drop_ratio: float = 0.0,
121
+ ) -> Tensor:
122
+ # 1) extract subset using permutation
123
+ b, n, d = x.shape
124
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
125
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
126
+ x_subset = x[brange]
127
+
128
+ # 2) apply residual_func to get residual
129
+ residual = residual_func(x_subset)
130
+
131
+ x_flat = x.flatten(1)
132
+ residual = residual.flatten(1)
133
+
134
+ residual_scale_factor = b / sample_subset_size
135
+
136
+ # 3) add the residual
137
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
138
+ return x_plus_residual.view_as(x)
139
+
140
+
141
+ def get_branges_scales(x, sample_drop_ratio=0.0):
142
+ b, n, d = x.shape
143
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
144
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
145
+ residual_scale_factor = b / sample_subset_size
146
+ return brange, residual_scale_factor
147
+
148
+
149
+ def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
150
+ if scaling_vector is None:
151
+ x_flat = x.flatten(1)
152
+ residual = residual.flatten(1)
153
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
154
+ else:
155
+ x_plus_residual = scaled_index_add(
156
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
157
+ )
158
+ return x_plus_residual
159
+
160
+
161
+ attn_bias_cache: Dict[Tuple, Any] = {}
162
+
163
+
164
+ def get_attn_bias_and_cat(x_list, branges=None):
165
+ """
166
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
167
+ """
168
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
169
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
170
+ if all_shapes not in attn_bias_cache.keys():
171
+ seqlens = []
172
+ for b, x in zip(batch_sizes, x_list):
173
+ for _ in range(b):
174
+ seqlens.append(x.shape[1])
175
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
176
+ attn_bias._batch_sizes = batch_sizes
177
+ attn_bias_cache[all_shapes] = attn_bias
178
+
179
+ if branges is not None:
180
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
181
+ else:
182
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
183
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
184
+
185
+ return attn_bias_cache[all_shapes], cat_tensors
186
+
187
+
188
+ def drop_add_residual_stochastic_depth_list(
189
+ x_list: List[Tensor],
190
+ residual_func: Callable[[Tensor, Any], Tensor],
191
+ sample_drop_ratio: float = 0.0,
192
+ scaling_vector=None,
193
+ ) -> Tensor:
194
+ # 1) generate random set of indices for dropping samples in the batch
195
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
196
+ branges = [s[0] for s in branges_scales]
197
+ residual_scale_factors = [s[1] for s in branges_scales]
198
+
199
+ # 2) get attention bias and index+concat the tensors
200
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
201
+
202
+ # 3) apply residual_func to get residual, and split the result
203
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
204
+
205
+ outputs = []
206
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
207
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
208
+ return outputs
209
+
210
+
211
+ class NestedTensorBlock(Block):
212
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
213
+ """
214
+ x_list contains a list of tensors to nest together and run
215
+ """
216
+ assert isinstance(self.attn, MemEffAttention)
217
+
218
+ if self.training and self.sample_drop_ratio > 0.0:
219
+
220
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
221
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
222
+
223
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
224
+ return self.mlp(self.norm2(x))
225
+
226
+ x_list = drop_add_residual_stochastic_depth_list(
227
+ x_list,
228
+ residual_func=attn_residual_func,
229
+ sample_drop_ratio=self.sample_drop_ratio,
230
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
231
+ )
232
+ x_list = drop_add_residual_stochastic_depth_list(
233
+ x_list,
234
+ residual_func=ffn_residual_func,
235
+ sample_drop_ratio=self.sample_drop_ratio,
236
+ scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
237
+ )
238
+ return x_list
239
+ else:
240
+
241
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
242
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
243
+
244
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
245
+ return self.ls2(self.mlp(self.norm2(x)))
246
+
247
+ attn_bias, x = get_attn_bias_and_cat(x_list)
248
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
249
+ x = x + ffn_residual_func(x)
250
+ return attn_bias.split(x)
251
+
252
+ def forward(self, x_or_x_list):
253
+ if isinstance(x_or_x_list, Tensor):
254
+ return super().forward(x_or_x_list)
255
+ elif isinstance(x_or_x_list, list):
256
+ if not XFORMERS_AVAILABLE:
257
+ raise AssertionError("xFormers is required for using nested tensors")
258
+ return self.forward_nested(x_or_x_list)
259
+ else:
260
+ raise AssertionError
cloud_adapter/dino_layers/dino_head.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.nn.init import trunc_normal_
9
+ from torch.nn.utils import weight_norm
10
+
11
+
12
+ class DINOHead(nn.Module):
13
+ def __init__(
14
+ self,
15
+ in_dim,
16
+ out_dim,
17
+ use_bn=False,
18
+ nlayers=3,
19
+ hidden_dim=2048,
20
+ bottleneck_dim=256,
21
+ mlp_bias=True,
22
+ ):
23
+ super().__init__()
24
+ nlayers = max(nlayers, 1)
25
+ self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
26
+ self.apply(self._init_weights)
27
+ self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
28
+ self.last_layer.weight_g.data.fill_(1)
29
+
30
+ def _init_weights(self, m):
31
+ if isinstance(m, nn.Linear):
32
+ trunc_normal_(m.weight, std=0.02)
33
+ if isinstance(m, nn.Linear) and m.bias is not None:
34
+ nn.init.constant_(m.bias, 0)
35
+
36
+ def forward(self, x):
37
+ x = self.mlp(x)
38
+ eps = 1e-6 if x.dtype == torch.float16 else 1e-12
39
+ x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
40
+ x = self.last_layer(x)
41
+ return x
42
+
43
+
44
+ def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
45
+ if nlayers == 1:
46
+ return nn.Linear(in_dim, bottleneck_dim, bias=bias)
47
+ else:
48
+ layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
49
+ if use_bn:
50
+ layers.append(nn.BatchNorm1d(hidden_dim))
51
+ layers.append(nn.GELU())
52
+ for _ in range(nlayers - 2):
53
+ layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
54
+ if use_bn:
55
+ layers.append(nn.BatchNorm1d(hidden_dim))
56
+ layers.append(nn.GELU())
57
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
58
+ return nn.Sequential(*layers)
cloud_adapter/dino_layers/drop_path.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
9
+
10
+
11
+ from torch import nn
12
+
13
+
14
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
15
+ if drop_prob == 0.0 or not training:
16
+ return x
17
+ keep_prob = 1 - drop_prob
18
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
19
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
20
+ if keep_prob > 0.0:
21
+ random_tensor.div_(keep_prob)
22
+ output = x * random_tensor
23
+ return output
24
+
25
+
26
+ class DropPath(nn.Module):
27
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
28
+
29
+ def __init__(self, drop_prob=None):
30
+ super(DropPath, self).__init__()
31
+ self.drop_prob = drop_prob
32
+
33
+ def forward(self, x):
34
+ return drop_path(x, self.drop_prob, self.training)
cloud_adapter/dino_layers/layer_scale.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
7
+
8
+ from typing import Union
9
+
10
+ import torch
11
+ from torch import Tensor
12
+ from torch import nn
13
+
14
+
15
+ class LayerScale(nn.Module):
16
+ def __init__(
17
+ self,
18
+ dim: int,
19
+ init_values: Union[float, Tensor] = 1e-5,
20
+ inplace: bool = False,
21
+ ) -> None:
22
+ super().__init__()
23
+ self.inplace = inplace
24
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
25
+
26
+ def forward(self, x: Tensor) -> Tensor:
27
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
cloud_adapter/dino_layers/mlp.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
9
+
10
+
11
+ from typing import Callable, Optional
12
+
13
+ from torch import Tensor, nn
14
+
15
+
16
+ class Mlp(nn.Module):
17
+ def __init__(
18
+ self,
19
+ in_features: int,
20
+ hidden_features: Optional[int] = None,
21
+ out_features: Optional[int] = None,
22
+ act_layer: Callable[..., nn.Module] = nn.GELU,
23
+ drop: float = 0.0,
24
+ bias: bool = True,
25
+ ) -> None:
26
+ super().__init__()
27
+ out_features = out_features or in_features
28
+ hidden_features = hidden_features or in_features
29
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
30
+ self.act = act_layer()
31
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
32
+ self.drop = nn.Dropout(drop)
33
+
34
+ def forward(self, x: Tensor) -> Tensor:
35
+ x = self.fc1(x)
36
+ x = self.act(x)
37
+ x = self.drop(x)
38
+ x = self.fc2(x)
39
+ x = self.drop(x)
40
+ return x
cloud_adapter/dino_layers/patch_embed.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
+
10
+ from typing import Callable, Optional, Tuple, Union
11
+
12
+ from torch import Tensor
13
+ import torch.nn as nn
14
+
15
+
16
+ def make_2tuple(x):
17
+ if isinstance(x, tuple):
18
+ assert len(x) == 2
19
+ return x
20
+
21
+ assert isinstance(x, int)
22
+ return (x, x)
23
+
24
+
25
+ class PatchEmbed(nn.Module):
26
+ """
27
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
28
+
29
+ Args:
30
+ img_size: Image size.
31
+ patch_size: Patch token size.
32
+ in_chans: Number of input image channels.
33
+ embed_dim: Number of linear projection output channels.
34
+ norm_layer: Normalization layer.
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ img_size: Union[int, Tuple[int, int]] = 224,
40
+ patch_size: Union[int, Tuple[int, int]] = 16,
41
+ in_chans: int = 3,
42
+ embed_dim: int = 768,
43
+ norm_layer: Optional[Callable] = None,
44
+ flatten_embedding: bool = True,
45
+ ) -> None:
46
+ super().__init__()
47
+
48
+ image_HW = make_2tuple(img_size)
49
+ patch_HW = make_2tuple(patch_size)
50
+ patch_grid_size = (
51
+ image_HW[0] // patch_HW[0],
52
+ image_HW[1] // patch_HW[1],
53
+ )
54
+
55
+ self.img_size = image_HW
56
+ self.patch_size = patch_HW
57
+ self.patches_resolution = patch_grid_size
58
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
59
+
60
+ self.in_chans = in_chans
61
+ self.embed_dim = embed_dim
62
+
63
+ self.flatten_embedding = flatten_embedding
64
+
65
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
66
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
67
+
68
+ def forward(self, x: Tensor) -> Tensor:
69
+ _, _, H, W = x.shape
70
+ patch_H, patch_W = self.patch_size
71
+
72
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
73
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
74
+
75
+ x = self.proj(x) # B C H W
76
+ H, W = x.size(2), x.size(3)
77
+ x = x.flatten(2).transpose(1, 2) # B HW C
78
+ x = self.norm(x)
79
+ if not self.flatten_embedding:
80
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
81
+ return x
82
+
83
+ def flops(self) -> float:
84
+ Ho, Wo = self.patches_resolution
85
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
86
+ if self.norm is not None:
87
+ flops += Ho * Wo * self.embed_dim
88
+ return flops
cloud_adapter/dino_layers/swiglu_ffn.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ from typing import Callable, Optional
8
+ import warnings
9
+
10
+ from torch import Tensor, nn
11
+ import torch.nn.functional as F
12
+
13
+
14
+ class SwiGLUFFN(nn.Module):
15
+ def __init__(
16
+ self,
17
+ in_features: int,
18
+ hidden_features: Optional[int] = None,
19
+ out_features: Optional[int] = None,
20
+ act_layer: Callable[..., nn.Module] = None,
21
+ drop: float = 0.0,
22
+ bias: bool = True,
23
+ ) -> None:
24
+ super().__init__()
25
+ out_features = out_features or in_features
26
+ hidden_features = hidden_features or in_features
27
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
28
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
29
+
30
+ def forward(self, x: Tensor) -> Tensor:
31
+ x12 = self.w12(x)
32
+ x1, x2 = x12.chunk(2, dim=-1)
33
+ hidden = F.silu(x1) * x2
34
+ return self.w3(hidden)
35
+
36
+
37
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
38
+ try:
39
+ if XFORMERS_ENABLED:
40
+ from xformers.ops import SwiGLU
41
+
42
+ XFORMERS_AVAILABLE = True
43
+ warnings.warn("xFormers is available (SwiGLU)")
44
+ else:
45
+ warnings.warn("xFormers is disabled (SwiGLU)")
46
+ raise ImportError
47
+ except ImportError:
48
+ SwiGLU = SwiGLUFFN
49
+ XFORMERS_AVAILABLE = False
50
+
51
+ warnings.warn("xFormers is not available (SwiGLU)")
52
+
53
+
54
+ class SwiGLUFFNFused(SwiGLU):
55
+ def __init__(
56
+ self,
57
+ in_features: int,
58
+ hidden_features: Optional[int] = None,
59
+ out_features: Optional[int] = None,
60
+ act_layer: Callable[..., nn.Module] = None,
61
+ drop: float = 0.0,
62
+ bias: bool = True,
63
+ ) -> None:
64
+ out_features = out_features or in_features
65
+ hidden_features = hidden_features or in_features
66
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
67
+ super().__init__(
68
+ in_features=in_features,
69
+ hidden_features=hidden_features,
70
+ out_features=out_features,
71
+ bias=bias,
72
+ )
cloud_adapter/dino_v2.py ADDED
@@ -0,0 +1,353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ from functools import partial
11
+ import math
12
+ from typing import Sequence, Tuple, Union, Callable
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.utils.checkpoint
17
+ from mmseg.models.builder import BACKBONES
18
+ from mmengine.model import BaseModule
19
+ import torch.nn.functional as F
20
+ from .dino_layers import (
21
+ Mlp,
22
+ PatchEmbed,
23
+ SwiGLUFFNFused,
24
+ MemEffAttention,
25
+ NestedTensorBlock as Block,
26
+ )
27
+
28
+
29
+ def named_apply(
30
+ fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False
31
+ ) -> nn.Module:
32
+ if not depth_first and include_root:
33
+ fn(module=module, name=name)
34
+ for child_name, child_module in module.named_children():
35
+ child_name = ".".join((name, child_name)) if name else child_name
36
+ named_apply(
37
+ fn=fn,
38
+ module=child_module,
39
+ name=child_name,
40
+ depth_first=depth_first,
41
+ include_root=True,
42
+ )
43
+ if depth_first and include_root:
44
+ fn(module=module, name=name)
45
+ return module
46
+
47
+
48
+ class BlockChunk(nn.ModuleList):
49
+ def forward(self, x):
50
+ for b in self:
51
+ x = b(x)
52
+ return x
53
+
54
+
55
+ @BACKBONES.register_module()
56
+ class DinoVisionTransformer(BaseModule):
57
+ def __init__(
58
+ self,
59
+ img_size=224,
60
+ patch_size=16,
61
+ in_chans=3,
62
+ embed_dim=768,
63
+ depth=12,
64
+ num_heads=12,
65
+ mlp_ratio=4.0,
66
+ qkv_bias=True,
67
+ ffn_bias=True,
68
+ proj_bias=True,
69
+ drop_path_rate=0.0,
70
+ drop_path_uniform=False,
71
+ init_values=None, # for layerscale: None or 0 => no layerscale
72
+ embed_layer=PatchEmbed,
73
+ act_layer=nn.GELU,
74
+ block_fn=partial(Block, attn_class=MemEffAttention),
75
+ ffn_layer="mlp",
76
+ block_chunks=1,
77
+ out_indices=[7, 11, 15, 23],
78
+ init_cfg=None,
79
+ ):
80
+ """
81
+ Args:
82
+ img_size (int, tuple): input image size
83
+ patch_size (int, tuple): patch size
84
+ in_chans (int): number of input channels
85
+ embed_dim (int): embedding dimension
86
+ depth (int): depth of transformer
87
+ num_heads (int): number of attention heads
88
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
89
+ qkv_bias (bool): enable bias for qkv if True
90
+ proj_bias (bool): enable bias for proj in attn if True
91
+ ffn_bias (bool): enable bias for ffn if True
92
+ drop_path_rate (float): stochastic depth rate
93
+ drop_path_uniform (bool): apply uniform drop rate across blocks
94
+ weight_init (str): weight init scheme
95
+ init_values (float): layer-scale init values
96
+ embed_layer (nn.Module): patch embedding layer
97
+ act_layer (nn.Module): MLP activation layer
98
+ block_fn (nn.Module): transformer block class
99
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
100
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
101
+ """
102
+ super().__init__(init_cfg)
103
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
104
+ self.out_indices = out_indices
105
+ self.drop_path_rate = drop_path_rate
106
+ self.num_features = (
107
+ self.embed_dim
108
+ ) = embed_dim # num_features for consistency with other models
109
+ self.num_tokens = 1
110
+ self.n_blocks = depth
111
+ self.num_heads = num_heads
112
+ self.norm_layer = norm_layer
113
+ self.patch_size = patch_size
114
+
115
+ self.patch_embed = embed_layer(
116
+ img_size=img_size,
117
+ patch_size=patch_size,
118
+ in_chans=in_chans,
119
+ embed_dim=embed_dim,
120
+ )
121
+ num_patches = self.patch_embed.num_patches
122
+
123
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
124
+ self.pos_embed = nn.Parameter(
125
+ torch.zeros(1, num_patches + self.num_tokens, embed_dim)
126
+ )
127
+
128
+ if drop_path_uniform is True:
129
+ dpr = [drop_path_rate] * depth
130
+ else:
131
+ dpr = [
132
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
133
+ ] # stochastic depth decay rule
134
+
135
+ if ffn_layer == "mlp":
136
+ ffn_layer = Mlp
137
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
138
+ ffn_layer = SwiGLUFFNFused
139
+ elif ffn_layer == "identity":
140
+
141
+ def f(*args, **kwargs):
142
+ return nn.Identity()
143
+
144
+ ffn_layer = f
145
+ else:
146
+ raise NotImplementedError
147
+
148
+ blocks_list = [
149
+ block_fn(
150
+ dim=embed_dim,
151
+ num_heads=num_heads,
152
+ mlp_ratio=mlp_ratio,
153
+ qkv_bias=qkv_bias,
154
+ proj_bias=proj_bias,
155
+ ffn_bias=ffn_bias,
156
+ drop_path=dpr[i],
157
+ norm_layer=norm_layer,
158
+ act_layer=act_layer,
159
+ ffn_layer=ffn_layer,
160
+ init_values=init_values,
161
+ )
162
+ for i in range(depth)
163
+ ]
164
+ if block_chunks > 0:
165
+ self.chunked_blocks = True
166
+ chunked_blocks = []
167
+ chunksize = depth // block_chunks
168
+ for i in range(0, depth, chunksize):
169
+ # this is to keep the block index consistent if we chunk the block list
170
+ chunked_blocks.append(
171
+ [nn.Identity()] * i + blocks_list[i : i + chunksize]
172
+ )
173
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
174
+ else:
175
+ self.chunked_blocks = False
176
+ self.blocks = nn.ModuleList(blocks_list)
177
+
178
+ self.norm = norm_layer(embed_dim)
179
+ self.head = nn.Identity()
180
+
181
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
182
+
183
+ def interpolate_pos_encoding(self, x, w, h):
184
+ previous_dtype = x.dtype
185
+ npatch = x.shape[1] - 1
186
+ N = self.pos_embed.shape[1] - 1
187
+ if npatch == N and w == h:
188
+ return self.pos_embed
189
+ pos_embed = self.pos_embed.float()
190
+ class_pos_embed = pos_embed[:, 0]
191
+ patch_pos_embed = pos_embed[:, 1:]
192
+ dim = x.shape[-1]
193
+ w0 = w // self.patch_size
194
+ h0 = h // self.patch_size
195
+ # we add a small number to avoid floating point error in the interpolation
196
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
197
+ w0, h0 = w0 + 0.1, h0 + 0.1
198
+
199
+ patch_pos_embed = nn.functional.interpolate(
200
+ patch_pos_embed.reshape(
201
+ 1, int(math.sqrt(N)), int(math.sqrt(N)), dim
202
+ ).permute(0, 3, 1, 2),
203
+ scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
204
+ mode="bicubic",
205
+ )
206
+
207
+ assert (
208
+ int(w0) == patch_pos_embed.shape[-2]
209
+ and int(h0) == patch_pos_embed.shape[-1]
210
+ )
211
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
212
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(
213
+ previous_dtype
214
+ )
215
+
216
+ def prepare_tokens_with_masks(self, x, masks=None):
217
+ B, nc, w, h = x.shape
218
+ x = self.patch_embed(x)
219
+ if masks is not None:
220
+ x = torch.where(
221
+ masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x
222
+ )
223
+
224
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
225
+ x = x + self.interpolate_pos_encoding(x, w, h)
226
+
227
+ return x
228
+
229
+ def forward_features_list(self, x_list, masks_list):
230
+ x = [
231
+ self.prepare_tokens_with_masks(x, masks)
232
+ for x, masks in zip(x_list, masks_list)
233
+ ]
234
+ for blk in self.blocks:
235
+ x = blk(x)
236
+
237
+ all_x = x
238
+ output = []
239
+ for x, masks in zip(all_x, masks_list):
240
+ x_norm = self.norm(x)
241
+ output.append(
242
+ {
243
+ "x_norm_clstoken": x_norm[:, 0],
244
+ "x_norm_patchtokens": x_norm[:, 1:],
245
+ "x_prenorm": x,
246
+ "masks": masks,
247
+ }
248
+ )
249
+ return output
250
+
251
+ def forward_features(self, x, masks=None):
252
+ B, _, h, w = x.shape
253
+ if isinstance(x, list):
254
+ return self.forward_features_list(x, masks)
255
+
256
+ x = self.prepare_tokens_with_masks(x, masks)
257
+ outs = []
258
+ for idx, blk in enumerate(self.blocks):
259
+ x = blk(x)
260
+ if idx in self.out_indices:
261
+ outs.append(
262
+ x[:, 1:, :]
263
+ .permute(0, 2, 1)
264
+ .reshape(B, -1, h // self.patch_size, w // self.patch_size)
265
+ .contiguous()
266
+ )
267
+ return outs
268
+
269
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
270
+ x = self.prepare_tokens_with_masks(x)
271
+ # If n is an int, take the n last blocks. If it's a list, take them
272
+ output, total_block_len = [], len(self.blocks)
273
+ blocks_to_take = (
274
+ range(total_block_len - n, total_block_len) if isinstance(n, int) else n
275
+ )
276
+ for i, blk in enumerate(self.blocks):
277
+ x = blk(x)
278
+ if i in blocks_to_take:
279
+ output.append(x)
280
+ assert len(output) == len(
281
+ blocks_to_take
282
+ ), f"only {len(output)} / {len(blocks_to_take)} blocks found"
283
+ return output
284
+
285
+ def _get_intermediate_layers_chunked(self, x, n=1):
286
+ x = self.prepare_tokens_with_masks(x)
287
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
288
+ # If n is an int, take the n last blocks. If it's a list, take them
289
+ blocks_to_take = (
290
+ range(total_block_len - n, total_block_len) if isinstance(n, int) else n
291
+ )
292
+ for block_chunk in self.blocks:
293
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
294
+ x = blk(x)
295
+ if i in blocks_to_take:
296
+ output.append(x)
297
+ i += 1
298
+ assert len(output) == len(
299
+ blocks_to_take
300
+ ), f"only {len(output)} / {len(blocks_to_take)} blocks found"
301
+ return output
302
+
303
+ def get_intermediate_layers(
304
+ self,
305
+ x: torch.Tensor,
306
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
307
+ reshape: bool = False,
308
+ return_class_token: bool = False,
309
+ norm=True,
310
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
311
+ if self.chunked_blocks:
312
+ outputs = self._get_intermediate_layers_chunked(x, n)
313
+ else:
314
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
315
+ if norm:
316
+ outputs = [self.norm(out) for out in outputs]
317
+ class_tokens = [out[:, 0] for out in outputs]
318
+ outputs = [out[:, 1:] for out in outputs]
319
+ if reshape:
320
+ B, _, w, h = x.shape
321
+ outputs = [
322
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1)
323
+ .permute(0, 3, 1, 2)
324
+ .contiguous()
325
+ for out in outputs
326
+ ]
327
+ if return_class_token:
328
+ return tuple(zip(outputs, class_tokens))
329
+ return tuple(outputs)
330
+
331
+ def forward(self, *args, **kwargs):
332
+ ret = self.forward_features(*args, **kwargs)
333
+ if isinstance(ret[0], torch.Tensor):
334
+ ret[0] = F.interpolate(
335
+ ret[0], scale_factor=4, mode="bilinear", align_corners=False
336
+ )
337
+ ret[1] = F.interpolate(
338
+ ret[1], scale_factor=2, mode="bilinear", align_corners=False
339
+ )
340
+ ret[3] = F.interpolate(
341
+ ret[3], scale_factor=0.5, mode="bilinear", align_corners=False
342
+ )
343
+ else:
344
+ ret[0][0] = F.interpolate(
345
+ ret[0][0], scale_factor=4, mode="bilinear", align_corners=False
346
+ )
347
+ ret[0][1] = F.interpolate(
348
+ ret[0][1], scale_factor=2, mode="bilinear", align_corners=False
349
+ )
350
+ ret[0][3] = F.interpolate(
351
+ ret[0][3], scale_factor=0.5, mode="bilinear", align_corners=False
352
+ )
353
+ return ret
cloud_adapter/hrcloudnet.py ADDED
@@ -0,0 +1,751 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 论文地址:https://arxiv.org/abs/2407.07365
2
+ #
3
+ from __future__ import absolute_import
4
+ from __future__ import division
5
+ from __future__ import print_function
6
+
7
+ import logging
8
+ import os
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch._utils
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+
16
+ BatchNorm2d = nn.BatchNorm2d
17
+ # BN_MOMENTUM = 0.01
18
+ relu_inplace = True
19
+ BN_MOMENTUM = 0.1
20
+ ALIGN_CORNERS = True
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ def conv3x3(in_planes, out_planes, stride=1):
26
+ """3x3 convolution with padding"""
27
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
28
+ padding=1, bias=False)
29
+
30
+
31
+ from yacs.config import CfgNode as CN
32
+ import math
33
+ from einops import rearrange
34
+
35
+ # configs for HRNet48
36
+ HRNET_48 = CN()
37
+ HRNET_48.FINAL_CONV_KERNEL = 1
38
+
39
+ HRNET_48.STAGE1 = CN()
40
+ HRNET_48.STAGE1.NUM_MODULES = 1
41
+ HRNET_48.STAGE1.NUM_BRANCHES = 1
42
+ HRNET_48.STAGE1.NUM_BLOCKS = [4]
43
+ HRNET_48.STAGE1.NUM_CHANNELS = [64]
44
+ HRNET_48.STAGE1.BLOCK = 'BOTTLENECK'
45
+ HRNET_48.STAGE1.FUSE_METHOD = 'SUM'
46
+
47
+ HRNET_48.STAGE2 = CN()
48
+ HRNET_48.STAGE2.NUM_MODULES = 1
49
+ HRNET_48.STAGE2.NUM_BRANCHES = 2
50
+ HRNET_48.STAGE2.NUM_BLOCKS = [4, 4]
51
+ HRNET_48.STAGE2.NUM_CHANNELS = [48, 96]
52
+ HRNET_48.STAGE2.BLOCK = 'BASIC'
53
+ HRNET_48.STAGE2.FUSE_METHOD = 'SUM'
54
+
55
+ HRNET_48.STAGE3 = CN()
56
+ HRNET_48.STAGE3.NUM_MODULES = 4
57
+ HRNET_48.STAGE3.NUM_BRANCHES = 3
58
+ HRNET_48.STAGE3.NUM_BLOCKS = [4, 4, 4]
59
+ HRNET_48.STAGE3.NUM_CHANNELS = [48, 96, 192]
60
+ HRNET_48.STAGE3.BLOCK = 'BASIC'
61
+ HRNET_48.STAGE3.FUSE_METHOD = 'SUM'
62
+
63
+ HRNET_48.STAGE4 = CN()
64
+ HRNET_48.STAGE4.NUM_MODULES = 3
65
+ HRNET_48.STAGE4.NUM_BRANCHES = 4
66
+ HRNET_48.STAGE4.NUM_BLOCKS = [4, 4, 4, 4]
67
+ HRNET_48.STAGE4.NUM_CHANNELS = [48, 96, 192, 384]
68
+ HRNET_48.STAGE4.BLOCK = 'BASIC'
69
+ HRNET_48.STAGE4.FUSE_METHOD = 'SUM'
70
+
71
+ HRNET_32 = CN()
72
+ HRNET_32.FINAL_CONV_KERNEL = 1
73
+
74
+ HRNET_32.STAGE1 = CN()
75
+ HRNET_32.STAGE1.NUM_MODULES = 1
76
+ HRNET_32.STAGE1.NUM_BRANCHES = 1
77
+ HRNET_32.STAGE1.NUM_BLOCKS = [4]
78
+ HRNET_32.STAGE1.NUM_CHANNELS = [64]
79
+ HRNET_32.STAGE1.BLOCK = 'BOTTLENECK'
80
+ HRNET_32.STAGE1.FUSE_METHOD = 'SUM'
81
+
82
+ HRNET_32.STAGE2 = CN()
83
+ HRNET_32.STAGE2.NUM_MODULES = 1
84
+ HRNET_32.STAGE2.NUM_BRANCHES = 2
85
+ HRNET_32.STAGE2.NUM_BLOCKS = [4, 4]
86
+ HRNET_32.STAGE2.NUM_CHANNELS = [32, 64]
87
+ HRNET_32.STAGE2.BLOCK = 'BASIC'
88
+ HRNET_32.STAGE2.FUSE_METHOD = 'SUM'
89
+
90
+ HRNET_32.STAGE3 = CN()
91
+ HRNET_32.STAGE3.NUM_MODULES = 4
92
+ HRNET_32.STAGE3.NUM_BRANCHES = 3
93
+ HRNET_32.STAGE3.NUM_BLOCKS = [4, 4, 4]
94
+ HRNET_32.STAGE3.NUM_CHANNELS = [32, 64, 128]
95
+ HRNET_32.STAGE3.BLOCK = 'BASIC'
96
+ HRNET_32.STAGE3.FUSE_METHOD = 'SUM'
97
+
98
+ HRNET_32.STAGE4 = CN()
99
+ HRNET_32.STAGE4.NUM_MODULES = 3
100
+ HRNET_32.STAGE4.NUM_BRANCHES = 4
101
+ HRNET_32.STAGE4.NUM_BLOCKS = [4, 4, 4, 4]
102
+ HRNET_32.STAGE4.NUM_CHANNELS = [32, 64, 128, 256]
103
+ HRNET_32.STAGE4.BLOCK = 'BASIC'
104
+ HRNET_32.STAGE4.FUSE_METHOD = 'SUM'
105
+
106
+ HRNET_18 = CN()
107
+ HRNET_18.FINAL_CONV_KERNEL = 1
108
+
109
+ HRNET_18.STAGE1 = CN()
110
+ HRNET_18.STAGE1.NUM_MODULES = 1
111
+ HRNET_18.STAGE1.NUM_BRANCHES = 1
112
+ HRNET_18.STAGE1.NUM_BLOCKS = [4]
113
+ HRNET_18.STAGE1.NUM_CHANNELS = [64]
114
+ HRNET_18.STAGE1.BLOCK = 'BOTTLENECK'
115
+ HRNET_18.STAGE1.FUSE_METHOD = 'SUM'
116
+
117
+ HRNET_18.STAGE2 = CN()
118
+ HRNET_18.STAGE2.NUM_MODULES = 1
119
+ HRNET_18.STAGE2.NUM_BRANCHES = 2
120
+ HRNET_18.STAGE2.NUM_BLOCKS = [4, 4]
121
+ HRNET_18.STAGE2.NUM_CHANNELS = [18, 36]
122
+ HRNET_18.STAGE2.BLOCK = 'BASIC'
123
+ HRNET_18.STAGE2.FUSE_METHOD = 'SUM'
124
+
125
+ HRNET_18.STAGE3 = CN()
126
+ HRNET_18.STAGE3.NUM_MODULES = 4
127
+ HRNET_18.STAGE3.NUM_BRANCHES = 3
128
+ HRNET_18.STAGE3.NUM_BLOCKS = [4, 4, 4]
129
+ HRNET_18.STAGE3.NUM_CHANNELS = [18, 36, 72]
130
+ HRNET_18.STAGE3.BLOCK = 'BASIC'
131
+ HRNET_18.STAGE3.FUSE_METHOD = 'SUM'
132
+
133
+ HRNET_18.STAGE4 = CN()
134
+ HRNET_18.STAGE4.NUM_MODULES = 3
135
+ HRNET_18.STAGE4.NUM_BRANCHES = 4
136
+ HRNET_18.STAGE4.NUM_BLOCKS = [4, 4, 4, 4]
137
+ HRNET_18.STAGE4.NUM_CHANNELS = [18, 36, 72, 144]
138
+ HRNET_18.STAGE4.BLOCK = 'BASIC'
139
+ HRNET_18.STAGE4.FUSE_METHOD = 'SUM'
140
+
141
+
142
+ class PPM(nn.Module):
143
+ def __init__(self, in_dim, reduction_dim, bins):
144
+ super(PPM, self).__init__()
145
+ self.features = []
146
+ for bin in bins:
147
+ self.features.append(nn.Sequential(
148
+ nn.AdaptiveAvgPool2d(bin),
149
+ nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
150
+ nn.BatchNorm2d(reduction_dim),
151
+ nn.ReLU(inplace=True)
152
+ ))
153
+ self.features = nn.ModuleList(self.features)
154
+
155
+ def forward(self, x):
156
+ x_size = x.size()
157
+ out = [x]
158
+ for f in self.features:
159
+ out.append(F.interpolate(f(x), x_size[2:], mode='bilinear', align_corners=True))
160
+ return torch.cat(out, 1)
161
+
162
+
163
+ class BasicBlock(nn.Module):
164
+ expansion = 1
165
+
166
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
167
+ super(BasicBlock, self).__init__()
168
+ self.conv1 = conv3x3(inplanes, planes, stride)
169
+ self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
170
+ self.relu = nn.ReLU(inplace=relu_inplace)
171
+ self.conv2 = conv3x3(planes, planes)
172
+ self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
173
+ self.downsample = downsample
174
+ self.stride = stride
175
+
176
+ def forward(self, x):
177
+ residual = x
178
+
179
+ out = self.conv1(x)
180
+ out = self.bn1(out)
181
+ out = self.relu(out)
182
+
183
+ out = self.conv2(out)
184
+ out = self.bn2(out)
185
+
186
+ if self.downsample is not None:
187
+ residual = self.downsample(x)
188
+ out = out + residual
189
+ out = self.relu(out)
190
+
191
+ return out
192
+
193
+
194
+ class Bottleneck(nn.Module):
195
+ expansion = 4
196
+
197
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
198
+ super(Bottleneck, self).__init__()
199
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
200
+ self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
201
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
202
+ padding=1, bias=False)
203
+ self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
204
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
205
+ bias=False)
206
+ self.bn3 = BatchNorm2d(planes * self.expansion,
207
+ momentum=BN_MOMENTUM)
208
+ self.relu = nn.ReLU(inplace=relu_inplace)
209
+ self.downsample = downsample
210
+ self.stride = stride
211
+
212
+ def forward(self, x):
213
+ residual = x
214
+
215
+ out = self.conv1(x)
216
+ out = self.bn1(out)
217
+ out = self.relu(out)
218
+
219
+ out = self.conv2(out)
220
+ out = self.bn2(out)
221
+ out = self.relu(out)
222
+
223
+ out = self.conv3(out)
224
+ out = self.bn3(out)
225
+
226
+ if self.downsample is not None:
227
+ residual = self.downsample(x)
228
+ # att = self.downsample(att)
229
+ out = out + residual
230
+ out = self.relu(out)
231
+
232
+ return out
233
+
234
+
235
+ class HighResolutionModule(nn.Module):
236
+ def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
237
+ num_channels, fuse_method, multi_scale_output=True):
238
+ super(HighResolutionModule, self).__init__()
239
+ self._check_branches(
240
+ num_branches, blocks, num_blocks, num_inchannels, num_channels)
241
+
242
+ self.num_inchannels = num_inchannels
243
+ self.fuse_method = fuse_method
244
+ self.num_branches = num_branches
245
+
246
+ self.multi_scale_output = multi_scale_output
247
+
248
+ self.branches = self._make_branches(
249
+ num_branches, blocks, num_blocks, num_channels)
250
+ self.fuse_layers = self._make_fuse_layers()
251
+ self.relu = nn.ReLU(inplace=relu_inplace)
252
+
253
+ def _check_branches(self, num_branches, blocks, num_blocks,
254
+ num_inchannels, num_channels):
255
+ if num_branches != len(num_blocks):
256
+ error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
257
+ num_branches, len(num_blocks))
258
+ logger.error(error_msg)
259
+ raise ValueError(error_msg)
260
+
261
+ if num_branches != len(num_channels):
262
+ error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
263
+ num_branches, len(num_channels))
264
+ logger.error(error_msg)
265
+ raise ValueError(error_msg)
266
+
267
+ if num_branches != len(num_inchannels):
268
+ error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
269
+ num_branches, len(num_inchannels))
270
+ logger.error(error_msg)
271
+ raise ValueError(error_msg)
272
+
273
+ def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
274
+ stride=1):
275
+ downsample = None
276
+ if stride != 1 or \
277
+ self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
278
+ downsample = nn.Sequential(
279
+ nn.Conv2d(self.num_inchannels[branch_index],
280
+ num_channels[branch_index] * block.expansion,
281
+ kernel_size=1, stride=stride, bias=False),
282
+ BatchNorm2d(num_channels[branch_index] * block.expansion,
283
+ momentum=BN_MOMENTUM),
284
+ )
285
+
286
+ layers = []
287
+ layers.append(block(self.num_inchannels[branch_index],
288
+ num_channels[branch_index], stride, downsample))
289
+ self.num_inchannels[branch_index] = \
290
+ num_channels[branch_index] * block.expansion
291
+ for i in range(1, num_blocks[branch_index]):
292
+ layers.append(block(self.num_inchannels[branch_index],
293
+ num_channels[branch_index]))
294
+
295
+ return nn.Sequential(*layers)
296
+
297
+ # 创建平行层
298
+ def _make_branches(self, num_branches, block, num_blocks, num_channels):
299
+ branches = []
300
+
301
+ for i in range(num_branches):
302
+ branches.append(
303
+ self._make_one_branch(i, block, num_blocks, num_channels))
304
+
305
+ return nn.ModuleList(branches)
306
+
307
+ def _make_fuse_layers(self):
308
+ if self.num_branches == 1:
309
+ return None
310
+ num_branches = self.num_branches # 3
311
+ num_inchannels = self.num_inchannels # [48, 96, 192]
312
+ fuse_layers = []
313
+ for i in range(num_branches if self.multi_scale_output else 1):
314
+ fuse_layer = []
315
+ for j in range(num_branches):
316
+ if j > i:
317
+ fuse_layer.append(nn.Sequential(
318
+ nn.Conv2d(num_inchannels[j],
319
+ num_inchannels[i],
320
+ 1,
321
+ 1,
322
+ 0,
323
+ bias=False),
324
+ BatchNorm2d(num_inchannels[i], momentum=BN_MOMENTUM)))
325
+ elif j == i:
326
+ fuse_layer.append(None)
327
+ else:
328
+ conv3x3s = []
329
+ for k in range(i - j):
330
+ if k == i - j - 1:
331
+ num_outchannels_conv3x3 = num_inchannels[i]
332
+ conv3x3s.append(nn.Sequential(
333
+ nn.Conv2d(num_inchannels[j],
334
+ num_outchannels_conv3x3,
335
+ 3, 2, 1, bias=False),
336
+ BatchNorm2d(num_outchannels_conv3x3,
337
+ momentum=BN_MOMENTUM)))
338
+ else:
339
+ num_outchannels_conv3x3 = num_inchannels[j]
340
+ conv3x3s.append(nn.Sequential(
341
+ nn.Conv2d(num_inchannels[j],
342
+ num_outchannels_conv3x3,
343
+ 3, 2, 1, bias=False),
344
+ BatchNorm2d(num_outchannels_conv3x3,
345
+ momentum=BN_MOMENTUM),
346
+ nn.ReLU(inplace=relu_inplace)))
347
+ fuse_layer.append(nn.Sequential(*conv3x3s))
348
+ fuse_layers.append(nn.ModuleList(fuse_layer))
349
+
350
+ return nn.ModuleList(fuse_layers)
351
+
352
+ def get_num_inchannels(self):
353
+ return self.num_inchannels
354
+
355
+ def forward(self, x):
356
+ if self.num_branches == 1:
357
+ return [self.branches[0](x[0])]
358
+
359
+ for i in range(self.num_branches):
360
+ x[i] = self.branches[i](x[i])
361
+
362
+ x_fuse = []
363
+ for i in range(len(self.fuse_layers)):
364
+ y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
365
+ for j in range(1, self.num_branches):
366
+ if i == j:
367
+ y = y + x[j]
368
+ elif j > i:
369
+ width_output = x[i].shape[-1]
370
+ height_output = x[i].shape[-2]
371
+ y = y + F.interpolate(
372
+ self.fuse_layers[i][j](x[j]),
373
+ size=[height_output, width_output],
374
+ mode='bilinear', align_corners=ALIGN_CORNERS)
375
+ else:
376
+ y = y + self.fuse_layers[i][j](x[j])
377
+ x_fuse.append(self.relu(y))
378
+
379
+ return x_fuse
380
+
381
+
382
+ blocks_dict = {
383
+ 'BASIC': BasicBlock,
384
+ 'BOTTLENECK': Bottleneck
385
+ }
386
+
387
+
388
+ class HRCloudNet(nn.Module):
389
+
390
+ def __init__(self, in_channels=3,num_classes=2, base_c=48, **kwargs):
391
+ global ALIGN_CORNERS
392
+ extra = HRNET_48
393
+ super(HRCloudNet, self).__init__()
394
+ ALIGN_CORNERS = True
395
+ # ALIGN_CORNERS = config.MODEL.ALIGN_CORNERS
396
+ self.num_classes = num_classes
397
+ # stem net
398
+ self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=2, padding=1,
399
+ bias=False)
400
+ self.bn1 = BatchNorm2d(64, momentum=BN_MOMENTUM)
401
+ self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1,
402
+ bias=False)
403
+ self.bn2 = BatchNorm2d(64, momentum=BN_MOMENTUM)
404
+ self.relu = nn.ReLU(inplace=relu_inplace)
405
+
406
+ self.stage1_cfg = extra['STAGE1']
407
+ num_channels = self.stage1_cfg['NUM_CHANNELS'][0]
408
+ block = blocks_dict[self.stage1_cfg['BLOCK']]
409
+ num_blocks = self.stage1_cfg['NUM_BLOCKS'][0]
410
+ self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)
411
+ stage1_out_channel = block.expansion * num_channels
412
+
413
+ self.stage2_cfg = extra['STAGE2']
414
+ num_channels = self.stage2_cfg['NUM_CHANNELS']
415
+ block = blocks_dict[self.stage2_cfg['BLOCK']]
416
+ num_channels = [
417
+ num_channels[i] * block.expansion for i in range(len(num_channels))]
418
+ self.transition1 = self._make_transition_layer(
419
+ [stage1_out_channel], num_channels)
420
+ self.stage2, pre_stage_channels = self._make_stage(
421
+ self.stage2_cfg, num_channels)
422
+
423
+ self.stage3_cfg = extra['STAGE3']
424
+ num_channels = self.stage3_cfg['NUM_CHANNELS']
425
+ block = blocks_dict[self.stage3_cfg['BLOCK']]
426
+ num_channels = [
427
+ num_channels[i] * block.expansion for i in range(len(num_channels))]
428
+ self.transition2 = self._make_transition_layer(
429
+ pre_stage_channels, num_channels) # 只在pre[-1]与cur[-1]之间下采样?
430
+ self.stage3, pre_stage_channels = self._make_stage(
431
+ self.stage3_cfg, num_channels)
432
+
433
+ self.stage4_cfg = extra['STAGE4']
434
+ num_channels = self.stage4_cfg['NUM_CHANNELS']
435
+ block = blocks_dict[self.stage4_cfg['BLOCK']]
436
+ num_channels = [
437
+ num_channels[i] * block.expansion for i in range(len(num_channels))]
438
+ self.transition3 = self._make_transition_layer(
439
+ pre_stage_channels, num_channels)
440
+ self.stage4, pre_stage_channels = self._make_stage(
441
+ self.stage4_cfg, num_channels, multi_scale_output=True)
442
+ self.out_conv = OutConv(base_c, num_classes)
443
+ last_inp_channels = int(np.sum(pre_stage_channels))
444
+
445
+ self.corr = Corr(nclass=2)
446
+ self.proj = nn.Sequential(
447
+ # 512 32
448
+ nn.Conv2d(720, 48, kernel_size=3, stride=1, padding=1, bias=True),
449
+ nn.BatchNorm2d(48),
450
+ nn.ReLU(inplace=True),
451
+ nn.Dropout2d(0.1),
452
+ )
453
+ # self.up1 = Up(base_c * 16, base_c * 8 // factor, bilinear)
454
+ self.up2 = Up(base_c * 8, base_c * 4, True)
455
+ self.up3 = Up(base_c * 4, base_c * 2, True)
456
+ self.up4 = Up(base_c * 2, base_c, True)
457
+ fea_dim = 720
458
+ bins = (1, 2, 3, 6)
459
+ self.ppm = PPM(fea_dim, int(fea_dim / len(bins)), bins)
460
+ fea_dim *= 2
461
+ self.cls = nn.Sequential(
462
+ nn.Conv2d(fea_dim, 512, kernel_size=3, padding=1, bias=False),
463
+ nn.BatchNorm2d(512),
464
+ nn.ReLU(inplace=True),
465
+ nn.Dropout2d(p=0.1),
466
+ nn.Conv2d(512, num_classes, kernel_size=1)
467
+ )
468
+
469
+ '''
470
+ 转换层的作用有两种情况:
471
+
472
+ 当前分支数小于之前分支数时,仅对前几个分支进行通道数调整。
473
+ 当前分支数大于之前分支数时,新建一些转换层,对多余的分支进行下采样,改变通道数以适应后续的连接。
474
+ 最终,这些转换层会被组合成一个 nn.ModuleList 对象,并在网络的构建过程中使用。
475
+ 这有助于确保每个分支的通道数在不同阶段之间能够正确匹配,以便进行特征的融合和连接
476
+ '''
477
+
478
+ def _make_transition_layer(
479
+ self, num_channels_pre_layer, num_channels_cur_layer):
480
+ # 现在的分支数
481
+ num_branches_cur = len(num_channels_cur_layer) # 3
482
+ # 处理前的分支数
483
+ num_branches_pre = len(num_channels_pre_layer) # 2
484
+
485
+ transition_layers = []
486
+ for i in range(num_branches_cur):
487
+ # 如果当前分支数小于之前分支数,仅针对第一到第二阶段
488
+ if i < num_branches_pre:
489
+ # 如果对应层的通道数不一致,则进行转化(
490
+ if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
491
+ transition_layers.append(nn.Sequential(
492
+
493
+ nn.Conv2d(num_channels_pre_layer[i],
494
+ num_channels_cur_layer[i],
495
+ 3,
496
+ 1,
497
+ 1,
498
+ bias=False),
499
+ BatchNorm2d(
500
+ num_channels_cur_layer[i], momentum=BN_MOMENTUM),
501
+ nn.ReLU(inplace=relu_inplace)))
502
+ else:
503
+ transition_layers.append(None)
504
+ else: # 在新建层下采样改变通道数
505
+ conv3x3s = []
506
+ for j in range(i + 1 - num_branches_pre): # 3
507
+ inchannels = num_channels_pre_layer[-1]
508
+ outchannels = num_channels_cur_layer[i] \
509
+ if j == i - num_branches_pre else inchannels
510
+ conv3x3s.append(nn.Sequential(
511
+ nn.Conv2d(
512
+ inchannels, outchannels, 3, 2, 1, bias=False),
513
+ BatchNorm2d(outchannels, momentum=BN_MOMENTUM),
514
+ nn.ReLU(inplace=relu_inplace)))
515
+ transition_layers.append(nn.Sequential(*conv3x3s))
516
+
517
+ return nn.ModuleList(transition_layers)
518
+
519
+ '''
520
+ _make_layer 函数的主要作用是创建一个由多个相同类型的残差块(Residual Block)组成的层。
521
+ '''
522
+
523
+ def _make_layer(self, block, inplanes, planes, blocks, stride=1):
524
+ downsample = None
525
+ if stride != 1 or inplanes != planes * block.expansion:
526
+ downsample = nn.Sequential(
527
+ nn.Conv2d(inplanes, planes * block.expansion,
528
+ kernel_size=1, stride=stride, bias=False),
529
+ BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
530
+ )
531
+
532
+ layers = []
533
+ layers.append(block(inplanes, planes, stride, downsample))
534
+ inplanes = planes * block.expansion
535
+ for i in range(1, blocks):
536
+ layers.append(block(inplanes, planes))
537
+
538
+ return nn.Sequential(*layers)
539
+
540
+ # 多尺度融合
541
+ def _make_stage(self, layer_config, num_inchannels,
542
+ multi_scale_output=True):
543
+ num_modules = layer_config['NUM_MODULES']
544
+ num_branches = layer_config['NUM_BRANCHES']
545
+ num_blocks = layer_config['NUM_BLOCKS']
546
+ num_channels = layer_config['NUM_CHANNELS']
547
+ block = blocks_dict[layer_config['BLOCK']]
548
+ fuse_method = layer_config['FUSE_METHOD']
549
+
550
+ modules = []
551
+ for i in range(num_modules): # 重复4次
552
+ # multi_scale_output is only used last module
553
+ if not multi_scale_output and i == num_modules - 1:
554
+ reset_multi_scale_output = False
555
+ else:
556
+ reset_multi_scale_output = True
557
+ modules.append(
558
+ HighResolutionModule(num_branches,
559
+ block,
560
+ num_blocks,
561
+ num_inchannels,
562
+ num_channels,
563
+ fuse_method,
564
+ reset_multi_scale_output)
565
+ )
566
+ num_inchannels = modules[-1].get_num_inchannels()
567
+
568
+ return nn.Sequential(*modules), num_inchannels
569
+
570
+ def forward(self, input, need_fp=True, use_corr=True):
571
+ # from ipdb import set_trace
572
+ # set_trace()
573
+ x = self.conv1(input)
574
+ x = self.bn1(x)
575
+ x = self.relu(x)
576
+ # x_176 = x
577
+ x = self.conv2(x)
578
+ x = self.bn2(x)
579
+ x = self.relu(x)
580
+ x = self.layer1(x)
581
+
582
+ x_list = []
583
+ for i in range(self.stage2_cfg['NUM_BRANCHES']): # 2
584
+ if self.transition1[i] is not None:
585
+ x_list.append(self.transition1[i](x))
586
+ else:
587
+ x_list.append(x)
588
+ y_list = self.stage2(x_list)
589
+ # Y1
590
+ x_list = []
591
+ for i in range(self.stage3_cfg['NUM_BRANCHES']):
592
+ if self.transition2[i] is not None:
593
+ if i < self.stage2_cfg['NUM_BRANCHES']:
594
+ x_list.append(self.transition2[i](y_list[i]))
595
+ else:
596
+ x_list.append(self.transition2[i](y_list[-1]))
597
+ else:
598
+ x_list.append(y_list[i])
599
+ y_list = self.stage3(x_list)
600
+
601
+ x_list = []
602
+ for i in range(self.stage4_cfg['NUM_BRANCHES']):
603
+ if self.transition3[i] is not None:
604
+ if i < self.stage3_cfg['NUM_BRANCHES']:
605
+ x_list.append(self.transition3[i](y_list[i]))
606
+ else:
607
+ x_list.append(self.transition3[i](y_list[-1]))
608
+ else:
609
+ x_list.append(y_list[i])
610
+ x = self.stage4(x_list)
611
+ dict_return = {}
612
+ # Upsampling
613
+ x0_h, x0_w = x[0].size(2), x[0].size(3)
614
+
615
+ x3 = F.interpolate(x[3], size=(x0_h, x0_w), mode='bilinear', align_corners=ALIGN_CORNERS)
616
+ # x = self.stage3_(x)
617
+ x[2] = self.up2(x[3], x[2])
618
+ x2 = F.interpolate(x[2], size=(x0_h, x0_w), mode='bilinear', align_corners=ALIGN_CORNERS)
619
+ # x = self.stage2_(x)
620
+ x[1] = self.up3(x[2], x[1])
621
+ x1 = F.interpolate(x[1], size=(x0_h, x0_w), mode='bilinear', align_corners=ALIGN_CORNERS)
622
+ x[0] = self.up4(x[1], x[0])
623
+ xk = torch.cat([x[0], x1, x2, x3], 1)
624
+ # PPM
625
+ feat = self.ppm(xk)
626
+ x = self.cls(feat)
627
+ # fp分支
628
+ if need_fp:
629
+ logits = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True)
630
+ # logits = self.out_conv(torch.cat((x, nn.Dropout2d(0.5)(x))))
631
+ out = logits
632
+ out_fp = logits
633
+ if use_corr:
634
+ proj_feats = self.proj(xk)
635
+ corr_out = self.corr(proj_feats, out)
636
+ corr_out = F.interpolate(corr_out, size=(352, 352), mode="bilinear", align_corners=True)
637
+ dict_return['corr_out'] = corr_out
638
+ dict_return['out'] = out
639
+ dict_return['out_fp'] = out_fp
640
+
641
+ return dict_return['out']
642
+
643
+ out = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True)
644
+ if use_corr: # True
645
+ proj_feats = self.proj(xk)
646
+ # 计算
647
+ corr_out = self.corr(proj_feats, out)
648
+ corr_out = F.interpolate(corr_out, size=(352, 352), mode="bilinear", align_corners=True)
649
+ dict_return['corr_out'] = corr_out
650
+ dict_return['out'] = out
651
+ return dict_return['out']
652
+ # return x
653
+
654
+ def init_weights(self, pretrained='', ):
655
+ logger.info('=> init weights from normal distribution')
656
+ for m in self.modules():
657
+ if isinstance(m, nn.Conv2d):
658
+ nn.init.normal_(m.weight, std=0.001)
659
+ elif isinstance(m, nn.BatchNorm2d):
660
+ nn.init.constant_(m.weight, 1)
661
+ nn.init.constant_(m.bias, 0)
662
+ if os.path.isfile(pretrained):
663
+ pretrained_dict = torch.load(pretrained)
664
+ logger.info('=> loading pretrained model {}'.format(pretrained))
665
+ model_dict = self.state_dict()
666
+ pretrained_dict = {k: v for k, v in pretrained_dict.items()
667
+ if k in model_dict.keys()}
668
+ for k, _ in pretrained_dict.items():
669
+ logger.info(
670
+ '=> loading {} pretrained model {}'.format(k, pretrained))
671
+ model_dict.update(pretrained_dict)
672
+ self.load_state_dict(model_dict)
673
+
674
+
675
+ class OutConv(nn.Sequential):
676
+ def __init__(self, in_channels, num_classes):
677
+ super(OutConv, self).__init__(
678
+ nn.Conv2d(720, num_classes, kernel_size=1)
679
+ )
680
+
681
+
682
+ class DoubleConv(nn.Sequential):
683
+ def __init__(self, in_channels, out_channels, mid_channels=None):
684
+ if mid_channels is None:
685
+ mid_channels = out_channels
686
+ super(DoubleConv, self).__init__(
687
+ nn.Conv2d(in_channels + out_channels, mid_channels, kernel_size=3, padding=1, bias=False),
688
+ nn.BatchNorm2d(mid_channels),
689
+ nn.ReLU(inplace=True),
690
+ nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
691
+ nn.BatchNorm2d(out_channels),
692
+ nn.ReLU(inplace=True)
693
+ )
694
+
695
+
696
+ class Up(nn.Module):
697
+ def __init__(self, in_channels, out_channels, bilinear=True):
698
+ super(Up, self).__init__()
699
+ if bilinear:
700
+ self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
701
+ self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
702
+ else:
703
+ self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
704
+ self.conv = DoubleConv(in_channels, out_channels)
705
+
706
+ def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
707
+ x1 = self.up(x1)
708
+ # [N, C, H, W]
709
+ diff_y = x2.size()[2] - x1.size()[2]
710
+ diff_x = x2.size()[3] - x1.size()[3]
711
+
712
+ # padding_left, padding_right, padding_top, padding_bottom
713
+ x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
714
+ diff_y // 2, diff_y - diff_y // 2])
715
+
716
+ x = torch.cat([x2, x1], dim=1)
717
+ x = self.conv(x)
718
+ return x
719
+
720
+
721
+ class Corr(nn.Module):
722
+ def __init__(self, nclass=2):
723
+ super(Corr, self).__init__()
724
+ self.nclass = nclass
725
+ self.conv1 = nn.Conv2d(48, self.nclass, kernel_size=1, stride=1, padding=0, bias=True)
726
+ self.conv2 = nn.Conv2d(48, self.nclass, kernel_size=1, stride=1, padding=0, bias=True)
727
+
728
+ def forward(self, feature_in, out):
729
+ # in torch.Size([4, 32, 22, 22])
730
+ # out = [4 2 352 352]
731
+ h_in, w_in = math.ceil(feature_in.shape[2] / (1)), math.ceil(feature_in.shape[3] / (1))
732
+ out = F.interpolate(out.detach(), (h_in, w_in), mode='bilinear', align_corners=True)
733
+ feature = F.interpolate(feature_in, (h_in, w_in), mode='bilinear', align_corners=True)
734
+ f1 = rearrange(self.conv1(feature), 'n c h w -> n c (h w)')
735
+ f2 = rearrange(self.conv2(feature), 'n c h w -> n c (h w)')
736
+ out_temp = rearrange(out, 'n c h w -> n c (h w)')
737
+ corr_map = torch.matmul(f1.transpose(1, 2), f2) / torch.sqrt(torch.tensor(f1.shape[1]).float())
738
+ corr_map = F.softmax(corr_map, dim=-1)
739
+ # out_temp 2 2 484
740
+ # corr_map 4 484 484
741
+ out = rearrange(torch.matmul(out_temp, corr_map), 'n c (h w) -> n c h w', h=h_in, w=w_in)
742
+ # out torch.Size([4, 2, 22, 22])
743
+ return out
744
+
745
+
746
+ if __name__ == '__main__':
747
+ input = torch.randn(4, 3, 352, 352)
748
+ cloud = HRCloudNet(num_classes=2)
749
+ output = cloud(input)
750
+ print(output.shape)
751
+ # torch.Size([4, 2, 352, 352]) torch.Size([4, 2, 352, 352]) torch.Size([4, 2, 352, 352])
cloud_adapter/kappamask.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2024/8/7 下午3:51
3
+ # @Author : xiaoshun
4
+ # @Email : [email protected]
5
+ # @File : kappamask.py.py
6
+ # @Software: PyCharm
7
+
8
+ import torch
9
+ from torch import nn as nn
10
+ from torch.nn import functional as F
11
+
12
+
13
+ class KappaMask(nn.Module):
14
+ def __init__(self, num_classes=2, in_channels=3):
15
+ super().__init__()
16
+ self.conv1 = nn.Sequential(
17
+ nn.Conv2d(in_channels, 64, 3, 1, 1),
18
+ nn.ReLU(inplace=True),
19
+ nn.Conv2d(64, 64, 3, 1, 1),
20
+ nn.ReLU(inplace=True),
21
+ )
22
+ self.conv2 = nn.Sequential(
23
+ nn.Conv2d(64, 128, 3, 1, 1),
24
+ nn.ReLU(inplace=True),
25
+ nn.Conv2d(128, 128, 3, 1, 1),
26
+ nn.ReLU(inplace=True),
27
+ )
28
+ self.conv3 = nn.Sequential(
29
+ nn.Conv2d(128, 256, 3, 1, 1),
30
+ nn.ReLU(inplace=True),
31
+ nn.Conv2d(256, 256, 3, 1, 1),
32
+ nn.ReLU(inplace=True),
33
+ )
34
+
35
+ self.conv4 = nn.Sequential(
36
+ nn.Conv2d(256, 512, 3, 1, 1),
37
+ nn.ReLU(inplace=True),
38
+ nn.Conv2d(512, 512, 3, 1, 1),
39
+ nn.ReLU(inplace=True),
40
+ )
41
+ self.drop4 = nn.Dropout(0.5)
42
+
43
+ self.conv5 = nn.Sequential(
44
+ nn.Conv2d(512, 1024, 3, 1, 1),
45
+ nn.ReLU(inplace=True),
46
+ nn.Conv2d(1024, 1024, 3, 1, 1),
47
+ nn.ReLU(inplace=True),
48
+ )
49
+ self.drop5 = nn.Dropout(0.5)
50
+
51
+ self.up6 = nn.Sequential(
52
+ nn.Upsample(scale_factor=2),
53
+ nn.ZeroPad2d((0, 1, 0, 1)),
54
+ nn.Conv2d(1024, 512, 2),
55
+ nn.ReLU(inplace=True)
56
+ )
57
+ self.conv6 = nn.Sequential(
58
+ nn.Conv2d(1024, 512, 3, 1, 1),
59
+ nn.ReLU(inplace=True),
60
+ nn.Conv2d(512, 512, 3, 1, 1),
61
+ nn.ReLU(inplace=True),
62
+ )
63
+ self.up7 = nn.Sequential(
64
+ nn.Upsample(scale_factor=2),
65
+ nn.ZeroPad2d((0, 1, 0, 1)),
66
+ nn.Conv2d(512, 256, 2),
67
+ nn.ReLU(inplace=True)
68
+ )
69
+ self.conv7 = nn.Sequential(
70
+ nn.Conv2d(512, 256, 3, 1, 1),
71
+ nn.ReLU(inplace=True),
72
+ nn.Conv2d(256, 256, 3, 1, 1),
73
+ nn.ReLU(inplace=True),
74
+ )
75
+
76
+ self.up8 = nn.Sequential(
77
+ nn.Upsample(scale_factor=2),
78
+ nn.ZeroPad2d((0, 1, 0, 1)),
79
+ nn.Conv2d(256, 128, 2),
80
+ nn.ReLU(inplace=True)
81
+ )
82
+ self.conv8 = nn.Sequential(
83
+ nn.Conv2d(256, 128, 3, 1, 1),
84
+ nn.ReLU(inplace=True),
85
+ nn.Conv2d(128, 128, 3, 1, 1),
86
+ nn.ReLU(inplace=True),
87
+ )
88
+
89
+ self.up9 = nn.Sequential(
90
+ nn.Upsample(scale_factor=2),
91
+ nn.ZeroPad2d((0, 1, 0, 1)),
92
+ nn.Conv2d(128, 64, 2),
93
+ nn.ReLU(inplace=True)
94
+ )
95
+ self.conv9 = nn.Sequential(
96
+ nn.Conv2d(128, 64, 3, 1, 1),
97
+ nn.ReLU(inplace=True),
98
+ nn.Conv2d(64, 64, 3, 1, 1),
99
+ nn.ReLU(inplace=True),
100
+ nn.Conv2d(64, 2, 3, 1, 1),
101
+ nn.ReLU(inplace=True),
102
+ )
103
+ self.conv10 = nn.Conv2d(2, num_classes, 1)
104
+ self.__init_weights()
105
+
106
+ def __init_weights(self):
107
+ for m in self.modules():
108
+ if isinstance(m, nn.Conv2d):
109
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
110
+
111
+ def forward(self, x):
112
+ conv1 = self.conv1(x)
113
+ pool1 = F.max_pool2d(conv1, 2, 2)
114
+
115
+ conv2 = self.conv2(pool1)
116
+ pool2 = F.max_pool2d(conv2, 2, 2)
117
+
118
+ conv3 = self.conv3(pool2)
119
+ pool3 = F.max_pool2d(conv3, 2, 2)
120
+
121
+ conv4 = self.conv4(pool3)
122
+ drop4 = self.drop4(conv4)
123
+ pool4 = F.max_pool2d(drop4, 2, 2)
124
+
125
+ conv5 = self.conv5(pool4)
126
+ drop5 = self.drop5(conv5)
127
+
128
+ up6 = self.up6(drop5)
129
+ merge6 = torch.cat((drop4, up6), dim=1)
130
+ conv6 = self.conv6(merge6)
131
+
132
+ up7 = self.up7(conv6)
133
+ merge7 = torch.cat((conv3, up7), dim=1)
134
+ conv7 = self.conv7(merge7)
135
+
136
+ up8 = self.up8(conv7)
137
+ merge8 = torch.cat((conv2, up8), dim=1)
138
+ conv8 = self.conv8(merge8)
139
+
140
+ up9 = self.up9(conv8)
141
+ merge9 = torch.cat((conv1, up9), dim=1)
142
+ conv9 = self.conv9(merge9)
143
+
144
+ output = self.conv10(conv9)
145
+ return output
146
+
147
+
148
+ if __name__ == '__main__':
149
+ model = KappaMask(num_classes=2, in_channels=3)
150
+ fake_data = torch.rand(2, 3, 256, 256)
151
+ output = model(fake_data)
152
+ print(output.shape)
cloud_adapter/mcdnet.py ADDED
@@ -0,0 +1,435 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2024/7/21 下午3:51
3
+ # @Author : xiaoshun
4
+ # @Email : [email protected]
5
+ # @File : mcdnet.py
6
+ # @Software: PyCharm
7
+ import image_dehazer
8
+ import numpy as np
9
+ # 论文地址:https://www.sciencedirect.com/science/article/pii/S1569843224001742?via%3Dihub
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+
15
+ class _DPFF(nn.Module):
16
+ def __init__(self, in_channels) -> None:
17
+ super(_DPFF, self).__init__()
18
+ self.cbr1 = nn.Conv2d(in_channels * 2, in_channels, 1, 1, bias=False)
19
+ self.cbr2 = nn.Conv2d(in_channels * 2, in_channels, 1, 1, bias=False)
20
+ # self.sigmoid = nn.Sigmoid()
21
+ self.cbr3 = nn.Conv2d(in_channels, in_channels, 1, 1, bias=False)
22
+ self.cbr4 = nn.Conv2d(in_channels * 2, in_channels, 1, 1, bias=False)
23
+
24
+ def forward(self, feature1, feature2):
25
+ d1 = torch.abs(feature1 - feature2)
26
+ d2 = self.cbr1(torch.cat([feature1, feature2], dim=1))
27
+ d = torch.cat([d1, d2], dim=1)
28
+ d = self.cbr2(d)
29
+ # d = self.sigmoid(d)
30
+
31
+ v1, v2 = self.cbr3(feature1), self.cbr3(feature2)
32
+ v1, v2 = v1 * d, v2 * d
33
+ features = torch.cat([v1, v2], dim=1)
34
+ features = self.cbr4(features)
35
+
36
+ return features
37
+
38
+
39
+ class DPFF(nn.Module):
40
+ def __init__(self, layer_channels) -> None:
41
+ super(DPFF, self).__init__()
42
+ self.cfes = nn.ModuleList()
43
+ for layer_channel in layer_channels:
44
+ self.cfes.append(_DPFF(layer_channel))
45
+
46
+ def forward(self, features1, features2):
47
+ outputs = []
48
+ for feature1, feature2, cfe in zip(features1, features2, self.cfes):
49
+ outputs.append(cfe(feature1, feature2))
50
+ return outputs
51
+
52
+
53
+ class DirectDPFF(nn.Module):
54
+ def __init__(self, layer_channels) -> None:
55
+ super(DirectDPFF, self).__init__()
56
+ self.fusions = nn.ModuleList(
57
+ [nn.Conv2d(layer_channel * 2, layer_channel, 1, 1) for layer_channel in layer_channels]
58
+ )
59
+
60
+ def forward(self, features1, features2):
61
+ outputs = []
62
+ for feature1, feature2, fusion in zip(features1, features2, self.fusions):
63
+ feature = torch.cat([feature1, feature2], dim=1)
64
+ outputs.append(fusion(feature))
65
+ return outputs
66
+
67
+
68
+ class ConvBlock(nn.Module):
69
+ def __init__(self, input_size, output_size, kernel_size=4, stride=2, padding=1, bias=True,
70
+ bn=False, activation=True, maxpool=True):
71
+ super(ConvBlock, self).__init__()
72
+ self.module = []
73
+ if maxpool:
74
+ down = nn.Sequential(
75
+ *[
76
+ nn.MaxPool2d(2),
77
+ nn.Conv2d(input_size, output_size, 1, 1, 0, bias=bias)
78
+ ]
79
+ )
80
+ else:
81
+ down = nn.Conv2d(input_size, output_size, kernel_size, stride, padding, bias=bias)
82
+ self.module.append(down)
83
+ if bn:
84
+ self.module.append(nn.BatchNorm2d(output_size))
85
+ if activation:
86
+ self.module.append(nn.PReLU())
87
+ self.module = nn.Sequential(*self.module)
88
+
89
+ def forward(self, x):
90
+ out = self.module(x)
91
+
92
+ return out
93
+
94
+
95
+ class DeconvBlock(nn.Module):
96
+ def __init__(self, input_size, output_size, kernel_size=4, stride=2, padding=1, bias=True,
97
+ bn=False, activation=True, bilinear=True):
98
+ super(DeconvBlock, self).__init__()
99
+ self.module = []
100
+ if bilinear:
101
+ deconv = nn.Sequential(
102
+ *[
103
+ nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
104
+ nn.Conv2d(input_size, output_size, 1, 1, 0, bias=bias)
105
+ ]
106
+ )
107
+ else:
108
+ deconv = nn.ConvTranspose2d(input_size, output_size, kernel_size, stride, padding, bias=bias)
109
+ self.module.append(deconv)
110
+ if bn:
111
+ self.module.append(nn.BatchNorm2d(output_size))
112
+ if activation:
113
+ self.module.append(nn.PReLU())
114
+ self.module = nn.Sequential(*self.module)
115
+
116
+ def forward(self, x):
117
+ out = self.module(x)
118
+
119
+ return out
120
+
121
+
122
+ class FusionBlock(torch.nn.Module):
123
+ def __init__(self, num_filter, num_ft, kernel_size=4, stride=2, padding=1, bias=True, maxpool=False,
124
+ bilinear=False):
125
+ super(FusionBlock, self).__init__()
126
+ self.num_ft = num_ft
127
+ self.up_convs = nn.ModuleList()
128
+ self.down_convs = nn.ModuleList()
129
+ for i in range(self.num_ft):
130
+ self.up_convs.append(
131
+ DeconvBlock(num_filter // (2 ** i), num_filter // (2 ** (i + 1)), kernel_size, stride, padding,
132
+ bias=bias, bilinear=bilinear)
133
+ )
134
+ self.down_convs.append(
135
+ ConvBlock(num_filter // (2 ** (i + 1)), num_filter // (2 ** i), kernel_size, stride, padding, bias=bias,
136
+ maxpool=maxpool)
137
+ )
138
+
139
+ def forward(self, ft_l, ft_h_list):
140
+ ft_fusion = ft_l
141
+ for i in range(len(ft_h_list)):
142
+ ft = ft_fusion
143
+ for j in range(self.num_ft - i):
144
+ ft = self.up_convs[j](ft)
145
+ ft = ft - ft_h_list[i]
146
+ for j in range(self.num_ft - i):
147
+ ft = self.down_convs[self.num_ft - i - j - 1](ft)
148
+ ft_fusion = ft_fusion + ft
149
+
150
+ return ft_fusion
151
+
152
+
153
+ class ConvLayer(nn.Module):
154
+ def __init__(self, in_channels, out_channels, kernel_size, stride, bias=True):
155
+ super(ConvLayer, self).__init__()
156
+ reflection_padding = kernel_size // 2
157
+ self.reflection_pad = nn.ReflectionPad2d(reflection_padding)
158
+ self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=bias)
159
+
160
+ def forward(self, x):
161
+ out = self.reflection_pad(x)
162
+ out = self.conv2d(out)
163
+ return out
164
+
165
+
166
+ class UpsampleConvLayer(torch.nn.Module):
167
+ def __init__(self, in_channels, out_channels, kernel_size, stride):
168
+ super(UpsampleConvLayer, self).__init__()
169
+ self.conv2d = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride)
170
+
171
+ def forward(self, x):
172
+ out = self.conv2d(x)
173
+ return out
174
+
175
+
176
+ class AddRelu(nn.Module):
177
+ """It is for adding two feed forwards to the output of the two following conv layers in expanding path
178
+ """
179
+
180
+ def __init__(self) -> None:
181
+ super(AddRelu, self).__init__()
182
+ self.relu = nn.PReLU()
183
+
184
+ def forward(self, input_tensor1, input_tensor2, input_tensor3):
185
+ x = input_tensor1 + input_tensor2 + input_tensor3
186
+ return self.relu(x)
187
+
188
+
189
+ class BasicBlock(nn.Module):
190
+ def __init__(self, in_channels, out_channels, mid_channels=None):
191
+ super(BasicBlock, self).__init__()
192
+ if not mid_channels:
193
+ mid_channels = out_channels
194
+ self.conv1 = ConvLayer(in_channels, mid_channels, kernel_size=3, stride=1)
195
+ self.bn1 = nn.BatchNorm2d(mid_channels, momentum=0.1)
196
+ self.relu = nn.PReLU()
197
+
198
+ self.conv2 = ConvLayer(mid_channels, out_channels, kernel_size=3, stride=1)
199
+ self.bn2 = nn.BatchNorm2d(out_channels, momentum=0.1)
200
+
201
+ self.conv3 = ConvLayer(in_channels, out_channels, kernel_size=1, stride=1)
202
+
203
+ def forward(self, x):
204
+ out = self.conv1(x)
205
+ out = self.bn1(out)
206
+ out = self.relu(out)
207
+
208
+ out = self.conv2(out)
209
+ out = self.bn2(out)
210
+
211
+ residual = self.conv3(x)
212
+
213
+ out = out + residual
214
+ out = self.relu(out)
215
+
216
+ return out
217
+
218
+
219
+ class Bottleneck(nn.Module):
220
+ def __init__(self, in_channels, out_channels):
221
+ super(Bottleneck, self).__init__()
222
+ self.conv1 = ConvLayer(in_channels, out_channels, kernel_size=3, stride=1)
223
+ self.bn1 = nn.BatchNorm2d(out_channels, momentum=0.1)
224
+
225
+ self.conv2 = ConvLayer(out_channels, out_channels, kernel_size=3, stride=1)
226
+ self.bn2 = nn.BatchNorm2d(out_channels, momentum=0.1)
227
+
228
+ self.conv3 = ConvLayer(out_channels, out_channels, kernel_size=3, stride=1)
229
+ self.bn3 = nn.BatchNorm2d(out_channels, momentum=0.1)
230
+
231
+ self.conv4 = ConvLayer(in_channels, out_channels, kernel_size=1, stride=1)
232
+
233
+ self.relu = nn.PReLU()
234
+
235
+ def forward(self, x):
236
+ out = self.conv1(x)
237
+ out = self.bn1(out)
238
+ out = self.relu(out)
239
+
240
+ out = self.conv2(out)
241
+ out = self.bn2(out)
242
+ out = self.relu(out)
243
+
244
+ out = self.conv3(out)
245
+ out = self.bn3(out)
246
+
247
+ residual = self.conv4(x)
248
+
249
+ out = out + residual
250
+ out = self.relu(out)
251
+
252
+ return out
253
+
254
+
255
+ class PPM(nn.Module):
256
+ def __init__(self, in_channels, out_channels):
257
+ super(PPM, self).__init__()
258
+
259
+ self.pool_sizes = [1, 2, 3, 6] # subregion size in each level
260
+ self.num_levels = len(self.pool_sizes) # number of pyramid levels
261
+
262
+ self.conv_layers = nn.ModuleList()
263
+ for i in range(self.num_levels):
264
+ self.conv_layers.append(nn.Sequential(
265
+ nn.AdaptiveAvgPool2d(output_size=self.pool_sizes[i]),
266
+ nn.Conv2d(in_channels, in_channels // self.num_levels, kernel_size=1),
267
+ nn.BatchNorm2d(in_channels // self.num_levels),
268
+ nn.ReLU(inplace=True)
269
+ ))
270
+ self.out_conv = nn.Conv2d(in_channels * 2, out_channels, kernel_size=1, stride=1)
271
+
272
+ def forward(self, x):
273
+ input_size = x.size()[2:] # get input size
274
+ output = [x]
275
+
276
+ # pyramid pooling
277
+ for i in range(self.num_levels):
278
+ out = self.conv_layers[i](x)
279
+ out = F.interpolate(out, size=input_size, mode='bilinear', align_corners=True)
280
+ output.append(out)
281
+
282
+ # concatenate features from different levels
283
+ output = torch.cat(output, dim=1)
284
+ output = self.out_conv(output)
285
+
286
+ return output
287
+
288
+
289
+ class MCDNet(nn.Module):
290
+ def __init__(self, in_channels=4, num_classes=4, maxpool=False, bilinear=False) -> None:
291
+ super().__init__()
292
+ level = 1
293
+ # encoder
294
+ self.conv_input = ConvLayer(in_channels, 32 * level, kernel_size=3, stride=2)
295
+
296
+ self.dense0 = BasicBlock(32 * level, 32 * level)
297
+ self.conv2x = ConvLayer(32 * level, 64 * level, kernel_size=3, stride=2)
298
+
299
+ self.dense1 = BasicBlock(64 * level, 64 * level)
300
+ self.conv4x = ConvLayer(64 * level, 128 * level, kernel_size=3, stride=2)
301
+
302
+ self.dense2 = BasicBlock(128 * level, 128 * level)
303
+ self.conv8x = ConvLayer(128 * level, 256 * level, kernel_size=3, stride=2)
304
+
305
+ self.dense3 = BasicBlock(256 * level, 256 * level)
306
+ self.conv16x = ConvLayer(256 * level, 512 * level, kernel_size=3, stride=2)
307
+
308
+ self.dense4 = PPM(512 * level, 512 * level)
309
+
310
+ # dpff
311
+ self.dpffm = DPFF([32, 64, 128, 256, 512])
312
+
313
+ # decoder
314
+ self.convd16x = UpsampleConvLayer(512 * level, 256 * level, kernel_size=3, stride=2)
315
+ self.fusion4 = FusionBlock(256 * level, 3, maxpool=maxpool, bilinear=bilinear)
316
+ self.dense_4 = Bottleneck(512 * level, 256 * level)
317
+ self.add_block4 = AddRelu()
318
+
319
+ self.convd8x = UpsampleConvLayer(256 * level, 128 * level, kernel_size=3, stride=2)
320
+ self.fusion3 = FusionBlock(128 * level, 2, maxpool=maxpool, bilinear=bilinear)
321
+ self.dense_3 = Bottleneck(256 * level, 128 * level)
322
+ self.add_block3 = AddRelu()
323
+
324
+ self.convd4x = UpsampleConvLayer(128 * level, 64 * level, kernel_size=3, stride=2)
325
+ self.fusion2 = FusionBlock(64 * level, 1, maxpool=maxpool, bilinear=bilinear)
326
+ self.dense_2 = Bottleneck(128 * level, 64 * level)
327
+ self.add_block2 = AddRelu()
328
+
329
+ self.convd2x = UpsampleConvLayer(64 * level, 32 * level, kernel_size=3, stride=2)
330
+ self.dense_1 = Bottleneck(64 * level, 32 * level)
331
+ self.add_block1 = AddRelu()
332
+
333
+ self.head = UpsampleConvLayer(32 * level, num_classes, kernel_size=3, stride=2)
334
+ self.apply(self._weights_init)
335
+
336
+ @torch.no_grad()
337
+ def get_lr_data(self, x: torch.Tensor) -> torch.Tensor:
338
+ images = x.cpu().permute(0, 2, 3, 1).numpy() # b, h, w, c
339
+ batch_size = images.shape[0]
340
+ lr = []
341
+ for i in range(batch_size):
342
+ lr_image = image_dehazer.remove_haze((images[i]*255).astype(np.uint8), showHazeTransmissionMap=False)[0] # h, w, c, numpy.array
343
+ lr_tensor = torch.from_numpy(lr_image).permute(2, 0, 1)/255. # c, h, w
344
+ lr.append(lr_tensor)
345
+ return torch.stack(lr, dim=0).to(x.device) # b, c, h, w
346
+
347
+ def _weights_init(self, m):
348
+ if isinstance(m, nn.Linear):
349
+ nn.init.xavier_normal_(m.weight)
350
+ nn.init.constant_(m.bias, 0)
351
+ elif isinstance(m, nn.Conv2d):
352
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
353
+ elif isinstance(m, nn.BatchNorm2d):
354
+ nn.init.constant_(m.weight, 1)
355
+ nn.init.constant_(m.bias, 0)
356
+
357
+ def forward(self, x1):
358
+ x2 = self.get_lr_data(x1)
359
+ # encoder1
360
+ res1x_1 = self.conv_input(x1)
361
+ res1x_1 = self.dense0(res1x_1)
362
+
363
+ res2x_1 = self.conv2x(res1x_1)
364
+ res2x_1 = self.dense1(res2x_1)
365
+
366
+ res4x_1 = self.conv4x(res2x_1)
367
+ res4x_1 = self.dense2(res4x_1)
368
+
369
+ res8x_1 = self.conv8x(res4x_1)
370
+ res8x_1 = self.dense3(res8x_1)
371
+
372
+ res16x_1 = self.conv16x(res8x_1)
373
+ res16x_1 = self.dense4(res16x_1)
374
+
375
+ # encoder2
376
+ res1x_2 = self.conv_input(x2)
377
+ res1x_2 = self.dense0(res1x_2)
378
+
379
+ res2x_2 = self.conv2x(res1x_2)
380
+ res2x_2 = self.dense1(res2x_2)
381
+
382
+ res4x_2 = self.conv4x(res2x_2)
383
+ res4x_2 = self.dense2(res4x_2)
384
+
385
+ res8x_2 = self.conv8x(res4x_2)
386
+ res8x_2 = self.dense3(res8x_2)
387
+
388
+ res16x_2 = self.conv16x(res8x_2)
389
+ res16x_2 = self.dense4(res16x_2)
390
+
391
+ # dual-perspective feature fusion
392
+ res1x, res2x, res4x, res8x, res16x = self.dpffm(
393
+ [res1x_1, res2x_1, res4x_1, res8x_1, res16x_1],
394
+ [res1x_2, res2x_2, res4x_2, res8x_2, res16x_2]
395
+ )
396
+
397
+ # decoder
398
+ res8x1 = self.convd16x(res16x)
399
+ res8x1 = F.interpolate(res8x1, res8x.size()[2:], mode='bilinear')
400
+ res8x2 = self.fusion4(res8x, [res1x, res2x, res4x])
401
+ res8x2 = torch.cat([res8x1, res8x2], dim=1)
402
+ res8x2 = self.dense_4(res8x2)
403
+ res8x2 = self.add_block4(res8x1, res8x, res8x2)
404
+
405
+ res4x1 = self.convd8x(res8x2)
406
+ res4x1 = F.interpolate(res4x1, res4x.size()[2:], mode='bilinear')
407
+ res4x2 = self.fusion3(res4x, [res1x, res2x])
408
+ res4x2 = torch.cat([res4x1, res4x2], dim=1)
409
+ res4x2 = self.dense_3(res4x2)
410
+ res4x2 = self.add_block3(res4x1, res4x, res4x2)
411
+
412
+ res2x1 = self.convd4x(res4x2)
413
+ res2x1 = F.interpolate(res2x1, res2x.size()[2:], mode='bilinear')
414
+ res2x2 = self.fusion2(res2x, [res1x])
415
+ res2x2 = torch.cat([res2x1, res2x2], dim=1)
416
+ res2x2 = self.dense_2(res2x2)
417
+ res2x2 = self.add_block2(res2x1, res2x, res2x2)
418
+
419
+ res1x1 = self.convd2x(res2x2)
420
+ res1x1 = F.interpolate(res1x1, res1x.size()[2:], mode='bilinear')
421
+ res1x2 = torch.cat([res1x1, res1x], dim=1)
422
+ res1x2 = self.dense_1(res1x2)
423
+ res1x2 = self.add_block1(res1x1, res1x, res1x2)
424
+
425
+ out = self.head(res1x2)
426
+ out = F.interpolate(out, x1.size()[2:], mode='bilinear')
427
+
428
+ return out
429
+
430
+
431
+ if __name__ == "__main__":
432
+ num_classes = 2
433
+ model = MCDNet()
434
+ # inp = torch.randn(size=(2, 3, 256, 256))
435
+ # assert model(input).shape == (2, 2, 256, 256)
cloud_adapter/scnn.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2024/7/21 下午5:11
3
+ # @Author : xiaoshun
4
+ # @Email : [email protected]
5
+ # @File : scnn.py
6
+ # @Software: PyCharm
7
+
8
+ # 论文地址:https://www.sciencedirect.com/science/article/abs/pii/S0924271624000352?via%3Dihub#fn1
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+
15
+ class SCNN(nn.Module):
16
+ def __init__(self, in_channels=3, num_classes=2, dropout_p=0.5):
17
+ super().__init__()
18
+ self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=1)
19
+ self.conv2 = nn.Conv2d(64, num_classes, kernel_size=1)
20
+ self.conv3 = nn.Conv2d(num_classes, num_classes, kernel_size=3, padding=1)
21
+ self.dropout = nn.Dropout2d(p=dropout_p)
22
+
23
+ def forward(self, x):
24
+ x = F.relu(self.conv1(x))
25
+ x = self.dropout(x)
26
+ x = self.conv2(x)
27
+ x = self.conv3(x)
28
+ return x
29
+
30
+
31
+ if __name__ == '__main__':
32
+ model = SCNN(num_classes=7)
33
+ fake_img = torch.randn((2, 3, 224, 224))
34
+ out = model(fake_img)
35
+ print(out.shape)
36
+ # torch.Size([2, 7, 224, 224])
cloud_adapter/unetmobv2.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2024/8/6 下午3:44
3
+ # @Author : xiaoshun
4
+ # @Email : [email protected]
5
+ # @File : unetmobv2.py
6
+ # @Software: PyCharm
7
+ import segmentation_models_pytorch as smp
8
+ import torch
9
+ from torch import nn as nn
10
+
11
+
12
+ class UNetMobV2(nn.Module):
13
+ def __init__(self,num_classes,in_channels=3):
14
+ super().__init__()
15
+ self.backbone = smp.Unet(
16
+ encoder_name='mobilenet_v2',
17
+ encoder_weights=None,
18
+ in_channels=in_channels,
19
+ classes=num_classes,
20
+ )
21
+
22
+ def forward(self, x):
23
+ x = self.backbone(x)
24
+ return x
25
+
26
+
27
+ if __name__ == '__main__':
28
+ fake_image = torch.rand(1, 3, 224, 224)
29
+ model = UNetMobV2(num_classes=2)
30
+ output = model(fake_image)
31
+ print(output.size())
cloud_adapter/utils.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from typing import List
3
+ from mmengine.logging import MMLogger
4
+
5
+ first_set_requires_grad = True
6
+ first_set_train = True
7
+
8
+
9
+ def set_requires_grad(model: nn.Module, keywords: List[str]):
10
+ """
11
+ notice:key in name!
12
+ """
13
+ requires_grad_names = []
14
+ num_params = 0
15
+ num_trainable = 0
16
+ for name, param in model.named_parameters():
17
+ num_params += param.numel()
18
+ if any(key in name for key in keywords):
19
+ param.requires_grad = True
20
+ requires_grad_names.append(name)
21
+ num_trainable += param.numel()
22
+ else:
23
+ param.requires_grad = False
24
+ global first_set_requires_grad
25
+ if first_set_requires_grad:
26
+ logger = MMLogger.get_current_instance()
27
+ for name in requires_grad_names:
28
+ logger.info(f"set_requires_grad----{name}")
29
+ logger.info(
30
+ f"Total trainable params--{num_trainable}, All params--{num_params}, Ratio--{num_trainable*100/num_params:.1f}%"
31
+ )
32
+ first_set_requires_grad = False
33
+
34
+
35
+ def _set_train(model: nn.Module, keywords: List[str], prefix: str = ""):
36
+ train_names = []
37
+ for name, child in model.named_children():
38
+ fullname = ".".join([prefix, name])
39
+ if any(name.startswith(key) for key in keywords):
40
+ train_names.append(fullname)
41
+ child.train()
42
+ else:
43
+ train_names += _set_train(child, keywords, prefix=fullname)
44
+ return train_names
45
+
46
+
47
+ def set_train(model: nn.Module, keywords: List[str]):
48
+ """
49
+ notice:sub name startwith key!
50
+ """
51
+ model.train(False)
52
+ train_names = _set_train(model, keywords)
53
+ global first_set_train
54
+ if first_set_train:
55
+ logger = MMLogger.get_current_instance()
56
+ for train_name in train_names:
57
+ logger.info(f"set_train----{train_name}")
58
+ first_set_train = False
example_inputs/gf1/11.png ADDED
example_inputs/gf1/48.png ADDED
example_inputs/gf1/9.png ADDED
example_inputs/gf2/160.png ADDED
example_inputs/gf2/2.png ADDED