Tianyinus's picture
init submit
edcf5ee verified
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from .conv_layers import DepthwiseSeparableConv, BasicBlock, Bottleneck, MBConv, FusedMBConv, ConvNormAct
from .trans_layers import TransformerBlock
from einops import rearrange
import pdb
class BidirectionAttention(nn.Module):
    """Bidirectional multi-head attention between a feature map and a semantic map.

    One affinity matrix between feature-map queries and semantic-map queries is computed per
    head and normalised twice:

    * over semantic-map positions, so each feature position gathers semantic-map values
      (feature update);
    * over feature positions, so each semantic-map position gathers feature values
      (semantic-map update).

    Args:
        feat_dim: channels of the input feature map.
        map_dim: channels of the input semantic map (and of the returned map).
        out_dim: channels of the returned feature map.
        heads: number of attention heads.
        dim_head: per-head width; the inner width is ``heads * dim_head``.
        attn_drop: dropout on the map<-feature attention weights only.
        proj_drop: dropout after both output projections.
        map_size: kept as an attribute for compatibility; the semantic map's spatial size is
            read from the input tensor in ``forward``.
        proj_type: ``'linear'`` uses 1x1 convs for the feature q/v and output projections,
            ``'depthwise'`` uses ``DepthwiseSeparableConv``.
    """

    def __init__(self, feat_dim, map_dim, out_dim, heads=4, dim_head=64, attn_drop=0.,
                 proj_drop=0., map_size=16, proj_type='depthwise'):
        super().__init__()
        self.inner_dim = dim_head * heads
        self.feat_dim = feat_dim
        self.map_dim = map_dim
        self.heads = heads
        self.scale = dim_head ** (-0.5)
        self.dim_head = dim_head
        self.map_size = map_size
        assert proj_type in ['linear', 'depthwise']
        if proj_type == 'linear':
            self.feat_qv = nn.Conv2d(feat_dim, self.inner_dim * 2, kernel_size=1, bias=False)
            self.feat_out = nn.Conv2d(self.inner_dim, out_dim, kernel_size=1, bias=False)
        else:
            self.feat_qv = DepthwiseSeparableConv(feat_dim, self.inner_dim * 2)
            self.feat_out = DepthwiseSeparableConv(self.inner_dim, out_dim)
        self.map_qv = nn.Conv2d(map_dim, self.inner_dim * 2, kernel_size=1, bias=False)
        self.map_out = nn.Conv2d(self.inner_dim, map_dim, kernel_size=1, bias=False)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

    def _split_heads(self, t):
        """(B, dim_head*heads, H, W) -> (B, heads, H*W, dim_head).

        Channel ``c`` is decomposed as ``c = d * heads + h``, i.e. the einops layout
        ``'b (dim_head heads) h w -> b heads (h w) dim_head'``.
        """
        b, _, h, w = t.shape
        return t.reshape(b, self.dim_head, self.heads, h * w).permute(0, 2, 3, 1)

    def _merge_heads(self, t, h, w):
        """Inverse of ``_split_heads``: (B, heads, h*w, dim_head) -> (B, dim_head*heads, h, w)."""
        b = t.shape[0]
        return t.permute(0, 3, 1, 2).reshape(b, self.dim_head * self.heads, h, w)

    def forward(self, feat, semantic_map):
        """Return ``(feat_out, map_out)``.

        ``feat_out`` has ``out_dim`` channels at the feature map's resolution and ``map_out``
        has ``map_dim`` channels at the semantic map's resolution (for ``'depthwise'`` this
        presumes ``DepthwiseSeparableConv`` keeps the spatial size — TODO confirm).
        """
        H, W = feat.shape[-2:]
        map_h, map_w = semantic_map.shape[-2:]
        feat_q, feat_v = self.feat_qv(feat).chunk(2, dim=1)  # each (B, inner_dim, H, W)
        map_q, map_v = self.map_qv(semantic_map).chunk(2, dim=1)  # each (B, inner_dim, map_h, map_w)
        feat_q, feat_v = self._split_heads(feat_q), self._split_heads(feat_v)  # (B, heads, H*W, d)
        map_q, map_v = self._split_heads(map_q), self._split_heads(map_v)  # (B, heads, map_h*map_w, d)

        # attn[b, h, i, j]: i indexes feature positions, j indexes semantic-map positions.
        attn = torch.einsum('bhid,bhjd->bhij', feat_q, map_q) * self.scale

        # Feature <- map: normalise over map positions. The semantic map is compact, so no
        # dropout here; the original author notes dropout on this path may destabilise training.
        feat_map_attn = F.softmax(attn, dim=-1)
        # Map <- feature: normalise over feature positions.
        map_feat_attn = self.attn_drop(F.softmax(attn, dim=-2))

        feat_out = torch.einsum('bhij,bhjd->bhid', feat_map_attn, map_v)
        feat_out = self._merge_heads(feat_out, H, W)
        # 'bhji' reads the (feature, map) matrix transposed, so i now indexes map positions.
        map_out = torch.einsum('bhji,bhjd->bhid', map_feat_attn, feat_v)
        map_out = self._merge_heads(map_out, map_h, map_w)

        feat_out = self.proj_drop(self.feat_out(feat_out))
        map_out = self.proj_drop(self.map_out(map_out))
        return feat_out, map_out
