"""
From Conformer with alter: conv and trans cls head was changed to volting together
ver: DEC 1st 16:00 official release
ref: https://github.com/pengzhiliang/Conformer/blob/main/conformer.py
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from timm.models.layers import DropPath, trunc_normal_
class Mlp(nn.Module): # FFN
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module): # MHSA
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
        B, N, C = x.shape  # N is the number of tokens (patches + cls token), C is the embedding dimension
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)  # rearrange to [3, B, num_heads, N, head_dim]
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):  # transformer encoder block, as in ViT
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=partial(nn.LayerNorm, eps=1e-6)):
super().__init__()
# pre norm 1
self.norm1 = norm_layer(dim)
# MHSA
self.attn = Attention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
# pre norm 2
self.norm2 = norm_layer(dim)
# FFN(MLP)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class ConvBlock(nn.Module):  # essentially a ResNet bottleneck block
def __init__(self, inplanes, outplanes, stride=1, res_conv=False, act_layer=nn.ReLU, groups=1,
norm_layer=partial(nn.BatchNorm2d, eps=1e-6), drop_block=None, drop_path=None):
super(ConvBlock, self).__init__()
expansion = 4
med_planes = outplanes // expansion
self.conv1 = nn.Conv2d(inplanes, med_planes, kernel_size=1, stride=1, padding=0, bias=False)
self.bn1 = norm_layer(med_planes)
self.act1 = act_layer(inplace=True)
self.conv2 = nn.Conv2d(med_planes, med_planes, kernel_size=3, stride=stride, groups=groups, padding=1,
bias=False)
self.bn2 = norm_layer(med_planes)
self.act2 = act_layer(inplace=True)
self.conv3 = nn.Conv2d(med_planes, outplanes, kernel_size=1, stride=1, padding=0, bias=False)
self.bn3 = norm_layer(outplanes)
self.act3 = act_layer(inplace=True)
if res_conv:
self.residual_conv = nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=stride, padding=0, bias=False)
self.residual_bn = norm_layer(outplanes)
self.res_conv = res_conv
self.drop_block = drop_block
self.drop_path = drop_path
def zero_init_last_bn(self):
nn.init.zeros_(self.bn3.weight)
def forward(self, x, x_t=None, return_x_2=True):
residual = x
x = self.conv1(x)
x = self.bn1(x)
if self.drop_block is not None:
x = self.drop_block(x)
x = self.act1(x)
x = self.conv2(x) if x_t is None else self.conv2(x + x_t)
x = self.bn2(x)
if self.drop_block is not None:
x = self.drop_block(x)
x2 = self.act2(x)
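        # x2 (the output of the 3x3 stage, med_planes channels) is optionally returned so the
        # FCU down-sampling module can project it into the transformer branch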
x = self.conv3(x2)
x = self.bn3(x)
if self.drop_block is not None:
x = self.drop_block(x)
if self.drop_path is not None:
x = self.drop_path(x)
if self.res_conv:
residual = self.residual_conv(residual)
residual = self.residual_bn(residual)
x += residual
x = self.act3(x)
if return_x_2:
return x, x2
else:
return x
class FCUDown(nn.Module):
""" CNN feature maps -> Transformer patch embeddings
"""
def __init__(self, inplanes, outplanes, dw_stride, act_layer=nn.GELU,
norm_layer=partial(nn.LayerNorm, eps=1e-6)):
super(FCUDown, self).__init__()
self.dw_stride = dw_stride
        self.conv_project = nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, padding=0)  # match the transformer embedding dimension
        self.sample_pooling = nn.AvgPool2d(kernel_size=dw_stride, stride=dw_stride)  # down-sample to the patch-grid resolution
self.ln = norm_layer(outplanes)
self.act = act_layer()
def forward(self, x, x_t):
x = self.conv_project(x) # [N, C, H, W]
x = self.sample_pooling(x).flatten(2).transpose(1, 2)
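        # e.g. a [N, C, 56, 56] map with dw_stride=4 becomes [N, 14*14, outplanes] = [N, 196, outplanes]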
x = self.ln(x)
x = self.act(x)
x = torch.cat([x_t[:, 0][:, None, :], x], dim=1) # concatenate class token from x_t
return x
class FCUUp(nn.Module):
""" Transformer patch embeddings -> CNN feature maps
    restored to the CNN resolution by interpolation
"""
def __init__(self, inplanes, outplanes, up_stride, act_layer=nn.ReLU,
norm_layer=partial(nn.BatchNorm2d, eps=1e-6), ):
super(FCUUp, self).__init__()
self.up_stride = up_stride
self.conv_project = nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, padding=0)
self.bn = norm_layer(outplanes)
self.act = act_layer()
    def forward(self, x, H, W):  # up-sample tokens back to an (H * up_stride, W * up_stride) feature map
B, _, C = x.shape
# [N, 197, 384] -> [N, 196, 384] -> [N, 384, 196] -> [N, 384, 14, 14]
x_r = x[:, 1:].transpose(1, 2).reshape(B, C, H, W) # drop cls token of x_t
x_r = self.act(self.bn(self.conv_project(x_r)))
        return F.interpolate(x_r, size=(H * self.up_stride, W * self.up_stride))  # nearest-neighbour up-sampling back to the CNN feature-map size
class Med_ConvBlock(nn.Module):  # essentially a ResNet bottleneck identity block
    """ Special case of ConvBlock: no downsampling and no residual projection.
    """
def __init__(self, inplanes, act_layer=nn.ReLU, groups=1, norm_layer=partial(nn.BatchNorm2d, eps=1e-6),
drop_block=None, drop_path=None):
super(Med_ConvBlock, self).__init__()
expansion = 4
med_planes = inplanes // expansion
self.conv1 = nn.Conv2d(inplanes, med_planes, kernel_size=1, stride=1, padding=0, bias=False)
self.bn1 = norm_layer(med_planes)
self.act1 = act_layer(inplace=True)
self.conv2 = nn.Conv2d(med_planes, med_planes, kernel_size=3, stride=1, groups=groups, padding=1, bias=False)
self.bn2 = norm_layer(med_planes)
self.act2 = act_layer(inplace=True)
self.conv3 = nn.Conv2d(med_planes, inplanes, kernel_size=1, stride=1, padding=0, bias=False)
self.bn3 = norm_layer(inplanes)
self.act3 = act_layer(inplace=True)
self.drop_block = drop_block
self.drop_path = drop_path
def zero_init_last_bn(self):
nn.init.zeros_(self.bn3.weight)
def forward(self, x):
residual = x
x = self.conv1(x)
x = self.bn1(x)
if self.drop_block is not None:
x = self.drop_block(x)
x = self.act1(x)
x = self.conv2(x)
x = self.bn2(x)
if self.drop_block is not None:
x = self.drop_block(x)
x = self.act2(x)
x = self.conv3(x)
x = self.bn3(x)
if self.drop_block is not None:
x = self.drop_block(x)
if self.drop_path is not None:
x = self.drop_path(x)
x += residual
x = self.act3(x)
return x
class ConvTransBlock(nn.Module):
"""
    Basic Conformer module: keeps feature maps for the CNN block and patch embeddings for the transformer encoder block
"""
def __init__(self, inplanes, outplanes, res_conv, stride, dw_stride, embed_dim, num_heads=12, mlp_ratio=4.,
qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
last_fusion=False, num_med_block=0, groups=1):
super(ConvTransBlock, self).__init__()
expansion = 4
# ConvBlock
self.cnn_block = ConvBlock(inplanes=inplanes, outplanes=outplanes, res_conv=res_conv, stride=stride,
groups=groups)
if last_fusion:
self.fusion_block = ConvBlock(inplanes=outplanes, outplanes=outplanes, stride=2, res_conv=True,
groups=groups)
else:
self.fusion_block = ConvBlock(inplanes=outplanes, outplanes=outplanes, groups=groups)
# identity block
if num_med_block > 0:
self.med_block = []
for i in range(num_med_block):
self.med_block.append(Med_ConvBlock(inplanes=outplanes, groups=groups))
self.med_block = nn.ModuleList(self.med_block) # nn.ModuleList
# FCU
self.squeeze_block = FCUDown(inplanes=outplanes // expansion, outplanes=embed_dim, dw_stride=dw_stride)
self.expand_block = FCUUp(inplanes=embed_dim, outplanes=outplanes // expansion, up_stride=dw_stride)
# Transformer Encoder block
self.trans_block = Block(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate)
self.dw_stride = dw_stride
self.embed_dim = embed_dim
self.num_med_block = num_med_block
self.last_fusion = last_fusion
def forward(self, x, x_t):
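        # x: CNN feature map, x_t: transformer tokens (patch tokens + cls token).
        # FCUDown feeds the CNN features into the transformer branch; FCUUp feeds the
        # attended tokens back into the CNN branch before the fusion ConvBlock.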
x, x2 = self.cnn_block(x)
_, _, H, W = x2.shape
x_st = self.squeeze_block(x2, x_t)
x_t = self.trans_block(x_st + x_t)
if self.num_med_block > 0:
for m in self.med_block:
x = m(x)
x_t_r = self.expand_block(x_t, H // self.dw_stride, W // self.dw_stride)
x = self.fusion_block(x, x_t_r, return_x_2=False)
return x, x_t
class Conformer(nn.Module):
def __init__(self, patch_size=16, in_chans=3, num_classes=1000, base_channel=64, channel_ratio=4, num_med_block=0,
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.):
# Transformer
super().__init__()
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
assert depth % 3 == 0
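        # learnable class token for the transformer branch; prepended to the patch tokens in forward()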
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.trans_dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
# Classifier head
self.trans_norm = nn.LayerNorm(embed_dim)
self.trans_cls_head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
self.pooling = nn.AdaptiveAvgPool2d(1)
self.conv_cls_head = nn.Linear(int(256 * channel_ratio), num_classes)
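        # voting head (the change noted in the module docstring): the concatenated
        # conv + trans logits (2 * num_classes) are projected back down to num_classes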
self.cls_head = nn.Linear(int(2 * num_classes), num_classes)
        # Stem stage: get the feature maps by conv block (copied from resnet.py)
self.conv1 = nn.Conv2d(in_chans, 64, kernel_size=7, stride=2, padding=3, bias=False) # 1 / 2 [112, 112]
self.bn1 = nn.BatchNorm2d(64)
self.act1 = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) # 1 / 4 [56, 56]
# 1 stage
stage_1_channel = int(base_channel * channel_ratio)
trans_dw_stride = patch_size // 4
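        # the stem already down-samples by 4, so a stride of patch_size // 4 here gives an
        # overall patch stride of patch_size for the transformer branch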
self.conv_1 = ConvBlock(inplanes=64, outplanes=stage_1_channel, res_conv=True, stride=1)
# embedding
self.trans_patch_conv = nn.Conv2d(64, embed_dim, kernel_size=trans_dw_stride, stride=trans_dw_stride, padding=0)
self.trans_1 = Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=self.trans_dpr[0],
)
# 2~4 stage
init_stage = 2
fin_stage = depth // 3 + 1
for i in range(init_stage, fin_stage):
self.add_module('conv_trans_' + str(i),
ConvTransBlock(
stage_1_channel, stage_1_channel, False, 1, dw_stride=trans_dw_stride,
embed_dim=embed_dim,
num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
drop_path_rate=self.trans_dpr[i - 1],
num_med_block=num_med_block
)
)
stage_2_channel = int(base_channel * channel_ratio * 2)
# 5~8 stage
init_stage = fin_stage # 5
fin_stage = fin_stage + depth // 3 # 9
for i in range(init_stage, fin_stage):
s = 2 if i == init_stage else 1
in_channel = stage_1_channel if i == init_stage else stage_2_channel
res_conv = True if i == init_stage else False
self.add_module('conv_trans_' + str(i),
ConvTransBlock(
in_channel, stage_2_channel, res_conv, s, dw_stride=trans_dw_stride // 2,
embed_dim=embed_dim,
num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
drop_path_rate=self.trans_dpr[i - 1],
num_med_block=num_med_block
)
)
stage_3_channel = int(base_channel * channel_ratio * 2 * 2)
# 9~12 stage
init_stage = fin_stage # 9
fin_stage = fin_stage + depth // 3 # 13
for i in range(init_stage, fin_stage):
s = 2 if i == init_stage else 1
in_channel = stage_2_channel if i == init_stage else stage_3_channel
res_conv = True if i == init_stage else False
last_fusion = True if i == depth else False
self.add_module('conv_trans_' + str(i),
ConvTransBlock(
in_channel, stage_3_channel, res_conv, s, dw_stride=trans_dw_stride // 4,
embed_dim=embed_dim,
num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
drop_path_rate=self.trans_dpr[i - 1],
num_med_block=num_med_block, last_fusion=last_fusion
)
)
self.fin_stage = fin_stage
trunc_normal_(self.cls_token, std=.02)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1.)
nn.init.constant_(m.bias, 0.)
elif isinstance(m, nn.GroupNorm):
nn.init.constant_(m.weight, 1.)
nn.init.constant_(m.bias, 0.)
@torch.jit.ignore
def no_weight_decay(self):
return {'cls_token'}
def forward(self, x):
B = x.shape[0]
cls_tokens = self.cls_token.expand(B, -1, -1)
# stem stage [N, 3, 224, 224] -> [N, 64, 56, 56]
x_base = self.maxpool(self.act1(self.bn1(self.conv1(x))))
# 1 stage
x = self.conv_1(x_base, return_x_2=False)
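        # x: [N, base_channel * channel_ratio, 56, 56] CNN features for a 224x224 input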
# embedding: [N, 64, 56, 56] -> [N, d, p, p] -> [N, d, p^2] -> [N, p^2, d] -> [N, p^2 + 1, d]
x_t = self.trans_patch_conv(x_base).flatten(2).transpose(1, 2)
x_t = torch.cat([cls_tokens, x_t], dim=1)
x_t = self.trans_1(x_t)
# 2 ~ final
for i in range(2, self.fin_stage):
            x, x_t = getattr(self, 'conv_trans_' + str(i))(x, x_t)
# conv classification
x_p = self.pooling(x).flatten(1)
conv_cls = self.conv_cls_head(x_p)
# trans classification
x_t = self.trans_norm(x_t)
tran_cls = self.trans_cls_head(x_t[:, 0])
        # class voting: fuse the predictions of the two classification heads
cls = torch.cat([conv_cls, tran_cls], dim=1)
cls = self.cls_head(cls)
return cls
# return [conv_cls, tran_cls]
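

if __name__ == '__main__':
    # Minimal shape-check sketch, not an official Conformer variant: the hyper-parameters
    # below are small, arbitrary values chosen only to exercise the forward pass; see the
    # reference repo's conformer.py for the released configurations.
    model = Conformer(patch_size=16, channel_ratio=1, embed_dim=384, depth=12,
                      num_heads=6, mlp_ratio=4., qkv_bias=True, num_classes=10)
    model.eval()
    dummy = torch.randn(2, 3, 224, 224)  # the stem and stage layout assume 224x224 inputs
    with torch.no_grad():
        logits = model(dummy)
    print(logits.shape)  # expected: torch.Size([2, 10]) from the voting head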