""" PyTorch AdaptFormer model."""
import itertools
import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange
from transformers import PreTrainedModel
from transformers.modeling_outputs import SemanticSegmenterOutput
from .configuration_adaptformer import AdaptFormerConfig
class SpatialExchange(nn.Module):
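    """Spatial exchange between the two temporal inputs: every ``1/p``-th column (along the
    width axis) of ``x1`` and ``x2`` is swapped, so each branch sees a mix of both time steps.

    Example (a minimal sketch; shapes and values are chosen purely for illustration):

    ```python
    >>> import torch
    >>> ex = SpatialExchange(p=1 / 2)
    >>> a, b = torch.zeros(1, 3, 4, 4), torch.ones(1, 3, 4, 4)
    >>> out_a, out_b = ex(a, b)
    >>> out_a[0, 0, 0]
    tensor([1., 0., 1., 0.])
    ```
    """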
def __init__(self, p=1 / 2):
super().__init__()
        assert 0 < p <= 1, "p must satisfy 0 < p <= 1"
self.p = int(1 / p)
def forward(self, x1: torch.Tensor, x2: torch.Tensor):
_, _, _, w = x1.shape
exchange_mask = torch.arange(w) % self.p == 0
out_x1 = torch.zeros_like(x1, device=x1.device)
out_x2 = torch.zeros_like(x2, device=x1.device)
out_x1[..., ~exchange_mask] = x1[..., ~exchange_mask]
out_x2[..., ~exchange_mask] = x2[..., ~exchange_mask]
out_x1[..., exchange_mask] = x2[..., exchange_mask]
out_x2[..., exchange_mask] = x1[..., exchange_mask]
return out_x1, out_x2
class ChannelExchange(nn.Module):
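    """Channel exchange between the two temporal inputs: every ``1/p``-th channel of ``x1``
    and ``x2`` is swapped for every sample in the batch.

    Example (a minimal sketch; shapes and values are chosen purely for illustration):

    ```python
    >>> import torch
    >>> ex = ChannelExchange(p=1 / 2)
    >>> a, b = torch.zeros(2, 4, 8, 8), torch.ones(2, 4, 8, 8)
    >>> out_a, out_b = ex(a, b)
    >>> out_a[0, :, 0, 0]
    tensor([1., 0., 1., 0.])
    ```
    """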
def __init__(self, p=1 / 2):
super().__init__()
        assert 0 < p <= 1, "p must satisfy 0 < p <= 1"
self.p = int(1 / p)
def forward(self, x1: torch.Tensor, x2: torch.Tensor):
N, c, _, _ = x1.shape
exchange_map = torch.arange(c) % self.p == 0
exchange_mask = exchange_map.unsqueeze(0).expand((N, -1))
out_x1 = torch.zeros_like(x1, device=x1.device)
out_x2 = torch.zeros_like(x2, device=x1.device)
out_x1[~exchange_mask, ...] = x1[~exchange_mask, ...]
out_x2[~exchange_mask, ...] = x2[~exchange_mask, ...]
out_x1[exchange_mask, ...] = x2[exchange_mask, ...]
out_x2[exchange_mask, ...] = x1[exchange_mask, ...]
return out_x1, out_x2
class CascadedGroupAttention(nn.Module):
r"""Cascaded Group Attention.
Args:
dim (int): Number of input channels.
key_dim (int): The dimension for query and key.
num_heads (int): Number of attention heads.
        attn_ratio (int): Multiplier applied to the query/key dimension to obtain the value dimension.
        resolution (int): Input resolution, corresponding to the window size.
kernels (List[int]): The kernel size of the dw conv on query.
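
    Example (a minimal sketch; the dimensions are illustrative and not taken from a released
    configuration):

    ```python
    >>> import torch
    >>> attn = CascadedGroupAttention(dim=64, key_dim=16, num_heads=4, attn_ratio=4, resolution=7)
    >>> x = torch.randn(1, 64, 7, 7)
    >>> attn(x).shape
    torch.Size([1, 64, 7, 7])
    ```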
"""
def __init__(
self,
dim,
key_dim,
num_heads=8,
attn_ratio=4,
resolution=14,
kernels=[5, 5, 5, 5],
):
super().__init__()
self.num_heads = num_heads
self.scale = key_dim**-0.5
self.key_dim = key_dim
self.d = int(attn_ratio * key_dim)
self.attn_ratio = attn_ratio
qkvs = []
dws = []
for i in range(num_heads):
qkvs.append(
nn.Sequential(
nn.Conv2d(
dim // (num_heads),
self.key_dim * 2 + self.d,
1,
1,
0,
bias=False,
),
nn.BatchNorm2d(self.key_dim * 2 + self.d),
)
)
dws.append(
nn.Sequential(
nn.Conv2d(
self.key_dim,
self.key_dim,
kernels[i],
1,
kernels[i] // 2,
groups=self.key_dim,
bias=False,
),
nn.BatchNorm2d(self.key_dim),
)
)
self.qkvs = nn.ModuleList(qkvs)
self.dws = nn.ModuleList(dws)
self.proj = nn.Sequential(
nn.ReLU(),
nn.Conv2d(self.d * num_heads, dim, 1, 1, 0, bias=False),
nn.BatchNorm2d(dim),
)
self.act_gelu = nn.GELU()
points = list(itertools.product(range(resolution), range(resolution)))
N = len(points)
attention_offsets = {}
idxs = []
for p1 in points:
for p2 in points:
offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
if offset not in attention_offsets:
attention_offsets[offset] = len(attention_offsets)
idxs.append(attention_offsets[offset])
self.attention_biases = nn.Parameter(
torch.zeros(num_heads, len(attention_offsets))
)
self.register_buffer("attention_bias_idxs", torch.LongTensor(idxs).view(N, N))
    @torch.no_grad()
    def train(self, mode=True):
        # Cache the indexed attention biases (`self.ab`) for inference and drop the cache
        # when switching back to training mode.
super().train(mode)
if mode and hasattr(self, "ab"):
del self.ab
else:
self.ab = self.attention_biases[:, self.attention_bias_idxs]
def forward(self, x):
B, _, H, W = x.shape
trainingab = self.attention_biases[:, self.attention_bias_idxs]
feats_in = x.chunk(len(self.qkvs), dim=1)
feats_out = []
feat = feats_in[0]
for i, qkv in enumerate(self.qkvs):
if i > 0:
feat = feat + feats_in[i]
feat = qkv(feat)
q, k, v = feat.view(B, -1, H, W).split(
[self.key_dim, self.key_dim, self.d], dim=1
)
q = self.act_gelu(self.dws[i](q)) + q
q, k, v = q.flatten(2), k.flatten(2), v.flatten(2)
attn = (q.transpose(-2, -1) @ k) * self.scale + (
trainingab[i] if self.training else self.ab[i].to(x.device)
)
attn = attn.softmax(dim=-1)
feat = (v @ attn.transpose(-2, -1)).view(B, self.d, H, W)
feats_out.append(feat)
x = self.proj(torch.cat(feats_out, 1))
return x
class LocalWindowAttention(nn.Module):
r"""Local Window Attention.
Args:
dim (int): Number of input channels.
key_dim (int): The dimension for query and key.
num_heads (int): Number of attention heads.
        attn_ratio (int): Multiplier applied to the query/key dimension to obtain the value dimension.
resolution (int): Input resolution.
window_resolution (int): Local window resolution.
kernels (List[int]): The kernel size of the dw conv on query.
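
    Example (a minimal sketch; the dimensions are illustrative and not taken from a released
    configuration):

    ```python
    >>> import torch
    >>> attn = LocalWindowAttention(dim=64, key_dim=16, num_heads=4, resolution=14, window_resolution=7)
    >>> x = torch.randn(1, 64, 14, 14)
    >>> attn(x).shape
    torch.Size([1, 64, 14, 14])
    ```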
"""
def __init__(
self,
dim,
key_dim,
num_heads=8,
attn_ratio=4,
resolution=14,
window_resolution=7,
kernels=[5, 5, 5, 5],
):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.resolution = resolution
assert window_resolution > 0, "window_size must be greater than 0"
self.window_resolution = window_resolution
window_resolution = min(window_resolution, resolution)
self.attn = CascadedGroupAttention(
dim,
key_dim,
num_heads,
attn_ratio=attn_ratio,
resolution=window_resolution,
kernels=kernels,
)
def forward(self, x):
H = W = self.resolution
B, C, H_, W_ = x.shape
        # Only check this for classification models.
assert (
H == H_ and W == W_
), "input feature has wrong size, expect {}, got {}".format((H, W), (H_, W_))
if H <= self.window_resolution and W <= self.window_resolution:
x = self.attn(x)
else:
x = x.permute(0, 2, 3, 1)
pad_b = (
self.window_resolution - H % self.window_resolution
) % self.window_resolution
pad_r = (
self.window_resolution - W % self.window_resolution
) % self.window_resolution
padding = pad_b > 0 or pad_r > 0
if padding:
x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
pH, pW = H + pad_b, W + pad_r
nH = pH // self.window_resolution
nW = pW // self.window_resolution
x = (
x.view(B, nH, self.window_resolution, nW, self.window_resolution, C)
.transpose(2, 3)
.reshape(B * nH * nW, self.window_resolution, self.window_resolution, C)
.permute(0, 3, 1, 2)
)
x = self.attn(x)
x = (
x.permute(0, 2, 3, 1)
.view(B, nH, nW, self.window_resolution, self.window_resolution, C)
.transpose(2, 3)
.reshape(B, pH, pW, C)
)
if padding:
x = x[:, :H, :W].contiguous()
x = x.permute(0, 3, 1, 2)
return x
class LocalAgg(nn.Module):
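    """Local aggregation block: BatchNorm -> 1x1 conv -> 3x3 depthwise conv -> BatchNorm ->
    1x1 conv. There is no residual here; the caller (``LocalMerge``) adds the skip connection."""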
def __init__(self, channels):
        super().__init__()
self.bn = nn.BatchNorm2d(channels)
self.pointwise_conv_0 = nn.Conv2d(channels, channels, kernel_size=1, bias=False)
self.depthwise_conv = nn.Conv2d(
channels, channels, padding=1, kernel_size=3, groups=channels, bias=False
)
self.pointwise_prenorm_1 = nn.BatchNorm2d(channels)
self.pointwise_conv_1 = nn.Conv2d(channels, channels, kernel_size=1, bias=False)
def forward(self, x):
x = self.bn(x)
x = self.pointwise_conv_0(x)
x = self.depthwise_conv(x)
x = self.pointwise_prenorm_1(x)
x = self.pointwise_conv_1(x)
return x
class Mlp(nn.Module):
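    """Two-layer feed-forward network built from 1x1 convolutions with a GELU activation,
    expanding the channel dimension by ``mlp_ratio``."""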
def __init__(self, channels, mlp_ratio):
        super().__init__()
self.up_proj = nn.Conv2d(
channels, channels * mlp_ratio, kernel_size=1, bias=False
)
self.down_proj = nn.Conv2d(
channels * mlp_ratio, channels, kernel_size=1, bias=False
)
def forward(self, x):
return self.down_proj(F.gelu(self.up_proj(x)))
class LocalMerge(nn.Module):
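    """Basic AdaptFormer block: a depthwise 3x3 convolution acting as a positional encoding,
    local aggregation and an MLP, each wrapped in a residual connection. When ``partial=True``
    (the last encoder stage), a second positional-encoding convolution, local-window attention
    and a second MLP are appended."""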
def __init__(self, channels, r, heads, resolution, partial=False):
        super().__init__()
self.partial = partial
self.cpe1 = nn.Conv2d(
channels, channels, kernel_size=3, padding=1, groups=channels, bias=False
)
self.local_agg = LocalAgg(channels)
self.mlp1 = Mlp(channels, r)
if partial:
self.cpe2 = nn.Conv2d(
channels,
channels,
kernel_size=3,
padding=1,
groups=channels,
bias=False,
)
self.attn = LocalWindowAttention(
channels,
16,
heads,
attn_ratio=r,
resolution=resolution,
window_resolution=7,
kernels=[5, 5, 5, 5],
)
self.mlp2 = Mlp(channels, r)
def forward(self, x):
x = self.cpe1(x) + x
x = self.local_agg(x) + x
x = self.mlp1(x) + x
if self.partial:
x = self.cpe2(x) + x
x = self.attn(x) + x
x = self.mlp2(x) + x
return x
class AdaptFormerEncoderBlock(nn.Module):
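    """Encoder stage: a stride-2 convolution with GroupNorm downsamples the input by 2x,
    followed by ``depth`` LocalMerge blocks.

    Example (a minimal sketch; the dimensions are illustrative only):

    ```python
    >>> import torch
    >>> block = AdaptFormerEncoderBlock(
    ...     in_chans=3, embed_dim=64, num_head=4, mlp_ratio=4, depth=1, resolution=64, partial=False
    ... )
    >>> block(torch.randn(1, 3, 128, 128)).shape
    torch.Size([1, 64, 64, 64])
    ```
    """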
def __init__(
self, in_chans, embed_dim, num_head, mlp_ratio, depth, resolution, partial
):
super().__init__()
self.down = nn.Sequential(
nn.Conv2d(in_chans, embed_dim, kernel_size=2, stride=2),
nn.GroupNorm(num_groups=1, num_channels=embed_dim),
)
self.block = nn.Sequential(
*[
LocalMerge(
channels=embed_dim,
r=mlp_ratio,
heads=num_head,
resolution=resolution,
partial=partial,
)
for _ in range(depth)
]
)
def forward(self, x: torch.Tensor):
return self.block(self.down(x))
class ChangeDetectionHead(nn.Module):
    """Change-detection head: fuses the concatenated multi-scale features with a 1x1
    convolution, upsamples them 2x with a transposed convolution, refines them with a
    small scaled-residual block and projects to ``num_classes`` logits."""

    def __init__(self, embedding_dim, in_channels, num_classes):
        super().__init__()
self.in_proj = nn.Sequential(
nn.Conv2d(
in_channels=embedding_dim * len(in_channels),
out_channels=embedding_dim,
kernel_size=1,
),
nn.BatchNorm2d(embedding_dim),
nn.ConvTranspose2d(embedding_dim, embedding_dim, 4, stride=2, padding=1),
)
self.conv1 = nn.Conv2d(embedding_dim, embedding_dim, 3, 1, 1)
self.conv2 = nn.Conv2d(embedding_dim, embedding_dim, 3, 1, 1)
self.out = nn.Conv2d(embedding_dim, num_classes, 3, 1, 1)
def forward(self, x: torch.Tensor):
x = self.in_proj(x)
x = self.conv2(F.relu(self.conv1(x))) * 0.1 + x
return self.out(x)
class AdaptFormerDecoder(nn.Module):
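    """Lightweight MLP-style decoder for change detection. Each encoder scale of the two
    temporal branches is linearly embedded to a common dimension, the bitemporal features are
    fused by convolutional difference blocks, and the scales are merged coarse-to-fine with
    bilinear upsampling. Returns the per-scale auxiliary predictions followed by the final
    change map produced by the change-detection head."""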
def __init__(
self,
config: AdaptFormerConfig,
):
        super().__init__()
self.in_channels = config.embed_dims
self.embedding_dim = config.embed_dims[-1]
self.linear_emb_layers = nn.ModuleList(
[
nn.Sequential(
Rearrange("n c ... -> n (...) c"),
nn.Linear(in_dim, self.embedding_dim),
)
for in_dim in self.in_channels
]
)
self.diff_layers = nn.ModuleList(
[
nn.Sequential(
nn.Conv2d(2 * self.embedding_dim, self.embedding_dim, 3, 1, 1),
nn.ReLU(),
nn.BatchNorm2d(self.embedding_dim),
nn.Conv2d(self.embedding_dim, self.embedding_dim, 3, 1, 1),
nn.ReLU(),
)
for _ in range(3)
]
)
self.prediction_layers = nn.ModuleList(
[
nn.Sequential(
nn.Conv2d(self.embedding_dim, config.num_classes, 3, 1, 1),
nn.ReLU(),
nn.BatchNorm2d(config.num_classes),
nn.Conv2d(config.num_classes, config.num_classes, 3, 1, 1),
)
for _ in range(3)
]
)
        self.head = ChangeDetectionHead(
self.embedding_dim, self.in_channels, config.num_classes
)
def forward(self, pixel_valuesA, pixel_valuesB):
N, _, H, W = pixel_valuesA[0].shape
# c3
pixel_values_c3 = torch.cat([pixel_valuesA[2], pixel_valuesB[2]], dim=0)
_c3_1, _c3_2 = torch.chunk(
self.linear_emb_layers[2](pixel_values_c3).permute(0, 2, 1), 2
)
_c3_1 = _c3_1.reshape(N, -1, pixel_values_c3.shape[2], pixel_values_c3.shape[3])
_c3_2 = _c3_2.reshape(N, -1, pixel_values_c3.shape[2], pixel_values_c3.shape[3])
_c3 = self.diff_layers[2](torch.cat((_c3_1, _c3_2), dim=1))
p_c3 = self.prediction_layers[2](_c3)
_c3_up = F.interpolate(_c3, (H, W), mode="bilinear", align_corners=False)
# c2
pixel_values_c2 = torch.cat([pixel_valuesA[1], pixel_valuesB[1]], dim=0)
_c2_1, _c2_2 = torch.chunk(
self.linear_emb_layers[1](pixel_values_c2).permute(0, 2, 1), 2
)
_c2_1 = _c2_1.reshape(N, -1, pixel_values_c2.shape[2], pixel_values_c2.shape[3])
_c2_2 = _c2_2.reshape(N, -1, pixel_values_c2.shape[2], pixel_values_c2.shape[3])
_c2 = self.diff_layers[1](torch.cat((_c2_1, _c2_2), dim=1)) + F.interpolate(
_c3, scale_factor=2, mode="bilinear"
)
p_c2 = self.prediction_layers[1](_c2)
_c2_up = F.interpolate(_c2, (H, W), mode="bilinear", align_corners=False)
# c1
pixel_values_c1 = torch.cat([pixel_valuesA[0], pixel_valuesB[0]], dim=0)
_c1_1, _c1_2 = torch.chunk(
self.linear_emb_layers[0](pixel_values_c1).permute(0, 2, 1), 2
)
_c1_1 = _c1_1.reshape(N, -1, pixel_values_c1.shape[2], pixel_values_c1.shape[3])
_c1_2 = _c1_2.reshape(N, -1, pixel_values_c1.shape[2], pixel_values_c1.shape[3])
_c1 = self.diff_layers[0](torch.cat((_c1_1, _c1_2), dim=1)) + F.interpolate(
_c2, scale_factor=2, mode="bilinear"
)
p_c1 = self.prediction_layers[0](_c1)
cp = self.head(torch.cat((_c3_up, _c2_up, _c1), dim=1))
return [p_c3, p_c2, p_c1, cp]
class AdaptFormerPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = AdaptFormerConfig
base_model_prefix = "adaptformer"
def _init_weights(self, m):
"""Initialize the weights"""
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
class AdaptFormerForChangeDetection(AdaptFormerPreTrainedModel):
"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Parameters:
config ([`AdaptFormerConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
def __init__(
self,
config: AdaptFormerConfig,
):
super().__init__(config)
self.config = config
self.block1 = AdaptFormerEncoderBlock(
in_chans=config.num_channels,
embed_dim=config.embed_dims[0],
num_head=config.num_heads[0],
mlp_ratio=config.mlp_ratios[0],
depth=config.depths[0],
resolution=config.embed_dims[2] // 2,
partial=False,
)
self.block2 = AdaptFormerEncoderBlock(
in_chans=config.embed_dims[0],
embed_dim=config.embed_dims[1],
num_head=config.num_heads[1],
mlp_ratio=config.mlp_ratios[1],
depth=config.depths[1],
resolution=config.embed_dims[1] // 2,
partial=False,
)
self.block3 = AdaptFormerEncoderBlock(
in_chans=config.embed_dims[1],
embed_dim=config.embed_dims[2],
num_head=config.num_heads[2],
mlp_ratio=config.mlp_ratios[2],
depth=config.depths[2],
resolution=config.embed_dims[0] // 2,
partial=True,
)
self.spatialex = SpatialExchange()
self.channelex = ChannelExchange()
self.decoder = AdaptFormerDecoder(config=config)
# Initialize weights and apply final processing
self.post_init()
def forward(
self,
pixel_valuesA: torch.Tensor,
pixel_valuesB: torch.Tensor,
labels: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
r"""
        labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
            Ground-truth change maps for computing the loss. Indices should be in `[0, ...,
            config.num_classes - 1]`. A cross-entropy loss is computed over the logits of every
            prediction scale and weighted by `config.semantic_loss_weight`.
Returns:
Examples:
```python
        >>> import torch
        >>> import requests
        >>> from PIL import Image
        >>> from transformers import AutoImageProcessor, AutoModel

        >>> image_processor = AutoImageProcessor.from_pretrained("deepang/adaptformer-LEVIR-CD")
        >>> model = AutoModel.from_pretrained("deepang/adaptformer-LEVIR-CD")

        >>> image_A = Image.open(requests.get("https://raw.githubusercontent.com/aigzhusmart/AdaptFormer/main/figures/test_2_1_A.png", stream=True).raw)
        >>> image_B = Image.open(requests.get("https://raw.githubusercontent.com/aigzhusmart/AdaptFormer/main/figures/test_2_1_B.png", stream=True).raw)
        >>> label = Image.open(requests.get("https://raw.githubusercontent.com/aigzhusmart/AdaptFormer/main/figures/test_2_1_label.png", stream=True).raw)

        >>> inputs = image_processor(images=(image_A, image_B), return_tensors="pt")
        >>> with torch.no_grad():
        ...     outputs = model(**inputs)
        >>> logits = outputs.logits.cpu()
        >>> pred = logits.argmax(dim=1)[0]
        ```"""
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
x1_1, x2_1 = torch.chunk(
self.block1(torch.cat((pixel_valuesA, pixel_valuesB), dim=0)), 2
)
x1_2, x2_2 = torch.chunk(
self.block2(torch.cat(self.spatialex(x1_1, x2_1), dim=0)), 2
)
x1_3, x2_3 = torch.chunk(
self.block3(torch.cat(self.channelex(x1_2, x2_2), dim=0)), 2
)
hidden_states = self.decoder([x1_1, x1_2, x1_3], [x2_1, x2_2, x2_3])
loss = None
if labels is not None:
loss = 0
for i, hidden_state in enumerate(hidden_states):
upsampled_logits = F.interpolate(
hidden_state,
size=labels.shape[-2:],
mode="bilinear",
align_corners=False,
)
loss += (
F.cross_entropy(
upsampled_logits,
labels.long(),
ignore_index=self.config.semantic_loss_ignore_index,
)
* self.config.semantic_loss_weight[i]
)
if not return_dict:
if output_hidden_states:
output = (hidden_states[-1], hidden_states)
else:
output = (hidden_states[-1],)
return ((loss,) + output) if loss is not None else output
return SemanticSegmenterOutput(
loss=loss,
logits=hidden_states[-1],
hidden_states=hidden_states if output_hidden_states else None,
)