# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" | |
This code is refer from: | |
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/encoders/sar_encoder.py | |
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/decoders/sar_decoder.py | |
""" | |
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F


class SAREncoder(nn.Layer):
    """
    Args:
        enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
        enc_drop_rnn (float): Dropout probability of RNN layer in encoder.
        enc_gru (bool): If True, use GRU, else LSTM in encoder.
        d_model (int): Dim of channels from backbone.
        d_enc (int): Dim of encoder RNN layer.
        mask (bool): If True, mask padding in RNN sequence.
    """
    def __init__(self,
                 enc_bi_rnn=False,
                 enc_drop_rnn=0.1,
                 enc_gru=False,
                 d_model=512,
                 d_enc=512,
                 mask=True,
                 **kwargs):
        super().__init__()
        assert isinstance(enc_bi_rnn, bool)
        assert isinstance(enc_drop_rnn, (int, float))
        assert 0 <= enc_drop_rnn < 1.0
        assert isinstance(enc_gru, bool)
        assert isinstance(d_model, int)
        assert isinstance(d_enc, int)
        assert isinstance(mask, bool)

        self.enc_bi_rnn = enc_bi_rnn
        self.enc_drop_rnn = enc_drop_rnn
        self.mask = mask

        # LSTM/GRU encoder over the vertically pooled feature sequence
        if enc_bi_rnn:
            direction = 'bidirectional'
        else:
            direction = 'forward'
        kwargs = dict(
            input_size=d_model,
            hidden_size=d_enc,
            num_layers=2,
            time_major=False,
            dropout=enc_drop_rnn,
            direction=direction)
        if enc_gru:
            self.rnn_encoder = nn.GRU(**kwargs)
        else:
            self.rnn_encoder = nn.LSTM(**kwargs)

        # global feature transformation
        encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
        self.linear = nn.Linear(encoder_rnn_out_size, encoder_rnn_out_size)
    def forward(self, feat, img_metas=None):
        if img_metas is not None:
            assert len(img_metas[0]) == paddle.shape(feat)[0]

        valid_ratios = None
        if img_metas is not None and self.mask:
            valid_ratios = img_metas[-1]

        h_feat = feat.shape[2]  # bsz c h w
        # pool over the height dimension to obtain a width-major sequence
        feat_v = F.max_pool2d(
            feat, kernel_size=(h_feat, 1), stride=1, padding=0)
        feat_v = feat_v.squeeze(2)  # bsz * C * W
        feat_v = paddle.transpose(feat_v, perm=[0, 2, 1])  # bsz * W * C

        holistic_feat = self.rnn_encoder(feat_v)[0]  # bsz * T * C

        if valid_ratios is not None:
            valid_hf = []
            T = paddle.shape(holistic_feat)[1]
            for i in range(paddle.shape(valid_ratios)[0]):
                valid_step = paddle.minimum(
                    T, paddle.ceil(valid_ratios[i] * T).astype('int32')) - 1
                valid_hf.append(holistic_feat[i, valid_step, :])
            valid_hf = paddle.stack(valid_hf, axis=0)
        else:
            valid_hf = holistic_feat[:, -1, :]  # bsz * C
        holistic_feat = self.linear(valid_hf)  # bsz * C

        return holistic_feat
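

# --- Usage sketch (illustrative addition, not part of the upstream file). ---
# Runs SAREncoder on random backbone features shaped (bsz, d_model, H, W) and
# prints the holistic feature size. The shapes are assumptions chosen only for
# this demo; call _sar_encoder_demo() manually to try it.
def _sar_encoder_demo():
    encoder = SAREncoder(d_model=512, d_enc=512)
    encoder.eval()
    feat = paddle.rand([2, 512, 8, 32])  # bsz=2, C=512, H=8, W=32
    with paddle.no_grad():
        holistic_feat = encoder(feat, img_metas=None)
    print(holistic_feat.shape)  # (bsz, d_enc) -> [2, 512]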


class BaseDecoder(nn.Layer):
    def __init__(self, **kwargs):
        super().__init__()

    def forward_train(self, feat, out_enc, targets, img_metas):
        raise NotImplementedError

    def forward_test(self, feat, out_enc, img_metas):
        raise NotImplementedError

    def forward(self,
                feat,
                out_enc,
                label=None,
                img_metas=None,
                train_mode=True):
        self.train_mode = train_mode

        if train_mode:
            return self.forward_train(feat, out_enc, label, img_metas)
        return self.forward_test(feat, out_enc, img_metas)


class ParallelSARDecoder(BaseDecoder):
    """
    Args:
        out_channels (int): Output class number.
        enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
        dec_bi_rnn (bool): If True, use bidirectional RNN in decoder.
        dec_drop_rnn (float): Dropout probability of RNN layer in decoder.
        dec_gru (bool): If True, use GRU, else LSTM in decoder.
        d_model (int): Dim of channels from backbone.
        d_enc (int): Dim of encoder RNN layer.
        d_k (int): Dim of channels of attention module.
        pred_dropout (float): Dropout probability of prediction layer.
        max_text_length (int): Maximum sequence length for decoding.
        mask (bool): If True, mask padding in feature map.
        start_idx (int): Index of start token.
        padding_idx (int): Index of padding token.
        pred_concat (bool): If True, concat glimpse feature from
            attention with holistic feature and hidden state.
    """
    def __init__(
            self,
            out_channels,  # 90 + unknown + start + padding
            enc_bi_rnn=False,
            dec_bi_rnn=False,
            dec_drop_rnn=0.0,
            dec_gru=False,
            d_model=512,
            d_enc=512,
            d_k=64,
            pred_dropout=0.1,
            max_text_length=30,
            mask=True,
            pred_concat=True,
            **kwargs):
        super().__init__()

        self.num_classes = out_channels
        self.enc_bi_rnn = enc_bi_rnn
        self.d_k = d_k
        self.start_idx = out_channels - 2
        self.padding_idx = out_channels - 1
        self.max_seq_len = max_text_length
        self.mask = mask
        self.pred_concat = pred_concat

        encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
        decoder_rnn_out_size = encoder_rnn_out_size * (int(dec_bi_rnn) + 1)

        # 2D attention layer
        self.conv1x1_1 = nn.Linear(decoder_rnn_out_size, d_k)
        self.conv3x3_1 = nn.Conv2D(
            d_model, d_k, kernel_size=3, stride=1, padding=1)
        self.conv1x1_2 = nn.Linear(d_k, 1)

        # Decoder RNN layer
        if dec_bi_rnn:
            direction = 'bidirectional'
        else:
            direction = 'forward'
        kwargs = dict(
            input_size=encoder_rnn_out_size,
            hidden_size=encoder_rnn_out_size,
            num_layers=2,
            time_major=False,
            dropout=dec_drop_rnn,
            direction=direction)
        if dec_gru:
            self.rnn_decoder = nn.GRU(**kwargs)
        else:
            self.rnn_decoder = nn.LSTM(**kwargs)

        # Decoder input embedding
        self.embedding = nn.Embedding(
            self.num_classes,
            encoder_rnn_out_size,
            padding_idx=self.padding_idx)

        # Prediction layer
        self.pred_dropout = nn.Dropout(pred_dropout)
        pred_num_classes = self.num_classes - 1
        if pred_concat:
            fc_in_channel = decoder_rnn_out_size + d_model + encoder_rnn_out_size
        else:
            fc_in_channel = d_model
        self.prediction = nn.Linear(fc_in_channel, pred_num_classes)
    def _2d_attention(self,
                      decoder_input,
                      feat,
                      holistic_feat,
                      valid_ratios=None):
        y = self.rnn_decoder(decoder_input)[0]
        # y: bsz * (seq_len + 1) * hidden_size

        attn_query = self.conv1x1_1(y)  # bsz * (seq_len + 1) * attn_size
        bsz, seq_len, attn_size = attn_query.shape
        attn_query = paddle.unsqueeze(attn_query, axis=[3, 4])
        # (bsz, seq_len + 1, attn_size, 1, 1)

        attn_key = self.conv3x3_1(feat)
        # bsz * attn_size * h * w
        attn_key = attn_key.unsqueeze(1)
        # bsz * 1 * attn_size * h * w

        attn_weight = paddle.tanh(paddle.add(attn_key, attn_query))
        # bsz * (seq_len + 1) * attn_size * h * w
        attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 3, 4, 2])
        # bsz * (seq_len + 1) * h * w * attn_size
        attn_weight = self.conv1x1_2(attn_weight)
        # bsz * (seq_len + 1) * h * w * 1
        bsz, T, h, w, c = paddle.shape(attn_weight)
        assert c == 1

        if valid_ratios is not None:
            # calculate mask of attention weight for padded widths
            for i in range(paddle.shape(valid_ratios)[0]):
                valid_width = paddle.minimum(
                    w, paddle.ceil(valid_ratios[i] * w).astype("int32"))
                if valid_width < w:
                    attn_weight[i, :, :, valid_width:, :] = float('-inf')

        attn_weight = paddle.reshape(attn_weight, [bsz, T, -1])
        attn_weight = F.softmax(attn_weight, axis=-1)

        attn_weight = paddle.reshape(attn_weight, [bsz, T, h, w, c])
        attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 4, 2, 3])
        # attn_weight: bsz * T * c * h * w
        # feat: bsz * c * h * w
        attn_feat = paddle.sum(paddle.multiply(feat.unsqueeze(1), attn_weight),
                               (3, 4),
                               keepdim=False)
        # bsz * (seq_len + 1) * C

        # Linear transformation
        if self.pred_concat:
            hf_c = holistic_feat.shape[-1]
            holistic_feat = paddle.expand(
                holistic_feat, shape=[bsz, seq_len, hf_c])
            y = self.prediction(
                paddle.concat((y, attn_feat.astype(y.dtype),
                               holistic_feat.astype(y.dtype)), 2))
        else:
            y = self.prediction(attn_feat)
        # bsz * (seq_len + 1) * num_classes
        if self.train_mode:
            y = self.pred_dropout(y)

        return y
    def forward_train(self, feat, out_enc, label, img_metas):
        '''
        img_metas: [label, valid_ratio]
        '''
        if img_metas is not None:
            assert paddle.shape(img_metas[0])[0] == paddle.shape(feat)[0]

        valid_ratios = None
        if img_metas is not None and self.mask:
            valid_ratios = img_metas[-1]

        lab_embedding = self.embedding(label)
        # bsz * seq_len * emb_dim
        out_enc = out_enc.unsqueeze(1).astype(lab_embedding.dtype)
        # bsz * 1 * emb_dim
        in_dec = paddle.concat((out_enc, lab_embedding), axis=1)
        # bsz * (seq_len + 1) * C
        out_dec = self._2d_attention(
            in_dec, feat, out_enc, valid_ratios=valid_ratios)

        return out_dec[:, 1:, :]  # bsz * seq_len * num_classes
    def forward_test(self, feat, out_enc, img_metas):
        if img_metas is not None:
            assert len(img_metas[0]) == feat.shape[0]

        valid_ratios = None
        if img_metas is not None and self.mask:
            valid_ratios = img_metas[-1]

        seq_len = self.max_seq_len
        bsz = feat.shape[0]

        start_token = paddle.full(
            (bsz, ), fill_value=self.start_idx, dtype='int64')
        # bsz
        start_token = self.embedding(start_token)
        # bsz * emb_dim
        emb_dim = start_token.shape[1]
        start_token = start_token.unsqueeze(1)
        start_token = paddle.expand(start_token, shape=[bsz, seq_len, emb_dim])
        # bsz * seq_len * emb_dim
        out_enc = out_enc.unsqueeze(1)
        # bsz * 1 * emb_dim
        decoder_input = paddle.concat((out_enc, start_token), axis=1)
        # bsz * (seq_len + 1) * emb_dim

        outputs = []
        for i in range(1, seq_len + 1):
            decoder_output = self._2d_attention(
                decoder_input, feat, out_enc, valid_ratios=valid_ratios)
            char_output = decoder_output[:, i, :]  # bsz * num_classes
            char_output = F.softmax(char_output, -1)
            outputs.append(char_output)
            max_idx = paddle.argmax(char_output, axis=1, keepdim=False)
            char_embedding = self.embedding(max_idx)  # bsz * emb_dim
            if i < seq_len:
                # feed the greedy prediction back in as the next decoder input
                decoder_input[:, i + 1, :] = char_embedding

        outputs = paddle.stack(outputs, 1)  # bsz * seq_len * num_classes

        return outputs
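

# --- Usage sketch (illustrative addition, not part of the upstream file). ---
# Runs ParallelSARDecoder in test mode on random features plus a random
# holistic feature standing in for the SAREncoder output. out_channels=93
# mirrors the "90 + unknown + start + padding" comment and, like the shapes,
# is an assumption made only for this demo.
def _parallel_sar_decoder_demo():
    decoder = ParallelSARDecoder(out_channels=93, d_model=512, d_enc=512,
                                 max_text_length=30)
    decoder.eval()
    feat = paddle.rand([2, 512, 8, 32])  # backbone feature map
    out_enc = paddle.rand([2, 512])      # holistic feature from SAREncoder
    with paddle.no_grad():
        probs = decoder(feat, out_enc, label=None, img_metas=None,
                        train_mode=False)
    print(probs.shape)  # (bsz, max_text_length, out_channels - 1) -> [2, 30, 92]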


class SARHead(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 enc_dim=512,
                 max_text_length=30,
                 enc_bi_rnn=False,
                 enc_drop_rnn=0.1,
                 enc_gru=False,
                 dec_bi_rnn=False,
                 dec_drop_rnn=0.0,
                 dec_gru=False,
                 d_k=512,
                 pred_dropout=0.1,
                 pred_concat=True,
                 **kwargs):
        super(SARHead, self).__init__()

        # encoder module
        self.encoder = SAREncoder(
            enc_bi_rnn=enc_bi_rnn,
            enc_drop_rnn=enc_drop_rnn,
            enc_gru=enc_gru,
            d_model=in_channels,
            d_enc=enc_dim)

        # decoder module
        self.decoder = ParallelSARDecoder(
            out_channels=out_channels,
            enc_bi_rnn=enc_bi_rnn,
            dec_bi_rnn=dec_bi_rnn,
            dec_drop_rnn=dec_drop_rnn,
            dec_gru=dec_gru,
            d_model=in_channels,
            d_enc=enc_dim,
            d_k=d_k,
            pred_dropout=pred_dropout,
            max_text_length=max_text_length,
            pred_concat=pred_concat)
    def forward(self, feat, targets=None):
        '''
        targets: [label, valid_ratio]
        '''
        holistic_feat = self.encoder(feat, targets)  # bsz c

        if self.training:
            label = targets[0]  # label
            final_out = self.decoder(
                feat, holistic_feat, label, img_metas=targets)
        else:
            final_out = self.decoder(
                feat,
                holistic_feat,
                label=None,
                img_metas=targets,
                train_mode=False)
            # (bsz, seq_len, num_classes)

        return final_out
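

# --- Usage sketch (illustrative addition, not part of the upstream file). ---
# End-to-end inference pass through SARHead. The channel count, dictionary
# size and input resolution below are assumptions chosen only for this demo;
# in PaddleOCR they come from the backbone and the character dictionary.
def _sar_head_demo():
    head = SARHead(in_channels=512, out_channels=93, max_text_length=30)
    head.eval()
    feat = paddle.rand([2, 512, 8, 32])  # (bsz, in_channels, H, W) backbone output
    with paddle.no_grad():
        probs = head(feat, targets=None)
    print(probs.shape)  # (bsz, max_text_length, out_channels - 1) -> [2, 30, 92]


if __name__ == "__main__":
    _sar_head_demo()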