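# NOTE: the six lines below appear to be file-viewer page chrome captured
# together with the source; they are not part of the module.
# topdu's picture
# openocr demo
# 29f689c
# raw
# history blame
# 5.77 kB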
import torch
import torch.nn as nn
import torch.nn.functional as F
from openrec.modeling.encoders.svtrnet import (
Block,
ConvBNLayer,
kaiming_normal_,
trunc_normal_,
zeros_,
ones_,
)
class Swish(nn.Module):
    """Swish activation: each element is scaled by its own sigmoid."""

    def forward(self, x):
        """Return ``x * sigmoid(x)`` computed element-wise."""
        gate = torch.sigmoid(x)
        return x * gate
class EncoderWithSVTR(nn.Module):
    """Convolutional neck wrapped around a stack of global-mixing blocks.

    The forward pass optionally pools the input (``use_pool``), reduces it
    with ``conv1``/``conv2`` to ``hidden_dims`` channels, flattens the map to
    a token sequence, runs ``depth`` ``Block`` layers followed by a
    ``LayerNorm``, reshapes back to a 4-D map, projects to ``in_channels``
    with ``conv3``, concatenates the result with the shortcut along the
    channel axis and reduces it to ``dims`` channels with ``conv4`` and
    ``conv1x1``.

    Args:
        in_channels: Channel count of the incoming feature map.
        dims: Channel count of the output, exposed as ``out_channels``.
        depth: Number of ``Block`` layers to stack.
        hidden_dims: Channel width fed to the blocks and the final norm.
        use_guide: If true, ``forward`` works on ``x.detach()`` (shortcut
            included), so this module does not back-propagate into ``x``.
        num_heads, qkv_bias, mlp_ratio, qk_scale: Forwarded to ``Block``.
        drop_rate: Forwarded to ``Block`` as ``drop``.
        attn_drop_rate: Forwarded to ``Block`` as ``attn_drop``.
        drop_path: Forwarded to ``Block`` as ``drop_path``.
        kernel_size: Two-element kernel size for ``conv1`` and ``conv4``.
            ``None`` (the default) means ``[3, 3]``; a sentinel is used so
            that no mutable list is shared between instances.
        use_pool: If true, ``pool_h_2`` is applied to the input first.
    """

    def __init__(
        self,
        in_channels,
        dims=64,  # XS
        depth=2,
        hidden_dims=120,
        use_guide=False,
        num_heads=8,
        qkv_bias=True,
        mlp_ratio=2.0,
        drop_rate=0.1,
        attn_drop_rate=0.1,
        drop_path=0.0,
        kernel_size=None,
        qk_scale=None,
        use_pool=True,
    ):
        super().__init__()
        if kernel_size is None:
            # Fresh list per instance instead of a shared mutable default.
            kernel_size = [3, 3]
        # "Same"-style padding for odd kernels: half the kernel on each axis.
        padding = [kernel_size[0] // 2, kernel_size[1] // 2]
        reduced_channels = in_channels // 8

        self.depth = depth
        self.use_guide = use_guide
        self.use_pool = use_pool

        self.conv1 = ConvBNLayer(
            in_channels,
            reduced_channels,
            kernel_size=kernel_size,
            padding=padding,
            act=Swish,
            bias=False,
        )
        self.conv2 = ConvBNLayer(
            reduced_channels,
            hidden_dims,
            kernel_size=1,
            act=Swish,
            bias=False,
        )

        self.svtr_block = nn.ModuleList([
            Block(
                dim=hidden_dims,
                num_heads=num_heads,
                mixer='Global',
                HW=None,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                act_layer=Swish,
                attn_drop=attn_drop_rate,
                drop_path=drop_path,
                norm_layer='nn.LayerNorm',
                eps=1e-05,
                prenorm=False,
            ) for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(hidden_dims, eps=1e-6)

        self.conv3 = ConvBNLayer(
            hidden_dims,
            in_channels,
            kernel_size=1,
            act=Swish,
            bias=False,
        )
        # Last n x n conv: its input is the channel-wise concatenation of the
        # shortcut tensor and the conv3 output, hence 2 * in_channels.
        self.conv4 = ConvBNLayer(
            2 * in_channels,
            reduced_channels,
            kernel_size=kernel_size,
            padding=padding,
            act=Swish,
            bias=False,
        )
        self.conv1x1 = ConvBNLayer(
            reduced_channels,
            dims,
            kernel_size=1,
            act=Swish,
            bias=False,
        )
        self.out_channels = dims
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one sub-module; invoked for every module via ``apply``.

        ``nn.Linear`` weights get ``trunc_normal_`` (std 0.02) and a zeroed
        bias when one exists, ``nn.LayerNorm`` gets zero bias and unit weight,
        and ``nn.Conv2d`` weights get ``kaiming_normal_`` (fan_out, relu).
        Other module types are left untouched.
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, mean=0, std=0.02)
            if m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            zeros_(m.bias)
            ones_(m.weight)
        elif isinstance(m, nn.Conv2d):
            kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def pool_h_2(self, x):
        """Average over the height axis, then halve the width.

        Args:
            x: Tensor laid out as (B, C, H, W).

        Returns:
            Tensor of shape (B, C, 1, W // 2).
        """
        x = x.mean(dim=2, keepdim=True)
        return F.avg_pool2d(x, kernel_size=(1, 2))

    def forward(self, x):
        """Encode feature map ``x`` and return the ``dims``-channel output."""
        if self.use_pool:
            x = self.pool_h_2(x)

        # With use_guide the whole branch, shortcut included, runs on a
        # detached tensor so gradients do not reach ``x`` through it.
        z = x.detach() if self.use_guide else x
        shortcut = z

        # Reduce dimensions.
        z = self.conv2(self.conv1(z))

        # Global blocks operate on a (B, H*W, C) token sequence.
        _, C, H, W = z.shape
        z = z.flatten(2).transpose(1, 2)
        for blk in self.svtr_block:
            z = blk(z)
        z = self.norm(z)

        # Restore the (B, C, H, W) layout for the closing convolutions.
        z = z.reshape(-1, H, W, C).permute(0, 3, 1, 2)
        z = self.conv3(z)
        z = torch.cat((shortcut, z), dim=1)
        return self.conv1x1(self.conv4(z))
class CTCDecoder(nn.Module):
    """Linear classification head with an optional ``EncoderWithSVTR`` neck.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Size of the last dimension of the prediction.
        mid_channels: If given, two linear layers
            (``in_channels -> mid_channels -> out_channels``) are used
            instead of a single one.
        return_feats: In training mode, return ``(feats, predicts)`` instead
            of ``predicts`` alone.
        svtr_encoder: Optional mapping of keyword arguments for
            ``EncoderWithSVTR``. ``in_channels`` is injected automatically;
            the mapping passed by the caller is not modified.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 in_channels,
                 out_channels=6625,
                 mid_channels=None,
                 return_feats=False,
                 svtr_encoder=None,
                 **kwargs):
        super().__init__()
        if svtr_encoder is not None:
            # Merge into a copy so the caller's config dict is left intact.
            encoder_cfg = dict(svtr_encoder, in_channels=in_channels)
            self.svtr_encoder = EncoderWithSVTR(**encoder_cfg)
            # The linear head consumes the encoder output from here on.
            in_channels = self.svtr_encoder.out_channels
        else:
            self.svtr_encoder = None

        if mid_channels is None:
            self.fc = nn.Linear(in_channels, out_channels)
        else:
            self.fc1 = nn.Linear(in_channels, mid_channels)
            self.fc2 = nn.Linear(mid_channels, out_channels)

        self.out_channels = out_channels
        self.mid_channels = mid_channels
        self.return_feats = return_feats

    def forward(self, x, data=None):
        """Compute per-position class scores.

        Args:
            x: Feature map; dimensions from index 2 onward are flattened into
                a single sequence axis before the linear layers.
            data: Unused.

        Returns:
            In eval mode, softmax probabilities over the last dimension,
            regardless of ``return_feats``. In training mode, the raw
            ``predicts``, or ``(feats, predicts)`` when ``return_feats`` is
            set, where ``feats`` is the input of the last linear layer.
        """
        if self.svtr_encoder is not None:
            x = self.svtr_encoder(x)

        # (B, C, ...) -> (B, L, C) so the linear layers act on channels.
        x = x.flatten(2).transpose(1, 2)

        if self.mid_channels is None:
            predicts = self.fc(x)
        else:
            x = self.fc1(x)
            predicts = self.fc2(x)

        if not self.training:
            # Inference always yields probabilities only.
            return F.softmax(predicts, dim=2)
        if self.return_feats:
            return (x, predicts)
        return predicts