import torch
import torch.nn as nn

from openrec.modeling.decoders.nrtr_decoder import PositionalEncoding, TransformerBlock


class Transformer_Encoder(nn.Module):

    def __init__(
        self,
        n_layers=3,
        n_head=8,
        d_model=512,
        d_inner=2048,
        dropout=0.1,
        n_position=256,
    ):
        super(Transformer_Encoder, self).__init__()
        self.pe = PositionalEncoding(dropout=dropout,
                                     dim=d_model,
                                     max_len=n_position)
        self.layer_stack = nn.ModuleList([
            TransformerBlock(d_model, n_head, d_inner) for _ in range(n_layers)
        ])
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, enc_output, src_mask):
        enc_output = self.pe(enc_output)  # position embedding
        for enc_layer in self.layer_stack:
            enc_output = enc_layer(enc_output, self_mask=src_mask)
        enc_output = self.layer_norm(enc_output)
        return enc_output


class PP_layer(nn.Module):

    def __init__(self, n_dim=512, N_max_character=25, n_position=256):
        super(PP_layer, self).__init__()
        self.character_len = N_max_character
        self.f0_embedding = nn.Embedding(N_max_character, n_dim)
        self.w0 = nn.Linear(N_max_character, n_position)
        self.wv = nn.Linear(n_dim, n_dim)
        self.we = nn.Linear(n_dim, N_max_character)
        self.active = nn.Tanh()
        self.softmax = nn.Softmax(dim=2)

    def forward(self, enc_output):
        # enc_output: (B, n_position, n_dim) visual features
        reading_order = torch.arange(self.character_len,
                                     dtype=torch.long,
                                     device=enc_output.device)
        reading_order = reading_order.unsqueeze(0).expand(
            enc_output.shape[0], -1)  # (S,) -> (B, S)
        reading_order = self.f0_embedding(reading_order)  # b,25,512
        # calculate attention
        t = self.w0(reading_order.transpose(1, 2))  # b,512,256
        t = self.active(t.transpose(1, 2) + self.wv(enc_output))  # b,256,512
        t = self.we(t)  # b,256,25
        t = self.softmax(t.transpose(1, 2))  # b,25,256
        g_output = torch.bmm(t, enc_output)  # b,25,512
        return g_output
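

# Illustrative sketch, not part of the original model: a hypothetical helper that checks
# the PP_layer shape contract with assumed sizes (batch of 2, n_position=256, n_dim=512).
# PP_layer attends over the n_position visual features once per reading-order step and
# returns N_max_character character-aligned vectors.
def _pp_layer_shape_demo():
    layer = PP_layer(n_dim=512, N_max_character=25, n_position=256)
    enc_output = torch.randn(2, 256, 512)  # (B, n_position, n_dim) visual features
    g_output = layer(enc_output)
    assert g_output.shape == (2, 25, 512)  # (B, N_max_character, n_dim)
    return g_output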


class Prediction(nn.Module):

    def __init__(
        self,
        n_dim=512,
        n_class=37,
        N_max_character=25,
        n_position=256,
    ):
        super(Prediction, self).__init__()
        self.pp = PP_layer(n_dim=n_dim,
                           N_max_character=N_max_character,
                           n_position=n_position)
        self.pp_share = PP_layer(n_dim=n_dim,
                                 N_max_character=N_max_character,
                                 n_position=n_position)
        self.w_vrm = nn.Linear(n_dim, n_class)  # output layer
        self.w_share = nn.Linear(n_dim, n_class)  # output layer
        self.nclass = n_class

    def forward(self, cnn_feature, f_res, f_sub, is_Train=False, use_mlm=True):
        if is_Train:
            if not use_mlm:
                g_output = self.pp(cnn_feature)  # b,25,512
                g_output = self.w_vrm(g_output)
                f_res = 0
                f_sub = 0
                return g_output, f_res, f_sub
            g_output = self.pp(cnn_feature)  # b,25,512
            f_res = self.pp_share(f_res)
            f_sub = self.pp_share(f_sub)
            g_output = self.w_vrm(g_output)
            f_res = self.w_share(f_res)
            f_sub = self.w_share(f_sub)
            return g_output, f_res, f_sub
        else:
            g_output = self.pp(cnn_feature)  # b,25,512
            g_output = self.w_vrm(g_output)
            return g_output


class MLM(nn.Module):
    """Architecture of MLM."""

    def __init__(
        self,
        n_dim=512,
        n_position=256,
        n_head=8,
        dim_feedforward=2048,
        max_text_length=25,
    ):
        super(MLM, self).__init__()
        self.MLM_SequenceModeling_mask = Transformer_Encoder(
            n_layers=2,
            n_head=n_head,
            d_model=n_dim,
            d_inner=dim_feedforward,
            n_position=n_position,
        )
        self.MLM_SequenceModeling_WCL = Transformer_Encoder(
            n_layers=1,
            n_head=n_head,
            d_model=n_dim,
            d_inner=dim_feedforward,
            n_position=n_position,
        )
        self.pos_embedding = nn.Embedding(max_text_length, n_dim)
        self.w0_linear = nn.Linear(1, n_position)
        self.wv = nn.Linear(n_dim, n_dim)
        self.active = nn.Tanh()
        self.we = nn.Linear(n_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input, label_pos):
        # transformer unit for generating mask_c
        feature_v_seq = self.MLM_SequenceModeling_mask(input, src_mask=None)
        # position embedding layer
        pos_emb = self.pos_embedding(label_pos.long())
        pos_emb = self.w0_linear(torch.unsqueeze(pos_emb,
                                                 dim=2)).transpose(1, 2)
        # fuse the position embedding with features V & generate mask_c
        att_map_sub = self.active(pos_emb + self.wv(feature_v_seq))
        att_map_sub = self.we(att_map_sub)  # b,256,1
        att_map_sub = self.sigmoid(att_map_sub.transpose(1, 2))  # b,1,256
        # WCL: generate inputs for the two branches
        f_res = input * (1 - att_map_sub.transpose(1, 2))  # remaining string
        f_sub = input * att_map_sub.transpose(1, 2)  # occluded character
        # transformer units in WCL
        f_res = self.MLM_SequenceModeling_WCL(f_res, src_mask=None)
        f_sub = self.MLM_SequenceModeling_WCL(f_sub, src_mask=None)
        return f_res, f_sub, att_map_sub
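

# Illustrative sketch, not part of the original model: a hypothetical helper showing the
# MLM inputs and outputs with assumed sizes. `label_pos` gives, per sample, the index of
# the character to occlude; `mask_c` is the predicted occlusion map over the n_position
# visual features, while `f_sub` / `f_res` are the occluded-character and remaining-string
# feature sequences produced by the WCL branch.
def _mlm_shape_demo():
    mlm = MLM(n_dim=512, n_position=256, n_head=8,
              dim_feedforward=2048, max_text_length=25)
    features = torch.randn(2, 256, 512)  # (B, n_position, n_dim)
    label_pos = torch.tensor([3, 7])  # occluded character index per sample (< max_text_length)
    f_res, f_sub, mask_c = mlm(features, label_pos)
    assert f_res.shape == (2, 256, 512) and f_sub.shape == (2, 256, 512)
    assert mask_c.shape == (2, 1, 256)  # one occlusion weight per visual position
    return mask_c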


class MLM_VRM(nn.Module):

    def __init__(
        self,
        n_layers=3,
        n_position=256,
        n_dim=512,
        n_head=8,
        dim_feedforward=2048,
        max_text_length=25,
        nclass=37,
    ):
        super(MLM_VRM, self).__init__()
        self.MLM = MLM(
            n_dim=n_dim,
            n_position=n_position,
            n_head=n_head,
            dim_feedforward=dim_feedforward,
            max_text_length=max_text_length,
        )
        self.SequenceModeling = Transformer_Encoder(
            n_layers=n_layers,
            n_head=n_head,
            d_model=n_dim,
            d_inner=dim_feedforward,
            n_position=n_position,
        )
        self.Prediction = Prediction(
            n_dim=n_dim,
            n_position=n_position,
            N_max_character=max_text_length + 1,
            n_class=nclass,
        )  # N_max_character = 1 eos + 25 characters
        self.nclass = nclass
        self.max_text_length = max_text_length

    def forward(self, input, label_pos, training_step, is_Train=False):
        nT = self.max_text_length
        b, c, h, w = input.shape
        input = input.reshape(b, c, -1)
        input = input.transpose(1, 2)
        if is_Train:
            if training_step == 'LF_1':
                f_res = 0
                f_sub = 0
                input = self.SequenceModeling(input, src_mask=None)
                text_pre, text_rem, text_mas = self.Prediction(input,
                                                               f_res,
                                                               f_sub,
                                                               is_Train=True,
                                                               use_mlm=False)
                return text_pre, text_pre, text_pre
            elif training_step == 'LF_2':
                # MLM
                f_res, f_sub, mask_c = self.MLM(input, label_pos)
                input = self.SequenceModeling(input, src_mask=None)
                text_pre, text_rem, text_mas = self.Prediction(input,
                                                               f_res,
                                                               f_sub,
                                                               is_Train=True)
                return text_pre, text_rem, text_mas
            elif training_step == 'LA':
                # MLM
                f_res, f_sub, mask_c = self.MLM(input, label_pos)
                # use mask_c (1 for the occluded character, 0 for the remaining
                # characters) to occlude the input; ratio controls how many
                # samples in a batch are occluded
                ratio = 2
                character_mask = torch.zeros_like(mask_c)
                character_mask[0:b // ratio, :, :] = mask_c[0:b // ratio, :, :]
                input = input * (1 - character_mask.transpose(1, 2))
                # VRM: transformer unit
                input = self.SequenceModeling(input, src_mask=None)
                # prediction layer for MLM and VRM
                text_pre, text_rem, text_mas = self.Prediction(input,
                                                               f_res,
                                                               f_sub,
                                                               is_Train=True)
                return text_pre, text_rem, text_mas
        else:  # VRM is only used in the testing stage
            f_res = 0
            f_sub = 0
            contextual_feature = self.SequenceModeling(input, src_mask=None)
            C = self.Prediction(contextual_feature,
                                f_res,
                                f_sub,
                                is_Train=False,
                                use_mlm=False)
            C = C.transpose(1, 0)  # (max_text_length + 1, b, nclass)
            out_res = torch.zeros(nT, b, self.nclass).type_as(input.data)
            out_length = torch.zeros(b).type_as(input.data)
            now_step = 0
            while 0 in out_length and now_step < nT:
                tmp_result = C[now_step, :, :]
                out_res[now_step] = tmp_result
                tmp_result = tmp_result.topk(1)[1].squeeze(dim=1)
                for j in range(b):
                    if out_length[j] == 0 and tmp_result[j] == 0:
                        out_length[j] = now_step + 1
                now_step += 1
            for j in range(0, b):
                if int(out_length[j]) == 0:
                    out_length[j] = nT
            start = 0
            output = torch.zeros(int(out_length.sum()),
                                 self.nclass).type_as(input.data)
            for i in range(0, b):
                cur_length = int(out_length[i])
                output[start:start + cur_length] = out_res[0:cur_length, i, :]
                start += cur_length
            return output, out_length


class VisionLANDecoder(nn.Module):

    def __init__(
        self,
        in_channels,
        out_channels,
        n_head=None,
        training_step='LA',
        n_layers=3,
        n_position=256,
        max_text_length=25,
    ):
        super(VisionLANDecoder, self).__init__()
        self.training_step = training_step
        n_dim = in_channels
        dim_feedforward = n_dim * 4
        n_head = n_head if n_head is not None else n_dim // 32
        self.MLM_VRM = MLM_VRM(
            n_layers=n_layers,
            n_position=n_position,
            n_dim=n_dim,
            n_head=n_head,
            dim_feedforward=dim_feedforward,
            max_text_length=max_text_length,
            nclass=out_channels + 1,
        )

    def forward(self, x, data=None):
        # MLM + VRM
        if self.training:
            label_pos = data[-2]
            text_pre, text_rem, text_mas = self.MLM_VRM(x,
                                                        label_pos,
                                                        self.training_step,
                                                        is_Train=True)
            return text_pre, text_rem, text_mas
        else:
            output, out_length = self.MLM_VRM(x,
                                              None,
                                              self.training_step,
                                              is_Train=False)
            return output, out_length
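

# Minimal end-to-end sketch under stated assumptions: a 512-channel 8x32 backbone feature
# map (so h * w == n_position == 256), a 37-symbol charset, and random inputs. It only
# exercises the inference path (VRM only) and is not the repository's test or training
# entry point.
if __name__ == '__main__':
    decoder = VisionLANDecoder(in_channels=512, out_channels=37, training_step='LA')
    decoder.eval()
    feats = torch.randn(2, 512, 8, 32)  # (B, C, H, W) backbone features, H * W = 256
    with torch.no_grad():
        logits, lengths = decoder(feats)
    # logits: (sum of decoded lengths, out_channels + 1); lengths: (B,)
    print(logits.shape, lengths)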