import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.init import ones_, trunc_normal_, zeros_

from openrec.modeling.common import DropPath, Identity, Mlp
from openrec.modeling.decoders.nrtr_decoder import Embeddings


class CrossAttention(nn.Module):
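    """Multi-head cross-attention block.

    Queries come from one sequence (``q``) and keys/values from another
    (``kv``). ``key_mask``, if provided, is added to the attention logits, so
    masked positions are expected to hold large negative values (e.g. -inf).
    In eval mode the attention map is cached on ``self.attn_map`` for
    inspection.
    """
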
    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, q, kv, key_mask=None):
        N, C = kv.shape[1:]
        QN = q.shape[1]
        q = self.q(q).reshape([-1, QN, self.num_heads,
                               C // self.num_heads]).transpose(1, 2)
        q = q * self.scale
        k, v = self.kv(kv).reshape(
            [-1, N, 2, self.num_heads,
             C // self.num_heads]).permute(2, 0, 3, 1, 4)
        attn = q.matmul(k.transpose(2, 3))

        if key_mask is not None:
            attn = attn + key_mask.unsqueeze(1)

        attn = F.softmax(attn, -1)
        if not self.training:
            self.attn_map = attn
        attn = self.attn_drop(attn)

        x = (attn.matmul(v)).transpose(1, 2).reshape((-1, QN, C))
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class EdgeDecoderLayer(nn.Module):
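    """Decoder layer driven by an "edge" attention matrix.

    An attention matrix is computed between the position queries ``p`` and the
    positional values ``pv`` and then used to aggregate the content values
    ``cv``; the result goes through a projection and an MLP, each followed by
    a residual connection and LayerNorm.
    """
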
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=[0.0, 0.0],
        act_layer=nn.GELU,
        norm_layer='nn.LayerNorm',
        epsilon=1e-6,
    ):
        super().__init__()
        self.head_dim = dim // num_heads
        self.scale = qk_scale or self.head_dim**-0.5

        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path1 = DropPath(
            drop_path[0]) if drop_path[0] > 0.0 else Identity()
        self.norm1 = eval(norm_layer)(dim, eps=epsilon)
        self.norm2 = eval(norm_layer)(dim, eps=epsilon)

        # self.c = nn.Linear(dim, dim*2)
        self.p = nn.Linear(dim, dim)
        self.cv = nn.Linear(dim, dim)
        self.pv = nn.Linear(dim, dim)

        self.dim = dim
        self.num_heads = num_heads
        self.p_proj = nn.Linear(dim, dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp_ratio = mlp_ratio
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, p, cv, pv):
        pN = p.shape[1]
        vN = cv.shape[1]
        p_shortcut = p

        p1 = self.p(p).reshape(
            [-1, pN, self.num_heads,
             self.dim // self.num_heads]).transpose(1, 2)
        cv1 = self.cv(cv).reshape(
            [-1, vN, self.num_heads,
             self.dim // self.num_heads]).transpose(1, 2)
        pv1 = self.pv(pv).reshape(
            [-1, vN, self.num_heads,
             self.dim // self.num_heads]).transpose(1, 2)

        edge = F.softmax(p1.matmul(pv1.transpose(2, 3)), -1)  # B h N N
        p_c = (edge @ cv1).transpose(1, 2).reshape((-1, pN, self.dim))

        x1 = self.norm1(p_shortcut + self.drop_path1(self.p_proj(p_c)))
        x = self.norm2(x1 + self.drop_path1(self.mlp(x1)))
        return x


class DecoderLayer(nn.Module):
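    """Pre-norm cross-attention decoder layer.

    Cross-attention from ``q`` to ``kv`` followed by an MLP, each wrapped in a
    residual connection with optional stochastic depth (DropPath).
    """
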
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer='nn.LayerNorm',
        epsilon=1e-6,
    ):
        super().__init__()
        self.norm1 = eval(norm_layer)(dim, eps=epsilon)
        self.normkv = eval(norm_layer)(dim, eps=epsilon)
        self.mixer = CrossAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
        self.norm2 = eval(norm_layer)(dim, eps=epsilon)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp_ratio = mlp_ratio
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, q, kv, key_mask=None):
        x1 = q + self.drop_path(
            self.mixer(self.norm1(q), self.normkv(kv), key_mask))
        x = x1 + self.drop_path(self.mlp(self.norm2(x1)))
        return x


class CMFFLayer(nn.Module):
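    """Fusion layer over concatenated question, prompt, and visual tokens
    (CMFF presumably stands for cross-modal feature fusion).

    The concatenated sequence first cross-attends to the prompt tokens, then
    to its own visual segment, and finally passes through an MLP; the three
    updated segments are returned separately.
    """
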
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        epsilon=1e-6,
    ):
        super().__init__()
        self.normq1 = nn.LayerNorm(dim, eps=epsilon)
        self.normkv1 = nn.LayerNorm(dim, eps=epsilon)
        self.images_to_question_cross_attn = CrossAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.normq2 = nn.LayerNorm(dim, eps=epsilon)
        self.normkv2 = nn.LayerNorm(dim, eps=epsilon)
        self.question_to_images_cross_attn = CrossAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
        self.normmlp = nn.LayerNorm(dim, eps=epsilon)

        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, question_f, prompt_f, visual_f, mask=None):
        query_add = torch.concat([question_f, prompt_f, visual_f], 1)

        query_add = query_add + self.drop_path(
            self.images_to_question_cross_attn(self.normq1(query_add),
                                               self.normkv1(prompt_f), mask))
        query_add = query_add + self.drop_path(
            self.question_to_images_cross_attn(
                self.normq2(query_add),
                self.normkv2(query_add[:, -visual_f.shape[1]:, :])))
        query_updated = query_add + self.drop_path(
            self.mlp(self.normmlp(query_add)))

        question_f_updated = query_updated[:, :question_f.shape[1], :]
        prompt_f_updated = query_updated[:, question_f.
                                         shape[1]:-visual_f.shape[1], :]
        visual_f_updated = query_updated[:, -visual_f.shape[1]:, :]
        return question_f_updated, prompt_f_updated, visual_f_updated


class IGTRDecoder(nn.Module):
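    """IGTR decoder (presumably "instruction-guided" text recognition).

    Training poses several "questions" against prompt and visual features
    through a stack of CMFF layers and answers them with dedicated heads
    (``ques1_head`` ... ``ques4_head``), roughly: character class, character
    occurrence, occurrence count, and character position. Inference supports
    parallel (default), check-and-search (``check_search``), autoregressive
    (``ar``), and next-token (``next_pred``) decoding, optionally followed by
    ``refine_iter`` refinement passes.
    """
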
    def __init__(self,
                 in_channels,
                 dim,
                 out_channels,
                 num_layer=2,
                 drop_path_rate=0.1,
                 max_len=25,
                 vis_seq=50,
                 ch=False,
                 ar=False,
                 refine_iter=0,
                 quesall=True,
                 next_pred=False,
                 ds=False,
                 pos2d=False,
                 check_search=False,
                 max_size=[8, 32],
                 **kwargs):
        super(IGTRDecoder, self).__init__()
        self.out_channels = out_channels
        self.dim = dim
        self.max_len = max_len + 3  # max_len + eos + bos
        self.ch = ch
        self.char_embed = Embeddings(d_model=dim,
                                     vocab=self.out_channels,
                                     scale_embedding=True)
        self.ignore_index = out_channels - 1
        self.ar = ar
        self.refine_iter = refine_iter
        self.bos = self.out_channels - 2
        self.eos = 0
        self.next_pred = next_pred
        self.quesall = quesall
        self.check_search = check_search
        dpr = np.linspace(0, drop_path_rate, num_layer + 2)

        self.cmff_decoder = nn.ModuleList([
            CMFFLayer(dim=dim,
                      num_heads=dim // 32,
                      mlp_ratio=4.0,
                      qkv_bias=True,
                      drop_path=dpr[i]) for i in range(num_layer)
        ])

        self.answer_to_question_layer = DecoderLayer(dim=dim,
                                                     num_heads=dim // 32,
                                                     mlp_ratio=4.0,
                                                     qkv_bias=True,
                                                     drop_path=dpr[-2])
        self.answer_to_image_layer = DecoderLayer(dim=dim,
                                                  num_heads=dim // 32,
                                                  mlp_ratio=4.0,
                                                  qkv_bias=True,
                                                  drop_path=dpr[-1])

        self.char_pos_embed = nn.Parameter(torch.zeros([self.max_len, dim],
                                                       dtype=torch.float32),
                                           requires_grad=True)
        self.appear_num_embed = nn.Parameter(torch.zeros([self.max_len, dim],
                                                         dtype=torch.float32),
                                             requires_grad=True)
        self.ds = ds
        self.pos2d = pos2d
        if not ds:
            self.vis_pos_embed = nn.Parameter(torch.zeros([1, vis_seq, dim],
                                                          dtype=torch.float32),
                                              requires_grad=True)
            trunc_normal_(self.vis_pos_embed, std=0.02)
        elif pos2d:
            pos_embed = torch.zeros([1, max_size[0] * max_size[1], dim],
                                    dtype=torch.float32)
            trunc_normal_(pos_embed, mean=0, std=0.02)
            self.vis_pos_embed = nn.Parameter(
                pos_embed.transpose(1, 2).reshape(1, dim, max_size[0],
                                                  max_size[1]),
                requires_grad=True,
            )
        self.prompt_pos_embed = nn.Parameter(torch.zeros([1, 6, dim],
                                                         dtype=torch.float32),
                                             requires_grad=True)

        self.answer_query = nn.Parameter(torch.zeros([1, 1, dim],
                                                     dtype=torch.float32),
                                         requires_grad=True)
        self.norm_pred = nn.LayerNorm(dim, eps=1e-6)
        self.ques1_head = nn.Linear(dim, self.out_channels - 2)
        self.ques2_head = nn.Linear(dim, self.max_len, bias=False)
        self.ques3_head = nn.Linear(dim, self.max_len - 1)
        self.ques4_head = nn.Linear(dim, self.max_len - 1)
        trunc_normal_(self.char_pos_embed, std=0.02)
        trunc_normal_(self.appear_num_embed, std=0.02)
        trunc_normal_(self.answer_query, std=0.02)
        trunc_normal_(self.prompt_pos_embed, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            zeros_(m.bias)
            ones_(m.weight)

    def no_weight_decay(self):
        return {
            'char_pos_embed', 'vis_pos_embed', 'appear_num_embed',
            'answer_query', 'char_embed'
        }

    def question_encoder(self, targets, train_i):
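        """Build prompt and question embeddings, their answer targets, and the
        prompt padding mask from the pre-computed index tensors in ``targets``.

        Returns the full (prompt, ques1, ques2, ques2_head, ques4, answers,
        mask) tuple when all questions are used (``quesall`` or the first
        group), otherwise only the prompt / ques1 subset.
        """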
        (
            prompt_pos_idx,
            prompt_char_idx,
            ques_pos_idx,
            ques1_answer,
            ques2_char_idx,
            ques2_answer,
            ques4_char_num,
            ques_len,
            ques2_len,
            prompt_len,
        ) = targets
        max_ques_len = torch.max(ques_len)
        max_ques2_len = torch.max(ques2_len)
        max_prompt_len = torch.max(prompt_len)

        if self.next_pred and (train_i == 2 or train_i == 3):
            prompt_pos = self.prompt_pos_embed
            prompt_char_idx = prompt_char_idx[:, :max_prompt_len]
        else:
            prompt_pos = F.embedding(
                prompt_pos_idx[:, :max_prompt_len], self.char_pos_embed
            )  # bs lp [ 0, 4, 3, 12, 12, 12, 12, 12, 12, 12, 12]
            prompt_char_idx = prompt_char_idx[:, :max_prompt_len]
        prompt_char = self.char_embed(prompt_char_idx)  # bs lp

        prompt = prompt_pos + prompt_char

        mask_1234 = torch.where(prompt_char_idx == self.ignore_index,
                                float('-inf'), 0)

        ques1 = F.embedding(ques_pos_idx[:, :max_ques_len],
                            self.char_pos_embed)  # bs lq1 dim
        ques1_answer = ques1_answer[:, :max_ques_len]

        if self.quesall or train_i == 0:
            ques2_char = self.char_embed(ques2_char_idx[:, :max_ques2_len, 1])
            ques2 = ques2_char + F.embedding(
                ques2_char_idx[:, :max_ques2_len, 0],
                self.char_pos_embed)  # bs lq2 dim
            ques2_answer = ques2_answer[:, :max_ques2_len]
            ques2_head = F.embedding(ques2_char_idx[:, :max_ques2_len, 0],
                                     self.ques2_head.weight)

            ques4_char = self.char_embed(ques1_answer)
            ques4_ap_num = F.embedding(ques4_char_num[:, :max_ques_len],
                                       self.appear_num_embed)
            ques4 = ques4_char + ques4_ap_num
            ques4_answer = ques_pos_idx[:, :max_ques_len]

            return (
                prompt,
                ques1,
                ques2,
                ques2_head,
                ques4,
                ques1_answer,
                ques2_answer,
                ques4_answer,
                mask_1234.unsqueeze(1),
            )
        else:
            return prompt, ques1, ques1_answer, mask_1234.unsqueeze(1)

    def forward(self, x, data=None):
        if self.training:
            return self.forward_train(x, data)
        else:
            return self.forward_test(x)

    def forward_test(self, x):
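        """Inference. Depending on the configuration flags this runs parallel
        decoding (default), check-and-search decoding (``check_search``),
        autoregressive decoding (``ar``), or next-token prediction
        (``next_pred``), optionally followed by ``refine_iter`` refinement
        passes. Returns class probabilities, or ``[pred_idxs, pred_probs]``
        for the next-token and refinement paths.
        """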
        if not self.ds:
            visual_f = x + self.vis_pos_embed
        elif self.pos2d:
            x = x + self.vis_pos_embed[:, :, :x.shape[2], :x.shape[3]]
            visual_f = x.flatten(2).transpose(1, 2)
        else:
            visual_f = x

        bs = x.shape[0]
        prompt_bos = self.char_embed(
            torch.full(
                [bs, 1], self.bos, dtype=torch.long,
                device=x.get_device())) + self.char_pos_embed[:1, :].unsqueeze(
                    0)  # BOS prompt
        ques_all = torch.tile(self.char_pos_embed.unsqueeze(0), (bs, 1, 1))

        if not self.ar:
            if self.check_search:
                tgt_in = torch.full((bs, self.max_len),
                                    self.ignore_index,
                                    dtype=torch.long,
                                    device=x.get_device())
                tgt_in[:, 0] = self.bos
                logits = []
                for j in range(1, self.max_len):
                    visual_f_check = visual_f
                    ques_check_i = ques_all[:, j:j + 1, :] + self.char_embed(
                        torch.arange(self.out_channels - 2,
                                     device=x.get_device())).unsqueeze(0)
                    prompt_check = ques_all[:, :j] + self.char_embed(
                        tgt_in[:, :j])
                    # prompt_check = prompt_bos
                    mask = torch.where(
                        (tgt_in[:, :j] == self.eos).int().cumsum(-1) > 0,
                        float('-inf'), 0)
                    for layer in self.cmff_decoder:
                        ques_check_i, prompt_check, visual_f_check = layer(
                            ques_check_i, prompt_check, visual_f_check,
                            mask.unsqueeze(1))
                    answer_query_i = self.answer_to_question_layer(
                        ques_check_i, prompt_check, mask.unsqueeze(1))
                    answer_pred_i = self.norm_pred(
                        self.answer_to_image_layer(
                            answer_query_i, visual_f_check))  # B, 26, 37
                    # the next token probability is in the output's ith token position
                    fc_2 = self.ques2_head.weight[j:j + 1].unsqueeze(0)
                    fc_2 = fc_2.tile([bs, 1, 1])
                    p_i = fc_2 @ answer_pred_i.transpose(1, 2)
                    # p_i = p_i[:, 0, :]
                    logits.append(p_i)
                    if j < self.max_len - 1:
                        # greedy decode. add the next token index to the target input
                        tgt_in[:, j] = p_i.squeeze().argmax(-1)
                        # Efficient batch decoding: If all output words have at least one EOS token, end decoding.
                        if (tgt_in == self.eos).any(dim=-1).all():
                            break
                logits = torch.cat(logits, dim=1)
            else:
                ques_pd = ques_all[:, 1:, :]
                prompt_pd = prompt_bos
                visual_f_pd = visual_f
                for layer in self.cmff_decoder:
                    ques_pd, prompt_pd, visual_f_pd = layer(
                        ques_pd, prompt_pd, visual_f_pd)
                answer_query_pd = self.answer_to_question_layer(
                    ques_pd, prompt_pd)
                answer_feats_pd = self.norm_pred(
                    self.answer_to_image_layer(answer_query_pd,
                                               visual_f_pd))  # B, 26, 37
                logits = self.ques1_head(answer_feats_pd)
        elif self.next_pred:
            ques_pd_1 = ques_all[:, 1:2, :]
            prompt_pd = prompt_bos
            visual_f_pd = visual_f
            for layer in self.cmff_decoder:
                ques_pd_1, prompt_pd, visual_f_pd = layer(
                    ques_pd_1, prompt_pd, visual_f_pd)
            answer_query_pd = self.answer_to_question_layer(
                ques_pd_1, prompt_pd)
            answer_feats_pd = self.norm_pred(
                self.answer_to_image_layer(answer_query_pd,
                                           visual_f_pd))  # B, 26, 37
            logits_pd_1 = self.ques1_head(answer_feats_pd)

            ques_next = self.char_pos_embed[-2:-1, :].unsqueeze(0).tile(
                [bs, 1, 1])
            prompt_next_bos = (self.char_embed(
                torch.full(
                    [bs, 1], self.bos, dtype=torch.long,
                    device=x.get_device())) + self.prompt_pos_embed[:, :1, :])
            pred_prob, pred_id = F.softmax(logits_pd_1, -1).max(-1)
            pred_prob_list = [pred_prob]
            pred_id_list = [pred_id]
            for j in range(1, 70):
                prompt_next_1 = self.char_embed(
                    pred_id) + self.prompt_pos_embed[:,
                                                     -1 * pred_id.shape[1]:, :]
                prompt_next = torch.concat([prompt_next_bos, prompt_next_1], 1)
                ques_next_i = ques_next
                visual_f_i = visual_f
                for layer in self.cmff_decoder:
                    ques_next_i, prompt_next, visual_f_pd = layer(
                        ques_next_i, prompt_next, visual_f_i)
                answer_query_next_i = self.answer_to_question_layer(
                    ques_next_i, prompt_next)
                answer_feats_next_i = self.norm_pred(
                    self.answer_to_image_layer(answer_query_next_i,
                                               visual_f_i))  # B, 26, 37
                logits_next_i = self.ques1_head(answer_feats_next_i)
                # pred_id = logits_next_i.argmax(-1)
                pred_prob_i, pred_id_i = F.softmax(logits_next_i, -1).max(-1)
                pred_prob_list.append(pred_prob_i)
                pred_id_list.append(pred_id_i)
                if (torch.concat(pred_id_list,
                                 1) == self.eos).any(dim=-1).all():
                    break
                if pred_id.shape[1] >= 5:
                    pred_id = torch.concat([pred_id[:, 1:], pred_id_i], 1)
                else:
                    pred_id = torch.concat([pred_id, pred_id_i], 1)
            return [
                torch.concat(pred_id_list, 1),
                torch.concat(pred_prob_list, 1)
            ]
        else:
            tgt_in = torch.full((bs, self.max_len),
                                self.ignore_index,
                                dtype=torch.long,
                                device=x.get_device())
            tgt_in[:, 0] = self.bos
            logits = []
            for j in range(1, self.max_len):
                visual_f_ar = visual_f
                ques_i = ques_all[:, j:j + 1, :]
                prompt_ar = ques_all[:, :j] + self.char_embed(tgt_in[:, :j])
                mask = torch.where(
                    (tgt_in[:, :j] == self.eos).int().cumsum(-1) > 0,
                    float('-inf'), 0)
                for layer in self.cmff_decoder:
                    ques_i, prompt_ar, visual_f_ar = layer(
                        ques_i, prompt_ar, visual_f_ar, mask.unsqueeze(1))
                answer_query_i = self.answer_to_question_layer(
                    ques_i, prompt_ar, mask.unsqueeze(1))
                answer_pred_i = self.norm_pred(
                    self.answer_to_image_layer(answer_query_i,
                                               visual_f_ar))  # B, 26, 37
                # the next token probability is in the output's ith token position
                p_i = self.ques1_head(answer_pred_i)
                logits.append(p_i)
                if j < self.max_len - 1:
                    # greedy decode. add the next token index to the target input
                    tgt_in[:, j] = p_i.squeeze().argmax(-1)
                    # Efficient batch decoding: If all output words have at least one EOS token, end decoding.
                    if (tgt_in == self.eos).any(dim=-1).all():
                        break
            logits = torch.cat(logits, dim=1)

        if self.refine_iter > 0:
            pred_probs, pred_idxs = F.softmax(logits, -1).max(-1)
            for i in range(self.refine_iter):
                mask_check = (pred_idxs == self.eos).int().cumsum(-1) <= 1

                ques_check_all = self.char_embed(
                    pred_idxs) + ques_all[:, 1:pred_idxs.shape[1] + 1, :]
                prompt_check = prompt_bos
                visual_f_check = visual_f
                ques_check = ques_check_all
                for layer in self.cmff_decoder:
                    ques_check, prompt_check, visual_f_check = layer(
                        ques_check, prompt_check, visual_f_check)
                answer_query_check = self.answer_to_question_layer(
                    ques_check, prompt_check)
                answer_pred_check = self.norm_pred(
                    self.answer_to_image_layer(answer_query_check,
                                               visual_f_check))  # B, 26, 37

                ques2_head = self.ques2_head.weight[1:pred_idxs.shape[1] +
                                                    1, :]
                ques2_head = torch.tile(ques2_head.unsqueeze(0), [bs, 1, 1])
                answer2_pred = answer_pred_check.matmul(
                    ques2_head.transpose(1, 2))
                diag_mask = torch.eye(answer2_pred.shape[1],
                                      device=x.get_device()).unsqueeze(0).tile(
                                          [bs, 1, 1])
                answer2_pred = F.sigmoid(
                    (answer2_pred * diag_mask).sum(-1)) * mask_check

                check_result = answer2_pred < 0.9  # pred_probs < 0.99

                prompt_refine = torch.concat([prompt_bos, ques_check_all], 1)
                mask_refine = torch.where(
                    check_result, float('-inf'), 0) + torch.where(
                        (pred_idxs == self.eos).int().cumsum(-1) < 1, 0,
                        float('-inf'))
                mask_refine = torch.concat(
                    [torch.zeros([bs, 1], device=x.get_device()), mask_refine],
                    1).unsqueeze(1)
                ques_refine = ques_all[:, 1:pred_idxs.shape[1] + 1, :]
                visual_f_refine = visual_f
                for layer in self.cmff_decoder:
                    ques_refine, prompt_refine, visual_f_refine = layer(
                        ques_refine, prompt_refine, visual_f_refine,
                        mask_refine)
                answer_query_refine = self.answer_to_question_layer(
                    ques_refine, prompt_refine, mask_refine)
                answer_pred_refine = self.norm_pred(
                    self.answer_to_image_layer(answer_query_refine,
                                               visual_f_refine))  # B, 26, 37
                answer_refine = self.ques1_head(answer_pred_refine)
                refine_probs, refine_idxs = F.softmax(answer_refine,
                                                      -1).max(-1)
                pred_idxs_refine = torch.where(check_result, refine_idxs,
                                               pred_idxs)
                pred_idxs = torch.where(mask_check, pred_idxs_refine,
                                        pred_idxs)
                pred_probs_refine = torch.where(check_result, refine_probs,
                                                pred_probs)
                pred_probs = torch.where(mask_check, pred_probs_refine,
                                         pred_probs)

            return [pred_idxs, pred_probs]

        return F.softmax(logits, -1)

    def forward_train(self, x, targets=None):
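        """Training forward. Iterates over the sampled question/prompt groups
        in ``targets``, runs the CMFF decoder and the answer layers, and
        accumulates the four question losses (each weighted by its number of
        valid targets). Returns ``[loss_dict, logits]``, where ``logits`` are
        the question-1 (character) predictions of the first group.
        """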
        bs = x.shape[0]
        answer_token = torch.tile(self.answer_query, (bs, 1, 1))
        if self.ch:
            ques3 = self.char_embed(targets[7][:, :,
                                               0]) + answer_token  # bs nc dim
            ques3_answer = targets[7][:, :, 1]
        else:
            ques3 = self.char_embed(
                torch.arange(self.out_channels - 2, device=x.get_device())
            ).unsqueeze(0) + answer_token  # bs nc dim
            ques3_answer = targets[7]
        loss1_list = []
        loss2_list = []
        loss3_list = []
        loss4_list = []
        sampler1_num = 0
        sampler2_num = 0
        sampler3_num = 0
        sampler4_num = 0
        if not self.ds:
            visual_f = x + self.vis_pos_embed
        elif self.pos2d:
            x = x + self.vis_pos_embed[:, :, :x.shape[2], :x.shape[3]]
            visual_f = x.flatten(2).transpose(1, 2)
        else:
            visual_f = x
        train_i = 0
        for target_ in zip(
                targets[1].transpose(0, 1),
                targets[2].transpose(0, 1),
                targets[3].transpose(0, 1),
                targets[4].transpose(0, 1),
                targets[5].transpose(0, 1),
                targets[6].transpose(0, 1),
                targets[8].transpose(0, 1),
                targets[9].transpose(0, 1),
                targets[10].transpose(0, 1),
                targets[11].transpose(0, 1),
        ):
            # target_ = [prompt_pos_idx, prompt_char_idx, ques_pos_idx, ques1_answer, \
            # ques2_char_idx, ques2_answer, ques4_char_num, ques_len, prompt_len]
            visual_f_1234 = visual_f
            if self.quesall or train_i == 0:
                (
                    prompt,
                    ques1,
                    ques2,
                    ques2_head,
                    ques4,
                    ques1_answer,
                    ques2_answer,
                    ques4_answer,
                    mask_1234,
                ) = self.question_encoder(target_, train_i)
                prompt_1234 = prompt
                ques_1234 = torch.concat([ques1, ques2, ques3, ques4], 1)
                for layer in self.cmff_decoder:
                    ques_1234, prompt_1234, visual_f_1234 = layer(
                        ques_1234, prompt_1234, visual_f_1234, mask_1234)
                answer_query_1234 = self.answer_to_question_layer(
                    ques_1234, prompt_1234, mask_1234)
                answer_feats_1234 = self.norm_pred(
                    self.answer_to_image_layer(answer_query_1234,
                                               visual_f_1234))  # B, 26, 37
                answer_feats_1 = answer_feats_1234[:, :ques1.shape[1], :]
                answer_feats_2 = answer_feats_1234[:, ques1.shape[1]:(
                    ques1.shape[1] + ques2.shape[1]), :]
                answer_feats_3 = answer_feats_1234[:, (
                    ques1.shape[1] + ques2.shape[1]):-ques4.shape[1], :]
                answer_feats_4 = answer_feats_1234[:, -ques4.shape[1]:, :]

                answer1_pred = self.ques1_head(answer_feats_1)
                if train_i == 0:
                    logits = answer1_pred

                n = (ques1_answer != self.ignore_index).sum().item()
                loss1 = n * F.cross_entropy(
                    answer1_pred.flatten(0, 1),
                    ques1_answer.flatten(0, 1),
                    ignore_index=self.ignore_index,
                    reduction='mean',
                )
                sampler1_num += n
                loss1_list.append(loss1)

                answer2_pred = answer_feats_2.matmul(ques2_head.transpose(
                    1, 2))
                diag_mask = torch.eye(answer2_pred.shape[1],
                                      device=x.get_device()).unsqueeze(0).tile(
                                          [bs, 1, 1])
                answer2_pred = (answer2_pred * diag_mask).sum(-1)

                ques2_answer = ques2_answer.flatten(0, 1)
                non_pad_mask = torch.not_equal(ques2_answer, self.ignore_index)
                n = non_pad_mask.sum().item()
                ques2_answer = torch.where(ques2_answer == self.ignore_index,
                                           0, ques2_answer)
                loss2_none = F.binary_cross_entropy_with_logits(
                    answer2_pred.flatten(0, 1), ques2_answer, reduction='none')
                loss2 = n * loss2_none.masked_select(non_pad_mask).mean()
                sampler2_num += n
                loss2_list.append(loss2)

                answer3_pred = self.ques3_head(answer_feats_3)
                n = (ques3_answer != self.ignore_index).sum().item()
                loss3 = n * F.cross_entropy(answer3_pred.flatten(0, 1),
                                            ques3_answer.flatten(0, 1),
                                            reduction='mean')
                sampler3_num += n
                loss3_list.append(loss3)

                answer4_pred = self.ques4_head(answer_feats_4)
                n = (ques4_answer != self.max_len - 1).sum().item()
                loss4 = n * F.cross_entropy(
                    answer4_pred.flatten(0, 1),
                    ques4_answer.flatten(0, 1),
                    ignore_index=self.max_len - 1,
                    reduction='mean',
                )
                sampler4_num += n
                loss4_list.append(loss4)
            else:
                prompt, ques1, ques1_answer, mask_1234 = self.question_encoder(
                    target_, train_i)
                prompt_1234 = prompt
                for layer in self.cmff_decoder:
                    ques1, prompt_1234, visual_f_1234 = layer(
                        ques1, prompt_1234, visual_f_1234, mask_1234)
                answer_query_1 = self.answer_to_question_layer(
                    ques1, prompt_1234, mask_1234)
                answer_feats_1 = self.norm_pred(
                    self.answer_to_image_layer(answer_query_1,
                                               visual_f_1234))  # B, 26, 37
                answer1_pred = self.ques1_head(answer_feats_1)

                n = (ques1_answer != self.ignore_index).sum().item()
                loss1 = n * F.cross_entropy(
                    answer1_pred.flatten(0, 1),
                    ques1_answer.flatten(0, 1),
                    ignore_index=self.ignore_index,
                    reduction='mean',
                )
                sampler1_num += n
                loss1_list.append(loss1)
            train_i += 1

        loss_list = [
            sum(loss1_list) / sampler1_num,
            sum(loss2_list) / sampler2_num,
            sum(loss3_list) / sampler3_num,
            sum(loss4_list) / sampler4_num,
        ]
        loss = {
            'loss': sum(loss_list),
            'loss1': loss_list[0],
            'loss2': loss_list[1],
            'loss3': loss_list[2],
            'loss4': loss_list[3],
        }
        return [loss, logits]
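

if __name__ == '__main__':
    # Minimal smoke-test sketch (not part of the original module). It assumes a
    # CUDA device, because the decoder calls `x.get_device()` when building the
    # BOS prompt, and it assumes visual features already projected to `dim`.
    # The sizes below (dim=256, 38 output classes, 50 visual tokens) are
    # illustrative choices, not values prescribed by the repository.
    if torch.cuda.is_available():
        decoder = IGTRDecoder(in_channels=256,
                              dim=256,
                              out_channels=38,
                              num_layer=2,
                              max_len=25,
                              vis_seq=50).cuda().eval()
        feats = torch.randn(2, 50, 256, device='cuda')  # [batch, tokens, dim]
        with torch.no_grad():
            probs = decoder(feats)  # default parallel decoding path
        # Expected shape: [2, max_len + 2, out_channels - 2] -> [2, 27, 36]
        print(probs.shape)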