import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.init import ones_, trunc_normal_, zeros_
from openrec.modeling.common import DropPath, Identity, Mlp
from openrec.modeling.decoders.nrtr_decoder import Embeddings
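

# CrossAttention: standard multi-head attention in which queries and keys/values
# come from different sequences. Shape sketch (illustrative, e.g. dim=256,
# num_heads=8): q is [B, QN, 256], kv is [B, N, 256], and key_mask, if given,
# is an additive float mask (0 = keep, -inf = drop) shaped e.g. [B, 1, N] that
# is unsqueezed over the head dimension and added to the attention logits.
# Output is [B, QN, 256]; in eval mode the attention map [B, num_heads, QN, N]
# is cached on self.attn_map for inspection.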
class CrossAttention(nn.Module):
def __init__(
self,
dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim**-0.5
self.q = nn.Linear(dim, dim, bias=qkv_bias)
self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, q, kv, key_mask=None):
N, C = kv.shape[1:]
QN = q.shape[1]
q = self.q(q).reshape([-1, QN, self.num_heads,
C // self.num_heads]).transpose(1, 2)
q = q * self.scale
k, v = self.kv(kv).reshape(
[-1, N, 2, self.num_heads,
C // self.num_heads]).permute(2, 0, 3, 1, 4)
attn = q.matmul(k.transpose(2, 3))
if key_mask is not None:
attn = attn + key_mask.unsqueeze(1)
attn = F.softmax(attn, -1)
if not self.training:
self.attn_map = attn
attn = self.attn_drop(attn)
x = (attn.matmul(v)).transpose(1, 2).reshape((-1, QN, C))
x = self.proj(x)
x = self.proj_drop(x)
return x
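

# EdgeDecoderLayer (descriptive note added here, not an upstream docstring):
# positional queries `p` attend over `pv` to build an "edge" (affinity) matrix,
# which then aggregates the content values `cv`; the aggregate is projected,
# added residually to `p`, and followed by a post-norm MLP block. All inputs
# are [B, L, dim]; `cv` and `pv` must share the same sequence length.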
class EdgeDecoderLayer(nn.Module):
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.0,
qkv_bias=False,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=[0.0, 0.0],
act_layer=nn.GELU,
norm_layer='nn.LayerNorm',
epsilon=1e-6,
):
super().__init__()
self.head_dim = dim // num_heads
self.scale = qk_scale or self.head_dim**-0.5
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path1 = DropPath(
drop_path[0]) if drop_path[0] > 0.0 else Identity()
self.norm1 = eval(norm_layer)(dim, eps=epsilon)
self.norm2 = eval(norm_layer)(dim, eps=epsilon)
# self.c = nn.Linear(dim, dim*2)
self.p = nn.Linear(dim, dim)
self.cv = nn.Linear(dim, dim)
self.pv = nn.Linear(dim, dim)
self.dim = dim
self.num_heads = num_heads
self.p_proj = nn.Linear(dim, dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp_ratio = mlp_ratio
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
)

    def forward(self, p, cv, pv):
pN = p.shape[1]
vN = cv.shape[1]
p_shortcut = p
p1 = self.p(p).reshape(
[-1, pN, self.num_heads,
self.dim // self.num_heads]).transpose(1, 2)
cv1 = self.cv(cv).reshape(
[-1, vN, self.num_heads,
self.dim // self.num_heads]).transpose(1, 2)
pv1 = self.pv(pv).reshape(
[-1, vN, self.num_heads,
self.dim // self.num_heads]).transpose(1, 2)
edge = F.softmax(p1.matmul(pv1.transpose(2, 3)), -1) # B h N N
p_c = (edge @ cv1).transpose(1, 2).reshape((-1, pN, self.dim))
x1 = self.norm1(p_shortcut + self.drop_path1(self.p_proj(p_c)))
x = self.norm2(x1 + self.drop_path1(self.mlp(x1)))
return x
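

# DecoderLayer: pre-norm cross-attention (queries `q` attend to `kv`, with an
# optional additive key mask) followed by a pre-norm MLP, each wrapped in a
# residual branch with optional stochastic depth (DropPath).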
class DecoderLayer(nn.Module):
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.0,
qkv_bias=False,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer='nn.LayerNorm',
epsilon=1e-6,
):
super().__init__()
self.norm1 = eval(norm_layer)(dim, eps=epsilon)
self.normkv = eval(norm_layer)(dim, eps=epsilon)
self.mixer = CrossAttention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
self.norm2 = eval(norm_layer)(dim, eps=epsilon)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp_ratio = mlp_ratio
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
)

    def forward(self, q, kv, key_mask=None):
x1 = q + self.drop_path(
self.mixer(self.norm1(q), self.normkv(kv), key_mask))
x = x1 + self.drop_path(self.mlp(self.norm2(x1)))
return x
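

# CMFFLayer (presumably "cross-modal feature fusion"): question, prompt, and
# visual tokens are concatenated along the sequence dimension; the joint
# sequence first cross-attends to the prompt tokens (optionally masked), then
# to its own visual slice, and finally passes through an MLP. The updated
# sequence is split back into question / prompt / visual parts in that order.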
class CMFFLayer(nn.Module):
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.0,
qkv_bias=False,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
act_layer=nn.GELU,
epsilon=1e-6,
):
super().__init__()
self.normq1 = nn.LayerNorm(dim, eps=epsilon)
self.normkv1 = nn.LayerNorm(dim, eps=epsilon)
self.images_to_question_cross_attn = CrossAttention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
)
self.normq2 = nn.LayerNorm(dim, eps=epsilon)
self.normkv2 = nn.LayerNorm(dim, eps=epsilon)
self.question_to_images_cross_attn = CrossAttention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
self.normmlp = nn.LayerNorm(dim, eps=epsilon)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
)

    def forward(self, question_f, prompt_f, visual_f, mask=None):
query_add = torch.concat([question_f, prompt_f, visual_f], 1)
query_add = query_add + self.drop_path(
self.images_to_question_cross_attn(self.normq1(query_add),
self.normkv1(prompt_f), mask))
query_add = query_add + self.drop_path(
self.question_to_images_cross_attn(
self.normq2(query_add),
self.normkv2(query_add[:, -visual_f.shape[1]:, :])))
query_updated = query_add + self.drop_path(
self.mlp(self.normmlp(query_add)))
question_f_updated = query_updated[:, :question_f.shape[1], :]
prompt_f_updated = query_updated[:, question_f.
shape[1]:-visual_f.shape[1], :]
visual_f_updated = query_updated[:, -visual_f.shape[1]:, :]
return question_f_updated, prompt_f_updated, visual_f_updated
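

# IGTRDecoder: instruction-guided decoder (IGTR). During training it answers
# several kinds of sampled "questions" about the text in the image (see
# question_encoder and forward_train); at inference it supports parallel,
# check-and-search, next-token, autoregressive, and iterative-refinement
# decoding depending on the ar / check_search / next_pred / refine_iter
# settings (see forward_test). This overview is a reading of the code below,
# not an upstream docstring.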
class IGTRDecoder(nn.Module):
def __init__(self,
in_channels,
dim,
out_channels,
num_layer=2,
drop_path_rate=0.1,
max_len=25,
vis_seq=50,
ch=False,
ar=False,
refine_iter=0,
quesall=True,
next_pred=False,
ds=False,
pos2d=False,
check_search=False,
max_size=[8, 32],
**kwargs):
super(IGTRDecoder, self).__init__()
self.out_channels = out_channels
self.dim = dim
self.max_len = max_len + 3 # max_len + eos + bos
self.ch = ch
self.char_embed = Embeddings(d_model=dim,
vocab=self.out_channels,
scale_embedding=True)
self.ignore_index = out_channels - 1
self.ar = ar
self.refine_iter = refine_iter
self.bos = self.out_channels - 2
self.eos = 0
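        # Special token ids, per the assignments above: eos = 0,
        # bos = out_channels - 2, padding/ignore index = out_channels - 1.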
self.next_pred = next_pred
self.quesall = quesall
self.check_search = check_search
dpr = np.linspace(0, drop_path_rate, num_layer + 2)
self.cmff_decoder = nn.ModuleList([
CMFFLayer(dim=dim,
num_heads=dim // 32,
mlp_ratio=4.0,
qkv_bias=True,
drop_path=dpr[i]) for i in range(num_layer)
])
self.answer_to_question_layer = DecoderLayer(dim=dim,
num_heads=dim // 32,
mlp_ratio=4.0,
qkv_bias=True,
drop_path=dpr[-2])
self.answer_to_image_layer = DecoderLayer(dim=dim,
num_heads=dim // 32,
mlp_ratio=4.0,
qkv_bias=True,
drop_path=dpr[-1])
self.char_pos_embed = nn.Parameter(torch.zeros([self.max_len, dim],
dtype=torch.float32),
requires_grad=True)
self.appear_num_embed = nn.Parameter(torch.zeros([self.max_len, dim],
dtype=torch.float32),
requires_grad=True)
self.ds = ds
self.pos2d = pos2d
if not ds:
self.vis_pos_embed = nn.Parameter(torch.zeros([1, vis_seq, dim],
dtype=torch.float32),
requires_grad=True)
trunc_normal_(self.vis_pos_embed, std=0.02)
elif pos2d:
pos_embed = torch.zeros([1, max_size[0] * max_size[1], dim],
dtype=torch.float32)
trunc_normal_(pos_embed, mean=0, std=0.02)
self.vis_pos_embed = nn.Parameter(
pos_embed.transpose(1, 2).reshape(1, dim, max_size[0],
max_size[1]),
requires_grad=True,
)
self.prompt_pos_embed = nn.Parameter(torch.zeros([1, 6, dim],
dtype=torch.float32),
requires_grad=True)
self.answer_query = nn.Parameter(torch.zeros([1, 1, dim],
dtype=torch.float32),
requires_grad=True)
self.norm_pred = nn.LayerNorm(dim, eps=1e-6)
self.ques1_head = nn.Linear(dim, self.out_channels - 2)
self.ques2_head = nn.Linear(dim, self.max_len, bias=False)
self.ques3_head = nn.Linear(dim, self.max_len - 1)
self.ques4_head = nn.Linear(dim, self.max_len - 1)
trunc_normal_(self.char_pos_embed, std=0.02)
trunc_normal_(self.appear_num_embed, std=0.02)
trunc_normal_(self.answer_query, std=0.02)
trunc_normal_(self.prompt_pos_embed, std=0.02)
self.apply(self._init_weights)

    def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
zeros_(m.bias)
ones_(m.weight)

    @torch.jit.ignore
def no_weight_decay(self):
return {
'char_pos_embed', 'vis_pos_embed', 'appear_num_embed',
'answer_query', 'char_embed'
}
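
    # question_encoder assembles the training-time instruction ("question")
    # embeddings. Judging from the prediction heads: ques1 asks which character
    # occupies a sampled position (ques1_head over character classes), ques2 is
    # a yes/no check of a (position, character) pair (scored against
    # ques2_head's position-indexed weights with BCE), ques3 asks how often a
    # character appears (ques3_head), and ques4 asks at which position an
    # occurrence of a character lies (ques4_head). This summary is an
    # interpretation added here.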
def question_encoder(self, targets, train_i):
(
prompt_pos_idx,
prompt_char_idx,
ques_pos_idx,
ques1_answer,
ques2_char_idx,
ques2_answer,
ques4_char_num,
ques_len,
ques2_len,
prompt_len,
) = targets
max_ques_len = torch.max(ques_len)
max_ques2_len = torch.max(ques2_len)
max_prompt_len = torch.max(prompt_len)
if self.next_pred and (train_i == 2 or train_i == 3):
prompt_pos = self.prompt_pos_embed
prompt_char_idx = prompt_char_idx[:, :max_prompt_len]
else:
prompt_pos = F.embedding(
prompt_pos_idx[:, :max_prompt_len], self.char_pos_embed
) # bs lp [ 0, 4, 3, 12, 12, 12, 12, 12, 12, 12, 12]
prompt_char_idx = prompt_char_idx[:, :max_prompt_len]
prompt_char = self.char_embed(prompt_char_idx) # bs lp
prompt = prompt_pos + prompt_char
mask_1234 = torch.where(prompt_char_idx == self.ignore_index,
float('-inf'), 0)
ques1 = F.embedding(ques_pos_idx[:, :max_ques_len],
self.char_pos_embed) # bs lq1 dim
ques1_answer = ques1_answer[:, :max_ques_len]
if self.quesall or train_i == 0:
ques2_char = self.char_embed(ques2_char_idx[:, :max_ques2_len, 1])
ques2 = ques2_char + F.embedding(ques2_char_idx[:, :max_ques2_len,
0],
self.char_pos_embed) # bs lq2 dim
ques2_answer = ques2_answer[:, :max_ques2_len]
ques2_head = F.embedding(ques2_char_idx[:, :max_ques2_len, 0],
self.ques2_head.weight)
ques4_char = self.char_embed(ques1_answer)
ques4_ap_num = F.embedding(ques4_char_num[:, :max_ques_len],
self.appear_num_embed)
ques4 = ques4_char + ques4_ap_num
ques4_answer = ques_pos_idx[:, :max_ques_len]
return (
prompt,
ques1,
ques2,
ques2_head,
ques4,
ques1_answer,
ques2_answer,
ques4_answer,
mask_1234.unsqueeze(1),
)
else:
return prompt, ques1, ques1_answer, mask_1234.unsqueeze(1)

    def forward(self, x, data=None):
if self.training:
return self.forward_train(x, data)
else:
return self.forward_test(x)
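
    # forward_test decoding modes (a reading of the branches below): the
    # default non-autoregressive path predicts the character at every position
    # in parallel from a BOS-only prompt; check_search re-scores a greedy
    # hypothesis position by position; next_pred iteratively predicts the next
    # character from a sliding prompt window; the ar path is conventional
    # greedy autoregressive decoding; refine_iter > 0 additionally re-examines
    # low-confidence positions with a verification pass.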
def forward_test(self, x):
if not self.ds:
visual_f = x + self.vis_pos_embed
elif self.pos2d:
x = x + self.vis_pos_embed[:, :, :x.shape[2], :x.shape[3]]
visual_f = x.flatten(2).transpose(1, 2)
else:
visual_f = x
bs = x.shape[0]
prompt_bos = self.char_embed(
torch.full(
[bs, 1], self.bos, dtype=torch.long,
device=x.get_device())) + self.char_pos_embed[:1, :].unsqueeze(
0) # BOS prompt
ques_all = torch.tile(self.char_pos_embed.unsqueeze(0), (bs, 1, 1))
if not self.ar:
if self.check_search:
tgt_in = torch.full((bs, self.max_len),
self.ignore_index,
dtype=torch.long,
device=x.get_device())
tgt_in[:, 0] = self.bos
logits = []
for j in range(1, self.max_len):
visual_f_check = visual_f
ques_check_i = ques_all[:, j:j + 1, :] + self.char_embed(
torch.arange(self.out_channels - 2,
device=x.get_device())).unsqueeze(0)
prompt_check = ques_all[:, :j] + self.char_embed(
tgt_in[:, :j])
# prompt_check = prompt_bos
mask = torch.where(
(tgt_in[:, :j] == self.eos).int().cumsum(-1) > 0,
float('-inf'), 0)
for layer in self.cmff_decoder:
ques_check_i, prompt_check, visual_f_check = layer(
ques_check_i, prompt_check, visual_f_check,
mask.unsqueeze(1))
answer_query_i = self.answer_to_question_layer(
ques_check_i, prompt_check, mask.unsqueeze(1))
answer_pred_i = self.norm_pred(
self.answer_to_image_layer(
answer_query_i, visual_f_check)) # B, 26, 37
                    # the probability of the next token is at the i-th position of the output
fc_2 = self.ques2_head.weight[j:j + 1].unsqueeze(0)
fc_2 = fc_2.tile([bs, 1, 1])
p_i = fc_2 @ answer_pred_i.transpose(1, 2)
# p_i = p_i[:, 0, :]
logits.append(p_i)
if j < self.max_len - 1:
# greedy decode. add the next token index to the target input
tgt_in[:, j] = p_i.squeeze().argmax(-1)
# Efficient batch decoding: If all output words have at least one EOS token, end decoding.
if (tgt_in == self.eos).any(dim=-1).all():
break
logits = torch.cat(logits, dim=1)
else:
ques_pd = ques_all[:, 1:, :]
prompt_pd = prompt_bos
visual_f_pd = visual_f
for layer in self.cmff_decoder:
ques_pd, prompt_pd, visual_f_pd = layer(
ques_pd, prompt_pd, visual_f_pd)
answer_query_pd = self.answer_to_question_layer(
ques_pd, prompt_pd)
answer_feats_pd = self.norm_pred(
self.answer_to_image_layer(answer_query_pd,
visual_f_pd)) # B, 26, 37
logits = self.ques1_head(answer_feats_pd)
elif self.next_pred:
ques_pd_1 = ques_all[:, 1:2, :]
prompt_pd = prompt_bos
visual_f_pd = visual_f
for layer in self.cmff_decoder:
ques_pd_1, prompt_pd, visual_f_pd = layer(
ques_pd_1, prompt_pd, visual_f_pd)
answer_query_pd = self.answer_to_question_layer(
ques_pd_1, prompt_pd)
answer_feats_pd = self.norm_pred(
self.answer_to_image_layer(answer_query_pd,
visual_f_pd)) # B, 26, 37
logits_pd_1 = self.ques1_head(answer_feats_pd)
ques_next = self.char_pos_embed[-2:-1, :].unsqueeze(0).tile(
[bs, 1, 1])
prompt_next_bos = (self.char_embed(
torch.full(
[bs, 1], self.bos, dtype=torch.long,
device=x.get_device())) + self.prompt_pos_embed[:, :1, :])
pred_prob, pred_id = F.softmax(logits_pd_1, -1).max(-1)
pred_prob_list = [pred_prob]
pred_id_list = [pred_id]
for j in range(1, 70):
prompt_next_1 = self.char_embed(
pred_id) + self.prompt_pos_embed[:,
-1 * pred_id.shape[1]:, :]
prompt_next = torch.concat([prompt_next_bos, prompt_next_1], 1)
ques_next_i = ques_next
visual_f_i = visual_f
for layer in self.cmff_decoder:
ques_next_i, prompt_next, visual_f_pd = layer(
ques_next_i, prompt_next, visual_f_i)
answer_query_next_i = self.answer_to_question_layer(
ques_next_i, prompt_next)
answer_feats_next_i = self.norm_pred(
self.answer_to_image_layer(answer_query_next_i,
visual_f_i)) # B, 26, 37
logits_next_i = self.ques1_head(answer_feats_next_i)
# pred_id = logits_next_i.argmax(-1)
pred_prob_i, pred_id_i = F.softmax(logits_next_i, -1).max(-1)
pred_prob_list.append(pred_prob_i)
pred_id_list.append(pred_id_i)
if (torch.concat(pred_id_list,
1) == self.eos).any(dim=-1).all():
break
if pred_id.shape[1] >= 5:
pred_id = torch.concat([pred_id[:, 1:], pred_id_i], 1)
else:
pred_id = torch.concat([pred_id, pred_id_i], 1)
return [
torch.concat(pred_id_list, 1),
torch.concat(pred_prob_list, 1)
]
else:
tgt_in = torch.full((bs, self.max_len),
self.ignore_index,
dtype=torch.long,
device=x.get_device())
tgt_in[:, 0] = self.bos
logits = []
for j in range(1, self.max_len):
visual_f_ar = visual_f
ques_i = ques_all[:, j:j + 1, :]
prompt_ar = ques_all[:, :j] + self.char_embed(tgt_in[:, :j])
mask = torch.where(
(tgt_in[:, :j] == self.eos).int().cumsum(-1) > 0,
float('-inf'), 0)
for layer in self.cmff_decoder:
ques_i, prompt_ar, visual_f_ar = layer(
ques_i, prompt_ar, visual_f_ar, mask.unsqueeze(1))
answer_query_i = self.answer_to_question_layer(
ques_i, prompt_ar, mask.unsqueeze(1))
answer_pred_i = self.norm_pred(
self.answer_to_image_layer(answer_query_i,
visual_f_ar)) # B, 26, 37
                # the probability of the next token is at the i-th position of the output
p_i = self.ques1_head(answer_pred_i)
logits.append(p_i)
if j < self.max_len - 1:
# greedy decode. add the next token index to the target input
tgt_in[:, j] = p_i.squeeze().argmax(-1)
# Efficient batch decoding: If all output words have at least one EOS token, end decoding.
if (tgt_in == self.eos).any(dim=-1).all():
break
logits = torch.cat(logits, dim=1)
if self.refine_iter > 0:
pred_probs, pred_idxs = F.softmax(logits, -1).max(-1)
for i in range(self.refine_iter):
mask_check = (pred_idxs == self.eos).int().cumsum(-1) <= 1
ques_check_all = self.char_embed(
pred_idxs) + ques_all[:, 1:pred_idxs.shape[1] + 1, :]
prompt_check = prompt_bos
visual_f_check = visual_f
ques_check = ques_check_all
for layer in self.cmff_decoder:
ques_check, prompt_check, visual_f_check = layer(
ques_check, prompt_check, visual_f_check)
answer_query_check = self.answer_to_question_layer(
ques_check, prompt_check)
answer_pred_check = self.norm_pred(
self.answer_to_image_layer(answer_query_check,
visual_f_check)) # B, 26, 37
ques2_head = self.ques2_head.weight[1:pred_idxs.shape[1] +
1, :]
ques2_head = torch.tile(ques2_head.unsqueeze(0), [bs, 1, 1])
answer2_pred = answer_pred_check.matmul(
ques2_head.transpose(1, 2))
diag_mask = torch.eye(answer2_pred.shape[1],
device=x.get_device()).unsqueeze(0).tile(
[bs, 1, 1])
answer2_pred = F.sigmoid(
(answer2_pred * diag_mask).sum(-1)) * mask_check
check_result = answer2_pred < 0.9 # pred_probs < 0.99
prompt_refine = torch.concat([prompt_bos, ques_check_all], 1)
mask_refine = torch.where(
check_result, float('-inf'), 0) + torch.where(
(pred_idxs == self.eos).int().cumsum(-1) < 1, 0,
float('-inf'))
mask_refine = torch.concat(
[torch.zeros([bs, 1], device=x.get_device()), mask_refine],
1).unsqueeze(1)
ques_refine = ques_all[:, 1:pred_idxs.shape[1] + 1, :]
visual_f_refine = visual_f
for layer in self.cmff_decoder:
ques_refine, prompt_refine, visual_f_refine = layer(
ques_refine, prompt_refine, visual_f_refine,
mask_refine)
answer_query_refine = self.answer_to_question_layer(
ques_refine, prompt_refine, mask_refine)
answer_pred_refine = self.norm_pred(
self.answer_to_image_layer(answer_query_refine,
visual_f_refine)) # B, 26, 37
answer_refine = self.ques1_head(answer_pred_refine)
refine_probs, refine_idxs = F.softmax(answer_refine,
-1).max(-1)
pred_idxs_refine = torch.where(check_result, refine_idxs,
pred_idxs)
pred_idxs = torch.where(mask_check, pred_idxs_refine,
pred_idxs)
pred_probs_refine = torch.where(check_result, refine_probs,
pred_probs)
pred_probs = torch.where(mask_check, pred_probs_refine,
pred_probs)
return [pred_idxs, pred_probs]
return F.softmax(logits, -1)
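
    # forward_train iterates over the sampled question groups in `targets`
    # (transposed over the group dimension) and accumulates four losses:
    # loss1, character-at-position cross-entropy; loss2, binary check of
    # (position, character) pairs; loss3, character-count cross-entropy;
    # loss4, occurrence-position cross-entropy. Each summed loss is normalised
    # by the total number of non-ignored samples across groups before the four
    # are added into the total loss.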
def forward_train(self, x, targets=None):
bs = x.shape[0]
answer_token = torch.tile(self.answer_query, (bs, 1, 1))
if self.ch:
ques3 = self.char_embed(targets[7][:, :,
0]) + answer_token # bs nc dim
ques3_answer = targets[7][:, :, 1]
else:
ques3 = self.char_embed(
torch.arange(self.out_channels - 2, device=x.get_device())
).unsqueeze(0) + answer_token # bs nc dim
ques3_answer = targets[7]
loss1_list = []
loss2_list = []
loss3_list = []
loss4_list = []
sampler1_num = 0
sampler2_num = 0
sampler3_num = 0
sampler4_num = 0
if not self.ds:
visual_f = x + self.vis_pos_embed
elif self.pos2d:
x = x + self.vis_pos_embed[:, :, :x.shape[2], :x.shape[3]]
visual_f = x.flatten(2).transpose(1, 2)
else:
visual_f = x
train_i = 0
for target_ in zip(
targets[1].transpose(0, 1),
targets[2].transpose(0, 1),
targets[3].transpose(0, 1),
targets[4].transpose(0, 1),
targets[5].transpose(0, 1),
targets[6].transpose(0, 1),
targets[8].transpose(0, 1),
targets[9].transpose(0, 1),
targets[10].transpose(0, 1),
targets[11].transpose(0, 1),
):
            # target_ unpacks to: prompt_pos_idx, prompt_char_idx, ques_pos_idx,
            # ques1_answer, ques2_char_idx, ques2_answer, ques4_char_num,
            # ques_len, ques2_len, prompt_len (see question_encoder).
visual_f_1234 = visual_f
if self.quesall or train_i == 0:
(
prompt,
ques1,
ques2,
ques2_head,
ques4,
ques1_answer,
ques2_answer,
ques4_answer,
mask_1234,
) = self.question_encoder(target_, train_i)
prompt_1234 = prompt
ques_1234 = torch.concat([ques1, ques2, ques3, ques4], 1)
for layer in self.cmff_decoder:
ques_1234, prompt_1234, visual_f_1234 = layer(
ques_1234, prompt_1234, visual_f_1234, mask_1234)
answer_query_1234 = self.answer_to_question_layer(
ques_1234, prompt_1234, mask_1234)
answer_feats_1234 = self.norm_pred(
self.answer_to_image_layer(answer_query_1234,
visual_f_1234)) # B, 26, 37
answer_feats_1 = answer_feats_1234[:, :ques1.shape[1], :]
answer_feats_2 = answer_feats_1234[:, ques1.shape[1]:(
ques1.shape[1] + ques2.shape[1]), :]
answer_feats_3 = answer_feats_1234[:, (
ques1.shape[1] + ques2.shape[1]):-ques4.shape[1], :]
answer_feats_4 = answer_feats_1234[:, -ques4.shape[1]:, :]
answer1_pred = self.ques1_head(answer_feats_1)
if train_i == 0:
logits = answer1_pred
n = (ques1_answer != self.ignore_index).sum().item()
loss1 = n * F.cross_entropy(
answer1_pred.flatten(0, 1),
ques1_answer.flatten(0, 1),
ignore_index=self.ignore_index,
reduction='mean',
)
sampler1_num += n
loss1_list.append(loss1)
answer2_pred = answer_feats_2.matmul(ques2_head.transpose(
1, 2))
diag_mask = torch.eye(answer2_pred.shape[1],
device=x.get_device()).unsqueeze(0).tile(
[bs, 1, 1])
answer2_pred = (answer2_pred * diag_mask).sum(-1)
ques2_answer = ques2_answer.flatten(0, 1)
non_pad_mask = torch.not_equal(ques2_answer, self.ignore_index)
n = non_pad_mask.sum().item()
ques2_answer = torch.where(ques2_answer == self.ignore_index,
0, ques2_answer)
loss2_none = F.binary_cross_entropy_with_logits(
answer2_pred.flatten(0, 1), ques2_answer, reduction='none')
loss2 = n * loss2_none.masked_select(non_pad_mask).mean()
sampler2_num += n
loss2_list.append(loss2)
answer3_pred = self.ques3_head(answer_feats_3)
n = (ques3_answer != self.ignore_index).sum().item()
loss3 = n * F.cross_entropy(answer3_pred.flatten(0, 1),
ques3_answer.flatten(0, 1),
reduction='mean')
sampler3_num += n
loss3_list.append(loss3)
answer4_pred = self.ques4_head(answer_feats_4)
n = (ques4_answer != self.max_len - 1).sum().item()
loss4 = n * F.cross_entropy(
answer4_pred.flatten(0, 1),
ques4_answer.flatten(0, 1),
ignore_index=self.max_len - 1,
reduction='mean',
)
sampler4_num += n
loss4_list.append(loss4)
else:
prompt, ques1, ques1_answer, mask_1234 = self.question_encoder(
target_, train_i)
prompt_1234 = prompt
for layer in self.cmff_decoder:
ques1, prompt_1234, visual_f_1234 = layer(
ques1, prompt_1234, visual_f_1234, mask_1234)
answer_query_1 = self.answer_to_question_layer(
ques1, prompt_1234, mask_1234)
answer_feats_1 = self.norm_pred(
self.answer_to_image_layer(answer_query_1,
visual_f_1234)) # B, 26, 37
answer1_pred = self.ques1_head(answer_feats_1)
n = (ques1_answer != self.ignore_index).sum().item()
loss1 = n * F.cross_entropy(
answer1_pred.flatten(0, 1),
ques1_answer.flatten(0, 1),
ignore_index=self.ignore_index,
reduction='mean',
)
sampler1_num += n
loss1_list.append(loss1)
train_i += 1
loss_list = [
sum(loss1_list) / sampler1_num,
sum(loss2_list) / sampler2_num,
sum(loss3_list) / sampler3_num,
sum(loss4_list) / sampler4_num,
]
loss = {
'loss': sum(loss_list),
'loss1': loss_list[0],
'loss2': loss_list[1],
'loss3': loss_list[2],
'loss4': loss_list[3],
}
return [loss, logits]
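

# Minimal smoke-test sketch (added for illustration; the feature size, charset
# size, and batch shape below are assumptions, not values fixed by this file).
# The decoder looks up devices via x.get_device(), so this only runs on CUDA.
if __name__ == '__main__':
    if torch.cuda.is_available():
        decoder = IGTRDecoder(in_channels=256, dim=256, out_channels=39,
                              vis_seq=50).cuda().eval()
        # [B, vis_seq, dim] visual features from the recognition encoder.
        visual_feats = torch.randn(2, 50, 256, device='cuda')
        with torch.no_grad():
            probs = decoder(visual_feats)  # softmax over the character classes
        # Expected shape: [2, max_len - 1, out_channels - 2] = [2, 27, 37].
        print(probs.shape)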