class BidirectionAttentionBlock(nn.Module):
    """Pre-norm bidirectional-attention block with a residual shortcut and a conv feed-forward.

    Feature path: ``norm1 -> attention -> + shortcut(x) -> feedforward``.
    Map path: ``norm2 -> attention -> + semantic_map``.

    Args:
        feat_dim / out_dim: input / output channels of the feature map. When they differ, the
            shortcut is a 1x1 ``ConvNormAct``; otherwise it is an identity ``nn.Sequential()``.
        map_dim: semantic-map channels (unchanged by the block).
        heads, dim_head, attn_drop, proj_drop, map_size, proj_type: forwarded to
            ``BidirectionAttention``.
        norm: norm-layer class, or falsy for ``nn.Identity``.
        act: activation class passed to the shortcut and the feed-forward.
        expansion: expansion ratio of the feed-forward (``FusedMBConv`` with 1x1 kernels for
            ``'linear'``, ``MBConv`` with 3x3 kernels otherwise; both defined in conv_layers).
    """

    def __init__(self, feat_dim, map_dim, out_dim, heads, dim_head, norm=nn.BatchNorm2d,
                 act=nn.GELU, expansion=4, attn_drop=0., proj_drop=0., map_size=8,
                 proj_type='depthwise'):
        super().__init__()
        assert norm in [nn.BatchNorm2d, nn.InstanceNorm2d, True, False]
        assert act in [nn.ReLU, nn.ReLU6, nn.GELU, nn.SiLU, True, False]
        assert proj_type in ['linear', 'depthwise']
        self.norm1 = norm(feat_dim) if norm else nn.Identity()  # norm layer for the feature map
        self.norm2 = norm(map_dim) if norm else nn.Identity()  # norm layer for the semantic map
        self.attn = BidirectionAttention(feat_dim, map_dim, out_dim, heads=heads, dim_head=dim_head,
                                         attn_drop=attn_drop, proj_drop=proj_drop,
                                         map_size=map_size, proj_type=proj_type)
        self.shortcut = nn.Sequential()
        if feat_dim != out_dim:
            self.shortcut = ConvNormAct(feat_dim, out_dim, kernel_size=1, padding=0, norm=norm,
                                        act=act, preact=True)
        if proj_type == 'linear':
            # two 1x1 convs
            self.feedforward = FusedMBConv(out_dim, out_dim, expansion=expansion, kernel_size=1,
                                           act=act, norm=norm)
        else:
            # depthwise conv
            self.feedforward = MBConv(out_dim, out_dim, expansion=expansion, kernel_size=3,
                                      act=act, norm=norm, p=proj_drop)

    def forward(self, x, semantic_map):
        """Return ``(out, map)``: the updated feature map and the updated semantic map."""
        feat = self.norm1(x)
        mapp = self.norm2(semantic_map)
        out, mapp = self.attn(feat, mapp)
        # Out-of-place adds: nn.Dropout returns its input unchanged in eval mode (and for p=0),
        # so `+=` would mutate the projection layers' outputs, which autograd may have saved.
        out = out + self.shortcut(x)
        out = self.feedforward(out)
        mapp = mapp + semantic_map
        return out, mapp
