"""This code is refer from:
https://github.com/MelosY/CAM
"""
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import trunc_normal_
from .convnextv2 import ConvNeXtV2, Block, LayerNorm
from .svtrv2_lnconv_two33 import SVTRv2LNConvTwo33
class Swish(nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
return x * torch.sigmoid(x)
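# Note: Swish(x) = x * sigmoid(x) is the same function as PyTorch's built-in
# SiLU, so nn.SiLU() would be a drop-in replacement. Illustrative check:
#
#   t = torch.randn(4)
#   assert torch.allclose(Swish()(t), F.silu(t))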
class UNetBlock(nn.Module):
def __init__(self, cin, cout, bn2d, stride, deformable=False):
"""
a UNet block with 2x up sampling
"""
super().__init__()
stride_h, stride_w = stride
if stride_h == 1:
kernel_h = 1
padding_h = 0
elif stride_h == 2:
kernel_h = 4
padding_h = 1
        elif stride_h == 4:
            kernel_h = 4
            padding_h = 0
        else:
            raise ValueError(f'unsupported stride_h: {stride_h}')
if stride_w == 1:
kernel_w = 1
padding_w = 0
elif stride_w == 2:
kernel_w = 4
padding_w = 1
        elif stride_w == 4:
            kernel_w = 4
            padding_w = 0
        else:
            raise ValueError(f'unsupported stride_w: {stride_w}')
conv = nn.Conv2d
self.up_sample = nn.ConvTranspose2d(cin,
cin,
kernel_size=(kernel_h, kernel_w),
stride=(stride_h, stride_w),
padding=(padding_h, padding_w),
bias=True)
self.conv = nn.Sequential(
conv(cin, cin, kernel_size=3, stride=1, padding=1, bias=False),
bn2d(cin),
nn.ReLU6(inplace=True),
conv(cin, cout, kernel_size=3, stride=1, padding=1, bias=False),
bn2d(cout),
)
def forward(self, x):
x = self.up_sample(x)
return self.conv(x)
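# Shape sketch for UNetBlock (illustrative; plain BatchNorm2d stands in for
# the SyncBatchNorm passed at training time):
#
#   blk = UNetBlock(64, 32, nn.BatchNorm2d, (2, 2))
#   y = blk(torch.randn(1, 64, 8, 32))  # -> (1, 32, 16, 64)
#
# A (kernel=4, stride=2, padding=1) transposed conv doubles H and W exactly;
# the two 3x3 convs then refine features and map cin -> cout.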
class DepthWiseUNetBlock(nn.Module):
def __init__(self, cin, cout, bn2d, stride, deformable=False):
"""
a UNet block with 2x up sampling
"""
super().__init__()
stride_h, stride_w = stride
if stride_h == 1:
kernel_h = 1
padding_h = 0
elif stride_h == 2:
kernel_h = 4
padding_h = 1
        elif stride_h == 4:
            kernel_h = 4
            padding_h = 0
        else:
            raise ValueError(f'unsupported stride_h: {stride_h}')
if stride_w == 1:
kernel_w = 1
padding_w = 0
elif stride_w == 2:
kernel_w = 4
padding_w = 1
        elif stride_w == 4:
            kernel_w = 4
            padding_w = 0
        else:
            raise ValueError(f'unsupported stride_w: {stride_w}')
self.up_sample = nn.ConvTranspose2d(cin,
cin,
kernel_size=(kernel_h, kernel_w),
stride=(stride_h, stride_w),
padding=(padding_h, padding_w),
bias=True)
self.conv = nn.Sequential(
nn.Conv2d(cin,
cin,
kernel_size=3,
stride=1,
padding=1,
bias=False,
groups=cin),
nn.Conv2d(cin, cin, kernel_size=1, stride=1, padding=0,
bias=False),
bn2d(cin),
nn.ReLU6(inplace=True),
nn.Conv2d(cin,
cin,
kernel_size=3,
stride=1,
padding=1,
bias=False,
groups=cin),
nn.Conv2d(cin,
cout,
kernel_size=1,
stride=1,
padding=0,
bias=False),
bn2d(cout),
)
def forward(self, x):
x = self.up_sample(x)
return self.conv(x)
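# Each depthwise 3x3 + pointwise 1x1 pair above factorizes a dense 3x3 conv,
# cutting weights from 9*cin*cin to 9*cin + cin*cin: for cin = 256 that is
# roughly 67.8k vs 589.8k parameters per pair (bias-free, illustrative).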
class SFTLayer(nn.Module):
def __init__(self, dim_in, dim_out):
super(SFTLayer, self).__init__()
self.SFT_scale_conv0 = nn.Linear(
dim_in,
dim_in,
)
self.SFT_scale_conv1 = nn.Linear(
dim_in,
dim_out,
)
self.SFT_shift_conv0 = nn.Linear(
dim_in,
dim_in,
)
self.SFT_shift_conv1 = nn.Linear(
dim_in,
dim_out,
)
def forward(self, x):
# x[0]: fea; x[1]: cond
scale = self.SFT_scale_conv1(
F.leaky_relu(self.SFT_scale_conv0(x[1]), 0.1, inplace=True))
shift = self.SFT_shift_conv1(
F.leaky_relu(self.SFT_shift_conv0(x[1]), 0.1, inplace=True))
return x[0] * (scale + 1) + shift
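# SFT (Spatial Feature Transform) modulates a feature with a scale and shift
# predicted from a condition: out = fea * (scale + 1) + shift; the `+ 1`
# keeps the layer close to identity at initialization. Usage sketch
# (illustrative shapes, with dim_in == dim_out so `fea` and `cond` match):
#
#   sft = SFTLayer(64, 64)
#   fea, cond = torch.randn(2, 100, 64), torch.randn(2, 100, 64)
#   out = sft((fea, cond))  # -> (2, 100, 64)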
class MoreUNetBlock(nn.Module):
def __init__(self, cin, cout, bn2d, stride, deformable=False):
"""
a UNet block with 2x up sampling
"""
super().__init__()
stride_h, stride_w = stride
if stride_h == 1:
kernel_h = 1
padding_h = 0
elif stride_h == 2:
kernel_h = 4
padding_h = 1
        elif stride_h == 4:
            kernel_h = 4
            padding_h = 0
        else:
            raise ValueError(f'unsupported stride_h: {stride_h}')
if stride_w == 1:
kernel_w = 1
padding_w = 0
elif stride_w == 2:
kernel_w = 4
padding_w = 1
        elif stride_w == 4:
            kernel_w = 4
            padding_w = 0
        else:
            raise ValueError(f'unsupported stride_w: {stride_w}')
self.up_sample = nn.ConvTranspose2d(cin,
cin,
kernel_size=(kernel_h, kernel_w),
stride=(stride_h, stride_w),
padding=(padding_h, padding_w),
bias=True)
self.conv = nn.Sequential(
nn.Conv2d(cin,
cin,
kernel_size=3,
stride=1,
padding=1,
bias=False,
groups=cin),
nn.Conv2d(cin, cin, kernel_size=1, stride=1, padding=0,
bias=False), bn2d(cin), nn.ReLU6(inplace=True),
nn.Conv2d(cin,
cin,
kernel_size=3,
stride=1,
padding=1,
bias=False,
groups=cin),
nn.Conv2d(cin, cin, kernel_size=1, stride=1, padding=0,
bias=False), bn2d(cin), nn.ReLU6(inplace=True),
nn.Conv2d(cin,
cin,
kernel_size=3,
stride=1,
padding=1,
bias=False,
groups=cin),
nn.Conv2d(cin,
cout,
kernel_size=1,
stride=1,
padding=0,
bias=False), bn2d(cout), nn.ReLU6(inplace=True),
nn.Conv2d(cout,
cout,
kernel_size=3,
stride=1,
padding=1,
bias=False,
groups=cout),
nn.Conv2d(cout,
cout,
kernel_size=1,
stride=1,
padding=0,
bias=False), bn2d(cout))
def forward(self, x):
x = self.up_sample(x)
return self.conv(x)
class BinaryDecoder(nn.Module):
def __init__(self,
dim,
num_classes,
strides,
use_depthwise_unet=False,
use_more_unet=False,
binary_loss_type='DiceLoss') -> None:
super().__init__()
channels = [dim // 2**i for i in range(4)]
self.linear_enc2binary = nn.Sequential(
nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1),
nn.SyncBatchNorm(dim),
)
self.strides = strides
self.use_deformable = False
self.binary_decoder = nn.ModuleList()
unet = DepthWiseUNetBlock if use_depthwise_unet else UNetBlock
unet = MoreUNetBlock if use_more_unet else unet
for i in range(3):
up_sample_stride = self.strides[::-1][i]
cin, cout = channels[i], channels[i + 1]
self.binary_decoder.append(
unet(cin, cout, nn.SyncBatchNorm, up_sample_stride,
self.use_deformable))
last_stride = (self.strides[0][0] // 2, self.strides[0][1] // 2)
self.binary_decoder.append(
unet(cout, cout, nn.SyncBatchNorm, last_stride,
self.use_deformable))
        # NOTE: the 'Banlance' (sic) spelling is kept to match the loss-type
        # keys used elsewhere in the codebase.
        if binary_loss_type in ('CrossEntropyDiceLoss',
                                'BanlanceMultiClassCrossEntropyLoss'):
            segm_num_cls = num_classes - 2
        else:
            segm_num_cls = num_classes - 3
self.binary_pred = nn.Conv2d(channels[-1],
segm_num_cls,
kernel_size=1,
stride=1,
bias=True)
def patchify(self, imgs):
"""
imgs: (N, 3, H, W)
x: (N, L, patch_size**2 *3)
"""
p_h, p_w = self.strides[0]
p_h = p_h // 2
p_w = p_w // 2
h = imgs.shape[2] // p_h
w = imgs.shape[3] // p_w
x = imgs.reshape(shape=(imgs.shape[0], 3, h, p_h, w, p_w))
x = torch.einsum('nchpwq->nhwpqc', x)
x = x.reshape(shape=(imgs.shape[0], h * w, p_h * p_w * 3))
return x
def unpatchify(self, x):
"""
x: (N, patch_size**2, h, w)
imgs: (N, 3, H, W)
"""
p_h, p_w = self.strides[0]
p_h = p_h // 2
p_w = p_w // 2
_, _, h, w = x.shape
assert p_h * p_w == x.shape[1]
        x = x.permute(0, 2, 3, 1)  # [N, h, w, p_h * p_w]
x = x.reshape(shape=(x.shape[0], h, w, p_h, p_w))
x = torch.einsum('nhwpq->nhpwq', x)
imgs = x.reshape(shape=(x.shape[0], h * p_h, w * p_w))
return imgs
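    # Round-trip sketch (illustrative; assumes a BinaryDecoder `decoder` with
    # strides[0] == (4, 4), hence p_h == p_w == 2). Note the asymmetry:
    # patchify tokenizes 3-channel images, while unpatchify reassembles a
    # single-channel map such as the predicted binary mask:
    #
    #   imgs = torch.randn(2, 3, 32, 128)
    #   tokens = decoder.patchify(imgs)  # -> (2, 16 * 64, 12)
    #   mask = decoder.unpatchify(torch.randn(2, 4, 16, 64))  # -> (2, 32, 128)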
def forward(self, x, time=None):
"""
x: the encoder feat to init the query for binary prediction, usually this is equal to the `img`.
img: the encoder feat.
txt: the unnormmed text to get the length of predicted words.
txt_feat: the text feat before character prediction.
xs: the encoder feat from different stages
"""
binary_feats = []
x = self.linear_enc2binary(x)
binary_feats.append(x.clone())
for i, d in enumerate(self.binary_decoder):
x = d(x)
binary_feats.append(x.clone())
#return None,binary_feats
x = self.binary_pred(x)
if self.training:
return x, binary_feats
else:
# return torch.sigmoid(x), binary_feat
return x.softmax(1), binary_feats
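# Resolution walkthrough (illustrative; the default strides
# [(4, 4), (2, 1), (2, 1), (1, 1)] give the backbone an overall stride of
# (16, 4)): the three UNet stages undo the backbone strides in reverse
# order, and the final stage applies strides[0] // 2 = (2, 2), so an
# encoder map of (H/16, W/4) is upsampled to (H/2, W/2) before the 1x1
# per-class prediction.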
class LayerNormProxy(nn.Module):
def __init__(self, dim):
super().__init__()
self.norm = nn.LayerNorm(dim)
def forward(self, x):
x = x.permute(0, 2, 3, 1)
x = self.norm(x)
return x.permute(0, 3, 1, 2)
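# LayerNormProxy lets nn.LayerNorm (which normalizes the trailing dim)
# operate on NCHW tensors by permuting to NHWC and back:
#
#   ln = LayerNormProxy(64)
#   y = ln(torch.randn(2, 64, 8, 32))  # per-position norm over 64 channels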
class DAttentionFuse(nn.Module):
def __init__(
self,
q_size=(4, 32),
kv_size=(4, 32),
n_heads=8,
n_head_channels=80,
n_groups=4,
attn_drop=0.0,
proj_drop=0.0,
stride=2,
offset_range_factor=2,
use_pe=True,
stage_idx=0,
):
        '''
        stage_idx (2 or 3 in practice) selects the offset-conv kernel
        size from ksizes = [9, 7, 5, 3].
        '''
super().__init__()
self.n_head_channels = n_head_channels
self.scale = self.n_head_channels**-0.5
self.n_heads = n_heads
self.q_h, self.q_w = q_size
self.kv_h, self.kv_w = kv_size
self.nc = n_head_channels * n_heads
self.n_groups = n_groups
self.n_group_channels = self.nc // self.n_groups
self.n_group_heads = self.n_heads // self.n_groups
self.use_pe = use_pe
self.offset_range_factor = offset_range_factor
ksizes = [9, 7, 5, 3]
kk = ksizes[stage_idx]
self.conv_offset = nn.Sequential(
nn.Conv2d(2 * self.n_group_channels,
2 * self.n_group_channels,
kk,
stride,
kk // 2,
groups=self.n_group_channels),
LayerNormProxy(2 * self.n_group_channels), nn.GELU(),
nn.Conv2d(2 * self.n_group_channels, 2, 1, 1, 0, bias=False))
self.proj_q = nn.Conv2d(self.nc,
self.nc,
kernel_size=1,
stride=1,
padding=0)
self.proj_k = nn.Conv2d(self.nc,
self.nc,
kernel_size=1,
stride=1,
padding=0)
self.proj_v = nn.Conv2d(self.nc,
self.nc,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = nn.Conv2d(self.nc,
self.nc,
kernel_size=1,
stride=1,
padding=0)
self.proj_drop = nn.Dropout(proj_drop, inplace=True)
self.attn_drop = nn.Dropout(attn_drop, inplace=True)
if self.use_pe:
self.rpe_table = nn.Parameter(
torch.zeros(self.n_heads, self.kv_h * 2 - 1,
self.kv_w * 2 - 1))
trunc_normal_(self.rpe_table, std=0.01)
else:
self.rpe_table = None
@torch.no_grad()
def _get_ref_points(self, H_key, W_key, B, dtype, device):
ref_y, ref_x = torch.meshgrid(
torch.linspace(0.5, H_key - 0.5, H_key, dtype=dtype,
device=device),
torch.linspace(0.5, W_key - 0.5, W_key, dtype=dtype,
device=device))
ref = torch.stack((ref_y, ref_x), -1)
ref[..., 1].div_(W_key).mul_(2).sub_(1)
ref[..., 0].div_(H_key).mul_(2).sub_(1)
ref = ref[None, ...].expand(B * self.n_groups, -1, -1,
-1) # B * g H W 2
return ref
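    # _get_ref_points yields pixel-center sampling coordinates normalized to
    # [-1, 1], the convention F.grid_sample uses with align_corners=False:
    # e.g. for H_key = 2, the y-centers 0.5 and 1.5 map to -0.5 and 0.5.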
def forward(self, x, y):
B, C, H, W = x.size()
dtype, device = x.dtype, x.device
q_off = torch.cat(
(x, y), dim=1
).reshape(B, self.n_groups, 2 * self.n_group_channels, H, W).flatten(
0, 1
) #einops.rearrange(torch.cat((x,y),dim=1), 'b (g c) h w -> (b g) c h w', g=self.n_groups, c=2*self.n_group_channels)
offset = self.conv_offset(q_off) # B * g 2 Hg Wg
Hk, Wk = offset.size(2), offset.size(3)
n_sample = Hk * Wk
if self.offset_range_factor > 0:
offset_range = torch.tensor([1.0 / Hk, 1.0 / Wk],
device=device).reshape(1, 2, 1, 1)
offset = offset.tanh().mul(offset_range).mul(
self.offset_range_factor)
offset = offset.permute(
0, 2, 3, 1) #einops.rearrange(offset, 'b p h w -> b h w p')
reference = self._get_ref_points(Hk, Wk, B, dtype, device)
if self.offset_range_factor >= 0:
pos = offset + reference
else:
pos = (offset + reference).tanh()
q = self.proj_q(y)
x_sampled = F.grid_sample(
input=x.reshape(B * self.n_groups, self.n_group_channels, H, W),
grid=pos[..., (1, 0)], # y, x -> x, y
mode='bilinear',
align_corners=False) # B * g, Cg, Hg, Wg
x_sampled = x_sampled.reshape(B, C, 1, n_sample)
q = q.reshape(B * self.n_heads, self.n_head_channels, H * W)
k = self.proj_k(x_sampled).reshape(B * self.n_heads,
self.n_head_channels, n_sample)
v = self.proj_v(x_sampled).reshape(B * self.n_heads,
self.n_head_channels, n_sample)
attn = torch.einsum('b c m, b c n -> b m n', q, k) # B * h, HW, Ns
attn = attn.mul(self.scale)
if self.use_pe:
rpe_table = self.rpe_table
rpe_bias = rpe_table[None, ...].expand(B, -1, -1, -1)
q_grid = self._get_ref_points(H, W, B, dtype, device)
displacement = (
q_grid.reshape(B * self.n_groups, H * W, 2).unsqueeze(2) -
pos.reshape(B * self.n_groups, n_sample,
2).unsqueeze(1)).mul(0.5)
attn_bias = F.grid_sample(input=rpe_bias.reshape(
B * self.n_groups, self.n_group_heads, 2 * H - 1, 2 * W - 1),
grid=displacement[..., (1, 0)],
mode='bilinear',
align_corners=False)
attn_bias = attn_bias.reshape(B * self.n_heads, H * W, n_sample)
attn = attn + attn_bias
attn = F.softmax(attn, dim=2)
attn = self.attn_drop(attn)
out = torch.einsum('b m n, b c n -> b c m', attn, v)
out = out.reshape(B, C, H, W)
out = self.proj_drop(self.proj_out(out))
return out, pos.reshape(B, self.n_groups, Hk, Wk,
2), reference.reshape(B, self.n_groups, Hk, Wk,
2)
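# Deformable-fusion sketch (illustrative; with the defaults n_heads=8 and
# n_head_channels=80 the channel count is nc = 640). The module predicts
# per-group 2-D offsets from (x, y), samples keys/values from `x` at the
# deformed positions via F.grid_sample, and attends from queries projected
# out of `y`:
#
#   attn = DAttentionFuse(q_size=(4, 32), kv_size=(4, 32))
#   x = y = torch.randn(2, 640, 4, 32)
#   out, pos, ref = attn(x, y)  # out: (2, 640, 4, 32)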
class FuseModel(nn.Module):
def __init__(self,
dim,
deform_stride=2,
stage_idx=2,
k_size=[(2, 2), (2, 1), (2, 1), (1, 1)],
q_size=(2, 32)):
super().__init__()
channels = [dim // 2**i for i in range(4)]
refine_conv = nn.Conv2d
self.deform_stride = deform_stride
in_out_ch = [(-1, -2), (-2, -3), (-3, -4), (-4, -4)]
self.binary_condition_layer = DAttentionFuse(q_size=q_size,
kv_size=q_size,
stride=self.deform_stride,
n_head_channels=dim // 8,
stage_idx=stage_idx)
self.binary2refine_linear_norm = nn.ModuleList()
for i in range(len(k_size)):
self.binary2refine_linear_norm.append(
nn.Sequential(
Block(dim=channels[in_out_ch[i][0]]),
LayerNorm(channels[in_out_ch[i][0]],
eps=1e-6,
data_format='channels_first'),
refine_conv(channels[in_out_ch[i][0]],
channels[in_out_ch[i][1]],
kernel_size=k_size[i],
stride=k_size[i])), # [8, 32]
)
def forward(self, recog_feat, binary_feats, dec_in=None):
multi_feat = []
binary_feat = binary_feats[-1]
for i in range(len(self.binary2refine_linear_norm)):
binary_feat = self.binary2refine_linear_norm[i](binary_feat)
multi_feat.append(binary_feat)
binary_feat = binary_feat + binary_feats[0]
multi_feat[3] += binary_feats[0]
binary_refined_feat, pos, _ = self.binary_condition_layer(
recog_feat, binary_feat)
binary_refined_feat = binary_refined_feat + binary_feat
return binary_refined_feat, binary_feat
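# FuseModel flow: the finest binary-decoder feature is progressively brought
# back to the encoder resolution through the Block + LayerNorm + strided-conv
# stages, the coarsest binary feature is added as a residual, and the result
# conditions the recognition feature through deformable cross-attention.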
class CAMEncoder(nn.Module):
"""
Args:
in_chans (int): Number of input image channels. Default: 3
depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
drop_path_rate (float): Stochastic depth rate. Default: 0.
head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
"""
def __init__(self,
in_channels=3,
encoder_config={'name': 'ConvNeXtV2'},
nb_classes=71,
strides=[(4, 4), (2, 1), (2, 1), (1, 1)],
k_size=[(2, 2), (2, 1), (2, 1), (1, 1)],
q_size=[2, 32],
deform_stride=2,
stage_idx=2,
use_depthwise_unet=True,
use_more_unet=False,
binary_loss_type='BanlanceMultiClassCrossEntropyLoss',
mid_size=True,
d_embedding=384):
super().__init__()
        encoder_name = encoder_config.pop('name')
        encoder_config['in_channels'] = in_channels
        # Resolve the backbone class by its configured name (e.g. ConvNeXtV2
        # or SVTRv2LNConvTwo33, both imported above).
        self.backbone = eval(encoder_name)(**encoder_config)
dim = self.backbone.out_channels
self.mid_size = mid_size
if self.mid_size:
self.enc_downsample = nn.Sequential(
nn.Conv2d(dim, dim // 2, kernel_size=1, stride=1),
nn.SyncBatchNorm(dim // 2),
#nn.ReLU6(inplace=True),
nn.Conv2d(dim // 2,
dim // 2,
kernel_size=3,
stride=1,
padding=1,
bias=False,
groups=dim // 2),
nn.Conv2d(dim // 2,
dim // 2,
kernel_size=1,
stride=1,
padding=0,
bias=False),
nn.SyncBatchNorm(dim // 2),
)
dim = dim // 2
# recognition decoder
self.linear_enc2recog = nn.Sequential(
nn.Conv2d(
dim,
dim,
kernel_size=1,
stride=1,
),
nn.SyncBatchNorm(dim),
#nn.ReLU6(inplace=True),
nn.Conv2d(dim,
dim,
kernel_size=3,
stride=1,
padding=1,
bias=False,
groups=dim),
nn.Conv2d(dim,
dim,
kernel_size=1,
stride=1,
padding=0,
bias=False),
nn.SyncBatchNorm(dim),
)
else:
self.linear_enc2recog = nn.Sequential(
nn.Conv2d(dim, dim // 2, kernel_size=1, stride=1),
nn.SyncBatchNorm(dim // 2),
#nn.ReLU6(inplace=True),
nn.Conv2d(dim // 2, dim, kernel_size=3, stride=1, padding=1),
nn.SyncBatchNorm(dim),
)
self.linear_norm = nn.Sequential(
nn.Linear(dim, d_embedding),
nn.LayerNorm(d_embedding, eps=1e-6),
)
self.out_channels = d_embedding
self.binary_decoder = BinaryDecoder(
dim,
nb_classes,
strides,
use_depthwise_unet=use_depthwise_unet,
use_more_unet=use_more_unet,
binary_loss_type=binary_loss_type)
self.fuse_model = FuseModel(dim,
deform_stride=deform_stride,
stage_idx=stage_idx,
k_size=k_size,
q_size=q_size)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, (nn.Conv2d, nn.Linear)):
trunc_normal_(m.weight, std=.02)
if isinstance(m, (nn.Conv2d, nn.Linear)) and m.bias is not None:
nn.init.constant_(m.bias, 0)
if isinstance(m, nn.ConvTranspose2d):
nn.init.kaiming_normal_(m.weight,
mode='fan_out',
nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0.)
elif isinstance(m, nn.LayerNorm):
if m.bias is not None:
nn.init.constant_(m.bias, 0)
if m.weight is not None:
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.SyncBatchNorm):
if m.bias is not None:
nn.init.constant_(m.bias, 0)
if m.weight is not None:
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.BatchNorm2d):
if m.bias is not None:
nn.init.constant_(m.bias, 0)
if m.weight is not None:
nn.init.constant_(m.weight, 1.0)
def no_weight_decay(self):
return {}
def forward(self, x):
output = {}
enc_feat = self.backbone(x)
if self.mid_size:
enc_feat = self.enc_downsample(enc_feat)
output['enc_feat'] = enc_feat
# binary mask
pred_binary, binary_feats = self.binary_decoder(enc_feat)
output['pred_binary'] = pred_binary
reg_feat = self.linear_enc2recog(enc_feat)
B, C, H, W = reg_feat.shape
last_feat, binary_feat = self.fuse_model(reg_feat, binary_feats)
dec_in = last_feat.reshape(B, C, H * W).permute(0, 2, 1)
dec_in = self.linear_norm(dec_in)
output['refined_feat'] = dec_in
output['binary_feat'] = binary_feats[-1]
return output
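# End-to-end sketch (illustrative; assumes the repo's ConvNeXtV2 defaults
# realize the (16, 4) overall stride implied by the default `strides`).
# SyncBatchNorm needs a distributed process group in training mode, so run
# a quick check like this under eval():
#
#   enc = CAMEncoder(encoder_config={'name': 'ConvNeXtV2'}).eval()
#   out = enc(torch.randn(1, 3, 32, 128))
#   out['pred_binary'].shape   # binary map at (H/2, W/2) = (16, 64)
#   out['refined_feat'].shape  # (1, (32//16) * (128//4), 384)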