import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np

from fastai.vision import *

from modules_matrn.attention import *
from modules_matrn.model import Model, _default_tfmer_cfg
from modules_matrn.transformer import (PositionalEncoding,
                                       TransformerEncoder,
                                       TransformerEncoderLayer)


class BaseSemanticVisual_backbone_feature(Model):
    """Fuse language (semantic) features with visual backbone features.

    The visual feature map is flattened into a token sequence, concatenated
    with the language tokens and passed through a shared transformer encoder
    (``model1``).  The encoder output is split back into a visual part, which
    is decoded with ``PositionAttention`` and classified by ``cls_vis``, and a
    semantic part, classified by ``cls_sem``.  A sigmoid gate (``w_att``) blends
    both parts before the final classifier ``cls``.

    During training, a few of the visual positions that receive the highest
    attention for a randomly chosen character are replaced by the learnable
    token ``v_token``.
    """

    def __init__(self, config):
        """Build the sub-modules from ``config``.

        Every ``config.model_alignment_*`` value is optional: ``ifnone``
        substitutes the default shown next to it when the value is ``None``.
        """
        super().__init__(config)
        d_model = ifnone(config.model_alignment_d_model, _default_tfmer_cfg['d_model'])
        nhead = ifnone(config.model_alignment_nhead, _default_tfmer_cfg['nhead'])
        d_inner = ifnone(config.model_alignment_d_inner, _default_tfmer_cfg['d_inner'])
        # Prefer the correctly spelled key.  The misspelled
        # ``model_alignmentl_dropout`` is still consulted afterwards so that
        # configurations written against the old spelling keep working.
        dropout = ifnone(
            getattr(config, 'model_alignment_dropout', None),
            ifnone(config.model_alignmentl_dropout, _default_tfmer_cfg['dropout']),
        )
        activation = ifnone(config.model_alignment_activation, _default_tfmer_cfg['activation'])
        num_layers = ifnone(config.model_alignment_num_layers, 2)

        # Probability that a sample is masked at all / that each candidate
        # position of a masked sample is actually replaced.
        self.mask_example_prob = ifnone(config.model_alignment_mask_example_prob, 0.9)
        self.mask_candidate_prob = ifnone(config.model_alignment_mask_candidate_prob, 0.9)
        # Number of top-attended visual positions considered as candidates.
        self.num_vis_mask = ifnone(config.model_alignment_num_vis_mask, 10)

        self.nhead = nhead
        self.d_model = d_model
        self.use_self_attn = ifnone(config.model_alignment_use_self_attn, False)
        self.loss_weight = ifnone(config.model_alignment_loss_weight, 1.0)
        self.max_length = config.dataset_max_length + 1  # additional stop token
        self.debug = ifnone(config.global_debug, False)

        encoder_layer = TransformerEncoderLayer(d_model=d_model, nhead=nhead,
                                                dim_feedforward=d_inner,
                                                dropout=dropout,
                                                activation=activation)
        self.model1 = TransformerEncoder(encoder_layer, num_layers)
        # NOTE(review): max_len is hard-coded to 8*32 positions; presumably
        # this matches the backbone feature map size (H*W) -- confirm.
        self.pos_encoder_tfm = PositionalEncoding(d_model, dropout=0, max_len=8*32)

        mode = ifnone(config.model_alignment_attention_mode, 'nearest')
        self.model2_vis = PositionAttention(
            max_length=config.dataset_max_length + 1,  # additional stop token
            mode=mode
        )
        self.cls_vis = nn.Linear(d_model, self.charset.num_classes)
        self.cls_sem = nn.Linear(d_model, self.charset.num_classes)
        # Gate that mixes the visual and semantic branch per feature channel.
        self.w_att = nn.Linear(2 * d_model, d_model)

        # Learnable token written over masked visual positions.
        v_token = torch.empty((1, d_model))
        self.v_token = nn.Parameter(v_token)
        torch.nn.init.uniform_(self.v_token, -0.001, 0.001)

        self.cls = nn.Linear(d_model, self.charset.num_classes)

    def _mask_visual_tokens(self, v_feature, v_attn, lengths_l):
        """Return a copy of ``v_feature`` with some positions set to ``v_token``.

        Args:
            v_feature: (H*W, N, E) flattened visual tokens.
            v_attn: (N, T, H*W) attention of each character over the positions.
            lengths_l: iterable with one length per sample; a character index
                is drawn uniformly from ``[0, length)``.

        Returns:
            Tensor shaped like ``v_feature``; the argument itself is untouched.
        """
        # Mask a copy: ``v_feature`` may share storage with the tensor the
        # caller passed to ``forward``, which must not be modified in place.
        v_feature = v_feature.clone()
        for idx, length in enumerate(lengths_l):
            if np.random.random() > self.mask_example_prob:
                continue
            length = int(length)
            if length <= 0:
                # np.random.randint(0) raises; nothing to pick for this sample.
                continue
            l_idx = np.random.randint(length)
            # Visual positions sorted by how strongly character l_idx attends
            # to them; only the first num_vis_mask are candidates.
            ranked = v_attn[idx, l_idx].argsort(descending=True)
            candidates = ranked[:self.num_vis_mask].cpu().numpy()
            chosen = np.random.random(candidates.shape) <= self.mask_candidate_prob
            v_feature[candidates[chosen], idx] = self.v_token
        return v_feature

    def forward(self, l_feature, v_feature, lengths_l=None, v_attn=None,
                l_logits=None, texts=None, training=True):
        """Run the joint semantic-visual encoder.

        Args:
            l_feature: (N, T, E) where T is length, N is batch size and E is
                the model dimension.
            v_feature: (N, E, H, W) visual feature map; never modified in place.
            lengths_l: (N,) lengths; only read when ``training`` is true.
            v_attn: (N, T, H, W) or already flattened (N, T, H*W) attention
                maps.  Required despite the ``None`` default.
            l_logits: (N, T, C); accepted for interface compatibility, unused.
            texts: (N, T, C); accepted for interface compatibility, unused.
            training: when true, visual tokens are randomly masked with
                ``v_token`` before encoding.

        Returns:
            dict with the fused ``logits`` / ``pt_lengths``, the visual branch
            ``v_logits`` / ``pt_v_lengths``, the semantic branch ``s_logits`` /
            ``pt_s_lengths``, ``loss_weight`` and ``name``.
        """
        l_feature = l_feature.permute(1, 0, 2)  # (T, N, E)
        N, E, H, W = v_feature.size()
        # reshape() accepts any memory layout; it may return a view, which is
        # why masking below operates on an explicit copy.
        v_feature = v_feature.reshape(N, E, H * W).permute(2, 0, 1)  # (H*W, N, E)

        # Normalise the attention maps once so that every later step can rely
        # on the flattened layout, whichever form the caller supplied.
        if v_attn.dim() == 4:
            v_attn = v_attn.flatten(2)  # (N, T, H*W)

        if training:
            v_feature = self._mask_visual_tokens(v_feature, v_attn, lengths_l)

        # Positional encoding of every visual position, averaged with the
        # attention weights, gives each character a position-aware offset.
        n, _, hw = v_attn.shape
        zeros = v_feature.new_zeros((hw, n, E))  # (H*W, N, E)
        base_pos = self.pos_encoder_tfm(zeros).permute(1, 0, 2)  # (N, H*W, E)
        base_pos = torch.bmm(v_attn, base_pos).permute(1, 0, 2)  # (T, N, E)
        l_feature = l_feature + base_pos

        sv_feature = torch.cat((v_feature, l_feature), dim=0)  # (H*W+T, N, E)
        sv_feature = self.model1(sv_feature)  # (H*W+T, N, E)
        sv_to_v_feature = sv_feature[:H * W]  # (H*W, N, E)
        sv_to_s_feature = sv_feature[H * W:]  # (T, N, E)

        # Visual branch: back to a feature map, then position attention.
        sv_to_v_feature = sv_to_v_feature.permute(1, 2, 0).view(N, E, H, W)
        sv_to_v_feature, _ = self.model2_vis(sv_to_v_feature)  # (N, T, E)
        sv_to_v_logits = self.cls_vis(sv_to_v_feature)  # (N, T, C)
        pt_v_lengths = self._get_length(sv_to_v_logits)  # (N,)

        # Semantic branch.
        sv_to_s_feature = sv_to_s_feature.permute(1, 0, 2)  # (N, T, E)
        sv_to_s_logits = self.cls_sem(sv_to_s_feature)  # (N, T, C)
        pt_s_lengths = self._get_length(sv_to_s_logits)  # (N,)

        # Gated fusion of both branches.
        f = torch.cat((sv_to_v_feature, sv_to_s_feature), dim=2)
        f_att = torch.sigmoid(self.w_att(f))
        output = f_att * sv_to_v_feature + (1 - f_att) * sv_to_s_feature

        logits = self.cls(output)  # (N, T, C)
        pt_lengths = self._get_length(logits)

        # NOTE(review): the loss weight is tripled, presumably because three
        # sets of logits are returned -- confirm against the loss code.
        return {'logits': logits, 'pt_lengths': pt_lengths,
                'loss_weight': self.loss_weight * 3,
                'v_logits': sv_to_v_logits, 'pt_v_lengths': pt_v_lengths,
                's_logits': sv_to_s_logits, 'pt_s_lengths': pt_s_lengths,
                'name': 'alignment'}