class PatchMerging(nn.Module):
    """
    Modified patch merging layer that works as down-sampling.

    Every 2x2 spatial neighbourhood is stacked into the channel axis
    (C, H, W) -> (4C, ceil(H/2), ceil(W/2)), normalised, and projected 4C -> out_dim.
    Odd spatial sizes are zero-padded on the bottom/right to the next even size first.
    An optional 1x1 conv projects the semantic map from ``dim`` to ``out_dim`` channels.

    Args:
        dim: input channels.
        out_dim: output channels.
        norm: norm-layer class applied to the 4*dim stacked channels.
        proj_type: ``'linear'`` for a 1x1 conv reduction, otherwise ``DepthwiseSeparableConv``.
        map_proj: if False, no map projection is created and a semantic map passed to
            ``forward`` is returned unchanged.
    """

    def __init__(self, dim, out_dim, norm=nn.BatchNorm2d, proj_type='depthwise', map_proj=True):
        super().__init__()
        self.dim = dim
        if proj_type == 'linear':
            self.reduction = nn.Conv2d(4 * dim, out_dim, kernel_size=1, bias=False)
        else:
            self.reduction = DepthwiseSeparableConv(4 * dim, out_dim)
        self.norm = norm(4 * dim)
        if map_proj:
            self.map_projection = nn.Conv2d(dim, out_dim, kernel_size=1, bias=False)

    def forward(self, x, semantic_map=None):
        """
        x: B, C, H, W
        semantic_map: optional, B, dim, h, w
        Returns (merged feature map, semantic map or None).
        """
        H, W = x.shape[-2:]
        if H % 2 or W % 2:
            # Without padding the four strided sub-grids differ in size and torch.cat fails.
            x = F.pad(x, (0, W % 2, 0, H % 2))
        x0 = x[:, :, 0::2, 0::2]
        x1 = x[:, :, 1::2, 0::2]
        x2 = x[:, :, 0::2, 1::2]
        x3 = x[:, :, 1::2, 1::2]
        x = torch.cat([x0, x1, x2, x3], 1)  # B, 4C, H/2, W/2
        x = self.norm(x)
        x = self.reduction(x)  # 4C -> out_dim
        # map_projection only exists when map_proj=True; otherwise pass the map through.
        if semantic_map is not None and hasattr(self, 'map_projection'):
            semantic_map = self.map_projection(semantic_map)  # dim -> out_dim
        return x, semantic_map
class BasicLayer(nn.Module):
    """
    A stack of ``num_blocks`` BidirectionAttentionBlocks forming one transformer stage.

    No down-sampling or up-sampling happens here; those are handled by Down_block / Up_block.
    Only the first block changes the feature channels (feat_dim -> out_dim); every following
    block maps out_dim -> out_dim. The semantic map keeps ``map_dim`` channels throughout.
    """
    def __init__(self, feat_dim, map_dim, out_dim, num_blocks, heads=4, dim_head=64, expansion=1, attn_drop=0., proj_drop=0., map_size=8, proj_type='depthwise', norm=nn.BatchNorm2d, act=nn.GELU):
        super().__init__()
        input_dims = [feat_dim if idx == 0 else out_dim for idx in range(num_blocks)]
        self.blocks = nn.ModuleList([
            BidirectionAttentionBlock(
                in_dim, map_dim, out_dim, heads, dim_head,
                expansion=expansion, attn_drop=attn_drop, proj_drop=proj_drop,
                map_size=map_size, proj_type=proj_type, norm=norm, act=act,
            )
            for in_dim in input_dims
        ])

    def forward(self, x, semantic_map):
        """Run every block in order, threading both the features and the semantic map."""
        feat, smap = x, semantic_map
        for layer in self.blocks:
            feat, smap = layer(feat, smap)
        return feat, smap
