import torch
import torch.nn as nn
import torch.nn.functional as F


class CAM(nn.Module):
    '''
    Convolutional Alignment Module (CAM).

    Cascades multi-scale backbone features and produces one attention map per
    decoding time step. The current version only supports inputs whose spatial
    size is a power of 2 (e.g. 32, 64, 128); you can adapt it to any input
    size by changing the padding or stride.
    '''
def __init__(self,
channels_list=[64, 128, 256, 512],
strides_list=[[2, 2], [1, 1], [1, 1]],
in_shape=[8, 32],
maxT=25,
depth=4,
num_channels=128):
super(CAM, self).__init__()
# cascade multiscale features
fpn = []
for i in range(1, len(channels_list)):
fpn.append(
nn.Sequential(
nn.Conv2d(channels_list[i - 1], channels_list[i], (3, 3),
(strides_list[i - 1][0], strides_list[i - 1][1]),
1), nn.BatchNorm2d(channels_list[i]),
nn.ReLU(True)))
self.fpn = nn.Sequential(*fpn)
# convolutional alignment
# convs
        assert depth % 2 == 0, 'the depth of CAM must be an even number.'
# in_shape = scales[-1]
strides = []
conv_ksizes = []
deconv_ksizes = []
h, w = in_shape[0], in_shape[1]
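        # Use stride 2 in a dimension only while it can still be halved by
        # every remaining conv layer (2**(layers left) <= size); otherwise use
        # stride 1 so the final feature map never collapses below 1 x 1.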
for i in range(0, int(depth / 2)):
stride = [2] if 2**(depth / 2 - i) <= h else [1]
stride = stride + [2] if 2**(depth / 2 - i) <= w else stride + [1]
strides.append(stride)
conv_ksizes.append([3, 3])
deconv_ksizes.append([_**2 for _ in stride])
convs = [
nn.Sequential(
nn.Conv2d(channels_list[-1], num_channels,
tuple(conv_ksizes[0]), tuple(strides[0]),
(int((conv_ksizes[0][0] - 1) / 2),
int((conv_ksizes[0][1] - 1) / 2))),
nn.BatchNorm2d(num_channels), nn.ReLU(True))
]
for i in range(1, int(depth / 2)):
convs.append(
nn.Sequential(
nn.Conv2d(num_channels, num_channels,
tuple(conv_ksizes[i]), tuple(strides[i]),
(int((conv_ksizes[i][0] - 1) / 2),
int((conv_ksizes[i][1] - 1) / 2))),
nn.BatchNorm2d(num_channels), nn.ReLU(True)))
self.convs = nn.Sequential(*convs)
# deconvs
deconvs = []
for i in range(1, int(depth / 2)):
deconvs.append(
nn.Sequential(
nn.ConvTranspose2d(
num_channels, num_channels,
tuple(deconv_ksizes[int(depth / 2) - i]),
tuple(strides[int(depth / 2) - i]),
(int(deconv_ksizes[int(depth / 2) - i][0] / 4.),
int(deconv_ksizes[int(depth / 2) - i][1] / 4.))),
nn.BatchNorm2d(num_channels), nn.ReLU(True)))
deconvs.append(
nn.Sequential(
nn.ConvTranspose2d(num_channels, maxT, tuple(deconv_ksizes[0]),
tuple(strides[0]),
(int(deconv_ksizes[0][0] / 4.),
int(deconv_ksizes[0][1] / 4.))),
nn.Sigmoid()))
self.deconvs = nn.Sequential(*deconvs)
def forward(self, input):
x = input[0]
for i in range(0, len(self.fpn)):
# print(self.fpn[i](x).shape, input[i+1].shape)
x = self.fpn[i](x) + input[i + 1]
conv_feats = []
for i in range(0, len(self.convs)):
x = self.convs[i](x)
conv_feats.append(x)
for i in range(0, len(self.deconvs) - 1):
x = self.deconvs[i](x)
x = x + conv_feats[len(conv_feats) - 2 - i]
x = self.deconvs[-1](x)
return x
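

# Shape sketch, assuming the default arguments (in_shape=[8, 32], depth=4):
# the conv stack downsamples 8x32 -> 4x16 -> 2x8, the deconv stack upsamples
# back to 4x16 -> 8x32, so CAM returns a (B, maxT, 8, 32) tensor of
# per-time-step attention maps squashed by the final Sigmoid.
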
class CAMSimp(nn.Module):
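    '''
    Simplified alignment module: a 1x1 convolution plus Sigmoid that maps a
    single feature map directly to maxT attention maps. Used by DANDecoder
    when use_cam=False.
    '''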
def __init__(self, maxT=25, num_channels=128):
super(CAMSimp, self).__init__()
self.conv = nn.Sequential(nn.Conv2d(num_channels, maxT, 1, 1, 0),
nn.Sigmoid())
def forward(self, x):
x = self.conv(x)
return x


class DANDecoder(nn.Module):
    '''
    Decoupled Text Decoder of DAN (Decoupled Attention Network).

    Pools the visual features with the attention maps produced by CAM (or
    CAMSimp) into a per-step context sequence, then decodes characters with a
    bidirectional LSTM followed by an autoregressive GRUCell.
    '''
def __init__(self,
out_channels,
in_channels,
use_cam=True,
max_len=25,
channels_list=[64, 128, 256, 512],
strides_list=[[2, 2], [1, 1], [1, 1]],
in_shape=[8, 32],
depth=4,
dropout=0.3,
**kwargs):
super(DANDecoder, self).__init__()
self.eos = 0
self.bos = out_channels - 2
self.ignore_index = out_channels - 1
nchannel = in_channels
self.nchannel = in_channels
self.use_cam = use_cam
if use_cam:
self.cam = CAM(channels_list=channels_list,
strides_list=strides_list,
in_shape=in_shape,
maxT=max_len + 1,
depth=depth,
num_channels=nchannel)
else:
self.cam = CAMSimp(maxT=max_len + 1, num_channels=nchannel)
self.pre_lstm = nn.LSTM(nchannel,
int(nchannel / 2),
bidirectional=True)
self.rnn = nn.GRUCell(nchannel * 2, nchannel)
self.generator = nn.Sequential(nn.Dropout(p=dropout),
nn.Linear(nchannel, out_channels - 2))
self.char_embeddings = nn.Embedding(out_channels,
embedding_dim=in_channels,
padding_idx=out_channels - 1)
def forward(self, inputs, data=None):
A = self.cam(inputs)
if isinstance(inputs, list):
feature = inputs[-1]
else:
feature = inputs
nB, nC, nH, nW = feature.shape
nT = A.shape[1]
        # Normalize each attention map so its weights sum to 1 over H x W
A = A / A.view(nB, nT, -1).sum(2).view(nB, nT, 1, 1)
        # Attention-weighted pooling of the feature map for each time step
C = feature.view(nB, 1, nC, nH, nW) * A.view(nB, nT, 1, nH, nW)
C = C.view(nB, nT, nC, -1).sum(3).transpose(1, 0) # T, B, C
C, _ = self.pre_lstm(C) # T, B, C
C = F.dropout(C, p=0.3, training=self.training)
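        # Training: teacher forcing; data[0] is expected to hold the
        # BOS-prefixed target ids and data[-1] the target lengths.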
if self.training:
text = data[0]
text_length = data[-1]
nsteps = int(text_length.max())
gru_res = torch.zeros_like(C)
hidden = torch.zeros(nB, self.nchannel).type_as(C.data)
prev_emb = self.char_embeddings(text[:, 0])
for i in range(0, nsteps + 1):
hidden = self.rnn(torch.cat((C[i, :, :], prev_emb), dim=1),
hidden)
gru_res[i, :, :] = hidden
prev_emb = self.char_embeddings(text[:, i + 1])
gru_res = self.generator(gru_res)
return gru_res[:nsteps + 1, :, :].transpose(1, 0)
else:
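            # Inference: greedy autoregressive decoding starting from BOS,
            # stopping early once every sequence in the batch has emitted EOS.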
gru_res = torch.zeros_like(C)
hidden = torch.zeros(nB, self.nchannel).type_as(C.data)
prev_emb = self.char_embeddings(
torch.zeros(nB, dtype=torch.int64, device=feature.device) +
self.bos)
            dec_seq = torch.full((nB, nT),
                                 self.ignore_index,
                                 dtype=torch.int64,
                                 device=feature.device)
for i in range(0, nT):
hidden = self.rnn(torch.cat((C[i, :, :], prev_emb), dim=1),
hidden)
gru_res[i, :, :] = hidden
mid_res = self.generator(hidden).argmax(-1)
dec_seq[:, i] = mid_res.squeeze(0)
if (dec_seq == self.eos).any(dim=-1).all():
break
prev_emb = self.char_embeddings(mid_res)
gru_res = self.generator(gru_res)
return F.softmax(gru_res.transpose(1, 0), -1)
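

if __name__ == '__main__':
    # Minimal shape-check sketch, not part of the original model code. It
    # assumes a backbone that yields four feature maps matching the default
    # channels_list [64, 128, 256, 512] with a final 8 x 32 map; the charset
    # size (38 = 36 symbols incl. EOS + BOS + padding) is made up for
    # illustration.
    torch.manual_seed(0)
    feats = [
        torch.randn(2, 64, 16, 64),
        torch.randn(2, 128, 8, 32),
        torch.randn(2, 256, 8, 32),
        torch.randn(2, 512, 8, 32),
    ]
    decoder = DANDecoder(out_channels=38, in_channels=512, max_len=25)
    decoder.eval()
    with torch.no_grad():
        probs = decoder(feats)
    print(probs.shape)  # torch.Size([2, 26, 36]) = (B, max_len + 1, out_channels - 2)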