# Copyright (c) OpenMMLab. All rights reserved.
import itertools
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule

from mmocr.registry import MODELS
from .base import BasePreprocessor


class TPStransform(nn.Module):
"""Implement TPS transform.
This was partially adapted from https://github.com/ayumiymk/aster.pytorch
Args:
output_image_size (tuple[int, int]): The size of the output image.
Defaults to (32, 128).
num_control_points (int): The number of control points. Defaults to 20.
margins (tuple[float, float]): The margins for control points to the
top and down side of the image. Defaults to [0.05, 0.05].
"""

    def __init__(self,
                 output_image_size: Tuple[int, int] = (32, 100),
                 num_control_points: int = 20,
                 margins: Tuple[float, float] = (0.05, 0.05)) -> None:
super().__init__()
self.output_image_size = output_image_size
self.num_control_points = num_control_points
self.margins = margins
self.target_height, self.target_width = output_image_size
# build output control points
target_control_points = self._build_output_control_points(
num_control_points, margins)
N = num_control_points
# create padded kernel matrix
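        # The padded TPS kernel has the block structure
        #   [[K, 1, P], [1^T, 0, 0], [P^T, 0, 0]],
        # where K holds the radial basis values between target control
        # points and P their coordinates; its inverse turns control-point
        # correspondences into the TPS mapping parameters.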
forward_kernel = torch.zeros(N + 3, N + 3)
target_control_partial_repr = self._compute_partial_repr(
target_control_points, target_control_points)
forward_kernel[:N, :N].copy_(target_control_partial_repr)
forward_kernel[:N, -3].fill_(1)
forward_kernel[-3, :N].fill_(1)
forward_kernel[:N, -2:].copy_(target_control_points)
forward_kernel[-2:, :N].copy_(target_control_points.transpose(0, 1))
# compute inverse matrix
inverse_kernel = torch.inverse(forward_kernel).contiguous()
# create target coordinate matrix
HW = self.target_height * self.target_width
tgt_coord = list(
itertools.product(
range(self.target_height), range(self.target_width)))
tgt_coord = torch.Tensor(tgt_coord)
Y, X = tgt_coord.split(1, dim=1)
Y = Y / (self.target_height - 1)
X = X / (self.target_width - 1)
tgt_coord = torch.cat([X, Y], dim=1)
tgt_coord_partial_repr = self._compute_partial_repr(
tgt_coord, target_control_points)
tgt_coord_repr = torch.cat(
[tgt_coord_partial_repr,
torch.ones(HW, 1), tgt_coord], dim=1)
# register precomputed matrices
self.register_buffer('inverse_kernel', inverse_kernel)
self.register_buffer('padding_matrix', torch.zeros(3, 2))
self.register_buffer('target_coordinate_repr', tgt_coord_repr)
self.register_buffer('target_control_points', target_control_points)

    def forward(self, input: torch.Tensor,
source_control_points: torch.Tensor) -> torch.Tensor:
"""Forward function of the TPS block.
Args:
input (Tensor): The input image.
source_control_points (Tensor): The control points of the source
image of shape (N, self.num_control_points, 2).
Returns:
Tensor: The output image after TPS transform.
"""
assert source_control_points.ndimension() == 3
assert source_control_points.size(1) == self.num_control_points
assert source_control_points.size(2) == 2
batch_size = source_control_points.size(0)
Y = torch.cat([
source_control_points,
self.padding_matrix.expand(batch_size, 3, 2)
], 1)
mapping_matrix = torch.matmul(self.inverse_kernel, Y)
source_coordinate = torch.matmul(self.target_coordinate_repr,
mapping_matrix)
grid = source_coordinate.view(-1, self.target_height,
self.target_width, 2)
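        # clamp to the unit square, then map [0, 1] -> [-1, 1] as expected
        # by F.grid_sample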
grid = torch.clamp(grid, 0, 1)
grid = 2.0 * grid - 1.0
output_maps = self._grid_sample(input, grid, canvas=None)
return output_maps

    def _grid_sample(self,
input: torch.Tensor,
grid: torch.Tensor,
canvas: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Sample the input image at the given grid.
Args:
input (Tensor): The input image.
grid (Tensor): The grid to sample the input image.
canvas (Optional[Tensor]): The canvas to store the output image.
Returns:
Tensor: The sampled image.
"""
output = F.grid_sample(input, grid, align_corners=True)
if canvas is None:
return output
else:
            input_mask = torch.ones_like(input)
output_mask = F.grid_sample(input_mask, grid, align_corners=True)
padded_output = output * output_mask + canvas * (1 - output_mask)
return padded_output

    def _compute_partial_repr(self, input_points: torch.Tensor,
control_points: torch.Tensor) -> torch.Tensor:
"""Compute the partial representation matrix.
Args:
input_points (Tensor): The input points.
control_points (Tensor): The control points.
Returns:
Tensor: The partial representation matrix.
"""
N = input_points.size(0)
M = control_points.size(0)
pairwise_diff = input_points.view(N, 1, 2) - control_points.view(
1, M, 2)
pairwise_diff_square = pairwise_diff * pairwise_diff
pairwise_dist = pairwise_diff_square[:, :,
0] + pairwise_diff_square[:, :, 1]
repr_matrix = 0.5 * pairwise_dist * torch.log(pairwise_dist)
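        # 0 * log(0) yields NaN for coincident points; NaN != NaN, so the
        # mask below selects exactly those entries for zeroing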
mask = repr_matrix != repr_matrix
repr_matrix.masked_fill_(mask, 0)
return repr_matrix

    # The output control points are fixed by the task setup.
    def _build_output_control_points(self, num_control_points: int,
                                     margins: Tuple[float, float]
                                     ) -> torch.Tensor:
"""Build the output control points.
The output points will be fix at
top and down side of the image.
Args:
num_control_points (Tensor): The number of control points.
margins (Tuple[float, float]): The margins for control points to
the top and down side of the image.
Returns:
Tensor: The output control points.
"""
margin_x, margin_y = margins
num_ctrl_pts_per_side = num_control_points // 2
ctrl_pts_x = np.linspace(margin_x, 1.0 - margin_x,
num_ctrl_pts_per_side)
ctrl_pts_y_top = np.ones(num_ctrl_pts_per_side) * margin_y
ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y)
ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
output_ctrl_pts_arr = np.concatenate([ctrl_pts_top, ctrl_pts_bottom],
axis=0)
output_ctrl_pts = torch.Tensor(output_ctrl_pts_arr)
return output_ctrl_pts
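
# A usage sketch (an assumption for illustration, not part of the upstream
# file): feeding the target control points back in as the source points
# yields an identity warp, i.e. a plain resize to `output_image_size`:
#     tps = TPStransform(output_image_size=(32, 100))
#     pts = tps.target_control_points.unsqueeze(0)  # (1, 20, 2)
#     out = tps(torch.rand(1, 3, 64, 256), pts)     # -> (1, 3, 32, 100)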


@MODELS.register_module()
class STN(BasePreprocessor):
"""Implement STN module in ASTER: An Attentional Scene Text Recognizer with
Flexible Rectification
(https://ieeexplore.ieee.org/abstract/document/8395027/)
Args:
in_channels (int): The number of input channels.
resized_image_size (Tuple[int, int]): The resized image size. The input
image will be downsampled to have a better recitified result.
output_image_size: The size of the output image for TPS. Defaults to
(32, 100).
num_control_points: The number of control points. Defaults to 20.
margins: The margins for control points to the top and down side of the
image for TPS. Defaults to [0.05, 0.05].
"""

    def __init__(self,
                 in_channels: int,
                 resized_image_size: Tuple[int, int] = (32, 64),
                 output_image_size: Tuple[int, int] = (32, 100),
                 num_control_points: int = 20,
                 margins: Tuple[float, float] = (0.05, 0.05),
                 init_cfg: Optional[Union[Dict, List[Dict]]] = [
                     dict(type='Xavier', layer='Conv2d'),
                     dict(type='Constant', val=1, layer='BatchNorm2d'),
                 ]):
super().__init__(init_cfg=init_cfg)
self.resized_image_size = resized_image_size
self.num_control_points = num_control_points
self.tps = TPStransform(output_image_size, num_control_points, margins)
self.stn_convnet = nn.Sequential(
ConvModule(in_channels, 32, 3, 1, 1, norm_cfg=dict(type='BN')),
nn.MaxPool2d(kernel_size=2, stride=2),
ConvModule(32, 64, 3, 1, 1, norm_cfg=dict(type='BN')),
nn.MaxPool2d(kernel_size=2, stride=2),
ConvModule(64, 128, 3, 1, 1, norm_cfg=dict(type='BN')),
nn.MaxPool2d(kernel_size=2, stride=2),
ConvModule(128, 256, 3, 1, 1, norm_cfg=dict(type='BN')),
nn.MaxPool2d(kernel_size=2, stride=2),
ConvModule(256, 256, 3, 1, 1, norm_cfg=dict(type='BN')),
nn.MaxPool2d(kernel_size=2, stride=2),
ConvModule(256, 256, 3, 1, 1, norm_cfg=dict(type='BN')),
)
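        # With the default resized_image_size of (32, 64), the five 2x2
        # max-pools shrink the feature map to 1 x 2 with 256 channels, so
        # the flattened feature length is 2 * 256; other input sizes would
        # require changing the linear layer below.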
self.stn_fc1 = nn.Sequential(
nn.Linear(2 * 256, 512), nn.BatchNorm1d(512),
nn.ReLU(inplace=True))
self.stn_fc2 = nn.Linear(512, num_control_points * 2)
self.init_stn(self.stn_fc2)

    def init_stn(self, stn_fc2: nn.Linear) -> None:
        """Initialize the output linear layer of stn, so that the initial
        source points lie along the top and bottom sides of the image, which
        helps optimization.

        Args:
            stn_fc2 (nn.Linear): The output linear layer of stn.
        """
margin = 0.01
sampling_num_per_side = int(self.num_control_points / 2)
ctrl_pts_x = np.linspace(margin, 1. - margin, sampling_num_per_side)
ctrl_pts_y_top = np.ones(sampling_num_per_side) * margin
ctrl_pts_y_bottom = np.ones(sampling_num_per_side) * (1 - margin)
ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
ctrl_points = np.concatenate([ctrl_pts_top, ctrl_pts_bottom],
axis=0).astype(np.float32)
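        # Zero the weights and copy the fiducial layout into the bias so
        # the initially predicted source points match this layout exactly.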
stn_fc2.weight.data.zero_()
stn_fc2.bias.data = torch.Tensor(ctrl_points).view(-1)

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """Forward function of STN.

        Args:
            img (Tensor): The input image tensor.

        Returns:
            Tensor: The rectified image tensor.
        """
resize_img = F.interpolate(
img, self.resized_image_size, mode='bilinear', align_corners=True)
points = self.stn_convnet(resize_img)
batch_size, _, _, _ = points.size()
points = points.view(batch_size, -1)
img_feat = self.stn_fc1(points)
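        # The 0.1 factor scales down the features, keeping the predicted
        # control points close to the bias initialization early in training.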
points = self.stn_fc2(0.1 * img_feat)
points = points.view(-1, self.num_control_points, 2)
        transformed_image = self.tps(img, points)
        return transformed_image
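

# A minimal smoke-test sketch, not part of the upstream module, assuming a
# 3-channel input. The relative import above prevents running this file
# directly as a script; it only executes in the installed package context.
# In mmocr the preprocessor would normally be built from a config via
# MODELS.build rather than constructed directly.
if __name__ == '__main__':
    stn = STN(in_channels=3)
    stn.eval()  # use BatchNorm running stats; weights stay at defaults
    dummy = torch.rand(2, 3, 64, 256)  # (batch, channels, height, width)
    with torch.no_grad():
        rectified = stn(dummy)
    # The TPS block resamples onto output_image_size = (32, 100).
    assert rectified.shape == (2, 3, 32, 100)
    print('rectified shape:', tuple(rectified.shape))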