lealaxy committed on
Commit
240df91
·
1 Parent(s): 500a435

upload model

README.md CHANGED
@@ -1,3 +1,52 @@
1
  ---
2
  license: mit
3
+ tags:
4
+ - vision
5
+ - image-segmentation
6
+ datasets:
7
+ - LEVIR-CD
8
  ---
9
+ # AdaptFormer model fine-tuned on LEVIR-CD
10
+
11
+ AdaptFormer model fine-tuned on LEVIR-CD at resolution 512x512. It was introduced in the paper [AdaptFormer: An Adaptive Hierarchical Semantic Approach for Change Detection on Remote Sensing Images](https://ieeexplore.ieee.org/document/10497147) by Pang et al. and first released in [this repository](https://github.com/aigzhusmart/AdaptFormer).
12
+
13
+ ## Model description
14
+
15
+ AdaptFormer is uniquely designed to adaptively interpret hierarchical semantics. Instead of a one-size-fits-all approach, it strategizes differently across three semantic depths: employing straightforward operations for shallow semantics, assimilating spatial data for medium semantics to emphasize detailed inter-regional changes, and integrating cascaded depthwise attention for in-depth semantics, focusing on high-level representations.
16
+
17
+ Here is how to use this model to detect changes between a pair of images:
18
+
19
+ ```python
20
+ from transformers import AutoImageProcessor, AutoModel
21
+ from PIL import Image
22
+ import requests
23
+
24
+ image_processor = AutoImageProcessor.from_pretrained("deepang/adaptformer-LEVIR-CD", trust_remote_code=True)
25
+ model = AutoModel.from_pretrained("deepang/adaptformer-LEVIR-CD", trust_remote_code=True)
26
+
27
+ image_A = Image.open(requests.get('https://raw.githubusercontent.com/aigzhusmart/AdaptFormer/main/figures/test_2_1_A.png', stream=True).raw)
28
+ image_B = Image.open(requests.get('https://raw.githubusercontent.com/aigzhusmart/AdaptFormer/main/figures/test_2_1_B.png', stream=True).raw)
29
+ label = Image.open(requests.get('https://raw.githubusercontent.com/aigzhusmart/AdaptFormer/main/figures/test_2_1_label.png', stream=True).raw)
30
+
31
+
32
+ inputs = image_processor(images=(image_A, image_B), return_tensors="pt")
33
+ outputs = model(**inputs)
34
+ logits = outputs.logits # shape (batch_size, num_labels, height, width)
35
+ pred = logits.argmax(dim=1)[0]
36
+ ```
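+
+ The `pred` tensor above holds per-pixel class indices. As a quick check, it can be converted back into a binary change-mask image; the following is a minimal sketch (it assumes class 1 marks changed pixels, as in LEVIR-CD):
+
+ ```python
+ import numpy as np
+ from PIL import Image
+
+ # pred: (height, width) tensor of class indices from the snippet above
+ mask = Image.fromarray((pred.numpy() * 255).astype(np.uint8))
+ mask.save("change_mask.png")
+ ```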
37
+
38
+ ### License
39
+
40
+ The license for this model can be found [here](https://github.com/aigzhusmart/AdaptFormer).
41
+
42
+ ### BibTeX entry and citation info
43
+
44
+ ```bibtex
45
+ @article{huang2024adaptformer,
46
+ title={AdaptFormer: An Adaptive Hierarchical Semantic Approach for Change Detection on Remote Sensing Images},
47
+ author={Huang, Teng and Hong, Yile and Pang, Yan and Liang, Jiaming and Hong, Jie and Huang, Lin and Zhang, Yuan and Jia, Yan and Savi, Patrizia},
48
+ journal={IEEE Transactions on Instrumentation and Measurement},
49
+ year={2024},
50
+ publisher={IEEE}
51
+ }
52
+ ```
config.json ADDED
@@ -0,0 +1,43 @@
1
+ {
2
+ "architectures": [
3
+ "AdaptFormerForChangeDetection"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_adaptformer.AdaptFormerConfig",
7
+ "AutoModel": "modeling_adaptformer.AdaptFormerForChangeDetection",
8
+ "AutoImageProcessor": "preprocessing_adaptformer.AdaptFormerImageProcessor"
9
+ },
10
+ "depths": [
11
+ 3,
12
+ 3,
13
+ 3
14
+ ],
15
+ "embed_dims": [
16
+ 64,
17
+ 128,
18
+ 256
19
+ ],
20
+ "initializer_range": 0.02,
21
+ "mlp_ratios": [
22
+ 4,
23
+ 4,
24
+ 4
25
+ ],
26
+ "model_type": "adaptformer",
27
+ "num_channels": 3,
28
+ "num_classes": 2,
29
+ "num_heads": [
30
+ 1,
31
+ 2,
32
+ 4
33
+ ],
34
+ "semantic_loss_ignore_index": 255,
35
+ "semantic_loss_weight": [
36
+ 0,
37
+ 0,
38
+ 0.5,
39
+ 1
40
+ ],
41
+ "torch_dtype": "float32",
42
+ "transformers_version": "4.39.3"
43
+ }
configuration_adaptformer.py ADDED
@@ -0,0 +1,80 @@
1
+ """ AdaptFormer model configuration"""
2
+
3
+ from transformers import PretrainedConfig
4
+
5
+
6
+ class AdaptFormerConfig(PretrainedConfig):
7
+ r"""
8
+ This is the configuration class to store the configuration of an [`AdaptFormerForChangeDetection`].
9
+ It is used to instantiate an AdaptFormer model according to the specified arguments,
10
+ defining the model architecture. Instantiating a configuration with the defaults will yield a similar
11
+ configuration to that of the AdaptFormer
12
+ [deepang/adaptformer-LEVIR-CD](https://huggingface.co/deepang/adaptformer-LEVIR-CD)
13
+ architecture.
14
+
15
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
16
+ documentation from [`PretrainedConfig`] for more information.
17
+
18
+ Args:
19
+ num_channels (`int`, *optional*, defaults to 3):
20
+ The number of input channels.
21
+ num_classes (`int`, *optional*, defaults to 2):
22
+ The number of classes.
23
+ embed_dims (`List[int]`, *optional*, defaults to `[64, 128, 256]`):
24
+ Dimension of each of the encoder blocks.
25
+ num_heads (`List[int]`, *optional*, defaults to `[1, 2, 4]`):
26
+ Number of attention heads for each attention layer in each block of the encoder.
27
+ mlp_ratios (`List[int]`, *optional*, defaults to `[4, 4, 4]`):
28
+ Ratio of the size of the hidden layer compared to the size of the input layer of the Mix FFNs in the
29
+ encoder blocks.
30
+ depths (`List[int]`, *optional*, defaults to `[3, 3, 3]`):
31
+ The number of layers in each encoder block.
32
+ semantic_loss_ignore_index (`int`, *optional*, defaults to 255):
33
+ The index that is ignored by the loss function of the semantic segmentation model.
34
+ semantic_loss_weight (`List[float]`, *optional*, defaults to `[0, 0, 0.5, 1]`):
35
+ The weight of the semantic segmentation loss.
36
+ initializer_range (`float`, *optional*, defaults to 0.02):
37
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
38
+
39
+ Example:
40
+
41
+ ```python
42
+ >>> from transformers import AutoModel, AutoConfig
43
+
44
+ >>> # Initializing an AdaptFormer configuration
45
+ >>> configuration = AutoConfig.from_pretrained("deepang/adaptformer-LEVIR-CD", trust_remote_code=True)
46
+
47
+ >>> # Initializing a model from the deepang/adaptformer-LEVIR-CD style configuration
48
+ >>> model = AutoModel.from_config(configuration, trust_remote_code=True)
49
+
50
+ >>> # Accessing the model configuration
51
+ >>> configuration = model.config
52
+ ```"""
53
+
54
+ model_type = "adaptformer"
55
+
56
+ def __init__(
57
+ self,
58
+ num_channels=3,
59
+ num_classes=2,
60
+ embed_dims=[64, 128, 256],
61
+ num_heads=[1, 2, 4],
62
+ mlp_ratios=[4, 4, 4],
63
+ depths=[3, 3, 3],
64
+ semantic_loss_ignore_index=255,
65
+ semantic_loss_weight=[0, 0, 0.5, 1],
66
+ initializer_range=0.02,
67
+ **kwargs,
68
+ ):
69
+ self.num_channels = num_channels
70
+ self.embed_dims = embed_dims
71
+ self.num_heads = num_heads
73
+ self.mlp_ratios = mlp_ratios
74
+ self.depths = depths
75
+ self.num_classes = num_classes
76
+ self.semantic_loss_ignore_index = semantic_loss_ignore_index
77
+ self.semantic_loss_weight = semantic_loss_weight
78
+ self.initializer_range = initializer_range
79
+
80
+ super().__init__(**kwargs)
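For reference, the configuration class above can also be used directly once this file has been downloaded locally. A minimal sketch (the import assumes `configuration_adaptformer.py` is on the Python path; the override in the last line is illustrative and does not correspond to the released checkpoint):

```python
from configuration_adaptformer import AdaptFormerConfig

# Defaults reproduce the released LEVIR-CD configuration
config = AdaptFormerConfig()
print(config.embed_dims, config.num_heads, config.depths)  # [64, 128, 256] [1, 2, 4] [3, 3, 3]

# Hyperparameters can be overridden, e.g. a shallower illustrative variant
small_config = AdaptFormerConfig(depths=[2, 2, 2])
```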
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:32a543900a391fcb9b974956cfc96811261f7c6ad4b6393e6907910d99b42e04
3
+ size 50178960
modeling_adaptformer.py ADDED
@@ -0,0 +1,647 @@
1
+ """ PyTorch AdaptFormer model."""
2
+
3
+ import itertools
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from einops.layers.torch import Rearrange
10
+ from transformers import PreTrainedModel
11
+ from transformers.modeling_outputs import SemanticSegmenterOutput
12
+
13
+ from .configuration_adaptformer import AdaptFormerConfig
14
+
15
+
16
+ class SpatialExchange(nn.Module):
17
+
18
+ def __init__(self, p=1 / 2):
19
+ super().__init__()
20
+ assert p >= 0 and p <= 1
21
+ self.p = int(1 / p)
22
+
23
+ def forward(self, x1: torch.Tensor, x2: torch.Tensor):
24
+ _, _, _, w = x1.shape
25
+ exchange_mask = torch.arange(w) % self.p == 0
26
+
27
+ out_x1 = torch.zeros_like(x1, device=x1.device)
28
+ out_x2 = torch.zeros_like(x2, device=x1.device)
29
+ out_x1[..., ~exchange_mask] = x1[..., ~exchange_mask]
30
+ out_x2[..., ~exchange_mask] = x2[..., ~exchange_mask]
31
+ out_x1[..., exchange_mask] = x2[..., exchange_mask]
32
+ out_x2[..., exchange_mask] = x1[..., exchange_mask]
33
+
34
+ return out_x1, out_x2
35
+
36
+
37
+ class ChannelExchange(nn.Module):
38
+
39
+ def __init__(self, p=1 / 2):
40
+ super().__init__()
41
+ assert p >= 0 and p <= 1
42
+ self.p = int(1 / p)
43
+
44
+ def forward(self, x1: torch.Tensor, x2: torch.Tensor):
45
+ N, c, _, _ = x1.shape
46
+
47
+ exchange_map = torch.arange(c) % self.p == 0
48
+ exchange_mask = exchange_map.unsqueeze(0).expand((N, -1))
49
+
50
+ out_x1 = torch.zeros_like(x1, device=x1.device)
51
+ out_x2 = torch.zeros_like(x2, device=x1.device)
52
+ out_x1[~exchange_mask, ...] = x1[~exchange_mask, ...]
53
+ out_x2[~exchange_mask, ...] = x2[~exchange_mask, ...]
54
+ out_x1[exchange_mask, ...] = x2[exchange_mask, ...]
55
+ out_x2[exchange_mask, ...] = x1[exchange_mask, ...]
56
+
57
+ return out_x1, out_x2
58
+
59
+
60
+ class CascadedGroupAttention(nn.Module):
61
+ r"""Cascaded Group Attention.
62
+
63
+ Args:
64
+ dim (int): Number of input channels.
65
+ key_dim (int): The dimension for query and key.
66
+ num_heads (int): Number of attention heads.
67
+ attn_ratio (int): Multiplier for the query dim for value dimension.
68
+ resolution (int): Input resolution, correspond to the window size.
69
+ kernels (List[int]): The kernel size of the dw conv on query.
70
+ """
71
+
72
+ def __init__(
73
+ self,
74
+ dim,
75
+ key_dim,
76
+ num_heads=8,
77
+ attn_ratio=4,
78
+ resolution=14,
79
+ kernels=[5, 5, 5, 5],
80
+ ):
81
+ super().__init__()
82
+ self.num_heads = num_heads
83
+ self.scale = key_dim**-0.5
84
+ self.key_dim = key_dim
85
+ self.d = int(attn_ratio * key_dim)
86
+ self.attn_ratio = attn_ratio
87
+
88
+ qkvs = []
89
+ dws = []
90
+ for i in range(num_heads):
91
+ qkvs.append(
92
+ nn.Sequential(
93
+ nn.Conv2d(
94
+ dim // (num_heads),
95
+ self.key_dim * 2 + self.d,
96
+ 1,
97
+ 1,
98
+ 0,
99
+ bias=False,
100
+ ),
101
+ nn.BatchNorm2d(self.key_dim * 2 + self.d),
102
+ )
103
+ )
104
+ dws.append(
105
+ nn.Sequential(
106
+ nn.Conv2d(
107
+ self.key_dim,
108
+ self.key_dim,
109
+ kernels[i],
110
+ 1,
111
+ kernels[i] // 2,
112
+ groups=self.key_dim,
113
+ bias=False,
114
+ ),
115
+ nn.BatchNorm2d(self.key_dim),
116
+ )
117
+ )
118
+
119
+ self.qkvs = nn.ModuleList(qkvs)
120
+ self.dws = nn.ModuleList(dws)
121
+ self.proj = nn.Sequential(
122
+ nn.ReLU(),
123
+ nn.Conv2d(self.d * num_heads, dim, 1, 1, 0, bias=False),
124
+ nn.BatchNorm2d(dim),
125
+ )
126
+ self.act_gelu = nn.GELU()
127
+ points = list(itertools.product(range(resolution), range(resolution)))
128
+ N = len(points)
129
+ attention_offsets = {}
130
+ idxs = []
131
+ for p1 in points:
132
+ for p2 in points:
133
+ offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
134
+ if offset not in attention_offsets:
135
+ attention_offsets[offset] = len(attention_offsets)
136
+ idxs.append(attention_offsets[offset])
137
+ self.attention_biases = nn.Parameter(
138
+ torch.zeros(num_heads, len(attention_offsets))
139
+ )
140
+ self.register_buffer("attention_bias_idxs", torch.LongTensor(idxs).view(N, N))
141
+
142
+ @torch.no_grad()
143
+ def train(self, mode=True):
144
+ super().train(mode)
145
+ if mode and hasattr(self, "ab"):
146
+ del self.ab
147
+ else:
148
+ self.ab = self.attention_biases[:, self.attention_bias_idxs]
149
+
150
+ def forward(self, x):
151
+ B, _, H, W = x.shape
152
+ trainingab = self.attention_biases[:, self.attention_bias_idxs]
153
+ feats_in = x.chunk(len(self.qkvs), dim=1)
154
+ feats_out = []
155
+ feat = feats_in[0]
156
+ for i, qkv in enumerate(self.qkvs):
157
+ if i > 0:
158
+ feat = feat + feats_in[i]
159
+ feat = qkv(feat)
160
+ q, k, v = feat.view(B, -1, H, W).split(
161
+ [self.key_dim, self.key_dim, self.d], dim=1
162
+ )
163
+ q = self.act_gelu(self.dws[i](q)) + q
164
+ q, k, v = q.flatten(2), k.flatten(2), v.flatten(2)
165
+ attn = (q.transpose(-2, -1) @ k) * self.scale + (
166
+ trainingab[i] if self.training else self.ab[i].to(x.device)
167
+ )
168
+ attn = attn.softmax(dim=-1)
169
+ feat = (v @ attn.transpose(-2, -1)).view(B, self.d, H, W)
170
+ feats_out.append(feat)
171
+ x = self.proj(torch.cat(feats_out, 1))
172
+ return x
173
+
174
+
175
+ class LocalWindowAttention(nn.Module):
176
+ r"""Local Window Attention.
177
+
178
+ Args:
179
+ dim (int): Number of input channels.
180
+ key_dim (int): The dimension for query and key.
181
+ num_heads (int): Number of attention heads.
182
+ attn_ratio (int): Multiplier for the query dim for value dimension.
183
+ resolution (int): Input resolution.
184
+ window_resolution (int): Local window resolution.
185
+ kernels (List[int]): The kernel size of the dw conv on query.
186
+ """
187
+
188
+ def __init__(
189
+ self,
190
+ dim,
191
+ key_dim,
192
+ num_heads=8,
193
+ attn_ratio=4,
194
+ resolution=14,
195
+ window_resolution=7,
196
+ kernels=[5, 5, 5, 5],
197
+ ):
198
+ super().__init__()
199
+ self.dim = dim
200
+ self.num_heads = num_heads
201
+ self.resolution = resolution
202
+ assert window_resolution > 0, "window_size must be greater than 0"
203
+ self.window_resolution = window_resolution
204
+
205
+ window_resolution = min(window_resolution, resolution)
206
+ self.attn = CascadedGroupAttention(
207
+ dim,
208
+ key_dim,
209
+ num_heads,
210
+ attn_ratio=attn_ratio,
211
+ resolution=window_resolution,
212
+ kernels=kernels,
213
+ )
214
+
215
+ def forward(self, x):
216
+ H = W = self.resolution
217
+ B, C, H_, W_ = x.shape
218
+ # Only check this for classification models
219
+ assert (
220
+ H == H_ and W == W_
221
+ ), "input feature has wrong size, expect {}, got {}".format((H, W), (H_, W_))
222
+
223
+ if H <= self.window_resolution and W <= self.window_resolution:
224
+ x = self.attn(x)
225
+ else:
226
+ x = x.permute(0, 2, 3, 1)
227
+ pad_b = (
228
+ self.window_resolution - H % self.window_resolution
229
+ ) % self.window_resolution
230
+ pad_r = (
231
+ self.window_resolution - W % self.window_resolution
232
+ ) % self.window_resolution
233
+ padding = pad_b > 0 or pad_r > 0
234
+
235
+ if padding:
236
+ x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
237
+
238
+ pH, pW = H + pad_b, W + pad_r
239
+ nH = pH // self.window_resolution
240
+ nW = pW // self.window_resolution
241
+ x = (
242
+ x.view(B, nH, self.window_resolution, nW, self.window_resolution, C)
243
+ .transpose(2, 3)
244
+ .reshape(B * nH * nW, self.window_resolution, self.window_resolution, C)
245
+ .permute(0, 3, 1, 2)
246
+ )
247
+ x = self.attn(x)
248
+ x = (
249
+ x.permute(0, 2, 3, 1)
250
+ .view(B, nH, nW, self.window_resolution, self.window_resolution, C)
251
+ .transpose(2, 3)
252
+ .reshape(B, pH, pW, C)
253
+ )
254
+ if padding:
255
+ x = x[:, :H, :W].contiguous()
256
+ x = x.permute(0, 3, 1, 2)
257
+ return x
258
+
259
+
260
+ class LocalAgg(nn.Module):
261
+
262
+ def __init__(self, channels):
263
+ super(LocalAgg, self).__init__()
264
+ self.bn = nn.BatchNorm2d(channels)
265
+ self.pointwise_conv_0 = nn.Conv2d(channels, channels, kernel_size=1, bias=False)
266
+ self.depthwise_conv = nn.Conv2d(
267
+ channels, channels, padding=1, kernel_size=3, groups=channels, bias=False
268
+ )
269
+ self.pointwise_prenorm_1 = nn.BatchNorm2d(channels)
270
+ self.pointwise_conv_1 = nn.Conv2d(channels, channels, kernel_size=1, bias=False)
271
+
272
+ def forward(self, x):
273
+ x = self.bn(x)
274
+ x = self.pointwise_conv_0(x)
275
+ x = self.depthwise_conv(x)
276
+ x = self.pointwise_prenorm_1(x)
277
+ x = self.pointwise_conv_1(x)
278
+ return x
279
+
280
+
281
+ class Mlp(nn.Module):
282
+
283
+ def __init__(self, channels, mlp_ratio):
284
+ super(Mlp, self).__init__()
285
+ self.up_proj = nn.Conv2d(
286
+ channels, channels * mlp_ratio, kernel_size=1, bias=False
287
+ )
288
+ self.down_proj = nn.Conv2d(
289
+ channels * mlp_ratio, channels, kernel_size=1, bias=False
290
+ )
291
+
292
+ def forward(self, x):
293
+ return self.down_proj(F.gelu(self.up_proj(x)))
294
+
295
+
296
+ class LocalMerge(nn.Module):
297
+ def __init__(self, channels, r, heads, resolution, partial=False):
298
+ super(LocalMerge, self).__init__()
299
+ self.partial = partial
300
+ self.cpe1 = nn.Conv2d(
301
+ channels, channels, kernel_size=3, padding=1, groups=channels, bias=False
302
+ )
303
+ self.local_agg = LocalAgg(channels)
304
+ self.mlp1 = Mlp(channels, r)
305
+ if partial:
306
+ self.cpe2 = nn.Conv2d(
307
+ channels,
308
+ channels,
309
+ kernel_size=3,
310
+ padding=1,
311
+ groups=channels,
312
+ bias=False,
313
+ )
314
+ self.attn = LocalWindowAttention(
315
+ channels,
316
+ 16,
317
+ heads,
318
+ attn_ratio=r,
319
+ resolution=resolution,
320
+ window_resolution=7,
321
+ kernels=[5, 5, 5, 5],
322
+ )
323
+ self.mlp2 = Mlp(channels, r)
324
+
325
+ def forward(self, x):
326
+ x = self.cpe1(x) + x
327
+ x = self.local_agg(x) + x
328
+ x = self.mlp1(x) + x
329
+ if self.partial:
330
+ x = self.cpe2(x) + x
331
+ x = self.attn(x) + x
332
+ x = self.mlp2(x) + x
333
+ return x
334
+
335
+
336
+ class AdaptFormerEncoderBlock(nn.Module):
337
+ def __init__(
338
+ self, in_chans, embed_dim, num_head, mlp_ratio, depth, resolution, partial
339
+ ):
340
+ super().__init__()
341
+
342
+ self.down = nn.Sequential(
343
+ nn.Conv2d(in_chans, embed_dim, kernel_size=2, stride=2),
344
+ nn.GroupNorm(num_groups=1, num_channels=embed_dim),
345
+ )
346
+
347
+ self.block = nn.Sequential(
348
+ *[
349
+ LocalMerge(
350
+ channels=embed_dim,
351
+ r=mlp_ratio,
352
+ heads=num_head,
353
+ resolution=resolution,
354
+ partial=partial,
355
+ )
356
+ for _ in range(depth)
357
+ ]
358
+ )
359
+
360
+ def forward(self, x: torch.Tensor):
361
+ return self.block(self.down(x))
362
+
363
+
364
+ class ChangeDetectionHead(nn.Module):
365
+ def __init__(self, embedding_dim, in_channels, num_classes):
366
+ super(ChangeDetectionHead, self).__init__()
367
+ self.in_proj = nn.Sequential(
368
+ nn.Conv2d(
369
+ in_channels=embedding_dim * len(in_channels),
370
+ out_channels=embedding_dim,
371
+ kernel_size=1,
372
+ ),
373
+ nn.BatchNorm2d(embedding_dim),
374
+ nn.ConvTranspose2d(embedding_dim, embedding_dim, 4, stride=2, padding=1),
375
+ )
376
+
377
+ self.conv1 = nn.Conv2d(embedding_dim, embedding_dim, 3, 1, 1)
378
+ self.conv2 = nn.Conv2d(embedding_dim, embedding_dim, 3, 1, 1)
379
+
380
+ self.out = nn.Conv2d(embedding_dim, num_classes, 3, 1, 1)
381
+
382
+ def forward(self, x: torch.Tensor):
383
+ x = self.in_proj(x)
384
+ x = self.conv2(F.relu(self.conv1(x))) * 0.1 + x
385
+ return self.out(x)
386
+
387
+
388
+ class AdaptFormerDecoder(nn.Module):
389
+
390
+ def __init__(
391
+ self,
392
+ config: AdaptFormerConfig,
393
+ ):
394
+ super(AdaptFormerDecoder, self).__init__()
395
+
396
+ self.in_channels = config.embed_dims
397
+ self.embedding_dim = config.embed_dims[-1]
398
+
399
+ self.linear_emb_layers = nn.ModuleList(
400
+ [
401
+ nn.Sequential(
402
+ Rearrange("n c ... -> n (...) c"),
403
+ nn.Linear(in_dim, self.embedding_dim),
404
+ )
405
+ for in_dim in self.in_channels
406
+ ]
407
+ )
408
+
409
+ self.diff_layers = nn.ModuleList(
410
+ [
411
+ nn.Sequential(
412
+ nn.Conv2d(2 * self.embedding_dim, self.embedding_dim, 3, 1, 1),
413
+ nn.ReLU(),
414
+ nn.BatchNorm2d(self.embedding_dim),
415
+ nn.Conv2d(self.embedding_dim, self.embedding_dim, 3, 1, 1),
416
+ nn.ReLU(),
417
+ )
418
+ for _ in range(3)
419
+ ]
420
+ )
421
+
422
+ self.prediction_layers = nn.ModuleList(
423
+ [
424
+ nn.Sequential(
425
+ nn.Conv2d(self.embedding_dim, config.num_classes, 3, 1, 1),
426
+ nn.ReLU(),
427
+ nn.BatchNorm2d(config.num_classes),
428
+ nn.Conv2d(config.num_classes, config.num_classes, 3, 1, 1),
429
+ )
430
+ for _ in range(3)
431
+ ]
432
+ )
433
+
434
+ self.head = ChangeDetectionHead(
435
+ self.embedding_dim, self.in_channels, config.num_classes
436
+ )
437
+
438
+ def forward(self, pixel_valuesA, pixel_valuesB):
439
+ N, _, H, W = pixel_valuesA[0].shape
440
+
441
+ # c3
442
+ pixel_values_c3 = torch.cat([pixel_valuesA[2], pixel_valuesB[2]], dim=0)
443
+
444
+ _c3_1, _c3_2 = torch.chunk(
445
+ self.linear_emb_layers[2](pixel_values_c3).permute(0, 2, 1), 2
446
+ )
447
+ _c3_1 = _c3_1.reshape(N, -1, pixel_values_c3.shape[2], pixel_values_c3.shape[3])
448
+ _c3_2 = _c3_2.reshape(N, -1, pixel_values_c3.shape[2], pixel_values_c3.shape[3])
449
+
450
+ _c3 = self.diff_layers[2](torch.cat((_c3_1, _c3_2), dim=1))
451
+
452
+ p_c3 = self.prediction_layers[2](_c3)
453
+ _c3_up = F.interpolate(_c3, (H, W), mode="bilinear", align_corners=False)
454
+
455
+ # c2
456
+ pixel_values_c2 = torch.cat([pixel_valuesA[1], pixel_valuesB[1]], dim=0)
457
+ _c2_1, _c2_2 = torch.chunk(
458
+ self.linear_emb_layers[1](pixel_values_c2).permute(0, 2, 1), 2
459
+ )
460
+ _c2_1 = _c2_1.reshape(N, -1, pixel_values_c2.shape[2], pixel_values_c2.shape[3])
461
+ _c2_2 = _c2_2.reshape(N, -1, pixel_values_c2.shape[2], pixel_values_c2.shape[3])
462
+ _c2 = self.diff_layers[1](torch.cat((_c2_1, _c2_2), dim=1)) + F.interpolate(
463
+ _c3, scale_factor=2, mode="bilinear"
464
+ )
465
+ p_c2 = self.prediction_layers[1](_c2)
466
+ _c2_up = F.interpolate(_c2, (H, W), mode="bilinear", align_corners=False)
467
+
468
+ # c1
469
+ pixel_values_c1 = torch.cat([pixel_valuesA[0], pixel_valuesB[0]], dim=0)
470
+ _c1_1, _c1_2 = torch.chunk(
471
+ self.linear_emb_layers[0](pixel_values_c1).permute(0, 2, 1), 2
472
+ )
473
+ _c1_1 = _c1_1.reshape(N, -1, pixel_values_c1.shape[2], pixel_values_c1.shape[3])
474
+ _c1_2 = _c1_2.reshape(N, -1, pixel_values_c1.shape[2], pixel_values_c1.shape[3])
475
+ _c1 = self.diff_layers[0](torch.cat((_c1_1, _c1_2), dim=1)) + F.interpolate(
476
+ _c2, scale_factor=2, mode="bilinear"
477
+ )
478
+ p_c1 = self.prediction_layers[0](_c1)
479
+
480
+ cp = self.head(torch.cat((_c3_up, _c2_up, _c1), dim=1))
481
+
482
+ return [p_c3, p_c2, p_c1, cp]
483
+
484
+
485
+ class AdaptFormerPreTrainedModel(PreTrainedModel):
486
+ """
487
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
488
+ models.
489
+ """
490
+ config_class = AdaptFormerConfig
491
+ base_model_prefix = "adaptformer"
492
+
493
+ def _init_weights(self, m):
494
+ """Initialize the weights"""
495
+ if isinstance(m, nn.Linear):
496
+ nn.init.trunc_normal_(m.weight, std=0.02)
497
+ if isinstance(m, nn.Linear) and m.bias is not None:
498
+ nn.init.constant_(m.bias, 0)
499
+ elif isinstance(m, nn.LayerNorm):
500
+ nn.init.constant_(m.bias, 0)
501
+ nn.init.constant_(m.weight, 1.0)
502
+ elif isinstance(m, nn.Conv2d):
503
+ import math
504
+
505
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
506
+ fan_out //= m.groups
507
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
508
+ if m.bias is not None:
509
+ m.bias.data.zero_()
510
+
511
+
512
+ class AdaptFormerForChangeDetection(AdaptFormerPreTrainedModel):
513
+ """
514
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
515
+ it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
516
+ behavior.
517
+
518
+ Parameters:
519
+ config ([`AdaptFormerConfig`]): Model configuration class with all the parameters of the model.
520
+ Initializing with a config file does not load the weights associated with the model, only the
521
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
522
+ """
523
+
524
+ def __init__(
525
+ self,
526
+ config: AdaptFormerConfig,
527
+ ):
528
+ super().__init__(config)
529
+ self.config = config
530
+ self.block1 = AdaptFormerEncoderBlock(
531
+ in_chans=config.num_channels,
532
+ embed_dim=config.embed_dims[0],
533
+ num_head=config.num_heads[0],
534
+ mlp_ratio=config.mlp_ratios[0],
535
+ depth=config.depths[0],
536
+ resolution=config.embed_dims[2] // 2,
537
+ partial=False,
538
+ )
539
+ self.block2 = AdaptFormerEncoderBlock(
540
+ in_chans=config.embed_dims[0],
541
+ embed_dim=config.embed_dims[1],
542
+ num_head=config.num_heads[1],
543
+ mlp_ratio=config.mlp_ratios[1],
544
+ depth=config.depths[1],
545
+ resolution=config.embed_dims[1] // 2,
546
+ partial=False,
547
+ )
548
+ self.block3 = AdaptFormerEncoderBlock(
549
+ in_chans=config.embed_dims[1],
550
+ embed_dim=config.embed_dims[2],
551
+ num_head=config.num_heads[2],
552
+ mlp_ratio=config.mlp_ratios[2],
553
+ depth=config.depths[2],
554
+ resolution=config.embed_dims[0] // 2,
555
+ partial=True,
556
+ )
557
+ self.spatialex = SpatialExchange()
558
+ self.channelex = ChannelExchange()
559
+
560
+ self.decoder = AdaptFormerDecoder(config=config)
561
+
562
+ # Initialize weights and apply final processing
563
+ self.post_init()
564
+
565
+ def forward(
566
+ self,
567
+ pixel_valuesA: torch.Tensor,
568
+ pixel_valuesB: torch.Tensor,
569
+ labels: Optional[torch.Tensor] = None,
570
+ output_hidden_states: Optional[bool] = None,
571
+ return_dict: Optional[bool] = None,
572
+ ):
573
+ r"""
574
+ labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
575
+ Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
576
+ config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
577
+
578
+ Returns:
579
+
580
+ Examples:
581
+
582
+ ```python
583
+ >>> from transformers import AutoImageProcessor, AutoModel
584
+ >>> from PIL import Image
585
+ >>> import requests
+ >>> import torch
586
+
587
+ >>> image_processor = AutoImageProcessor.from_pretrained("deepang/adaptformer-LEVIR-CD", trust_remote_code=True)
588
+ >>> model = AutoModel.from_pretrained("deepang/adaptformer-LEVIR-CD", trust_remote_code=True)
589
+
590
+ >>> image_A = Image.open(requests.get('https://raw.githubusercontent.com/aigzhusmart/AdaptFormer/main/figures/test_2_1_A.png', stream=True).raw)
591
+ >>> image_B = Image.open(requests.get('https://raw.githubusercontent.com/aigzhusmart/AdaptFormer/main/figures/test_2_1_B.png', stream=True).raw)
592
+ >>> label = Image.open(requests.get('https://raw.githubusercontent.com/aigzhusmart/AdaptFormer/main/figures/test_2_1_label.png', stream=True).raw)
593
+
594
+ >>> with torch.no_grad():
595
+ ...     inputs = image_processor(images=(image_A, image_B), return_tensors="pt")
596
+ ...     outputs = model(**inputs)
597
+ ...     logits = outputs.logits.cpu()
598
+ >>> pred = logits.argmax(dim=1)[0]
599
+ ```"""
600
+ return_dict = (
601
+ return_dict if return_dict is not None else self.config.use_return_dict
602
+ )
603
+ x1_1, x2_1 = torch.chunk(
604
+ self.block1(torch.cat((pixel_valuesA, pixel_valuesB), dim=0)), 2
605
+ )
606
+
607
+ x1_2, x2_2 = torch.chunk(
608
+ self.block2(torch.cat(self.spatialex(x1_1, x2_1), dim=0)), 2
609
+ )
610
+
611
+ x1_3, x2_3 = torch.chunk(
612
+ self.block3(torch.cat(self.channelex(x1_2, x2_2), dim=0)), 2
613
+ )
614
+
615
+ hidden_states = self.decoder([x1_1, x1_2, x1_3], [x2_1, x2_2, x2_3])
616
+
617
+ loss = None
618
+ if labels is not None:
619
+ loss = 0
620
+ for i, hidden_state in enumerate(hidden_states):
621
+ upsampled_logits = F.interpolate(
622
+ hidden_state,
623
+ size=labels.shape[-2:],
624
+ mode="bilinear",
625
+ align_corners=False,
626
+ )
627
+ loss += (
628
+ F.cross_entropy(
629
+ upsampled_logits,
630
+ labels.long(),
631
+ ignore_index=self.config.semantic_loss_ignore_index,
632
+ )
633
+ * self.config.semantic_loss_weight[i]
634
+ )
635
+
636
+ if not return_dict:
637
+ if output_hidden_states:
638
+ output = (hidden_states[-1], hidden_states)
639
+ else:
640
+ output = (hidden_states[-1],)
641
+ return ((loss,) + output) if loss is not None else output
642
+
643
+ return SemanticSegmenterOutput(
644
+ loss=loss,
645
+ logits=hidden_states[-1],
646
+ hidden_states=hidden_states if output_hidden_states else None,
647
+ )
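When ground-truth change masks are passed via `labels`, the forward pass above also returns a loss: the cross-entropy of the three auxiliary prediction maps and the final head, each weighted by `semantic_loss_weight`. A minimal training-step sketch with random tensors (it assumes the model is loaded with `trust_remote_code=True` and uses the preprocessor's 256x256 working size):

```python
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("deepang/adaptformer-LEVIR-CD", trust_remote_code=True)
model.train()

x_a = torch.randn(2, 3, 256, 256)            # pre-change images
x_b = torch.randn(2, 3, 256, 256)            # post-change images
labels = torch.randint(0, 2, (2, 256, 256))  # per-pixel change labels in {0, 1}

outputs = model(pixel_valuesA=x_a, pixel_valuesB=x_b, labels=labels)
outputs.loss.backward()  # weighted sum of the four per-map cross-entropy terms
```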
preprocessing_adaptformer.py ADDED
@@ -0,0 +1,99 @@
1
+ from typing import Tuple
2
+
3
+ from transformers import ViTImageProcessor
4
+ from transformers.image_processing_utils import BatchFeature
5
+ from transformers.image_utils import ImageInput
6
+
7
+
8
+ class AdaptFormerImageProcessor(ViTImageProcessor):
9
+ r"""
10
+ Constructs an AdaptFormer image processor.
11
+
12
+ Args:
13
+ do_resize (`bool`, *optional*, defaults to `True`):
14
+ Whether to resize the image's (height, width) dimensions to the specified `(size["height"],
15
+ size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method.
16
+ size (`dict`, *optional*, defaults to `{"height": 224, "width": 224}`):
17
+ Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
18
+ method.
19
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
20
+ Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
21
+ `preprocess` method.
22
+ do_rescale (`bool`, *optional*, defaults to `True`):
23
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
24
+ parameter in the `preprocess` method.
25
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
26
+ Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
27
+ `preprocess` method.
28
+ do_normalize (`bool`, *optional*, defaults to `True`):
29
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
30
+ method.
31
+ image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
32
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
33
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
34
+ image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
35
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
36
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
37
+ """
38
+
39
+ def __init__(self, *args, **kwargs):
40
+ super().__init__(*args, **kwargs)
41
+
42
+ def preprocess(
43
+ self,
44
+ images: Tuple[ImageInput, ImageInput],
45
+ **kwargs,
46
+ ) -> BatchFeature:
47
+ """
48
+ Preprocess an image or batch of images.
49
+
50
+ Args:
51
+ images (`Tuple[ImageInput, ImageInput]`):
52
+ Tuple of the two images (or two batches of images) to preprocess, with pixel values ranging from 0 to 255. If
53
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
54
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
55
+ Whether to resize the image.
56
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
57
+ Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after
58
+ resizing.
59
+ resample (`PILImageResampling` filter, *optional*, defaults to `self.resample`):
60
+ `PILImageResampling` filter to use if resizing the image e.g. `PILImageResampling.BILINEAR`. Only has
61
+ an effect if `do_resize` is set to `True`.
62
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
63
+ Whether to rescale the image values between [0 - 1].
64
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
65
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
66
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
67
+ Whether to normalize the image.
68
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
69
+ Image mean to use if `do_normalize` is set to `True`.
70
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
71
+ Image standard deviation to use if `do_normalize` is set to `True`.
72
+ return_tensors (`str` or `TensorType`, *optional*):
73
+ The type of tensors to return. Can be one of:
74
+ - Unset: Return a list of `np.ndarray`.
75
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
76
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
77
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
78
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
79
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
80
+ The channel dimension format for the output image. Can be one of:
81
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
82
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
83
+ - Unset: Use the channel dimension format of the input image.
84
+ input_data_format (`ChannelDimension` or `str`, *optional*):
85
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
86
+ from the input image. Can be one of:
87
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
88
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
89
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
90
+ """
91
+ imagesA, imagesB = images
92
+ feature_A = super().preprocess(imagesA, **kwargs)
93
+ feature_B = super().preprocess(imagesB, **kwargs)
94
+
95
+ data = {
96
+ "pixel_valuesA": feature_A["pixel_values"],
97
+ "pixel_valuesB": feature_B["pixel_values"],
98
+ }
99
+ return BatchFeature(data=data, tensor_type=kwargs.pop("return_tensors", None))
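Unlike the base ViT processor, this processor takes a `(images_A, images_B)` tuple and returns separate `pixel_valuesA`/`pixel_valuesB` tensors for the two acquisition dates. A minimal sketch with random images (input sizes are arbitrary; the released preprocessor config resizes to 256x256):

```python
import numpy as np
from PIL import Image
from transformers import AutoImageProcessor

processor = AutoImageProcessor.from_pretrained("deepang/adaptformer-LEVIR-CD", trust_remote_code=True)

image_A = Image.fromarray(np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8))
image_B = Image.fromarray(np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8))

inputs = processor(images=(image_A, image_B), return_tensors="pt")
print(inputs["pixel_valuesA"].shape, inputs["pixel_valuesB"].shape)  # torch.Size([1, 3, 256, 256]) each
```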
preprocessor_config.json ADDED
@@ -0,0 +1,23 @@
1
+ {
2
+ "auto_map": {
3
+ "AutoImageProcessor": "preprocessing_adaptformer.AdaptFormerImageProcessor"
4
+ },
5
+ "size": 256,
6
+ "do_center_crop": false,
7
+ "do_convert_rgb": true,
8
+ "do_normalize": true,
9
+ "do_rescale": true,
10
+ "do_resize": true,
11
+ "image_mean": [
12
+ 0.485,
13
+ 0.456,
14
+ 0.406
15
+ ],
16
+ "image_processor_type": "AdaptFormerImageProcessor",
17
+ "image_std": [
18
+ 0.229,
19
+ 0.224,
20
+ 0.225
21
+ ],
22
+ "resample": 3
23
+ }