class SemanticMapGeneration(nn.Module):
    """Summarise a feature map into a compact ``map_size x map_size`` semantic map.

    Each of the ``map_size**2`` output positions ("codes") has its own softmax over all
    input spatial positions. The semantic map at that position is the resulting weighted
    average of the 3x3-conv-projected features.

    Args:
        feat_dim: channels of the input feature map.
        map_dim: channels of the output semantic map.
        map_size: side length of the output semantic map.
    """
    def __init__(self, feat_dim, map_dim, map_size):
        super().__init__()
        self.map_size = map_size
        self.map_dim = map_dim
        self.map_code_num = map_size * map_size
        # Values to be pooled: feat_dim -> map_dim.
        self.base_proj = nn.Conv2d(feat_dim, map_dim, kernel_size=3, padding=1, bias=False)
        # One pooling-weight logit per code and input position.
        self.semantic_proj = nn.Conv2d(feat_dim, self.map_code_num, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        """x: (B, feat_dim, H, W) -> (B, map_dim, map_size, map_size)."""
        batch, _, _, _ = x.shape
        values = self.base_proj(x).flatten(2)  # (B, map_dim, HW)
        logits = self.semantic_proj(x).flatten(2)  # (B, map_code_num, HW)
        weights = logits.softmax(dim=2)  # each code's weights sum to 1 over positions
        pooled = torch.bmm(values, weights.transpose(1, 2))  # (B, map_dim, map_code_num)
        return pooled.view(batch, self.map_dim, self.map_size, self.map_size)
class SemanticMapFusion(nn.Module):
    """Jointly refine several semantic maps with one shared transformer.

    Each map is projected to ``dim`` channels with a 1x1 conv and flattened to a token
    sequence. All sequences are concatenated and passed through a ``TransformerBlock``.
    Each map's tokens are then split back out and projected to that map's original channel
    count. Maps may have different spatial sizes.

    Args:
        in_dim_list: channel count of each input map, in ``map_list`` order.
        dim: common channel width used inside the fusion transformer.
        heads: attention heads of the fusion transformer (per-head width ``dim // heads``).
        depth: number of transformer layers.
        norm: accepted for interface compatibility; not used by this module.
    """

    def __init__(self, in_dim_list, dim, heads, depth=1, norm=nn.BatchNorm2d):
        super().__init__()
        self.dim = dim
        # Project all maps to the same channel count so their tokens can be concatenated.
        self.in_proj = nn.ModuleList(
            [nn.Conv2d(in_dim, dim, kernel_size=1, bias=False) for in_dim in in_dim_list])
        self.fusion = TransformerBlock(dim, depth, heads, dim // heads, dim, attn_drop=0., proj_drop=0.)
        # Project all maps back to their original channel counts.
        self.out_proj = nn.ModuleList(
            [nn.Conv2d(dim, in_dim, kernel_size=1, bias=False) for in_dim in in_dim_list])

    def forward(self, map_list):
        """Return a list of refined maps, each with the same shape as its input map."""
        spatial_sizes = [m.shape[-2:] for m in map_list]
        # (B, C_i, H_i, W_i) -> (B, H_i*W_i, dim)
        tokens = [self.in_proj[i](m).flatten(2).transpose(1, 2) for i, m in enumerate(map_list)]
        lengths = [t.shape[1] for t in tokens]
        # Assumes the fusion transformer preserves the (B, L, dim) shape, as the split below requires.
        fused = self.fusion(torch.cat(tokens, dim=1))
        # Split by each map's own token count (identical to an even chunk when all sizes match).
        pieces = fused.split(lengths, dim=1)
        maps_out = []
        for i, (piece, (h, w)) in enumerate(zip(pieces, spatial_sizes)):
            grid = piece.transpose(1, 2).reshape(piece.shape[0], self.dim, h, w)
            maps_out.append(self.out_proj[i](grid))
        return maps_out
#######################################################################
# UTNet building blocks for a single stage; each stage combines conv blocks and transformer blocks
class inconv(nn.Module):
    """Input stem: a 3x3 conv (in_ch -> out_ch, padding=1) followed by one ``block``.

    Accepts a (B, C, H, W) tensor, or a 5-D tensor such as (B, 1, C, H, W) whose
    dimension 1 is squeezed away first.

    Args:
        in_ch: input channels.
        out_ch: output channels.
        block: second-stage block class, called as ``block(out_ch, out_ch, norm=..., act=...)``.
            Defaults to ``BasicBlock`` from conv_layers (presumably a residual
            conv-norm-act block — confirm there).
        norm: norm-layer class passed to ``block``.
        act: activation class passed to ``block``.
    """

    def __init__(self, in_ch, out_ch, block=BasicBlock, norm=nn.BatchNorm2d, act=nn.GELU):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False)
        self.conv2 = block(out_ch, out_ch, norm=norm, act=act)

    def forward(self, x):
        # Fixed: the original tested `x.shape == 5`, comparing a torch.Size with an int, which
        # is always False, so 5-D inputs were never squeezed. squeeze(1) only drops dim 1 if
        # its size is 1.
        if x.dim() == 5:
            x = x.squeeze(1)
        out = self.conv1(x)  # in_ch -> out_ch, spatial size preserved
        out = self.conv2(out)  # block(out_ch, out_ch, norm, act)
        return out
class Down_block(nn.Module):
    """Encoder stage: PatchMerging down-sampling, conv blocks, then bidirectional-attention blocks.

    Args:
        in_ch / out_ch: input / output channels; PatchMerging performs in_ch -> out_ch.
        conv_num: number of ``conv_block`` instances (out_ch -> out_ch).
        trans_num: number of BidirectionAttentionBlocks in the BasicLayer.
        map_generate: if True, a new semantic map is produced from the conv output by
            SemanticMapGeneration, replacing whatever PatchMerging returned.
        map_proj: forwarded to PatchMerging (1x1 projection of the incoming semantic map).
        map_dim: semantic-map channels used by the attention blocks; defaults to out_ch.
        Remaining arguments configure the conv and attention blocks.
    Returns (out, semantic_map) from ``forward``.
    """
    def __init__(self, in_ch, out_ch, conv_num, trans_num, conv_block=BasicBlock,
                 heads=4, dim_head=64, expansion=4, attn_drop=0., proj_drop=0., map_size=8,
                 proj_type='depthwise', norm=nn.BatchNorm2d, act=nn.GELU, map_generate=False,
                 map_proj=True, map_dim=None):
        super().__init__()
        if map_dim is None:
            map_dim = out_ch
        self.map_generate = map_generate
        if map_generate:
            # Produces (B, map_dim, map_size, map_size) from the conv output.
            self.map_gen = SemanticMapGeneration(out_ch, map_dim, map_size)
        self.patch_merging = PatchMerging(in_ch, out_ch, proj_type=proj_type, norm=norm, map_proj=map_proj)
        self.conv_blocks = nn.Sequential(
            *[conv_block(out_ch, out_ch, norm=norm, act=act) for _ in range(conv_num)]
        )
        self.trans_blocks = BasicLayer(
            out_ch, map_dim, out_ch, num_blocks=trans_num,
            heads=heads, dim_head=dim_head, norm=norm, act=act, expansion=expansion,
            attn_drop=attn_drop, proj_drop=proj_drop, map_size=map_size, proj_type=proj_type,
        )

    def forward(self, x, semantic_map=None):
        merged, smap = self.patch_merging(x, semantic_map)  # in_ch -> out_ch
        feats = self.conv_blocks(merged)
        if self.map_generate:
            smap = self.map_gen(feats)
        feats, smap = self.trans_blocks(feats, smap)
        return feats, smap
class Up_block(nn.Module):
    """Decoder stage: upsample and fuse with the skip feature, attention blocks, then conv blocks.

    Args:
        in_ch: channels of the low-resolution input ``x1`` (and, without ``map_shortcut``,
            of ``map1``).
        out_ch: output channels. After upsampling, ``x1`` is concatenated with ``x2``, and
            ``self.norm`` expects in_ch + out_ch channels, so ``x2`` presumably has out_ch channels.
        conv_num / trans_num: number of conv blocks and BidirectionAttentionBlocks.
        map_dim: semantic-map channels used by the attention blocks; defaults to out_ch.
        map_shortcut: if True, ``map1`` and ``map2`` are concatenated before ``map_reduction``,
            whose input width is then in_ch + out_ch; ``map2`` is then required.
        Remaining arguments configure the conv and attention blocks.
    """

    def __init__(self, in_ch, out_ch, conv_num, trans_num, conv_block=BasicBlock,
                 heads=4, dim_head=64, expansion=1, attn_drop=0., proj_drop=0., map_size=8,
                 proj_type='linear', norm=nn.BatchNorm2d, act=nn.GELU, map_dim=None,
                 map_shortcut=False):
        super().__init__()
        self.reduction = nn.Conv2d(in_ch + out_ch, out_ch, kernel_size=1, padding=0, bias=False)
        self.norm = norm(in_ch + out_ch)
        self.map_shortcut = map_shortcut
        map_dim = out_ch if map_dim is None else map_dim
        if map_shortcut:
            self.map_reduction = nn.Conv2d(in_ch + out_ch, map_dim, kernel_size=1, bias=False)
        else:
            self.map_reduction = nn.Conv2d(in_ch, map_dim, kernel_size=1, bias=False)
        self.trans_blocks = BasicLayer(out_ch, map_dim, out_ch, num_blocks=trans_num,
                                       heads=heads, dim_head=dim_head, norm=norm, act=act,
                                       expansion=expansion, attn_drop=attn_drop,
                                       proj_drop=proj_drop, map_size=map_size,
                                       proj_type=proj_type)
        conv_list = []
        for _ in range(conv_num):
            conv_list.append(conv_block(out_ch, out_ch, norm=norm, act=act))
        self.conv_blocks = nn.Sequential(*conv_list)

    def forward(self, x1, x2, map1, map2=None):
        """
        x1: low-res feature; x2: high-res (skip) feature.
        map1: semantic map from the previous low-res layer.
        map2: semantic map from the encoder shortcut path; only used when map_shortcut=True.

        Raises:
            ValueError: if map_shortcut=True and map2 is None.
        """
        x1 = F.interpolate(x1, size=x2.shape[-2:], mode='bilinear', align_corners=True)
        feat = torch.cat([x1, x2], dim=1)
        out = self.reduction(self.norm(feat))
        if self.map_shortcut:
            if map2 is None:
                # map_reduction was built for in_ch + out_ch input channels; map1 alone cannot
                # satisfy it, so fail with a clear message instead of a conv shape error.
                raise ValueError("Up_block was built with map_shortcut=True; map2 must be provided")
            semantic_map = torch.cat([map1, map2], dim=1)
        else:
            semantic_map = map1
        semantic_map = self.map_reduction(semantic_map)
        out, semantic_map = self.trans_blocks(out, semantic_map)
        out = self.conv_blocks(out)
        return out, semantic_map