import torch
import torch.nn as nn

from openrec.modeling.decoders.nrtr_decoder import PositionalEncoding, TransformerBlock


class Transformer_Encoder(nn.Module):
    """Stack of TransformerBlocks with positional encoding over the visual feature sequence."""

    def __init__(
        self,
        n_layers=3,
        n_head=8,
        d_model=512,
        d_inner=2048,
        dropout=0.1,
        n_position=256,
    ):
        super(Transformer_Encoder, self).__init__()
        self.pe = PositionalEncoding(dropout=dropout,
                                     dim=d_model,
                                     max_len=n_position)
        self.layer_stack = nn.ModuleList([
            TransformerBlock(d_model, n_head, d_inner)
            for _ in range(n_layers)
        ])
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, enc_output, src_mask):
        enc_output = self.pe(enc_output)  # position embedding
        for enc_layer in self.layer_stack:
            enc_output = enc_layer(enc_output, self_mask=src_mask)
        enc_output = self.layer_norm(enc_output)
        return enc_output


class PP_layer(nn.Module):
    """Prediction layer: learned reading-order queries attend over the visual
    sequence to produce one feature vector per character position."""

    def __init__(self, n_dim=512, N_max_character=25, n_position=256):
        super(PP_layer, self).__init__()
        self.character_len = N_max_character
        self.f0_embedding = nn.Embedding(N_max_character, n_dim)
        self.w0 = nn.Linear(N_max_character, n_position)
        self.wv = nn.Linear(n_dim, n_dim)
        self.we = nn.Linear(n_dim, N_max_character)
        self.active = nn.Tanh()
        self.softmax = nn.Softmax(dim=2)

    def forward(self, enc_output):
        # one reading-order embedding per output character position
        reading_order = torch.arange(self.character_len,
                                     dtype=torch.long,
                                     device=enc_output.device)
        reading_order = reading_order.unsqueeze(0).expand(
            enc_output.shape[0], -1)  # (S,) -> (B, S)
        reading_order = self.f0_embedding(reading_order)  # b,25,512

        # calculate attention
        t = self.w0(reading_order.transpose(1, 2))  # b,512,256
        t = self.active(t.transpose(1, 2) + self.wv(enc_output))  # b,256,512
        t = self.we(t)  # b,256,25
        t = self.softmax(t.transpose(1, 2))  # b,25,256
        g_output = torch.bmm(t, enc_output)  # b,25,512
        return g_output


class Prediction(nn.Module):
    """Maps attended character features to class logits: w_vrm for the VRM
    branch and w_share for the two shared MLM branches."""

    def __init__(
        self,
        n_dim=512,
        n_class=37,
        N_max_character=25,
        n_position=256,
    ):
        super(Prediction, self).__init__()
        self.pp = PP_layer(n_dim=n_dim,
                           N_max_character=N_max_character,
                           n_position=n_position)
        self.pp_share = PP_layer(n_dim=n_dim,
                                 N_max_character=N_max_character,
                                 n_position=n_position)
        self.w_vrm = nn.Linear(n_dim, n_class)  # output layer
        self.w_share = nn.Linear(n_dim, n_class)  # output layer
        self.nclass = n_class

    def forward(self, cnn_feature, f_res, f_sub, is_Train=False, use_mlm=True):
        if is_Train:
            if not use_mlm:
                g_output = self.pp(cnn_feature)  # b,25,512
                g_output = self.w_vrm(g_output)
                f_res = 0
                f_sub = 0
                return g_output, f_res, f_sub
            g_output = self.pp(cnn_feature)  # b,25,512
            f_res = self.pp_share(f_res)
            f_sub = self.pp_share(f_sub)
            g_output = self.w_vrm(g_output)
            f_res = self.w_share(f_res)
            f_sub = self.w_share(f_sub)
            return g_output, f_res, f_sub
        else:
            g_output = self.pp(cnn_feature)  # b,25,512
            g_output = self.w_vrm(g_output)
            return g_output


class MLM(nn.Module):
    """Architecture of MLM (masked language-aware module)."""

    def __init__(
        self,
        n_dim=512,
        n_position=256,
        n_head=8,
        dim_feedforward=2048,
        max_text_length=25,
    ):
        super(MLM, self).__init__()
        self.MLM_SequenceModeling_mask = Transformer_Encoder(
            n_layers=2,
            n_head=n_head,
            d_model=n_dim,
            d_inner=dim_feedforward,
            n_position=n_position,
        )
        self.MLM_SequenceModeling_WCL = Transformer_Encoder(
            n_layers=1,
            n_head=n_head,
            d_model=n_dim,
            d_inner=dim_feedforward,
            n_position=n_position,
        )
        self.pos_embedding = nn.Embedding(max_text_length, n_dim)
        self.w0_linear = nn.Linear(1, n_position)
        self.wv = nn.Linear(n_dim, n_dim)
        self.active = nn.Tanh()
        self.we = nn.Linear(n_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input, label_pos):
        # transformer unit for generating mask_c
        feature_v_seq = self.MLM_SequenceModeling_mask(input, src_mask=None)
        # position embedding layer
        pos_emb = self.pos_embedding(label_pos.long())
        pos_emb = self.w0_linear(torch.unsqueeze(pos_emb,
                                                 dim=2)).transpose(1, 2)
        # fuse position embedding with features V & generate mask_c
        att_map_sub = self.active(pos_emb + self.wv(feature_v_seq))
        att_map_sub = self.we(att_map_sub)  # b,256,1
        att_map_sub = self.sigmoid(att_map_sub.transpose(1, 2))  # b,1,256
        # WCL
        # generate inputs for WCL
        f_res = input * (1 - att_map_sub.transpose(1, 2)
                         )  # second path with remaining string
        f_sub = input * (att_map_sub.transpose(1, 2)
                         )  # first path with occluded character
        # transformer units in WCL
        f_res = self.MLM_SequenceModeling_WCL(f_res, src_mask=None)
        f_sub = self.MLM_SequenceModeling_WCL(f_sub, src_mask=None)
        return f_res, f_sub, att_map_sub


class MLM_VRM(nn.Module):
    """MLM (masked language-aware module) + VRM (visual reasoning module)."""

    def __init__(
        self,
        n_layers=3,
        n_position=256,
        n_dim=512,
        n_head=8,
        dim_feedforward=2048,
        max_text_length=25,
        nclass=37,
    ):
        super(MLM_VRM, self).__init__()
        self.MLM = MLM(
            n_dim=n_dim,
            n_position=n_position,
            n_head=n_head,
            dim_feedforward=dim_feedforward,
            max_text_length=max_text_length,
        )
        self.SequenceModeling = Transformer_Encoder(
            n_layers=n_layers,
            n_head=n_head,
            d_model=n_dim,
            d_inner=dim_feedforward,
            n_position=n_position,
        )
        self.Prediction = Prediction(
            n_dim=n_dim,
            n_position=n_position,
            N_max_character=max_text_length + 1,
            n_class=nclass,
        )  # N_max_character = 1 eos + 25 characters
        self.nclass = nclass
        self.max_text_length = max_text_length

    def forward(self, input, label_pos, training_step, is_Train=False):
        nT = self.max_text_length
        b, c, h, w = input.shape
        input = input.reshape(b, c, -1)
        input = input.transpose(1, 2)
        if is_Train:
            if training_step == 'LF_1':
                # language-free step: no MLM, prediction without shared branches
                f_res = 0
                f_sub = 0
                input = self.SequenceModeling(input, src_mask=None)
                text_pre, text_rem, text_mas = self.Prediction(input,
                                                               f_res,
                                                               f_sub,
                                                               is_Train=True,
                                                               use_mlm=False)
                return text_pre, text_pre, text_pre
            elif training_step == 'LF_2':
                # MLM
                f_res, f_sub, mask_c = self.MLM(input, label_pos)
                input = self.SequenceModeling(input, src_mask=None)
                text_pre, text_rem, text_mas = self.Prediction(input,
                                                               f_res,
                                                               f_sub,
                                                               is_Train=True)
                return text_pre, text_rem, text_mas
            elif training_step == 'LA':
                # MLM
                f_res, f_sub, mask_c = self.MLM(input, label_pos)
                # use the mask_c (1 for occluded character and 0 for remaining
                # characters) to occlude input
                # ratio controls the occluded number in a batch
                ratio = 2
                character_mask = torch.zeros_like(mask_c)
                character_mask[0:b // ratio, :, :] = mask_c[0:b // ratio, :, :]
                input = input * (1 - character_mask.transpose(1, 2))
                # VRM
                # transformer unit for VRM
                input = self.SequenceModeling(input, src_mask=None)
                # prediction layer for MLM and VSR
                text_pre, text_rem, text_mas = self.Prediction(input,
                                                               f_res,
                                                               f_sub,
                                                               is_Train=True)
                return text_pre, text_rem, text_mas
        else:  # VRM is only used in the testing stage
            f_res = 0
            f_sub = 0
            contextual_feature = self.SequenceModeling(input, src_mask=None)
            C = self.Prediction(contextual_feature,
                                f_res,
                                f_sub,
                                is_Train=False,
                                use_mlm=False)
            C = C.transpose(1, 0)  # (max_text_length + 1, b, nclass)
            out_res = torch.zeros(nT, b, self.nclass).type_as(input.data)
            out_length = torch.zeros(b).type_as(input.data)
            now_step = 0
            # greedy decoding: a sample's length is fixed when class 0 (eos)
            # first becomes the top prediction
            while 0 in out_length and now_step < nT:
                tmp_result = C[now_step, :, :]
                out_res[now_step] = tmp_result
                tmp_result = tmp_result.topk(1)[1].squeeze(dim=1)
                for j in range(b):
                    if out_length[j] == 0 and tmp_result[j] == 0:
                        out_length[j] = now_step + 1
                now_step += 1
            for j in range(0, b):
                if int(out_length[j]) == 0:
                    out_length[j] = nT
            start = 0
            output = torch.zeros(int(out_length.sum()),
                                 self.nclass).type_as(input.data)
            for i in range(0, b):
                cur_length = int(out_length[i])
                output[start:start + cur_length] = out_res[0:cur_length, i, :]
                start += cur_length
            return output, out_length


class VisionLANDecoder(nn.Module):
    """VisionLAN decoder: wraps MLM_VRM and selects the training step
    ('LF_1', 'LF_2' or 'LA')."""

    def __init__(
        self,
        in_channels,
        out_channels,
        n_head=None,
        training_step='LA',
        n_layers=3,
        n_position=256,
        max_text_length=25,
    ):
        super(VisionLANDecoder, self).__init__()
        self.training_step = training_step
        n_dim = in_channels
        dim_feedforward = n_dim * 4
        n_head = n_head if n_head is not None else n_dim // 32
        self.MLM_VRM = MLM_VRM(
            n_layers=n_layers,
            n_position=n_position,
            n_dim=n_dim,
            n_head=n_head,
            dim_feedforward=dim_feedforward,
            max_text_length=max_text_length,
            nclass=out_channels + 1,
        )

    def forward(self, x, data=None):
        # MLM + VRM
        if self.training:
            label_pos = data[-2]
            text_pre, text_rem, text_mas = self.MLM_VRM(x,
                                                        label_pos,
                                                        self.training_step,
                                                        is_Train=True)
            return text_pre, text_rem, text_mas
        else:
            output, out_length = self.MLM_VRM(x, None, self.training_step,
                                              is_Train=False)
            return output, out_length
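

# Minimal usage sketch. Assumptions: the openrec package is importable, the
# backbone delivers a feature map whose flattened spatial size H * W equals
# n_position (256 by default, e.g. an 8 x 32 map), and `data` is a batch list
# whose second-to-last entry holds the masked-character position per sample
# (that is all this decoder reads from it). Batch size, channel count and
# label_pos values below are illustrative only.
if __name__ == '__main__':
    batch_size, in_channels, height, width = 2, 512, 8, 32
    num_classes = 37  # charset size; the decoder adds one class internally

    decoder = VisionLANDecoder(
        in_channels=in_channels,
        out_channels=num_classes,
        n_head=8,
        training_step='LA',
        max_text_length=25,
    )
    feats = torch.randn(batch_size, in_channels, height, width)

    # PP_layer in isolation: per-character attention over the flattened map.
    pp = PP_layer(n_dim=in_channels, N_max_character=26, n_position=256)
    glimpses = pp(feats.reshape(batch_size, in_channels, -1).transpose(1, 2))
    print(glimpses.shape)  # (batch, 26, in_channels)

    # Training step 'LA': data[-2] supplies one masked position per sample,
    # so a stand-in list is enough for this smoke test.
    label_pos = torch.randint(0, 25, (batch_size, ))
    decoder.train()
    text_pre, text_rem, text_mas = decoder(feats,
                                           data=[None, label_pos, None])
    print(text_pre.shape)  # (batch, max_text_length + 1, num_classes + 1)

    # Inference: only the VRM branch runs; predictions of all samples are
    # concatenated along dim 0 and out_length gives each sample's length.
    decoder.eval()
    with torch.no_grad():
        output, out_length = decoder(feats)
    print(output.shape, out_length)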