import torch
from torch import nn

from .RecSVTR import Block
class Swish(nn.Module):
    """Swish activation: x * sigmoid(x)."""

    def __init__(self):
        super(Swish, self).__init__()

    def forward(self, x):
        return x * torch.sigmoid(x)
class Im2Im(nn.Module):
    """Identity neck: returns the input feature map unchanged."""

    def __init__(self, in_channels, **kwargs):
        super().__init__()
        self.out_channels = in_channels

    def forward(self, x):
        return x
class Im2Seq(nn.Module):
    """Flattens a (B, C, H, W) feature map into a (B, H * W, C) sequence."""

    def __init__(self, in_channels, **kwargs):
        super().__init__()
        self.out_channels = in_channels

    def forward(self, x):
        B, C, H, W = x.shape
        x = x.reshape(B, C, H * W)
        x = x.permute((0, 2, 1))  # (B, C, H*W) -> (B, H*W, C)
        return x
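# Im2Seq example: with a recognition backbone that pools the height to 1, a
# (1, 512, 1, 40) feature map becomes a (1, 40, 512) sequence of per-column
# features (the concrete shapes here are illustrative assumptions, not fixed by this class).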
class EncoderWithRNN(nn.Module):
    """Bidirectional two-layer LSTM encoder over a (B, T, C) sequence."""

    def __init__(self, in_channels, **kwargs):
        super(EncoderWithRNN, self).__init__()
        hidden_size = kwargs.get("hidden_size", 256)
        self.out_channels = hidden_size * 2  # bidirectional: forward + backward states
        self.lstm = nn.LSTM(in_channels, hidden_size, bidirectional=True, num_layers=2, batch_first=True)

    def forward(self, x):
        self.lstm.flatten_parameters()
        x, _ = self.lstm(x)
        return x
class SequenceEncoder(nn.Module):
    def __init__(self, in_channels, encoder_type="rnn", **kwargs):
        super(SequenceEncoder, self).__init__()
        self.encoder_reshape = Im2Seq(in_channels)
        self.out_channels = self.encoder_reshape.out_channels
        self.encoder_type = encoder_type
        if encoder_type == "reshape":
            self.only_reshape = True
        else:
            support_encoder_dict = {"reshape": Im2Seq, "rnn": EncoderWithRNN, "svtr": EncoderWithSVTR}
            assert encoder_type in support_encoder_dict, "{} must be in {}".format(
                encoder_type, support_encoder_dict.keys()
            )

            self.encoder = support_encoder_dict[encoder_type](self.encoder_reshape.out_channels, **kwargs)
            self.out_channels = self.encoder.out_channels
            self.only_reshape = False

    def forward(self, x):
        if self.encoder_type != "svtr":
            # reshape to (B, H*W, C) first, then optionally run the sequence encoder
            x = self.encoder_reshape(x)
            if not self.only_reshape:
                x = self.encoder(x)
            return x
        else:
            # the SVTR encoder works on the 4D feature map, so reshape afterwards
            x = self.encoder(x)
            x = self.encoder_reshape(x)
            return x
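# SequenceEncoder usage sketch (the shapes below are illustrative assumptions):
# given a (B, C, 1, W) backbone feature map,
#   SequenceEncoder(C, "reshape")              -> (B, W, C)
#   SequenceEncoder(C, "rnn", hidden_size=256) -> (B, W, 512)
#   SequenceEncoder(C, "svtr", dims=64)        -> (B, W, 64)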
class ConvBNLayer(nn.Module):
    def __init__(
        self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, bias_attr=False, groups=1, act=nn.GELU
    ):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=bias_attr,
        )
        self.norm = nn.BatchNorm2d(out_channels)
        # NOTE: the `act` argument is kept for call-site compatibility but is not used;
        # Swish is applied regardless of what is passed in.
        self.act = Swish()

    def forward(self, inputs):
        out = self.conv(inputs)
        out = self.norm(out)
        out = self.act(out)
        return out
class EncoderWithSVTR(nn.Module):
    """Mixes local conv features with global self-attention (SVTR) blocks on the 4D feature map."""

    def __init__(
        self,
        in_channels,
        dims=64,
        depth=2,
        hidden_dims=120,
        use_guide=False,
        num_heads=8,
        qkv_bias=True,
        mlp_ratio=2.0,
        drop_rate=0.1,
        attn_drop_rate=0.1,
        drop_path=0.0,
        qk_scale=None,
    ):
        super(EncoderWithSVTR, self).__init__()
        self.depth = depth
        self.use_guide = use_guide
        self.conv1 = ConvBNLayer(in_channels, in_channels // 8, padding=1, act="swish")
        self.conv2 = ConvBNLayer(in_channels // 8, hidden_dims, kernel_size=1, act="swish")

        self.svtr_block = nn.ModuleList(
            [
                Block(
                    dim=hidden_dims,
                    num_heads=num_heads,
                    mixer="Global",
                    HW=None,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    act_layer="swish",
                    attn_drop=attn_drop_rate,
                    drop_path=drop_path,
                    norm_layer="nn.LayerNorm",
                    epsilon=1e-05,
                    prenorm=False,
                )
                for _ in range(depth)
            ]
        )
        self.norm = nn.LayerNorm(hidden_dims, eps=1e-6)
        self.conv3 = ConvBNLayer(hidden_dims, in_channels, kernel_size=1, act="swish")
        # conv4 consumes the concatenation of the input feature map and the conv3 output
        self.conv4 = ConvBNLayer(2 * in_channels, in_channels // 8, padding=1, act="swish")

        self.conv1x1 = ConvBNLayer(in_channels // 8, dims, kernel_size=1, act="swish")
        self.out_channels = dims
        self.apply(self._init_weights)
    def _init_weights(self, m):
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode="fan_out")
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0, 0.01)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.ConvTranspose2d):
            nn.init.kaiming_normal_(m.weight, mode="fan_out")
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
    def forward(self, x):
        # `use_guide` detaches the guide branch so it receives no gradients
        # (PyTorch equivalent of Paddle's `stop_gradient = True`)
        if self.use_guide:
            z = x.clone().detach()
        else:
            z = x

        # keep the original feature map for the concatenation below
        h = z

        # reduce channels, then project to the transformer hidden size
        z = self.conv1(z)
        z = self.conv2(z)

        # flatten (B, C, H, W) -> (B, H*W, C) for the global-mixing blocks
        B, C, H, W = z.shape
        z = z.flatten(2).permute(0, 2, 1)

        for blk in self.svtr_block:
            z = blk(z)

        z = self.norm(z)

        # restore the spatial layout and fuse with the original feature map
        z = z.reshape([-1, H, W, C]).permute(0, 3, 1, 2)
        z = self.conv3(z)
        z = torch.cat((h, z), dim=1)
        z = self.conv1x1(self.conv4(z))

        return z
if __name__ == "__main__":
    svtrRNN = EncoderWithSVTR(56)
    print(svtrRNN)
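    # A small forward-pass smoke test; the input shape is an arbitrary assumption
    # and this relies on the imported RecSVTR `Block` handling variable-length input.
    feats = torch.randn(1, 56, 8, 25)
    print(svtrRNN(feats).shape)  # expected: torch.Size([1, 64, 8, 25])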