# Copyright (c) OpenMMLab. All rights reserved.
import itertools
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule

from mmocr.registry import MODELS
from .base import BasePreprocessor


class TPStransform(nn.Module):
    """Implement the TPS transform.

    This was partially adapted from https://github.com/ayumiymk/aster.pytorch

    Args:
        output_image_size (tuple[int, int]): The size of the output image.
            Defaults to (32, 100).
        num_control_points (int): The number of control points. Defaults
            to 20.
        margins (tuple[float, float]): The margins for the control points on
            the top and bottom sides of the image. Defaults to (0.05, 0.05).
    """

    def __init__(self,
                 output_image_size: Tuple[int, int] = (32, 100),
                 num_control_points: int = 20,
                 margins: Tuple[float, float] = (0.05, 0.05)) -> None:
super().__init__()
self.output_image_size = output_image_size
self.num_control_points = num_control_points
self.margins = margins
self.target_height, self.target_width = output_image_size
# build output control points
target_control_points = self._build_output_control_points(
num_control_points, margins)
N = num_control_points
# create padded kernel matrix
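        # The padded kernel is the standard TPS system matrix
        #     [[K,   1, C],
        #      [1^T, 0, 0],
        #      [C^T, 0, 0]]
        # where K[i, j] is the radial basis value between target control
        # points i and j, and C holds their (x, y) coordinates. Inverting it
        # once here lets forward() recover the mapping coefficients with a
        # single matmul per batch.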
forward_kernel = torch.zeros(N + 3, N + 3)
target_control_partial_repr = self._compute_partial_repr(
target_control_points, target_control_points)
forward_kernel[:N, :N].copy_(target_control_partial_repr)
forward_kernel[:N, -3].fill_(1)
forward_kernel[-3, :N].fill_(1)
forward_kernel[:N, -2:].copy_(target_control_points)
forward_kernel[-2:, :N].copy_(target_control_points.transpose(0, 1))
# compute inverse matrix
inverse_kernel = torch.inverse(forward_kernel).contiguous()
# create target coordinate matrix
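        # Every output pixel (x, y), normalized to [0, 1], is expanded into
        # its TPS representation [U(d_1), ..., U(d_N), 1, x, y], so that its
        # source coordinate in forward() is a single matrix product with the
        # mapping coefficients.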
HW = self.target_height * self.target_width
tgt_coord = list(
itertools.product(
range(self.target_height), range(self.target_width)))
tgt_coord = torch.Tensor(tgt_coord)
Y, X = tgt_coord.split(1, dim=1)
Y = Y / (self.target_height - 1)
X = X / (self.target_width - 1)
tgt_coord = torch.cat([X, Y], dim=1)
tgt_coord_partial_repr = self._compute_partial_repr(
tgt_coord, target_control_points)
tgt_coord_repr = torch.cat(
[tgt_coord_partial_repr,
torch.ones(HW, 1), tgt_coord], dim=1)
# register precomputed matrices
self.register_buffer('inverse_kernel', inverse_kernel)
self.register_buffer('padding_matrix', torch.zeros(3, 2))
self.register_buffer('target_coordinate_repr', tgt_coord_repr)
self.register_buffer('target_control_points', target_control_points)

    def forward(self, input: torch.Tensor,
                source_control_points: torch.Tensor) -> torch.Tensor:
        """Forward function of the TPS block.

        Args:
            input (Tensor): The input image.
            source_control_points (Tensor): The control points of the source
                image of shape (N, self.num_control_points, 2).

        Returns:
            Tensor: The output image after the TPS transform.
        """
assert source_control_points.ndimension() == 3
assert source_control_points.size(1) == self.num_control_points
assert source_control_points.size(2) == 2
batch_size = source_control_points.size(0)
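        # Pad the predicted control points with a 3x2 zero block (the
        # right-hand side of the TPS system) and solve for the mapping
        # coefficients with the kernel inverse precomputed in __init__.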
Y = torch.cat([
source_control_points,
self.padding_matrix.expand(batch_size, 3, 2)
], 1)
mapping_matrix = torch.matmul(self.inverse_kernel, Y)
source_coordinate = torch.matmul(self.target_coordinate_repr,
mapping_matrix)
grid = source_coordinate.view(-1, self.target_height,
self.target_width, 2)
        # Keep sampling positions inside the image, then rescale from [0, 1]
        # to the [-1, 1] range expected by F.grid_sample.
        grid = torch.clamp(grid, 0, 1)
        grid = 2.0 * grid - 1.0
output_maps = self._grid_sample(input, grid, canvas=None)
return output_maps

    def _grid_sample(self,
                     input: torch.Tensor,
                     grid: torch.Tensor,
                     canvas: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Sample the input image at the given grid.

        Args:
            input (Tensor): The input image.
            grid (Tensor): The grid to sample the input image with.
            canvas (Optional[Tensor]): The canvas to store the output image.

        Returns:
            Tensor: The sampled image.
        """
output = F.grid_sample(input, grid, align_corners=True)
if canvas is None:
return output
else:
input_mask = input.data.new(input.size()).fill_(1)
output_mask = F.grid_sample(input_mask, grid, align_corners=True)
padded_output = output * output_mask + canvas * (1 - output_mask)
return padded_output

    def _compute_partial_repr(self, input_points: torch.Tensor,
                              control_points: torch.Tensor) -> torch.Tensor:
        """Compute the partial representation matrix, i.e. the TPS radial
        basis U(d) = d^2 * log(d) between each input point and each control
        point.

        Args:
            input_points (Tensor): The input points.
            control_points (Tensor): The control points.

        Returns:
            Tensor: The partial representation matrix of shape (N, M).
        """
N = input_points.size(0)
M = control_points.size(0)
pairwise_diff = input_points.view(N, 1, 2) - control_points.view(
1, M, 2)
pairwise_diff_square = pairwise_diff * pairwise_diff
pairwise_dist = pairwise_diff_square[:, :,
0] + pairwise_diff_square[:, :, 1]
        # 0.5 * d^2 * log(d^2) equals the TPS kernel d^2 * log(d).
        repr_matrix = 0.5 * pairwise_dist * torch.log(pairwise_dist)
        # 0 * log(0) yields NaN; a NaN never equals itself, so this mask
        # selects those entries and resets them to the correct limit, 0.
        mask = repr_matrix != repr_matrix
        repr_matrix.masked_fill_(mask, 0)
        return repr_matrix

    # The output control points are fixed for our task: evenly spaced along
    # the top and bottom borders of the rectified image.
    def _build_output_control_points(
            self, num_control_points: int,
            margins: Tuple[float, float]) -> torch.Tensor:
        """Build the output control points.

        The output points are fixed on the top and bottom sides of the
        image.

        Args:
            num_control_points (int): The number of control points.
            margins (Tuple[float, float]): The margins for the control points
                on the top and bottom sides of the image.

        Returns:
            Tensor: The output control points.
        """
margin_x, margin_y = margins
num_ctrl_pts_per_side = num_control_points // 2
ctrl_pts_x = np.linspace(margin_x, 1.0 - margin_x,
num_ctrl_pts_per_side)
ctrl_pts_y_top = np.ones(num_ctrl_pts_per_side) * margin_y
ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y)
ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
output_ctrl_pts_arr = np.concatenate([ctrl_pts_top, ctrl_pts_bottom],
axis=0)
output_ctrl_pts = torch.Tensor(output_ctrl_pts_arr)
return output_ctrl_pts
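

# A minimal usage sketch (not part of the original module; shapes are assumed
# for illustration): control points predicted in [0, 1] coordinates rectify a
# batch of images to ``output_image_size``.
#
#   tps = TPStransform(output_image_size=(32, 100), num_control_points=20)
#   imgs = torch.rand(2, 3, 64, 256)   # (N, C, H, W)
#   pts = torch.rand(2, 20, 2)         # (N, num_control_points, 2) in [0, 1]
#   out = tps(imgs, pts)               # -> (2, 3, 32, 100)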


@MODELS.register_module()
class STN(BasePreprocessor):
    """Implement the STN module in ASTER: An Attentional Scene Text
    Recognizer with Flexible Rectification
    (https://ieeexplore.ieee.org/abstract/document/8395027/).

    Args:
        in_channels (int): The number of input channels.
        resized_image_size (Tuple[int, int]): The resized image size. The
            input image will be downsampled to this size to obtain a better
            rectification result. Defaults to (32, 64).
        output_image_size (Tuple[int, int]): The size of the output image
            for TPS. Defaults to (32, 100).
        num_control_points (int): The number of control points. Defaults
            to 20.
        margins (Tuple[float, float]): The margins for the control points on
            the top and bottom sides of the image for TPS. Defaults to
            (0.05, 0.05).
        init_cfg (dict or list[dict], optional): Initialization configs.
    """

def __init__(self,
in_channels: int,
resized_image_size: Tuple[int, int] = (32, 64),
output_image_size: Tuple[int, int] = (32, 100),
num_control_points: int = 20,
                 margins: Tuple[float, float] = (0.05, 0.05),
init_cfg: Optional[Union[Dict, List[Dict]]] = [
dict(type='Xavier', layer='Conv2d'),
dict(type='Constant', val=1, layer='BatchNorm2d'),
]):
super().__init__(init_cfg=init_cfg)
self.resized_image_size = resized_image_size
self.num_control_points = num_control_points
self.tps = TPStransform(output_image_size, num_control_points, margins)
self.stn_convnet = nn.Sequential(
ConvModule(in_channels, 32, 3, 1, 1, norm_cfg=dict(type='BN')),
nn.MaxPool2d(kernel_size=2, stride=2),
ConvModule(32, 64, 3, 1, 1, norm_cfg=dict(type='BN')),
nn.MaxPool2d(kernel_size=2, stride=2),
ConvModule(64, 128, 3, 1, 1, norm_cfg=dict(type='BN')),
nn.MaxPool2d(kernel_size=2, stride=2),
ConvModule(128, 256, 3, 1, 1, norm_cfg=dict(type='BN')),
nn.MaxPool2d(kernel_size=2, stride=2),
ConvModule(256, 256, 3, 1, 1, norm_cfg=dict(type='BN')),
nn.MaxPool2d(kernel_size=2, stride=2),
ConvModule(256, 256, 3, 1, 1, norm_cfg=dict(type='BN')),
)
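        # With resized_image_size = (32, 64), the five stride-2 poolings
        # shrink the feature map to (256, 1, 2), so the flattened feature fed
        # to stn_fc1 has 2 * 256 = 512 elements.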
self.stn_fc1 = nn.Sequential(
nn.Linear(2 * 256, 512), nn.BatchNorm1d(512),
nn.ReLU(inplace=True))
self.stn_fc2 = nn.Linear(512, num_control_points * 2)
self.init_stn(self.stn_fc2)

    def init_stn(self, stn_fc2: nn.Linear) -> None:
        """Initialize the output linear layer of stn, so that the initial
        source points lie on the top and bottom sides of the image, which
        helps optimization.

        Args:
            stn_fc2 (nn.Linear): The output linear layer of stn.
        """
margin = 0.01
sampling_num_per_side = int(self.num_control_points / 2)
ctrl_pts_x = np.linspace(margin, 1. - margin, sampling_num_per_side)
ctrl_pts_y_top = np.ones(sampling_num_per_side) * margin
ctrl_pts_y_bottom = np.ones(sampling_num_per_side) * (1 - margin)
ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
ctrl_points = np.concatenate([ctrl_pts_top, ctrl_pts_bottom],
axis=0).astype(np.float32)
stn_fc2.weight.data.zero_()
stn_fc2.bias.data = torch.Tensor(ctrl_points).view(-1)

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """Forward function of STN.

        Args:
            img (Tensor): The input image tensor.

        Returns:
            Tensor: The rectified image tensor.
        """
resize_img = F.interpolate(
img, self.resized_image_size, mode='bilinear', align_corners=True)
points = self.stn_convnet(resize_img)
batch_size, _, _, _ = points.size()
points = points.view(batch_size, -1)
img_feat = self.stn_fc1(points)
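        # The 0.1 scale on the fc1 activations follows the original ASTER
        # implementation; it keeps the predicted control-point offsets small
        # early on, so training starts near the layout set in init_stn.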
        points = self.stn_fc2(0.1 * img_feat)
        points = points.view(-1, self.num_control_points, 2)
        transformed_image = self.tps(img, points)
        return transformed_image
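

# A minimal usage sketch (not part of the original module; shapes are assumed
# for illustration): the STN downsamples internally, predicts control points
# and returns TPS-rectified images.
#
#   stn = STN(in_channels=3).eval()
#   imgs = torch.rand(2, 3, 64, 256)   # (N, C, H, W)
#   rectified = stn(imgs)              # -> (2, 3, 32, 100)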