Spaces:
Runtime error
Runtime error
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
""" | |
This code is refer from: | |
https://github.com/FudanVI/FudanOCR/blob/main/text-gestalt/model/tsrn.py | |
""" | |
import math | |
import paddle | |
import paddle.nn.functional as F | |
from paddle import nn | |
from collections import OrderedDict | |
import sys | |
import numpy as np | |
import warnings | |
import math, copy | |
import cv2 | |
warnings.filterwarnings("ignore") | |
from .tps_spatial_transformer import TPSSpatialTransformer | |
from .stn import STN as STN_model | |
from ppocr.modeling.heads.sr_rensnet_transformer import Transformer | |
class TSRN(nn.Layer):
    """Text super-resolution network built from recurrent residual blocks.

    The network is a chain of sub-layers registered as ``block1`` ..
    ``block{srb_nums + 3}``:

    * ``block1``: 9x9 convolution + PReLU producing ``2 * hidden_units``
      feature channels.
    * ``block2`` .. ``block{srb_nums + 1}``: ``srb_nums``
      ``RecurrentResidualBlock`` layers.
    * ``block{srb_nums + 2}``: 3x3 convolution + batch norm.
    * ``block{srb_nums + 3}``: ``log2(scale_factor)`` ``UpsampleBLock``
      layers (each upsampling by 2) followed by a 9x9 convolution back to
      the input plane count.

    Args:
        in_channels: Only stored as ``self.out_channels``; the convolution
            input width is decided by ``mask`` instead.
        scale_factor: Upscaling factor; must be a power of two.
        width: Used with ``height`` and ``scale_factor`` to size the TPS
            output (``width // scale_factor``).
        height: See ``width``.
        STN: When truthy, build an STN head and a TPS transformer. They are
            applied to the input in ``forward`` only while training.
        srb_nums: Number of recurrent residual blocks.
        mask: When truthy the network works on 4 planes instead of 3.
        hidden_units: Half of the internal feature channel count.
        infer_mode: When truthy, ``forward`` receives the low-resolution
            image directly instead of an indexable batch.
        **kwargs: Accepted and ignored.

    Raises:
        AssertionError: If ``scale_factor`` is not a power of two.
    """

    def __init__(self,
                 in_channels,
                 scale_factor=2,
                 width=128,
                 height=32,
                 STN=False,
                 srb_nums=5,
                 mask=False,
                 hidden_units=32,
                 infer_mode=False,
                 **kwargs):
        super(TSRN, self).__init__()
        in_planes = 4 if mask else 3
        feat_channels = 2 * hidden_units

        # math.log2 is exact for powers of two, whereas math.log(x, 2) can
        # be off by one ulp (e.g. for 2 ** 29) and wrongly reject a valid
        # scale factor.
        upscale_steps = math.log2(scale_factor)
        assert upscale_steps.is_integer(), \
            "scale_factor must be a power of two, got %r" % (scale_factor, )
        upsample_block_num = int(upscale_steps)

        # NOTE: sub-layers keep their historical ``block%d`` attribute names
        # and registration order so parameter names stay unchanged.
        self.block1 = nn.Sequential(
            nn.Conv2D(
                in_planes, feat_channels, kernel_size=9, padding=4),
            nn.PReLU())

        self.srb_nums = srb_nums
        for i in range(srb_nums):
            setattr(self, 'block%d' % (i + 2),
                    RecurrentResidualBlock(feat_channels))

        setattr(
            self,
            'block%d' % (srb_nums + 2),
            nn.Sequential(
                nn.Conv2D(
                    feat_channels, feat_channels, kernel_size=3, padding=1),
                nn.BatchNorm2D(feat_channels)))

        upsample_layers = [
            UpsampleBLock(feat_channels, 2)
            for _ in range(upsample_block_num)
        ]
        upsample_layers.append(
            nn.Conv2D(
                feat_channels, in_planes, kernel_size=9, padding=4))
        setattr(self, 'block%d' % (srb_nums + 3),
                nn.Sequential(*upsample_layers))

        # The TPS input and output sizes are both the down-scaled image size.
        self.tps_inputsize = [height // scale_factor, width // scale_factor]
        tps_outputsize = [height // scale_factor, width // scale_factor]
        num_control_points = 20
        tps_margins = [0.05, 0.05]
        self.stn = STN
        if self.stn:
            self.tps = TPSSpatialTransformer(
                output_image_size=tuple(tps_outputsize),
                num_control_points=num_control_points,
                margins=tuple(tps_margins))
            self.stn_head = STN_model(
                in_channels=in_planes,
                num_ctrlpoints=num_control_points,
                activation='none')

        self.out_channels = in_channels

        # Auxiliary transformer used only in the training branch of
        # ``forward``; its parameters are marked as non-trainable.
        self.r34_transformer = Transformer()
        for param in self.r34_transformer.parameters():
            param.trainable = False

        self.infer_mode = infer_mode

    def forward(self, x):
        """Super-resolve the low-resolution input.

        Args:
            x: In ``infer_mode`` the low-resolution image itself. Otherwise
                an indexable batch where ``x[0]`` is the low-resolution
                image and ``x[1]`` the high-resolution image; while training
                ``x[2]`` and ``x[3]`` are additionally forwarded to the
                auxiliary transformer as ``length`` and ``input_tensor``.

        Returns:
            dict: Always contains ``"lr_img"`` and ``"sr_img"``; contains
            ``"hr_img"`` outside ``infer_mode``. While training it also
            contains ``"hr_img"``, ``"hr_pred"``, ``"word_attention_map_gt"``,
            ``"sr_pred"`` and ``"word_attention_map_pred"``.
        """
        output = {}
        if self.infer_mode:
            output["lr_img"] = x
            y = x
        else:
            output["lr_img"] = x[0]
            output["hr_img"] = x[1]
            y = x[0]

        # Rectification is only applied while training.
        if self.stn and self.training:
            _, ctrl_points_x = self.stn_head(y)
            y, _ = self.tps(y, ctrl_points_x)

        # Run block1 .. block{srb_nums + 2} as a chain, keeping only the
        # first block's output for the long skip connection.
        shallow = self.block1(y)
        deep = shallow
        for idx in range(2, self.srb_nums + 3):
            deep = getattr(self, 'block%d' % idx)(deep)

        upsampler = getattr(self, 'block%d' % (self.srb_nums + 3))
        sr_img = paddle.tanh(upsampler(shallow + deep))
        output["sr_img"] = sr_img

        if self.training:
            hr_img = x[1]
            length = x[2]
            input_tensor = x[3]

            # Run the frozen transformer on both the super-resolved and the
            # high-resolution image with the same targets.
            sr_pred, word_attention_map_pred, _ = self.r34_transformer(
                sr_img, length, input_tensor)
            hr_pred, word_attention_map_gt, _ = self.r34_transformer(
                hr_img, length, input_tensor)

            output["hr_img"] = hr_img
            output["hr_pred"] = hr_pred
            output["word_attention_map_gt"] = word_attention_map_gt
            output["sr_pred"] = sr_pred
            output["word_attention_map_pred"] = word_attention_map_pred

        return output
class RecurrentResidualBlock(nn.Layer):
    """Residual block combining two conv/batch-norm stages with two GRU blocks.

    The residual branch is conv -> BN -> mish -> conv -> BN -> ``gru1``, where
    ``gru1`` sees the tensor with its last two axes swapped (and the result is
    swapped back). The branch is added to the input and the sum goes through
    ``gru2`` without any axis swap, so the two GRU blocks scan along
    different spatial axes.

    Args:
        channels: Channel count of both the input and the output.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2D(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2D(channels)
        self.gru1 = GruBlock(channels, channels)
        # Attribute is named ``prelu`` but holds a mish activation.
        self.prelu = mish()
        self.conv2 = nn.Conv2D(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2D(channels)
        self.gru2 = GruBlock(channels, channels)

    def forward(self, x):
        """Apply the block to ``x`` and return a tensor of the same layout."""
        swap_last_two = [0, 1, 3, 2]

        branch = self.prelu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))

        # Swap the last two axes around gru1 so that it scans the other
        # spatial axis, then restore the original layout.
        swapped = branch.transpose(swap_last_two)
        branch = self.gru1(swapped).transpose(swap_last_two)

        return self.gru2(x + branch)
class UpsampleBLock(nn.Layer):
    """Upsample by ``up_scale`` using conv -> pixel shuffle -> mish.

    The 3x3 convolution expands the channels by ``up_scale ** 2`` and the
    pixel shuffle is constructed with the same ``up_scale``.

    Args:
        in_channels: Channel count of the input feature map.
        up_scale: Upscale factor handed to ``nn.PixelShuffle``.
    """

    def __init__(self, in_channels, up_scale):
        super().__init__()
        expanded_channels = in_channels * up_scale**2
        self.conv = nn.Conv2D(
            in_channels, expanded_channels, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(up_scale)
        # Attribute is named ``prelu`` but holds a mish activation.
        self.prelu = mish()

    def forward(self, x):
        """Return the upsampled, activated feature map."""
        return self.prelu(self.pixel_shuffle(self.conv(x)))
class mish(nn.Layer):
    """Mish activation: ``x * tanh(softplus(x))``.

    The ``activated`` attribute (``True`` by default) can be cleared to turn
    the layer into an identity.
    """

    def __init__(self):
        super().__init__()
        self.activated = True

    def forward(self, x):
        """Return ``x`` unchanged when deactivated, otherwise mish(x)."""
        if not self.activated:
            return x
        return x * paddle.tanh(F.softplus(x))
class GruBlock(nn.Layer):
    """1x1 convolution followed by a bidirectional GRU over the last axis.

    Every slice along axis 2 of the 4-D input is treated as an independent
    sequence whose steps run along axis 3 and whose features are the
    channels (axis 1).

    Args:
        in_channels: Channel count of the input.
        out_channels: Channel count of the output; must be even because it
            is split between the two GRU directions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        assert out_channels % 2 == 0
        self.conv1 = nn.Conv2D(
            in_channels, out_channels, kernel_size=1, padding=0)
        # Each direction yields out_channels // 2 features.
        self.gru = nn.GRU(
            out_channels, out_channels // 2, direction='bidirectional')

    def forward(self, x):
        """Return a tensor with the same axis layout as the input."""
        feats = self.conv1(x)

        # Move the channel axis last: (n, c, d2, d3) -> (n, d2, d3, c).
        feats = feats.transpose([0, 2, 3, 1])
        _, dim2, dim3, channels = feats.shape

        # Fold axis 2 into the batch so the GRU scans along axis 3.
        sequences = feats.reshape([-1, dim3, channels])
        sequences, _ = self.gru(sequences)

        # Unfold and move the channels back to axis 1.
        restored = sequences.reshape([-1, dim2, dim3, channels])
        return restored.transpose([0, 3, 1, 2])