"""This code is refer from:
https://github.com/jjwei66/BUSNet
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from .nrtr_decoder import PositionalEncoding, TransformerBlock
from .abinet_decoder import _get_mask, _get_length


class BUSDecoder(nn.Module):
def __init__(self,
in_channels,
out_channels,
nhead=8,
num_layers=4,
dim_feedforward=2048,
dropout=0.1,
max_length=25,
ignore_index=100,
pretraining=False,
detach=True):
super().__init__()
d_model = in_channels
self.ignore_index = ignore_index
self.pretraining = pretraining
self.d_model = d_model
self.detach = detach
self.max_length = max_length + 1 # additional stop token
self.out_channels = out_channels
# --------------------------------------------------------------------------
# decoder specifics
        self.proj = nn.Linear(out_channels, d_model, bias=False)
self.token_encoder = PositionalEncoding(dropout=0.1,
dim=d_model,
max_len=self.max_length)
self.pos_encoder = PositionalEncoding(dropout=0.1,
dim=d_model,
max_len=self.max_length)
self.decoder = nn.ModuleList([
TransformerBlock(
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
attention_dropout_rate=dropout,
residual_dropout_rate=dropout,
with_self_attn=False,
with_cross_attn=True,
            ) for _ in range(num_layers)
])
v_mask = torch.empty((1, 1, d_model))
l_mask = torch.empty((1, 1, d_model))
self.v_mask = nn.Parameter(v_mask)
self.l_mask = nn.Parameter(l_mask)
torch.nn.init.uniform_(self.v_mask, -0.001, 0.001)
torch.nn.init.uniform_(self.l_mask, -0.001, 0.001)
v_embeding = torch.empty((1, 1, d_model))
l_embeding = torch.empty((1, 1, d_model))
self.v_embeding = nn.Parameter(v_embeding)
self.l_embeding = nn.Parameter(l_embeding)
torch.nn.init.uniform_(self.v_embeding, -0.001, 0.001)
torch.nn.init.uniform_(self.l_embeding, -0.001, 0.001)
self.cls = nn.Linear(d_model, out_channels)

    def forward_decoder(self, q, x, mask=None):
for decoder_layer in self.decoder:
q = decoder_layer(q, x, cross_mask=mask)
output = q # (N, T, E)
logits = self.cls(output) # (N, T, C)
return logits

    def forward(self, img_feat, data=None):
        """
        Args:
            img_feat: (N, L, E) visual features, where N is batch size, L is
                the number of visual positions and E is the embedding dimension.
            data: optional list of ground-truth tensors used during training;
                data[0] holds the target labels and data[-1] their lengths (N,).
        """
img_feat = img_feat + self.v_embeding
B, L, C = img_feat.shape
# --------------------------------------------------------------------------
# decoder procedure
T = self.max_length
zeros = img_feat.new_zeros((B, T, C))
zeros_len = img_feat.new_zeros(B)
query = self.pos_encoder(zeros)
# 1. vision decode
        # Visual features plus the learned l_mask placeholder standing in for
        # the not-yet-available language tokens.
        v_embed = torch.cat((img_feat, self.l_mask.repeat(B, T, 1)), dim=1)
        # Mask query positions beyond the token lengths, (B, 1, max_len, max_len).
        padding_mask = _get_mask(self.max_length + zeros_len, self.max_length)
        # Visual positions are never masked, (B, 1, max_len, L).
        v_mask = torch.zeros((1, 1, self.max_length, L),
                             device=img_feat.device).tile([B, 1, 1, 1])
        mask = torch.cat((v_mask, padding_mask), 3)
v_logits = self.forward_decoder(query, v_embed, mask=mask)
# 2. language decode
        # Language branch input: ground-truth labels during pretraining,
        # otherwise the (detached) vision predictions.
        if self.training and self.pretraining:
tgt = torch.where(data[0] == self.ignore_index, 0, data[0])
tokens = F.one_hot(tgt, num_classes=self.out_channels)
tokens = tokens.float()
lengths = data[-1]
else:
tokens = torch.softmax(v_logits, dim=-1)
lengths = _get_length(v_logits)
tokens = tokens.detach()
token_embed = self.proj(tokens) # (N, T, E)
        token_embed = self.token_encoder(token_embed)  # (N, T, E)
token_embed = token_embed + self.l_embeding
        # Mask query positions beyond the token lengths.
        padding_mask = _get_mask(lengths, self.max_length)
mask = torch.cat((v_mask, padding_mask), 3)
l_embed = torch.cat((self.v_mask.repeat(B, L, 1), token_embed), dim=1)
l_logits = self.forward_decoder(query, l_embed, mask=mask)
# 3. vision language decode
vl_embed = torch.cat((img_feat, token_embed), dim=1)
vl_logits = self.forward_decoder(query, vl_embed, mask=mask)
if self.training:
return {'align': [vl_logits], 'lang': l_logits, 'vision': v_logits}
else:
return F.softmax(vl_logits, -1)
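

# --------------------------------------------------------------------------
# Minimal usage sketch (not part of the upstream BUSNet code). The channel
# size, feature length and charset size below are illustrative assumptions.
# Because this module uses relative imports, run it as a module inside its
# package (e.g. `python -m <package>.bus_decoder`, path assumed) rather than
# as a standalone script.
if __name__ == '__main__':
    # In eval mode the decoder only needs visual features of shape
    # (N, L, in_channels) and returns per-step class probabilities of shape
    # (N, max_length + 1, out_channels).
    decoder = BUSDecoder(in_channels=512, out_channels=37, max_length=25)
    decoder.eval()
    img_feat = torch.randn(2, 256, 512)  # (N, L, E) dummy visual features
    with torch.no_grad():
        probs = decoder(img_feat)
    print(probs.shape)  # torch.Size([2, 26, 37])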