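# NOTE: the next three lines appear to be repository-page residue (uploader
# avatar caption, commit message, short commit hash), not part of the module:
#   Mountchicken's picture
#   Upload 704 files
#   9bf4bd7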
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple, Union
import torch
import torch.nn as nn
from mmocr.registry import MODELS
@MODELS.register_module()
class Maxpool2d(nn.Module):
    """A thin wrapper around ``nn.MaxPool2d``.

    The wrapped layer is stored as ``self.model`` and is built from the three
    arguments below only.

    Args:
        kernel_size (int or tuple(int)): Kernel size for max pooling layer
        stride (int or tuple(int)): Stride for max pooling layer
        padding (int or tuple(int)): Padding for pooling layer. Defaults to 0.
        **kwargs: Accepted and ignored; nothing is forwarded to
            ``nn.MaxPool2d``.
    """

    def __init__(self,
                 kernel_size: Union[int, Tuple[int, int]],
                 stride: Union[int, Tuple[int, int]],
                 padding: Union[int, Tuple[int, int]] = 0,
                 **kwargs) -> None:
        super().__init__()
        self.model = nn.MaxPool2d(kernel_size, stride, padding)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward function.

        Args:
            x (Tensor): Input feature map.

        Returns:
            Tensor: Output tensor after the max pooling layer.
        """
        return self.model(x)
@MODELS.register_module()
class GCAModule(nn.Module):
    """GCAModule in MASTER.

    Pools the input feature map into one context vector per sample (by
    multi-head spatial attention or by global average pooling), passes that
    vector through a 1x1-conv bottleneck and fuses the result back into the
    input.

    Args:
        in_channels (int): Channels of input tensor. Must be at least 8 and
            evenly divisible by ``n_head``.
        ratio (float): Scale ratio of in_channels. The bottleneck width is
            ``int(in_channels * ratio)``.
        n_head (int): Numbers of attention head.
        pooling_type (str): Spatial pooling type. Options are [``avg``,
            ``att``].
        scale_attn (bool): Whether to scale the attention map by the square
            root of the per-head channel count. Only takes effect when
            ``n_head > 1``. Defaults to False.
        fusion_type (str): Fusion type of input and context. Options are
            [``channel_add``, ``channel_mul``, ``channel_concat``].
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 in_channels: int,
                 ratio: float,
                 n_head: int,
                 pooling_type: str = 'att',
                 scale_attn: bool = False,
                 fusion_type: str = 'channel_add',
                 **kwargs) -> None:
        super().__init__()
        assert pooling_type in ['avg', 'att']
        assert fusion_type in ['channel_add', 'channel_mul', 'channel_concat']
        # in_channels must be divided by headers evenly
        assert in_channels % n_head == 0 and in_channels >= 8

        self.n_head = n_head
        self.in_channels = in_channels
        self.ratio = ratio
        self.planes = int(in_channels * ratio)
        self.pooling_type = pooling_type
        self.fusion_type = fusion_type
        self.scale_attn = scale_attn
        # Exact because divisibility was asserted above.
        self.single_header_inplanes = in_channels // n_head

        if pooling_type == 'att':
            self.conv_mask = nn.Conv2d(
                self.single_header_inplanes, 1, kernel_size=1)
            self.softmax = nn.Softmax(dim=2)
        else:
            self.avg_pool = nn.AdaptiveAvgPool2d(1)

        # Attribute names are kept per fusion type so existing checkpoints
        # (state-dict keys) keep loading.
        if fusion_type == 'channel_add':
            self.channel_add_conv = self._build_bottleneck()
        elif fusion_type == 'channel_concat':
            self.channel_concat_conv = self._build_bottleneck()
            # for concat
            self.cat_conv = nn.Conv2d(
                2 * self.in_channels, self.in_channels, kernel_size=1)
        elif fusion_type == 'channel_mul':
            self.channel_mul_conv = self._build_bottleneck()

    def _build_bottleneck(self) -> nn.Sequential:
        """Build the 1x1 conv -> LayerNorm -> ReLU -> 1x1 conv transform."""
        return nn.Sequential(
            nn.Conv2d(self.in_channels, self.planes, kernel_size=1),
            nn.LayerNorm([self.planes, 1, 1]), nn.ReLU(inplace=True),
            nn.Conv2d(self.planes, self.in_channels, kernel_size=1))

    def spatial_pool(self, x: torch.Tensor) -> torch.Tensor:
        """Spatial pooling function.

        Args:
            x (Tensor): Input feature map of shape [N, C, H, W].

        Returns:
            Tensor: Output tensor after spatial pooling, of shape
            [N, C, 1, 1].
        """
        batch, _, height, width = x.size()
        if self.pooling_type != 'att':
            # [N, C, 1, 1]
            return self.avg_pool(x)

        heads = batch * self.n_head
        head_channels = self.single_header_inplanes
        # ``reshape`` rather than ``view`` so non-contiguous inputs
        # (e.g. channels_last) are accepted as well.
        # [N*headers, C', H , W] C = headers * C'
        x = x.reshape(heads, head_channels, height, width)
        # [N*headers, 1, C', H * W]
        input_x = x.reshape(heads, head_channels, height * width).unsqueeze(1)
        # [N*headers, 1, H, W] -> [N*headers, 1, H * W]
        context_mask = self.conv_mask(x).reshape(heads, 1, height * width)

        # scale variance
        if self.scale_attn and self.n_head > 1:
            # ``torch.sqrt`` only accepts tensors; the channel count is a
            # plain int, so take the root with Python arithmetic.
            context_mask = context_mask / head_channels**0.5

        # Softmax over spatial positions, then [N*headers, 1, H * W, 1]
        context_mask = self.softmax(context_mask).unsqueeze(-1)
        # [N*headers, 1, C', 1] =
        # [N*headers, 1, C', H * W] * [N*headers, 1, H * W, 1]
        context = torch.matmul(input_x, context_mask)
        # [N, headers * C', 1, 1]
        return context.reshape(batch, self.n_head * head_channels, 1, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward function.

        Args:
            x (Tensor): Input feature map.

        Returns:
            Tensor: Output tensor after GCAModule, with the same shape as
            ``x``.
        """
        # [N, C, 1, 1]
        context = self.spatial_pool(x)

        if self.fusion_type == 'channel_mul':
            # [N, C, 1, 1] gate broadcast over H and W
            return x * torch.sigmoid(self.channel_mul_conv(context))

        if self.fusion_type == 'channel_add':
            # [N, C, 1, 1] bias broadcast over H and W
            return x + self.channel_add_conv(context)

        # channel_concat: tile the context over H x W, concatenate along the
        # channel axis and project back to in_channels.
        # [N, C, 1, 1]
        channel_concat_term = self.channel_concat_conv(context)
        _, _, height, width = x.shape
        out = torch.cat(
            [x, channel_concat_term.expand(-1, -1, height, width)], dim=1)
        out = self.cat_conv(out)
        # Parameter-free layer norm over (C, H, W) of each sample.
        out = nn.functional.layer_norm(out, [self.in_channels, height, width])
        return nn.functional.relu(out)