Spaces:
Sleeping
Sleeping
# Copyright (c) OpenMMLab. All rights reserved.
import itertools | |
from typing import Dict, List, Optional, Tuple, Union | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from mmcv.cnn import ConvModule | |
from mmocr.registry import MODELS | |
from .base import BasePreprocessor | |
class TPStransform(nn.Module):
    """Implement TPS (thin-plate-spline) transform.

    This was partially adapted from https://github.com/ayumiymk/aster.pytorch

    All matrices that depend only on the (fixed) target control points are
    precomputed once in ``__init__`` and registered as buffers, so that
    ``forward`` only needs two matrix products and one ``grid_sample``.

    Args:
        output_image_size (tuple[int, int]): The (height, width) of the
            output image. Defaults to (32, 100).
        num_control_points (int): The number of control points. Half of them
            are placed on the top side and half on the bottom side, so the
            value must be an even number not smaller than 4.
            Defaults to 20.
        margins (tuple[float, float]): The (x, y) margins, in normalized
            [0, 1] coordinates, between the control points and the image
            border. Defaults to (0.05, 0.05).

    Raises:
        ValueError: If ``num_control_points`` is odd or smaller than 4.
    """

    def __init__(self,
                 output_image_size: Tuple[int, int] = (32, 100),
                 num_control_points: int = 20,
                 margins: Tuple[float, float] = (0.05, 0.05)) -> None:
        super().__init__()
        # The points are split evenly between the top and the bottom side,
        # and the TPS kernel is only invertible with at least three
        # non-collinear points; fail early with a readable message instead
        # of an opaque shape-mismatch / singular-matrix error further down.
        if num_control_points < 4 or num_control_points % 2 != 0:
            raise ValueError(
                'num_control_points must be an even integer >= 4, '
                f'but got {num_control_points}')
        self.output_image_size = output_image_size
        self.num_control_points = num_control_points
        self.margins = margins
        self.target_height, self.target_width = output_image_size

        # build output control points
        target_control_points = self._build_output_control_points(
            num_control_points, margins)
        N = num_control_points

        # create padded kernel matrix:
        #   rows/cols [0, N)  -> pairwise radial basis values
        #   row/col   N       -> constant term (ones)
        #   rows/cols N+1,N+2 -> x / y coordinates of the control points
        forward_kernel = torch.zeros(N + 3, N + 3)
        target_control_partial_repr = self._compute_partial_repr(
            target_control_points, target_control_points)
        forward_kernel[:N, :N].copy_(target_control_partial_repr)
        forward_kernel[:N, -3].fill_(1)
        forward_kernel[-3, :N].fill_(1)
        forward_kernel[:N, -2:].copy_(target_control_points)
        forward_kernel[-2:, :N].copy_(target_control_points.transpose(0, 1))

        # compute inverse matrix
        inverse_kernel = torch.inverse(forward_kernel).contiguous()

        # create target coordinate matrix: one (x, y) row per output pixel,
        # enumerated row by row and normalized to [0, 1]
        HW = self.target_height * self.target_width
        tgt_coord = list(
            itertools.product(
                range(self.target_height), range(self.target_width)))
        tgt_coord = torch.Tensor(tgt_coord)
        Y, X = tgt_coord.split(1, dim=1)
        Y = Y / (self.target_height - 1)
        X = X / (self.target_width - 1)
        tgt_coord = torch.cat([X, Y], dim=1)
        tgt_coord_partial_repr = self._compute_partial_repr(
            tgt_coord, target_control_points)
        tgt_coord_repr = torch.cat(
            [tgt_coord_partial_repr,
             torch.ones(HW, 1), tgt_coord], dim=1)

        # register precomputed matrices
        self.register_buffer('inverse_kernel', inverse_kernel)
        self.register_buffer('padding_matrix', torch.zeros(3, 2))
        self.register_buffer('target_coordinate_repr', tgt_coord_repr)
        self.register_buffer('target_control_points', target_control_points)

    def forward(self, input: torch.Tensor,
                source_control_points: torch.Tensor) -> torch.Tensor:
        """Forward function of the TPS block.

        Args:
            input (Tensor): The input image.
            source_control_points (Tensor): The control points of the source
                image of shape (N, self.num_control_points, 2). The computed
                sampling coordinates are clamped to [0, 1] before being
                rescaled to the [-1, 1] range used by ``grid_sample``.

        Returns:
            Tensor: The output image after TPS transform.
        """
        assert source_control_points.ndimension() == 3, \
            'source_control_points must be a 3-D tensor'
        assert source_control_points.size(1) == self.num_control_points, \
            'unexpected number of control points'
        assert source_control_points.size(2) == 2, \
            'each control point must have 2 coordinates'
        batch_size = source_control_points.size(0)

        # Append three zero rows so the right-hand side matches the padded
        # (N + 3) x (N + 3) kernel.
        Y = torch.cat([
            source_control_points,
            self.padding_matrix.expand(batch_size, 3, 2)
        ], 1)
        mapping_matrix = torch.matmul(self.inverse_kernel, Y)
        source_coordinate = torch.matmul(self.target_coordinate_repr,
                                         mapping_matrix)

        grid = source_coordinate.view(-1, self.target_height,
                                      self.target_width, 2)
        grid = torch.clamp(grid, 0, 1)
        # [0, 1] -> [-1, 1]
        grid = 2.0 * grid - 1.0
        output_maps = self._grid_sample(input, grid, canvas=None)
        return output_maps

    def _grid_sample(self,
                     input: torch.Tensor,
                     grid: torch.Tensor,
                     canvas: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Sample the input image at the given grid.

        Args:
            input (Tensor): The input image.
            grid (Tensor): The grid to sample the input image.
            canvas (Optional[Tensor]): The canvas to store the output image.
                When given, it is blended in wherever the sampling mask is
                below one. Defaults to None.

        Returns:
            Tensor: The sampled image.
        """
        output = F.grid_sample(input, grid, align_corners=True)
        if canvas is None:
            return output

        # Sampling an all-ones image yields the coverage mask of the grid.
        input_mask = torch.ones_like(input)
        output_mask = F.grid_sample(input_mask, grid, align_corners=True)
        padded_output = output * output_mask + canvas * (1 - output_mask)
        return padded_output

    def _compute_partial_repr(self, input_points: torch.Tensor,
                              control_points: torch.Tensor) -> torch.Tensor:
        """Compute the partial representation matrix.

        Each entry is ``0.5 * d * log(d)`` where ``d`` is the squared
        Euclidean distance between an input point and a control point.

        Args:
            input_points (Tensor): The input points of shape (N, 2).
            control_points (Tensor): The control points of shape (M, 2).

        Returns:
            Tensor: The partial representation matrix of shape (N, M).
        """
        N = input_points.size(0)
        M = control_points.size(0)
        pairwise_diff = input_points.view(N, 1, 2) - control_points.view(
            1, M, 2)
        pairwise_diff_square = pairwise_diff * pairwise_diff
        pairwise_dist = pairwise_diff_square[:, :,
                                             0] + pairwise_diff_square[:, :, 1]
        repr_matrix = 0.5 * pairwise_dist * torch.log(pairwise_dist)
        # Coincident points give 0 * log(0) = NaN; replace those with 0.
        repr_matrix.masked_fill_(torch.isnan(repr_matrix), 0)
        return repr_matrix

    # output_ctrl_pts are specified, according to our task.
    def _build_output_control_points(self, num_control_points: int,
                                     margins: Tuple[float,
                                                    float]) -> torch.Tensor:
        """Build the output control points.

        The output points will be fixed at the top and bottom side of the
        image, evenly spaced along the x axis.

        Args:
            num_control_points (int): The number of control points. Half of
                them go to the top side and half to the bottom side.
            margins (Tuple[float, float]): The (x, y) margins between the
                control points and the image border.

        Returns:
            Tensor: The output control points of shape
            (2 * (num_control_points // 2), 2), top side first.
        """
        margin_x, margin_y = margins
        num_ctrl_pts_per_side = num_control_points // 2
        ctrl_pts_x = np.linspace(margin_x, 1.0 - margin_x,
                                 num_ctrl_pts_per_side)
        ctrl_pts_y_top = np.ones(num_ctrl_pts_per_side) * margin_y
        ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y)
        ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
        ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
        output_ctrl_pts_arr = np.concatenate([ctrl_pts_top, ctrl_pts_bottom],
                                             axis=0)
        output_ctrl_pts = torch.Tensor(output_ctrl_pts_arr)
        return output_ctrl_pts
class STN(BasePreprocessor):
    """Implement STN module in ASTER: An Attentional Scene Text Recognizer with
    Flexible Rectification
    (https://ieeexplore.ieee.org/abstract/document/8395027/)

    Args:
        in_channels (int): The number of input channels.
        resized_image_size (Tuple[int, int]): The resized image size. The input
            image will be downsampled to have a better rectified result.
            Both sides must be at least 32 because the localization network
            halves the resolution five times. Defaults to (32, 64).
        output_image_size: The size of the output image for TPS. Defaults to
            (32, 100).
        num_control_points: The number of control points. Defaults to 20.
        margins: The margins for control points to the top and down side of the
            image for TPS. Defaults to (0.05, 0.05).
        init_cfg (dict or list[dict], optional): Initialization config
            forwarded to the base class.

    Raises:
        ValueError: If ``resized_image_size`` is too small for the
            localization network to produce any feature.
    """

    def __init__(self,
                 in_channels: int,
                 resized_image_size: Tuple[int, int] = (32, 64),
                 output_image_size: Tuple[int, int] = (32, 100),
                 num_control_points: int = 20,
                 margins: Tuple[float, float] = (0.05, 0.05),
                 init_cfg: Optional[Union[Dict, List[Dict]]] = [
                     dict(type='Xavier', layer='Conv2d'),
                     dict(type='Constant', val=1, layer='BatchNorm2d'),
                 ]):
        super().__init__(init_cfg=init_cfg)
        self.resized_image_size = resized_image_size
        self.num_control_points = num_control_points

        # Each of the five stride-2 max-poolings below floors the spatial
        # size by two, i.e. the final feature map is (H // 32, W // 32) with
        # 256 channels. For the default (32, 64) this is 2 * 256 = 512.
        feat_height = resized_image_size[0] // 32
        feat_width = resized_image_size[1] // 32
        num_flat_features = 256 * feat_height * feat_width
        if num_flat_features == 0:
            raise ValueError(
                'resized_image_size must be at least (32, 32), '
                f'but got {tuple(resized_image_size)}')

        self.tps = TPStransform(output_image_size, num_control_points, margins)
        self.stn_convnet = nn.Sequential(
            ConvModule(in_channels, 32, 3, 1, 1, norm_cfg=dict(type='BN')),
            nn.MaxPool2d(kernel_size=2, stride=2),
            ConvModule(32, 64, 3, 1, 1, norm_cfg=dict(type='BN')),
            nn.MaxPool2d(kernel_size=2, stride=2),
            ConvModule(64, 128, 3, 1, 1, norm_cfg=dict(type='BN')),
            nn.MaxPool2d(kernel_size=2, stride=2),
            ConvModule(128, 256, 3, 1, 1, norm_cfg=dict(type='BN')),
            nn.MaxPool2d(kernel_size=2, stride=2),
            ConvModule(256, 256, 3, 1, 1, norm_cfg=dict(type='BN')),
            nn.MaxPool2d(kernel_size=2, stride=2),
            ConvModule(256, 256, 3, 1, 1, norm_cfg=dict(type='BN')),
        )

        self.stn_fc1 = nn.Sequential(
            nn.Linear(num_flat_features, 512), nn.BatchNorm1d(512),
            nn.ReLU(inplace=True))
        self.stn_fc2 = nn.Linear(512, num_control_points * 2)
        self.init_stn(self.stn_fc2)

    def init_stn(self, stn_fc2: nn.Linear) -> None:
        """Initialize the output linear layer of stn, so that the initial
        source point will be at the top and down side of the image, which will
        help to optimize.

        The weight is zeroed and the bias is set to the flattened (x, y)
        coordinates of evenly spaced points on the top and bottom side, so
        the layer initially outputs those points for any input.

        Args:
            stn_fc2 (nn.Linear): The output linear layer of stn.
        """
        margin = 0.01
        sampling_num_per_side = self.num_control_points // 2
        ctrl_pts_x = np.linspace(margin, 1. - margin, sampling_num_per_side)
        ctrl_pts_y_top = np.ones(sampling_num_per_side) * margin
        ctrl_pts_y_bottom = np.ones(sampling_num_per_side) * (1 - margin)
        ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
        ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
        ctrl_points = np.concatenate([ctrl_pts_top, ctrl_pts_bottom],
                                     axis=0).astype(np.float32)
        # In-place initialization keeps the parameter's storage, dtype and
        # device untouched.
        with torch.no_grad():
            stn_fc2.weight.zero_()
            stn_fc2.bias.copy_(torch.from_numpy(ctrl_points).view(-1))

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """Forward function of STN.

        Args:
            img (Tensor): The input image tensor.

        Returns:
            Tensor: The rectified image tensor.
        """
        # Control points are predicted on a resized copy of the image, while
        # the TPS sampling below is applied to the original input.
        resize_img = F.interpolate(
            img, self.resized_image_size, mode='bilinear', align_corners=True)
        feat = self.stn_convnet(resize_img)
        feat = feat.view(feat.size(0), -1)
        img_feat = self.stn_fc1(feat)
        # NOTE(review): the 0.1 factor down-scales the features before the
        # output layer; presumably inherited from the reference ASTER
        # implementation - confirm before changing.
        points = self.stn_fc2(0.1 * img_feat)
        points = points.view(-1, self.num_control_points, 2)
        transformed_image = self.tps(img, points)
        return transformed_image