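"""CPPD decoder from the OpenOCR demo.

Defines the attention and transformer building blocks (Attention,
EdgeDecoderLayer, DecoderLayer) and the CPPDDecoder recognition head.
"""
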
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.init import ones_, trunc_normal_, zeros_
from openrec.modeling.common import DropPath, Identity, Mlp
from openrec.modeling.decoders.nrtr_decoder import Embeddings
class Attention(nn.Module):
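    """Multi-head attention with separate query and key/value projections,
    usable as self-attention (pass the same tensor as q and kv) or as
    cross-attention."""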
def __init__(
self,
dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim**-0.5
self.q = nn.Linear(dim, dim, bias=qkv_bias)
self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, q, kv, key_mask=None):
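        """q: (B, QN, C); kv: (B, N, C); key_mask: optional additive mask
        (0 = attend, -inf = block), broadcast over the head dimension."""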
N, C = kv.shape[1:]
QN = q.shape[1]
q = self.q(q).reshape([-1, QN, self.num_heads,
C // self.num_heads]).transpose(1, 2)
q = q * self.scale
k, v = self.kv(kv).reshape(
[-1, N, 2, self.num_heads,
C // self.num_heads]).permute(2, 0, 3, 1, 4)
attn = q.matmul(k.transpose(2, 3))
if key_mask is not None:
attn = attn + key_mask.unsqueeze(1)
attn = F.softmax(attn, -1)
# if not self.training:
# self.attn_map = attn
attn = self.attn_drop(attn)
x = (attn.matmul(v)).transpose(1, 2).reshape((-1, QN, C))
x = self.proj(x)
x = self.proj_drop(x)
return x
class EdgeDecoderLayer(nn.Module):
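    """Decoder layer that computes an explicit edge (attention) map between
    position queries and position values and uses it to aggregate content
    values. Note that CPPDDecoder below builds its `edge_decoder` from
    DecoderLayer, so this layer is not instantiated in this file."""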
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.0,
qkv_bias=False,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=[0.0, 0.0],
act_layer=nn.GELU,
norm_layer='nn.LayerNorm',
epsilon=1e-6,
):
super().__init__()
self.head_dim = dim // num_heads
self.scale = qk_scale or self.head_dim**-0.5
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path1 = DropPath(
drop_path[0]) if drop_path[0] > 0.0 else Identity()
        # norm_layer arrives as a string (e.g. 'nn.LayerNorm'); nn.LayerNorm
        # takes its epsilon via the `eps` keyword.
        self.norm1 = eval(norm_layer)(dim, eps=epsilon)
        self.norm2 = eval(norm_layer)(dim, eps=epsilon)
self.p = nn.Linear(dim, dim)
self.cv = nn.Linear(dim, dim)
self.pv = nn.Linear(dim, dim)
self.dim = dim
self.num_heads = num_heads
self.p_proj = nn.Linear(dim, dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp_ratio = mlp_ratio
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
)
def forward(self, p, cv, pv):
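        # p: position queries (B, pN, dim); cv / pv: content and position
        # key/value streams over the same vN tokens (names follow the
        # p/cv/pv projections above).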
pN = p.shape[1]
vN = cv.shape[1]
p_shortcut = p
p1 = self.p(p).reshape(
[-1, pN, self.num_heads,
self.dim // self.num_heads]).transpose(1, 2)
cv1 = self.cv(cv).reshape(
[-1, vN, self.num_heads,
self.dim // self.num_heads]).transpose(1, 2)
pv1 = self.pv(pv).reshape(
[-1, vN, self.num_heads,
self.dim // self.num_heads]).transpose(1, 2)
edge = F.softmax(p1.matmul(pv1.transpose(2, 3)), -1) # B h N N
p_c = (edge @ cv1).transpose(1, 2).reshape((-1, pN, self.dim))
x1 = self.norm1(p_shortcut + self.drop_path1(self.p_proj(p_c)))
x = self.norm2(x1 + self.drop_path1(self.mlp(x1)))
return x
class DecoderLayer(nn.Module):
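    """Post-norm transformer block: Attention (self- or cross-attention,
    depending on the q/kv arguments) followed by an MLP, each wrapped in a
    residual connection and LayerNorm."""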
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.0,
qkv_bias=False,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
epsilon=1e-6,
):
super().__init__()
self.norm1 = norm_layer(dim, eps=epsilon)
self.mixer = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
self.norm2 = norm_layer(dim, eps=epsilon)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp_ratio = mlp_ratio
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
)
def forward(self, q, kv, key_mask=None):
x1 = self.norm1(q + self.drop_path(self.mixer(q, kv, key_mask)))
x = self.norm2(x1 + self.drop_path(self.mlp(x1)))
return x
class CPPDDecoder(nn.Module):
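    """CPPD (Context Perception Parallel Decoder) recognition head.

    Visual features are decoded by two parallel branches: a character node
    branch (`char_node_decoder`) and a position node branch
    (`pos_node_decoder`). The refined position nodes are then projected to
    per-position character logits via `edge_decoder` and `edge_fc`.
    """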
def __init__(self,
in_channels,
out_channels,
num_layer=2,
drop_path_rate=0.1,
max_len=25,
vis_seq=50,
iters=1,
pos_len=False,
ch=False,
rec_layer=1,
num_heads=None,
ds=False,
**kwargs):
        super().__init__()
        self.out_channels = out_channels  # e.g. 37 = none/EOS + 26 letters + 10 digits
dim = in_channels
self.dim = dim
self.iters = iters
self.max_len = max_len + 1 # max_len + eos
self.pos_len = pos_len
self.ch = ch
self.char_node_embed = Embeddings(d_model=dim,
vocab=self.out_channels,
scale_embedding=True)
self.pos_node_embed = Embeddings(d_model=dim,
vocab=self.max_len,
scale_embedding=True)
dpr = np.linspace(0, drop_path_rate, num_layer + rec_layer)
self.char_node_decoder = nn.ModuleList([
DecoderLayer(
dim=dim,
num_heads=dim // 32 if num_heads is None else num_heads,
mlp_ratio=4.0,
qkv_bias=True,
drop_path=dpr[i],
) for i in range(num_layer)
])
self.pos_node_decoder = nn.ModuleList([
DecoderLayer(
dim=dim,
num_heads=dim // 32 if num_heads is None else num_heads,
mlp_ratio=4.0,
qkv_bias=True,
drop_path=dpr[i],
) for i in range(num_layer)
])
self.edge_decoder = nn.ModuleList([
DecoderLayer(
dim=dim,
num_heads=dim // 32 if num_heads is None else num_heads,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=1.0 if (rec_layer + i) % 2 != 0 else None,
drop_path=dpr[num_layer + i],
) for i in range(rec_layer)
])
self.rec_layer_num = rec_layer
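        # Additive causal mask over the position nodes: 0 on and below the
        # diagonal (attend), -inf above it (block); shape (1, max_len, max_len).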
self_mask = torch.tril(
torch.ones([self.max_len, self.max_len], dtype=torch.float32))
self_mask = torch.where(
self_mask > 0,
torch.zeros_like(self_mask, dtype=torch.float32),
torch.full([self.max_len, self.max_len],
float('-inf'),
dtype=torch.float32),
)
        # Register as a non-persistent buffer so the mask moves with the
        # module across devices.
        self.register_buffer('self_mask',
                             self_mask.unsqueeze(0),
                             persistent=False)
self.char_pos_embed = nn.Parameter(torch.zeros([1, self.max_len, dim],
dtype=torch.float32),
requires_grad=True)
self.ds = ds
if not self.ds:
self.vis_pos_embed = nn.Parameter(torch.zeros([1, vis_seq, dim],
dtype=torch.float32),
requires_grad=True)
trunc_normal_(self.vis_pos_embed, std=0.02)
self.char_node_fc1 = nn.Linear(dim, max_len)
self.pos_node_fc1 = nn.Linear(dim, self.max_len)
self.edge_fc = nn.Linear(dim, self.out_channels)
trunc_normal_(self.char_pos_embed, std=0.02)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
zeros_(m.bias)
ones_(m.weight)
@torch.jit.ignore
def no_weight_decay(self):
return {
'char_pos_embed', 'vis_pos_embed', 'char_node_embed',
'pos_node_embed'
}
def forward(self, x, data=None):
if self.training:
return self.forward_train(x, data)
else:
return self.forward_test(x)
def forward_test(self, x):
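        """Inference path.

        Decodes self.max_len position nodes against the visual features and
        returns softmax probabilities of shape (B, self.max_len, out_channels).
        """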
if not self.ds:
visual_feats = x + self.vis_pos_embed
else:
visual_feats = x
bs = visual_feats.shape[0]
        pos_node_embed = self.pos_node_embed(
            torch.arange(self.max_len,
                         device=x.device)).unsqueeze(0) + self.char_pos_embed
pos_node_embed = torch.tile(pos_node_embed, [bs, 1, 1])
char_vis_node_query = visual_feats
pos_vis_node_query = torch.concat([pos_node_embed, visual_feats], 1)
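        # Character branch: plain self-attention over the visual tokens (no
        # character nodes are prepended at inference). Position branch: the
        # [position nodes; visual tokens] sequence attends to its visual slice.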
for char_decoder_layer, pos_decoder_layer in zip(
self.char_node_decoder, self.pos_node_decoder):
char_vis_node_query = char_decoder_layer(char_vis_node_query,
char_vis_node_query)
pos_vis_node_query = pos_decoder_layer(
pos_vis_node_query, pos_vis_node_query[:, self.max_len:, :])
pos_node_query = pos_vis_node_query[:, :self.max_len, :]
char_vis_feats = char_vis_node_query
# pos_vis_feats = pos_vis_node_query[:, self.max_len :, :]
# pos_node_feats = self.edge_decoder(
# pos_node_query, char_vis_feats, pos_vis_feats
# ) # B, 26, dim
pos_node_feats = pos_node_query
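        # Refine the position nodes: alternate causally masked self-attention
        # with cross-attention over the character-visual features.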
for layer_i in range(self.rec_layer_num):
rec_layer = self.edge_decoder[layer_i]
if (self.rec_layer_num + layer_i) % 2 == 0:
pos_node_feats = rec_layer(pos_node_feats, pos_node_feats,
self.self_mask)
else:
pos_node_feats = rec_layer(pos_node_feats, char_vis_feats)
        edge_feats = self.edge_fc(pos_node_feats)  # B, max_len, out_channels
        edge_logits = F.softmax(edge_feats, -1)
        return edge_logits
def forward_train(self, x, targets=None):
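        """Training path.

        Returns node_feats (character node logits from char_node_fc1 and
        position node logits from pos_node_fc1) together with edge_feats, the
        per-position character logits from edge_fc.
        """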
if not self.ds:
visual_feats = x + self.vis_pos_embed
else:
visual_feats = x
bs = visual_feats.shape[0]
if self.ch:
char_node_embed = self.char_node_embed(targets[-2])
else:
            char_node_embed = self.char_node_embed(
                torch.arange(self.out_channels,
                             device=x.device)).unsqueeze(0)
char_node_embed = torch.tile(char_node_embed, [bs, 1, 1])
counting_char_num = char_node_embed.shape[1]
        pos_node_embed = self.pos_node_embed(
            torch.arange(self.max_len,
                         device=x.device)).unsqueeze(0) + self.char_pos_embed
pos_node_embed = torch.tile(pos_node_embed, [bs, 1, 1])
node_feats = []
char_vis_node_query = torch.concat([char_node_embed, visual_feats], 1)
pos_vis_node_query = torch.concat([pos_node_embed, visual_feats], 1)
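        # Both branches prepend their node embeddings to the visual tokens and
        # use only the visual slice as keys/values, so the nodes gather visual
        # evidence without attending to each other.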
for char_decoder_layer, pos_decoder_layer in zip(
self.char_node_decoder, self.pos_node_decoder):
char_vis_node_query = char_decoder_layer(
char_vis_node_query,
char_vis_node_query[:, counting_char_num:, :])
pos_vis_node_query = pos_decoder_layer(
pos_vis_node_query, pos_vis_node_query[:, self.max_len:, :])
char_node_query = char_vis_node_query[:, :counting_char_num, :]
pos_node_query = pos_vis_node_query[:, :self.max_len, :]
char_vis_feats = char_vis_node_query[:, counting_char_num:, :]
char_node_feats1 = self.char_node_fc1(char_node_query)
pos_node_feats1 = self.pos_node_fc1(pos_node_query)
if not self.pos_len:
            # Extract the diagonal of the (B, max_len, max_len) score matrix.
            diag_mask = torch.eye(
                pos_node_feats1.shape[1],
                device=pos_node_feats1.device).unsqueeze(0).tile(
                    [pos_node_feats1.shape[0], 1, 1])
            pos_node_feats1 = (pos_node_feats1 * diag_mask).sum(-1)
node_feats.append(char_node_feats1)
node_feats.append(pos_node_feats1)
pos_node_feats = pos_node_query
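        # Same alternating self-/cross-attention refinement as in forward_test.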
for layer_i in range(self.rec_layer_num):
rec_layer = self.edge_decoder[layer_i]
if (self.rec_layer_num + layer_i) % 2 == 0:
pos_node_feats = rec_layer(pos_node_feats, pos_node_feats,
self.self_mask)
else:
pos_node_feats = rec_layer(pos_node_feats, char_vis_feats)
        edge_feats = self.edge_fc(pos_node_feats)  # B, max_len, out_channels
return node_feats, edge_feats
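

# Minimal usage sketch (an assumption for illustration, not part of the
# upstream source; shapes follow the defaults above, with dim=384 visual
# features, 50 visual tokens, and a 37-class vocabulary):
#
#   decoder = CPPDDecoder(in_channels=384, out_channels=37, vis_seq=50)
#   decoder.eval()
#   feats = torch.randn(2, 50, 384)  # (B, vis_seq, dim) backbone features
#   probs = decoder(feats)           # (2, 26, 37) per-position char probs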