import torch
import torch.nn as nn
import torch.nn.functional as F

from openrec.modeling.decoders.nrtr_decoder import Embeddings, PositionalEncoding, TransformerBlock  # , Beam
from openrec.modeling.decoders.visionlan_decoder import Transformer_Encoder


def generate_square_subsequent_mask(sz):
    r"""Generate a square mask for the sequence. The masked positions are
    filled with float('-inf'). Unmasked positions are filled with float(0.0).
    """
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = (mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(
        mask == 1, float(0.0)))
    return mask
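
# Illustrative example (not part of the original module): for sz=3 the mask
# built above is additive and causal, i.e. position i may only attend to
# positions <= i:
#
#   >>> generate_square_subsequent_mask(3)
#   tensor([[0., -inf, -inf],
#           [0., 0., -inf],
#           [0., 0., 0.]])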


class SEM_Pre(nn.Module):
    """Semantic branch: embeds target tokens and adds positional encoding."""

    def __init__(
        self,
        d_model=512,
        dst_vocab_size=40,
        residual_dropout_rate=0.1,
    ):
        super(SEM_Pre, self).__init__()
        self.embedding = Embeddings(d_model=d_model, vocab=dst_vocab_size)
        self.positional_encoding = PositionalEncoding(
            dropout=residual_dropout_rate,
            dim=d_model,
        )

    def forward(self, tgt):
        tgt = self.embedding(tgt)
        tgt = self.positional_encoding(tgt)
        # Causal mask so each position only attends to earlier positions.
        tgt_mask = generate_square_subsequent_mask(tgt.shape[1]).to(tgt.device)
        return tgt, tgt_mask
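
# Illustrative shapes (hedged sketch, not part of the original module): given a
# batch of token ids of shape (N, T), SEM_Pre returns the embedded sequence of
# shape (N, T, d_model) together with a (T, T) causal mask, e.g.
#
#   sem = SEM_Pre(d_model=512, dst_vocab_size=40)
#   tgt = torch.randint(0, 40, (2, 26))    # (N, T) token ids
#   sem_feat, sem_mask = sem(tgt)          # (2, 26, 512), (26, 26)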


class POS_Pre(nn.Module):
    """Positional branch: builds position features from positional encodings
    alone; the input tensor is only used for its shape and device."""

    def __init__(
        self,
        d_model=512,
    ):
        super(POS_Pre, self).__init__()
        self.pos_encoding = PositionalEncoding(
            dropout=0.1,
            dim=d_model,
        )
        self.linear1 = nn.Linear(d_model, d_model)
        self.linear2 = nn.Linear(d_model, d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, tgt):
        # Start from zeros so the branch depends only on position, not content.
        pos = tgt.new_zeros(*tgt.shape)
        pos = self.pos_encoding(pos)
        pos2 = self.linear2(F.relu(self.linear1(pos)))
        pos = self.norm2(pos + pos2)
        return pos
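
# Illustrative sketch (not part of the original module): because the forward
# pass starts from zeros, POS_Pre depends only on the input's shape, not its
# content; in eval mode two different inputs of the same shape yield identical
# position features.
#
#   pos_branch = POS_Pre(d_model=512).eval()
#   a, b = torch.rand(2, 26, 512), torch.rand(2, 26, 512)
#   assert torch.allclose(pos_branch(a), pos_branch(b))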


class DSF(nn.Module):
    """Dynamic Shared Fusion: gated fusion of two feature streams."""

    def __init__(self, d_model, fusion_num):
        super(DSF, self).__init__()
        self.w_att = nn.Linear(fusion_num * d_model, d_model)

    def forward(self, l_feature, v_feature):
        """
        Args:
            l_feature: (N, T, E) where N is batch size, T is length and E is
                the model dimension
            v_feature: (N, T, E), same shape as l_feature
        """
        f = torch.cat((l_feature, v_feature), dim=2)
        f_att = torch.sigmoid(self.w_att(f))
        output = f_att * v_feature + (1 - f_att) * l_feature
        return output
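
# Illustrative sketch (not part of the original module): DSF computes a
# per-element sigmoid gate from the concatenated streams and interpolates
# between them, so the output is an elementwise convex combination of the two
# inputs.
#
#   dsf = DSF(d_model=512, fusion_num=2)
#   l_feat, v_feat = torch.rand(2, 26, 512), torch.rand(2, 26, 512)
#   fused = dsf(l_feat, v_feat)    # (2, 26, 512)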


class MDCDP(nn.Module):
    r"""Multi-Domain Character Distance Perception.

    Each layer runs three steps: Self-Attention Enhancement (SAE) on the
    positional query, Cross-Branch Interaction (CBI) with the visual and
    semantic features, and Dynamic Shared Fusion (DSF) of the two results.
    """

    def __init__(self, d_model, n_head, d_inner, num_layers):
        super(MDCDP, self).__init__()
        self.num_layers = num_layers
        # step 1: SAE
        self.layers_pos = nn.ModuleList([
            TransformerBlock(d_model, n_head, d_inner)
            for _ in range(num_layers)
        ])
        # step 2: CBI
        self.layers2 = nn.ModuleList([
            TransformerBlock(
                d_model,
                n_head,
                d_inner,
                with_self_attn=False,
                with_cross_attn=True,
            ) for _ in range(num_layers)
        ])
        self.layers3 = nn.ModuleList([
            TransformerBlock(
                d_model,
                n_head,
                d_inner,
                with_self_attn=False,
                with_cross_attn=True,
            ) for _ in range(num_layers)
        ])
        # step 3: DSF
        self.dynamic_shared_fusion = DSF(d_model, 2)

    def forward(
        self,
        sem,
        vis,
        pos,
        tgt_mask=None,
        memory_mask=None,
    ):
        for i in range(self.num_layers):
            # ---------- step 1: SAE (Self-Attention Enhancement) ----------
            pos = self.layers_pos[i](pos, self_mask=tgt_mask)
            # ---------- step 2: CBI (Cross-Branch Interaction) ----------
            # CBI-V: attend to the visual features
            pos_vis = self.layers2[i](
                pos,
                vis,
                cross_mask=memory_mask,
            )
            # CBI-S: attend to the semantic features
            pos_sem = self.layers3[i](
                pos,
                sem,
                cross_mask=tgt_mask,
            )
            # ---------- step 3: DSF (Dynamic Shared Fusion) ----------
            pos = self.dynamic_shared_fusion(pos_vis, pos_sem)
        output = pos
        return output
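
# Illustrative sketch (not part of the original module): MDCDP refines the
# positional query `pos` by attending to the visual features (CBI-V, no mask)
# and the semantic features (CBI-S, causal mask), then gating the two with DSF.
# The sizes below are hypothetical and only show the expected shapes.
#
#   mdcdp = MDCDP(d_model=512, n_head=16, d_inner=1024, num_layers=3)
#   vis = torch.rand(2, 128, 512)               # (N, H*W, E) visual tokens
#   sem = torch.rand(2, 26, 512)                # (N, T, E) semantic branch
#   pos = torch.rand(2, 26, 512)                # (N, T, E) positional branch
#   mask = generate_square_subsequent_mask(26)
#   out = mdcdp(sem, vis, pos, tgt_mask=mask)   # (2, 26, 512)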


class ConvBnRelu(nn.Module):
    # adapt padding for kernel_size change
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        conv=nn.Conv2d,
        stride=2,
        inplace=True,
    ):
        super().__init__()
        p_size = [int(k // 2) for k in kernel_size]
        # p_size = int(kernel_size//2)
        self.conv = conv(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=p_size,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=inplace)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x
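
# Illustrative sketch (not part of the original module): as used by
# CDistNetDecoder below (kernel_size=(1, 3), stride=(1, 2)), this block keeps
# the feature height and roughly halves the width.
#
#   conv = ConvBnRelu(in_channels=512, out_channels=512,
#                     kernel_size=(1, 3), stride=(1, 2))
#   feat = torch.rand(2, 512, 8, 32)
#   out = conv(feat)    # (2, 512, 8, 16)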


class CDistNetDecoder(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 n_head=None,
                 num_encoder_blocks=3,
                 num_decoder_blocks=3,
                 beam_size=0,
                 max_len=25,
                 residual_dropout_rate=0.1,
                 add_conv=False,
                 **kwargs):
        super(CDistNetDecoder, self).__init__()
        dst_vocab_size = out_channels
        self.ignore_index = dst_vocab_size - 1
        self.bos = dst_vocab_size - 2
        self.eos = 0
        self.beam_size = beam_size
        self.max_len = max_len
        self.add_conv = add_conv
        d_model = in_channels
        dim_feedforward = d_model * 4
        n_head = n_head if n_head is not None else d_model // 32
        if add_conv:
            self.convbnrelu = ConvBnRelu(
                in_channels=in_channels,
                out_channels=in_channels,
                kernel_size=(1, 3),
                stride=(1, 2),
            )
        if num_encoder_blocks > 0:
            self.positional_encoding = PositionalEncoding(
                dropout=0.1,
                dim=d_model,
            )
            self.trans_encoder = Transformer_Encoder(
                n_layers=num_encoder_blocks,
                n_head=n_head,
                d_model=d_model,
                d_inner=dim_feedforward,
            )
        else:
            self.trans_encoder = None
        self.semantic_branch = SEM_Pre(
            d_model=d_model,
            dst_vocab_size=dst_vocab_size,
            residual_dropout_rate=residual_dropout_rate,
        )
        self.positional_branch = POS_Pre(d_model=d_model)
        self.mdcdp = MDCDP(d_model, n_head, dim_feedforward // 2,
                           num_decoder_blocks)
        self._reset_parameters()
        self.tgt_word_prj = nn.Linear(
            d_model, dst_vocab_size - 2,
            bias=False)  # We don't predict <bos> nor <pad>
        self.tgt_word_prj.weight.data.normal_(mean=0.0, std=d_model**-0.5)

    def forward(self, x, data=None):
        if self.add_conv:
            x = self.convbnrelu(x)
        # x = rearrange(x, "b c h w -> b (w h) c")
        x = x.flatten(2).transpose(1, 2)  # (B, C, H, W) -> (B, H*W, C)
        if self.trans_encoder is not None:
            x = self.positional_encoding(x)
            vis_feat = self.trans_encoder(x, src_mask=None)
        else:
            vis_feat = x
        if self.training:
            # data: [padded target token ids, target lengths]
            max_len = data[1].max()
            tgt = data[0][:, :1 + max_len]
            res = self.forward_train(vis_feat, tgt)
        else:
            if self.beam_size > 0:
                res = self.forward_beam(vis_feat)
            else:
                res = self.forward_test(vis_feat)
        return res

    def forward_train(self, vis_feat, tgt):
        sem_feat, sem_mask = self.semantic_branch(tgt)
        pos_feat = self.positional_branch(sem_feat)
        output = self.mdcdp(
            sem_feat,
            vis_feat,
            pos_feat,
            tgt_mask=sem_mask,
            memory_mask=None,
        )
        logit = self.tgt_word_prj(output)
        return logit

    def forward_test(self, vis_feat):
        bs = vis_feat.size(0)
        # Start every sequence with <bos>; the remaining positions hold the
        # padding index until they are filled by greedy decoding.
        dec_seq = torch.full(
            (bs, self.max_len + 1),
            self.ignore_index,
            dtype=torch.int64,
            device=vis_feat.device,
        )
        dec_seq[:, 0] = self.bos
        logits = []
        for len_dec_seq in range(0, self.max_len):
            sem_feat, sem_mask = self.semantic_branch(dec_seq[:, :len_dec_seq + 1])
            pos_feat = self.positional_branch(sem_feat)
            output = self.mdcdp(
                sem_feat,
                vis_feat,
                pos_feat,
                tgt_mask=sem_mask,
                memory_mask=None,
            )
            dec_output = output[:, -1:, :]
            word_prob = F.softmax(self.tgt_word_prj(dec_output), dim=-1)
            logits.append(word_prob)
            if len_dec_seq < self.max_len:
                # Greedy decoding: append the most probable token to the input.
                dec_seq[:, len_dec_seq + 1] = word_prob.squeeze(1).argmax(-1)
            # Efficient batch decoding: stop once every sequence contains at
            # least one <eos> token.
            if (dec_seq == self.eos).any(dim=-1).all():
                break
        logits = torch.cat(logits, dim=1)
        return logits

    def forward_beam(self, x):
        """Translation work in one batch."""
        # TODO: beam search is not implemented yet; use beam_size=0 (greedy
        # decoding via forward_test) instead.
        raise NotImplementedError('Beam-search decoding is not implemented.')

    def _reset_parameters(self):
        r"""Initialize multi-dimensional parameters with Xavier uniform."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
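

if __name__ == '__main__':
    # Minimal smoke test (hedged sketch, not part of the original module): the
    # feature-map size (2, 512, 8, 32) and charset size 40 are hypothetical,
    # and running it requires the `openrec` package imported at the top.
    decoder = CDistNetDecoder(in_channels=512, out_channels=40)
    feat = torch.rand(2, 512, 8, 32)  # (B, C, H, W) backbone features
    decoder.eval()
    with torch.no_grad():
        probs = decoder(feat)
    # Greedy decoding returns per-step probabilities over the 38 predictable
    # classes (vocab minus <bos> and <pad>): (B, <=max_len, 38).
    print(probs.shape)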