import torch from torch import nn from einops import rearrange from torch import nn, einsum from einops import rearrange from mmseg.models.builder import MODELS import math import torch from torch import nn as nn from mmseg.models.builder import MODELS from timm.layers import DropPath, trunc_normal_ from typing import List from timm.layers import create_act_layer from functools import partial import torch.nn.functional as F import torch import torch.nn as nn import torch.nn.functional as F import math from timm.layers import CondConv2d, get_condconv_initializer, create_conv2d, DropPath, get_norm_act_layer class LoRaMLP(nn.Module): def __init__(self, in_dim, out_dim, rank_dim=8): super().__init__() self.loramlp = nn.Sequential( nn.Linear(in_dim, rank_dim, bias=False), nn.Linear(rank_dim, out_dim, bias=False), ) def forward(self, x): return self.loramlp(x) class CrossAttention(nn.Module): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, rank_dim=None): super().__init__() inner_dim = dim_head * heads # 512 context_dim = query_dim if context_dim is None else context_dim self.scale = dim_head ** -0.5 self.heads = heads if not rank_dim: self.to_q = nn.Linear(query_dim, inner_dim, bias=False) self.to_k = nn.Linear(context_dim, inner_dim, bias=False) self.to_v = nn.Linear(context_dim, inner_dim, bias=False) self.to_out = nn.Linear(inner_dim, query_dim, bias=False) else: self.to_q = LoRaMLP(query_dim, inner_dim, rank_dim=rank_dim) self.to_k = LoRaMLP(context_dim, inner_dim, rank_dim=rank_dim) self.to_v = LoRaMLP(context_dim, inner_dim, rank_dim=rank_dim) self.to_out = LoRaMLP(inner_dim, query_dim, rank_dim=rank_dim) def forward(self, x, context): h = self.heads q = self.to_q(x) k = self.to_k(context) v = self.to_v(context) q, k, v = map(lambda t: rearrange( t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) sim = einsum('b i d, b j d -> b i j', q, k) * self.scale attn = sim.softmax(dim=-1) out = einsum('b i j, b j d -> b i d', attn, v) out = rearrange(out, '(b h) n d -> b n (h d)', h=h) return self.to_out(out) def num_groups(group_size, channels): if not group_size: return 1 else: assert channels % group_size == 0 return channels // group_size def _init_weight_goog(m, n='', fix_group_fanout=True): if isinstance(m, CondConv2d): fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels if fix_group_fanout: fan_out //= m.groups init_weight_fn = get_condconv_initializer( lambda w: nn.init.normal_(w, 0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape) init_weight_fn(m.weight) if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.Conv2d): fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels if fix_group_fanout: fan_out //= m.groups nn.init.normal_(m.weight, 0, math.sqrt(2.0 / fan_out)) if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.BatchNorm2d): nn.init.ones_(m.weight) nn.init.zeros_(m.bias) elif isinstance(m, nn.Linear): fan_out = m.weight.size(0) fan_in = 0 if 'routing_fn' in n: fan_in = m.weight.size(1) init_range = 1.0 / math.sqrt(fan_in + fan_out) nn.init.uniform_(m.weight, -init_range, init_range) if m.bias is not None: nn.init.zeros_(m.bias) class DepthwiseSeparableConv(nn.Module): def __init__( self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, group_size=1, pad_type='', noskip=False, pw_kernel_size=1, pw_act=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.): super(DepthwiseSeparableConv, self).__init__() norm_act_layer = get_norm_act_layer(norm_layer) groups = num_groups(group_size, in_chs) self.has_skip = (stride == 1 and in_chs == out_chs) and not noskip self.has_pw_act = pw_act self.conv_dw = create_conv2d( in_chs, in_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=pad_type, groups=groups) self.bn1 = norm_act_layer(in_chs, inplace=True) self.se = se_layer( in_chs, act_layer=act_layer) if se_layer else nn.Identity() self.conv_pw = create_conv2d( in_chs, out_chs, pw_kernel_size, padding=pad_type) self.bn2 = norm_act_layer( out_chs, inplace=True, apply_act=self.has_pw_act) self.drop_path = DropPath( drop_path_rate) if drop_path_rate else nn.Identity() def feature_info(self, location): if location == 'expansion': return dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels) else: return dict(module='', hook_type='', num_chs=self.conv_pw.out_channels) def forward(self, x): shortcut = x x = self.conv_dw(x) x = self.bn1(x) x = self.se(x) x = self.conv_pw(x) x = self.bn2(x) if self.has_skip: x = self.drop_path(x) + shortcut return x class PMAAConvBlock(nn.Module): def __init__(self, in_channels=3, hidden_channels=256, depth=4, norm=nn.BatchNorm2d, act=nn.ReLU, return_multi_feats=False, return_last_feature=True, has_stem=True, has_block=True): super().__init__() self.return_last_feature = return_last_feature self.depth = depth self.has_stem = has_stem self.return_multi_feats = return_multi_feats self.proj_1x1 = DepthwiseSeparableConv( in_channels, hidden_channels, dw_kernel_size=1, norm_layer=norm, act_layer=act) self.spp_dw = nn.ModuleList() if has_stem: self.spp_dw.append( DepthwiseSeparableConv(hidden_channels, hidden_channels, dw_kernel_size=3, stride=1, group_size=hidden_channels, pad_type="same") ) else: self.spp_dw.append(nn.Identity()) if has_block: for _ in range(self.depth): self.spp_dw.append( DepthwiseSeparableConv( hidden_channels, hidden_channels, dw_kernel_size=3, stride=2, group_size=hidden_channels ) ) else: for _ in range(self.depth): self.spp_dw.append( nn.MaxPool2d(kernel_size=2, stride=2) ) self._init_weights() def forward(self, x): B, C, H, W = x.shape output1 = self.proj_1x1(x) output = [self.spp_dw[0](output1)] for k in range(1, self.depth+1): out_k = self.spp_dw[k](output[-1]) output.append(out_k) if self.return_multi_feats: return output[1:] else: if self.return_last_feature: return output[-1] global_f = torch.zeros( output[-1].shape, requires_grad=True, device=output1.device) for fea in output: global_f = global_f + F.adaptive_avg_pool2d( fea, output_size=output[-1].shape[-2:] ) return global_f def _init_weights(self): init_fn = _init_weight_goog for n, m in self.named_modules(): init_fn(m, n) class ConvnextInteractiveModule(nn.Module): def __init__(self, emd_dim=1024, context_dim=256, rank_dim=None): super().__init__() self.attn = CrossAttention(emd_dim, context_dim, rank_dim=rank_dim) def forward(self, x, cache, index): # x: 1024 2 1024 if isinstance(cache, list) or isinstance(cache, tuple): # len(cache) 4 cache[4]-23 # 0-5->0 6-11 -> 1 12-17->2 18-23->3 cache = cache[index] cache = F.interpolate( cache, (int(math.sqrt(x.shape[0])), int(math.sqrt(x.shape[0]))), mode="bilinear", align_corners=False ) cache = cache.flatten(2) # B C N cache = cache.permute(2, 0, 1) # N B C # Reshape: batch first x = x.permute(1, 0, 2) # B N C cache = cache.permute(1, 0, 2) # B N C return (x + self.attn(x, cache)).permute(1, 0, 2) class PMAAInteractiveModule(nn.Module): def __init__(self, emd_dim=1024, context_dim=64, kernel: int = 1, norm=nn.BatchNorm2d, local_groups=32, global_groups=2, return_multi_feats=False, ): super().__init__() self.return_multi_feats = return_multi_feats self.local_embedding = nn.Sequential( nn.Conv2d(emd_dim, emd_dim, kernel, groups=local_groups, padding=int((kernel - 1) / 2), bias=False), norm(emd_dim) ) self.global_embedding = nn.Sequential( nn.Conv2d(context_dim, emd_dim, kernel, groups=global_groups, padding=int((kernel - 1) / 2), bias=False), norm(emd_dim) ) self.global_act = nn.Sequential( nn.Conv2d(context_dim, emd_dim, kernel, groups=global_groups, padding=int((kernel - 1) / 2), bias=False), norm(emd_dim) ) self.act = nn.Sigmoid() self._init_weights() def _init_weights(self): init_fn = _init_weight_goog for n, m in self.named_modules(): init_fn(m, n) def forward(self, x, cache, index): if isinstance(cache, list) or isinstance(cache, tuple): cache = cache[index] N, B, C = x.shape H = W = int(math.sqrt(N)) # reshape x -> B, C, H, W x = x.permute(1, 2, 0).reshape(B, C, H, W) local_feat = self.local_embedding(x) # 32 global_act = self.global_act(cache) sig_act = F.interpolate(self.act(global_act), size=(H, W)) # 32 global_feat = self.global_embedding(cache) global_feat = F.interpolate(global_feat, size=(H, W)) # 32 out = local_feat * sig_act + global_feat return out.permute(2, 3, 0, 1).reshape(N, B, C) class LayerNorm(nn.Module): r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height, width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width). """ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): super().__init__() self.weight = nn.Parameter(torch.ones(normalized_shape)) self.bias = nn.Parameter(torch.zeros(normalized_shape)) self.eps = eps self.data_format = data_format if self.data_format not in ["channels_last", "channels_first"]: raise NotImplementedError self.normalized_shape = (normalized_shape, ) def forward(self, x): if self.data_format == "channels_last": return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) elif self.data_format == "channels_first": u = x.mean(1, keepdim=True) s = (x - u).pow(2).mean(1, keepdim=True) x = (x - u) / torch.sqrt(s + self.eps) x = self.weight[:, None, None] * x + self.bias[:, None, None] return x class Block(nn.Module): r""" ConvNeXt Block. There are two equivalent implementations: (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back We use (2) as we find it slightly faster in PyTorch Args: dim (int): Number of input channels. drop_path (float): Stochastic depth rate. Default: 0.0 layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. """ def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6): super().__init__() self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv self.norm = LayerNorm(dim, eps=1e-6) # pointwise/1x1 convs, implemented with linear layers self.pwconv1 = nn.Linear(dim, 4 * dim) self.act = nn.GELU() self.pwconv2 = nn.Linear(4 * dim, dim) self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) if layer_scale_init_value > 0 else None self.drop_path = DropPath( drop_path) if drop_path > 0. else nn.Identity() def forward(self, x): input = x x = self.dwconv(x) x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) x = self.norm(x) x = self.pwconv1(x) x = self.act(x) x = self.pwconv2(x) if self.gamma is not None: x = self.gamma * x x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) x = input + self.drop_path(x) return x class ConvNeXt(nn.Module): r""" ConvNeXt A PyTorch impl of : `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf Args: in_chans (int): Number of input image channels. Default: 3 num_classes (int): Number of classes for classification head. Default: 1000 depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] drop_path_rate (float): Stochastic depth rate. Default: 0. layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. """ def __init__(self, in_chans=3, depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0., layer_scale_init_value=1e-6, out_indices=[0, 1, 2, 3], return_multi_feats=False, return_last_feature=True ): super().__init__() self.return_last_feature = return_last_feature self.return_multi_feats = return_multi_feats # stem and 3 intermediate downsampling conv layers self.downsample_layers = nn.ModuleList() stem = nn.Sequential( nn.Conv2d(in_chans, dims[0], kernel_size=2, stride=2), LayerNorm(dims[0], eps=1e-6, data_format="channels_first") ) self.downsample_layers.append(stem) for i in range(3): downsample_layer = nn.Sequential( LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2), ) self.downsample_layers.append(downsample_layer) # 4 feature resolution stages, each consisting of multiple residual blocks self.stages = nn.ModuleList() dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] cur = 0 for i in range(4): stage = nn.Sequential( *[Block(dim=dims[i], drop_path=dp_rates[cur + j], layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])] ) self.stages.append(stage) cur += depths[i] self.out_indices = out_indices norm_layer = partial(LayerNorm, eps=1e-6, data_format="channels_first") for i_layer in range(4): layer = norm_layer(dims[i_layer]) layer_name = f'norm{i_layer}' self.add_module(layer_name, layer) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, (nn.Conv2d, nn.Linear)): trunc_normal_(m.weight, std=.02) nn.init.constant_(m.bias, 0) def init_weights(self, pretrained=None): """Initialize the weights in backbone. Args: pretrained (str, optional): Path to pre-trained weights. Defaults to None. """ def _init_weights(m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) if isinstance(pretrained, str): self.apply(_init_weights) # logger = get_root_logger() # load_checkpoint(self, pretrained, strict=False, logger=logger) elif pretrained is None: self.apply(_init_weights) else: raise TypeError('pretrained must be a str or None') def forward_features(self, x): outs = [] for i in range(4): x = self.downsample_layers[i](x) x = self.stages[i](x) if i in self.out_indices: norm_layer = getattr(self, f'norm{i}') x_out = norm_layer(x) outs.append(x_out) if self.return_multi_feats: return tuple(outs) if self.return_last_feature: return outs[-1] global_f = torch.zeros( outs[-1].shape, requires_grad=True, device=outs[-1].device) for fea in outs: global_f = global_f + F.adaptive_avg_pool2d( fea, output_size=outs[-1].shape[-2:] ) return global_f def forward(self, x): x = self.forward_features(x) return x class NoAdaptingModule(nn.Identity): def __init__(self): super().__init__() def forward(self, x, cache, layer): return x @MODELS.register_module() class CloudAdapter(nn.Module): def __init__(self, cnn_type="convnext", # convnext or mobilenet int_type="convnext", # cross_attention or # 共同的参数 start emd_dim=1024, num_layers=24, # 先判断是否返回多特征,之后再判断是否进行特征融合 return_multi_feats=True, return_last_feature=False, # 共同的参数 end # pmaa 提取单个特征 or 多尺寸特征 start hidden_channels=256, depth=4, norm=nn.BatchNorm2d, act=nn.ReLU, # pmaa 提取单个特征 or 多尺寸特征 end # pmaa net start local_groups=1, global_groups=1, # pmaa net end # convnext 提取单个特征 or 多尺寸特征 start context_dim=256, rank_dim=None, # convnext 提取单个特征 or 多尺寸特征 end, has_stem=True, has_block=True, ): super().__init__() self.cnn = nn.Identity() self.net = nn.Identity() if cnn_type == "pmaa": self.cnn = PMAAConvBlock( hidden_channels=hidden_channels, depth=depth, norm=norm, act=act, return_multi_feats=return_multi_feats, return_last_feature=return_last_feature, has_stem=has_stem, has_block=has_block ) elif cnn_type == "convnext": self.cnn = ConvNeXt(depths=[1]*4, dims=[context_dim]*4, return_multi_feats=return_multi_feats, return_last_feature=return_last_feature ) else: raise ValueError( f"cnn_type must in ['convnext','pmaa'],but got {cnn_type}") if int_type == "convnext": self.net = nn.ModuleList( ConvnextInteractiveModule(emd_dim, context_dim, rank_dim) for _ in range(num_layers) ) elif int_type == "pmaa": self.net = nn.ModuleList( PMAAInteractiveModule( emd_dim, context_dim, local_groups=local_groups, global_groups=global_groups) for _ in range(num_layers) ) elif int_type == "no_adapting": self.net = nn.ModuleList( NoAdaptingModule() for _ in range(num_layers) ) else: raise ValueError( f"int_type must in ['convnext','pmaa'],but got {int_type}") def forward(self, feats, layer, batch_first=True, has_cls_token=True, cache=None): if batch_first: feats = feats.permute(1, 0, 2) # 1025 2 1024 if has_cls_token: cls_token, feats = torch.tensor_split(feats, [1], dim=0) # 24 // 1 # feat: 1024 2 1024 feats = self.net[layer].forward( feats, cache, layer//(len(self.net) // 4)) if has_cls_token: feats = torch.cat([cls_token, feats], dim=0) if batch_first: feats = feats.permute(1, 0, 2) return feats