Spaces:
Running
Running
import numpy as np | |
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
from torch.nn.init import ones_, trunc_normal_, zeros_ | |
from openrec.modeling.common import DropPath, Identity | |
from openrec.modeling.decoders.cppd_decoder import DecoderLayer | |
from openrec.modeling.decoders.nrtr_decoder import Embeddings | |
class CrossAttention(nn.Module):
    """Multi-head attention in which queries attend over a separate key/value sequence.

    Args:
        dim: Channel width of both inputs; split evenly across ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the query and key/value projections carry a bias.
        qk_scale: Optional attention scale. When falsy, the scale is
            ``(dim // num_heads) ** -0.5``.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        per_head = dim // num_heads
        self.scale = qk_scale if qk_scale else per_head**-0.5
        # One projection for the queries; a second one yields keys and
        # values together (2 * dim outputs).
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, q, kv, key_mask=None):
        """Attend from ``q`` (B, QN, C) over ``kv`` (B, N, C).

        Args:
            q: Query sequence.
            kv: Key/value sequence; its last dimension fixes the width ``C``.
            key_mask: Optional additive mask. It is unsqueezed at dim 1 (the
                head axis) before being added to the attention scores, so
                (B, QN or 1, N) broadcasts over heads.

        Returns:
            Tensor of shape (B, QN, C).

        In eval mode, the post-softmax attention weights (before dropout)
        are kept on ``self.attn_map``.
        """
        kv_len, channels = kv.shape[1:]
        q_len = q.shape[1]
        heads = self.num_heads
        head_dim = channels // heads

        query = self.q(q).reshape([-1, q_len, heads, head_dim]).transpose(1, 2)
        query = query * self.scale
        # (2, B, heads, N, head_dim) -> separate key and value tensors.
        key, value = self.kv(kv).reshape(
            [-1, kv_len, 2, heads, head_dim]).permute(2, 0, 3, 1, 4)

        scores = query @ key.transpose(2, 3)
        if key_mask is not None:
            scores = scores + key_mask.unsqueeze(1)
        weights = scores.softmax(dim=-1)
        if not self.training:
            self.attn_map = weights
        weights = self.attn_drop(weights)

        merged = (weights @ value).transpose(1, 2).reshape((-1, q_len, channels))
        return self.proj_drop(self.proj(merged))
class SSMatchLayer(nn.Module):
    """Two-stage matching block used by the SMTR decoder.

    Stage 1: each query attends over its sub-string prompt through a residual
    ``CrossAttention``. Both sides are pre-normalised, and ``self.drop_path``
    is applied to the residual branch.

    Stage 2: the queries are regrouped to (batch, -1, dim), where batch is
    taken from ``visual_f``. They are layer-normed and passed through
    ``num_layer`` ``DecoderLayer`` blocks, each called as
    ``layer(question_f, visual_f)``.

    Args:
        dim: Model width.
        nextq2subs_head2: Heads of the stage-1 attention; ``dim // 32`` when
            None.
        dynq2img_heads: Heads of each stage-2 ``DecoderLayer``.
        mlp_ratio: Accepted but not forwarded. The ``DecoderLayer`` stack is
            always built with ``mlp_ratio=4.0``.
        qkv_bias, qk_scale, attn_drop: Forwarded to the stage-1
            ``CrossAttention`` only. The ``DecoderLayer`` stack always uses
            ``qkv_bias=True``.
        drop: ``proj_drop`` of the stage-1 ``CrossAttention``.
        drop_path: Drop-path rate of the stage-1 residual branch. It is also
            spread linearly from 0 to this value across the stage-2 layers.
        act_layer: Activation class passed to every ``DecoderLayer``.
        num_layer: Number of stage-2 ``DecoderLayer`` blocks.
        epsilon: LayerNorm epsilon.
    """

    def __init__(
        self,
        dim,
        nextq2subs_head2=None,
        dynq2img_heads=2,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        num_layer=2,
        epsilon=1e-6,
    ):
        super().__init__()
        self.dim = dim
        prompt_heads = dim // 32 if nextq2subs_head2 is None else nextq2subs_head2

        # Stage 1: query -> sub-string prompt attention (pre-norm on both sides).
        self.normq1 = nn.LayerNorm(dim, eps=epsilon)
        self.normkv1 = nn.LayerNorm(dim, eps=epsilon)
        self.images_to_question_cross_attn = CrossAttention(
            dim,
            num_heads=prompt_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        # Stage 2: per-image query refinement with linearly increasing drop path.
        self.normq2 = nn.LayerNorm(dim, eps=epsilon)
        path_rates = np.linspace(0, drop_path, num_layer)
        layers = []
        for rate in path_rates:
            layers.append(
                DecoderLayer(
                    dim=dim,
                    num_heads=dynq2img_heads,
                    mlp_ratio=4.0,
                    qkv_bias=True,
                    drop_path=rate,
                    act_layer=act_layer,
                ))
        self.question_to_images_cross_attn = nn.ModuleList(layers)

        # Drop path (stochastic depth) on the stage-1 residual branch.
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()

    def forward(self, question_f, prompt_f, visual_f, mask=None):
        """Match queries against their prompts, then against the image.

        Args:
            question_f: Query tokens fed to the stage-1 attention.
            prompt_f: Prompt tokens used as keys/values in stage 1.
            visual_f: Visual features. Its first dimension sets how the
                queries are regrouped before stage 2.
            mask: Optional additive key mask for the stage-1 attention.

        Returns:
            Refined queries of shape (visual_f.shape[0], -1, dim).
        """
        attended = self.images_to_question_cross_attn(
            self.normq1(question_f), self.normkv1(prompt_f), mask)
        question_f = question_f + self.drop_path(attended)

        out = self.normq2(question_f.reshape(visual_f.shape[0], -1, self.dim))
        for layer in self.question_to_images_cross_attn:
            out = layer(out, visual_f)
        return out
class SMTRDecoderNumAttn(nn.Module):
    """Sub-string matching decoder that predicts one character per step.

    A learned query token, ``next_token`` (left-to-right) or ``pre_token``
    (right-to-left), is processed by ``SSMatchLayer`` together with a prompt
    and the visual features. The prompt is built from the ``sub_str_len``
    most recent character ids; prompt slots still holding the stream's BOS id
    are masked with ``-inf``. The refined query is layer-normed and scored by
    ``ques1_head``.

    Token ids:

    - ``0`` is EOS.
    - ``out_channels - 3`` is the BOS id of the "next" stream.
    - ``out_channels - 2`` is the BOS id of the "pre" stream.
    - ``out_channels - 1`` is the ignore id.

    ``ques1_head`` scores only the first ``out_channels - 3`` ids.

    Args:
        in_channels: Width of the incoming visual features, used as the
            model width ``dim``.
        out_channels: Vocabulary size including the three reserved ids.
        num_layer: Number of ``DecoderLayer`` blocks inside ``SSMatchLayer``.
        nextq2subs_head2: Heads of the query-to-prompt attention (None lets
            ``SSMatchLayer`` pick its default).
        dynq2img_heads: Heads of the query-to-image layers.
        drop_path_rate: Drop-path rate passed to ``SSMatchLayer``.
        max_len: Maximum text length. ``forward_test`` runs at most
            ``max_len + 2`` steps.
        vis_seq: Length of the learned 1-D position embedding (``ds`` False).
        ds: When False, add a (1, vis_seq, dim) position embedding to ``x``.
            When True, see ``pos2d``.
        pos2d: Only with ``ds``. Add a learned 2-D embedding of size
            ``max_size``, cropped to the input's H x W, and flatten the
            features to (B, H*W, dim).
        max_size: (H, W) of the 2-D position embedding.
        sub_str_len: Number of previously predicted ids fed back as prompt.
        next_mode: Direction used by ``forward_test`` (True = "next").
        infer_aug: At inference, use ``forward_test_bi`` instead of
            ``forward_test``.
        **kwargs: Ignored.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_layer=2,
                 nextq2subs_head2=None,
                 dynq2img_heads=2,
                 drop_path_rate=0.1,
                 max_len=25,
                 vis_seq=50,
                 ds=False,
                 pos2d=False,
                 max_size=(8, 32),
                 sub_str_len=5,
                 next_mode=True,
                 infer_aug=False,
                 **kwargs):
        super().__init__()
        self.out_channels = out_channels
        dim = in_channels
        self.dim = dim
        # NOTE(review): originally commented "max_len + eos + bos", but three
        # slots are added. Confirm what the third slot accounts for.
        self.max_len = max_len + 3
        self.char_embed = Embeddings(d_model=dim,
                                     vocab=self.out_channels,
                                     scale_embedding=True)
        self.ignore_index = out_channels - 1
        self.sub_str_len = sub_str_len
        self.bos_next = out_channels - 3
        self.bos_pre = out_channels - 2
        self.eos = 0
        self.next_mode = next_mode
        self.infer_aug = infer_aug
        self.cmff_decoder = SSMatchLayer(dim=dim,
                                         nextq2subs_head2=nextq2subs_head2,
                                         dynq2img_heads=dynq2img_heads,
                                         mlp_ratio=4.0,
                                         qkv_bias=True,
                                         drop_path=drop_path_rate,
                                         num_layer=num_layer)
        self.ds = ds
        self.pos2d = pos2d
        if not ds:
            self.vis_pos_embed = nn.Parameter(torch.zeros([1, vis_seq, dim],
                                                          dtype=torch.float32),
                                              requires_grad=True)
            trunc_normal_(self.vis_pos_embed, std=0.02)
        elif pos2d:
            pos_embed = torch.zeros([1, max_size[0] * max_size[1], dim],
                                    dtype=torch.float32)
            trunc_normal_(pos_embed, mean=0, std=0.02)
            # Stored as (1, dim, H, W) so it can be cropped per input size.
            self.vis_pos_embed = nn.Parameter(pos_embed.transpose(
                1, 2).reshape(1, dim, max_size[0], max_size[1]),
                                              requires_grad=True)
        self.next_token = nn.Parameter(torch.zeros([1, 1, 1, dim],
                                                   dtype=torch.float32),
                                       requires_grad=True)
        self.pre_token = nn.Parameter(torch.zeros([1, 1, 1, dim],
                                                  dtype=torch.float32),
                                      requires_grad=True)
        # One extra slot ahead of the sub_str_len character positions; that
        # slot never receives a character embedding and is never masked.
        self.prompt_next_embed = nn.Parameter(torch.zeros(
            [1, 1, self.sub_str_len + 1, dim], dtype=torch.float32),
                                              requires_grad=True)
        self.prompt_pre_embed = nn.Parameter(torch.zeros(
            [1, 1, self.sub_str_len + 1, dim], dtype=torch.float32),
                                             requires_grad=True)
        self.norm_pred = nn.LayerNorm(dim, eps=1e-6)
        self.ques1_head = nn.Linear(dim, self.out_channels - 3)
        trunc_normal_(self.next_token, std=0.02)
        trunc_normal_(self.pre_token, std=0.02)
        trunc_normal_(self.prompt_pre_embed, std=0.02)
        trunc_normal_(self.prompt_next_embed, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Init Linear with trunc-normal(std=0.02) and zero bias, and LayerNorm
        with unit weight and zero bias."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            zeros_(m.bias)
            ones_(m.weight)

    def no_weight_decay(self):
        """Parameter names the optimizer should exclude from weight decay."""
        return {'vis_pos_embed', 'pre_token', 'next_token', 'char_embed'}

    def forward(self, x, data=None):
        """Dispatch to training, bidirectional or single-direction decoding."""
        if self.training:
            return self.forward_train(x, data)
        if self.infer_aug:
            return self.forward_test_bi(x)
        return self.forward_test(x)

    def _add_visual_pos(self, x):
        """Add the learned position embedding to ``x`` and return a sequence.

        - ``ds`` False: adds the 1-D embedding directly.
        - ``ds`` and ``pos2d``: ``x`` is indexed as (B, C, H, W). The 2-D
          embedding is cropped to H x W and added, and the result is
          flattened to (B, H*W, C).
        - ``ds`` without ``pos2d``: ``x`` is returned unchanged (presumably
          already a (B, L, C) sequence — TODO confirm).
        """
        if not self.ds:
            return x + self.vis_pos_embed
        if self.pos2d:
            visual_f = x + self.vis_pos_embed[:, :, :x.shape[2], :x.shape[3]]
            # Flatten the embedded features (not the raw input) so the
            # position information actually reaches the decoder.
            return visual_f.flatten(2).transpose(1, 2)
        return x

    def forward_test_bi(self, x):
        """Bidirectional inference used when ``infer_aug`` is set.

        ``visual_f[0]`` is decoded with the "next" query (left-to-right) and
        ``visual_f[1]`` with the "pre" query (right-to-left), in one batch.
        ``x`` therefore needs at least two samples, and three when the
        stitching phase runs (presumably augmented views of one image —
        TODO confirm).

        If both streams yield more than ``sub_str_len`` steps, the stitching
        phase runs:

        1. ``visual_f[2]`` is decoded left-to-right.
        2. It is seeded with the tail of the "next" result.
        3. It stops once its last ``sub_str_len`` ids equal the head of the
           "pre" result, or when the step cap is reached.
        4. The pieces are concatenated.

        Otherwise the "next" steps are concatenated with the reversed "pre"
        steps.

        Returns:
            Softmax probabilities of shape (1, L, out_channels - 3).
        """
        visual_f = self._add_visual_pos(x)
        device = x.device
        num_steps = min(70, self.max_len - 1)  # hard cap on decoding steps

        # Row 0 = "next" stream, row 1 = "pre" stream.
        next_pre = torch.concat([self.next_token, self.pre_token],
                                0).squeeze(1)  # 2, 1, dim
        prompt_next_embed = self.prompt_next_embed.squeeze(1)
        prompt_pre_embed = self.prompt_pre_embed.squeeze(1)
        next_id = torch.full([1, self.sub_str_len],
                             self.bos_next,
                             dtype=torch.long,
                             device=device)
        pre_id = torch.full([1, self.sub_str_len],
                            self.bos_pre,
                            dtype=torch.long,
                            device=device)
        next_pred_id_list = torch.full([1, self.max_len],
                                       self.ignore_index,
                                       dtype=torch.long,
                                       device=device)
        pre_pred_id_list = torch.full([1, self.max_len],
                                      self.ignore_index,
                                      dtype=torch.long,
                                      device=device)
        next_logits_all = []
        pre_logits_all = []
        mask_pad = torch.zeros([2, 1], dtype=torch.float32, device=device)
        visual_f_bi = visual_f[:2]  # 2, l, dim

        for j in range(num_steps):
            prompt_char_next = torch.concat([
                prompt_next_embed[:, :1, :],
                prompt_next_embed[:, 1:, :] + self.char_embed(next_id)
            ], 1)
            prompt_char_pre = torch.concat([
                prompt_pre_embed[:, :1, :],
                prompt_pre_embed[:, 1:, :] + self.char_embed(pre_id)
            ], 1)
            prompt_char = torch.concat([prompt_char_next, prompt_char_pre],
                                       0)  # 2, sub_str_len + 1, dim
            mask_next = torch.where(next_id == self.bos_next, float('-inf'),
                                    0)
            mask_pre = torch.where(pre_id == self.bos_pre, float('-inf'), 0)
            mask = torch.concat(
                [mask_pad, torch.concat([mask_next, mask_pre], 0)], 1)
            pred_token = self.cmff_decoder(next_pre, prompt_char, visual_f_bi,
                                           mask.unsqueeze(1))
            logits = F.softmax(self.ques1_head(self.norm_pred(pred_token)),
                               -1)
            pred_id_i = logits.argmax(-1)  # 2, 1
            next_pred_id_list[:, j:j + 1] = pred_id_i[:1]
            pre_pred_id_list[:, j:j + 1] = pred_id_i[1:2]
            next_done = (next_pred_id_list == self.eos).any(dim=-1).all()
            pre_done = (pre_pred_id_list == self.eos).any(dim=-1).all()
            # Once a stream has emitted EOS its logits (including the EOS
            # step) are no longer collected and its prompt stops moving.
            if not next_done:
                next_logits_all.append(logits[:1])
                next_id = torch.concat([next_id[:, 1:], pred_id_i[:1]], 1)
            if not pre_done:
                pre_logits_all.append(logits[1:2])
                pre_id = torch.concat([pred_id_i[1:2], pre_id[:, :-1]], 1)
            if next_done and pre_done:
                break

        if (len(next_logits_all) > self.sub_str_len
                and len(pre_logits_all) > self.sub_str_len):
            next_logits_all_ = torch.concat(next_logits_all[:-1], 1)  # 1, l
            pre_logits_all_ = torch.concat(pre_logits_all[:-1][::-1], 1)
            next_id = next_logits_all_.argmax(-1)[:, -self.sub_str_len:]
            pre_id = pre_logits_all_.argmax(-1)[:, :self.sub_str_len]
            tail_logits = []
            ques_next = self.next_token.squeeze(1)  # 1, 1, dim
            mask_pad = torch.zeros([1, 1], dtype=torch.float32, device=device)
            visual_f_tail = visual_f[2:3]
            for _ in range(num_steps):
                prompt_next = torch.concat([
                    prompt_next_embed[:, :1, :],
                    prompt_next_embed[:, 1:, :] + self.char_embed(next_id)
                ], 1)
                mask_next = torch.where(next_id == self.bos_next,
                                        float('-inf'), 0)
                mask = torch.concat([mask_pad, mask_next], 1)
                ques_next_i = self.cmff_decoder(ques_next, prompt_next,
                                                visual_f_tail,
                                                mask.unsqueeze(1))
                logits = F.softmax(
                    self.ques1_head(self.norm_pred(ques_next_i)), -1)
                pred_id_i = logits.argmax(-1)
                tail_logits.append(logits)
                next_id = torch.concat([next_id[:, 1:], pred_id_i], 1)
                # Stop once the left-to-right window meets the "pre" head.
                if next_id.equal(pre_id):
                    break
            next_logits_all_ = torch.concat(
                [next_logits_all_, torch.concat(tail_logits, 1)], 1)
            return torch.concat(
                [next_logits_all_, pre_logits_all_[:, self.sub_str_len:]], 1)

        if not next_logits_all and not pre_logits_all:
            # Both streams emitted EOS at the first step. Return that step's
            # "next" prediction instead of letting torch.concat fail on an
            # empty list.
            return logits[:1]
        return torch.concat(next_logits_all + pre_logits_all[::-1], 1)

    def _greedy_decode(self, visual_f, bs, device, query_token, prompt_embed,
                       bos_id, left_to_right):
        """Greedily decode one direction for the whole batch.

        Args:
            visual_f: Visual sequence from ``_add_visual_pos``.
            bs: Batch size.
            device: Device for the id/mask bookkeeping tensors.
            query_token: ``self.next_token`` or ``self.pre_token``.
            prompt_embed: The matching learned prompt embedding.
            bos_id: BOS id of this stream; prompt slots holding it are masked.
            left_to_right: If True, each new id is appended to the right end
                of the prompt window; otherwise it is prepended on the left.

        Returns:
            Softmax probabilities (bs, steps, out_channels - 3). Decoding stops
            after the step in which every sample has produced EOS, or after
            ``self.max_len - 1`` steps.
        """
        ques = query_token.tile([bs, 1, 1, 1]).squeeze(1)  # bs, 1, dim
        prompt_embed = prompt_embed.tile([bs, 1, 1, 1]).squeeze(1)
        prompt_id = torch.full([bs, self.sub_str_len],
                               bos_id,
                               dtype=torch.long,
                               device=device)
        pred_id_list = torch.full([bs, self.max_len],
                                  self.ignore_index,
                                  dtype=torch.long,
                                  device=device)
        mask_pad = torch.zeros([bs, 1], dtype=torch.float32, device=device)
        logits_all = []
        for j in range(self.max_len - 1):
            prompt = torch.concat([
                prompt_embed[:, :1, :],
                prompt_embed[:, 1:, :] + self.char_embed(prompt_id)
            ], 1)  # bs, sub_str_len + 1, dim
            mask_ids = torch.where(prompt_id == bos_id, float('-inf'), 0)
            mask = torch.concat([mask_pad, mask_ids], 1)
            ques_i = self.cmff_decoder(ques, prompt, visual_f,
                                       mask.unsqueeze(1))
            logits = F.softmax(self.ques1_head(self.norm_pred(ques_i)), -1)
            pred_id_i = logits.argmax(-1)
            logits_all.append(logits)
            pred_id_list[:, j:j + 1] = pred_id_i
            if (pred_id_list == self.eos).any(dim=-1).all():
                break
            if left_to_right:
                prompt_id = torch.concat([prompt_id[:, 1:], pred_id_i], 1)
            else:
                prompt_id = torch.concat([pred_id_i, prompt_id[:, :-1]], 1)
        return torch.concat(logits_all, 1)

    def forward_test(self, x):
        """Single-direction greedy inference ("next" if ``next_mode``)."""
        visual_f = self._add_visual_pos(x)
        bs = x.shape[0]
        if self.next_mode:
            return self._greedy_decode(visual_f, bs, x.device,
                                       self.next_token,
                                       self.prompt_next_embed, self.bos_next,
                                       left_to_right=True)
        return self._greedy_decode(visual_f, bs, x.device, self.pre_token,
                                   self.prompt_pre_embed, self.bos_pre,
                                   left_to_right=False)

    def forward_train(self, x, targets=None):
        """Teacher-forced training step over both decoding directions.

        Every "next" and "pre" step of every sample is decoded in a single
        flattened batch through ``cmff_decoder``.

        Args:
            x: Visual features (see ``_add_visual_pos``).
            targets: Indexable batch. As used here:

                - ``targets[1]`` / ``targets[4]``: "next" / "pre" prompt ids,
                  (B, n, sub_str_len).
                - ``targets[2]`` / ``targets[5]``: the matching labels (B, n).
                - ``targets[3]`` / ``targets[6]``: per-sample step counts
                  whose max bounds n.

        Returns:
            ``[{'loss': loss}, logits]``, where ``loss`` is the mean
            cross-entropy over both streams (``ignore_index`` skipped) and
            ``logits`` are the raw "next"-stream scores.
        """
        bs = x.shape[0]
        device = x.device
        visual_f = self._add_visual_pos(x)

        # "next" stream
        max_len_curr = targets[3].max()
        subs = targets[1][:, :max_len_curr, :]  # b, n, subs_l
        mask_next = torch.where(subs == self.bos_next, float('-inf'), 0)
        prompt_next_embed = self.prompt_next_embed.tile(
            [bs, max_len_curr, 1, 1])
        prompt_char_next = torch.concat([
            prompt_next_embed[:, :, :1, :],
            prompt_next_embed[:, :, 1:, :] + self.char_embed(subs)
        ], 2)  # b, n, sub_l + 1, dim
        next_query = self.next_token.tile([bs, max_len_curr, 1, 1])

        # "pre" stream
        max_len_curr_pre = targets[6].max()
        subs = targets[4][:, :max_len_curr_pre, :]  # b, n, subs_l
        mask_pre = torch.where(subs == self.bos_pre, float('-inf'), 0)
        prompt_pre_embed = self.prompt_pre_embed.tile(
            [bs, max_len_curr_pre, 1, 1])
        prompt_char_pre = torch.concat([
            prompt_pre_embed[:, :, :1, :],
            prompt_pre_embed[:, :, 1:, :] + self.char_embed(subs)
        ], 2)  # b, n, sub_l + 1, dim
        pre_query = self.pre_token.tile([bs, max_len_curr_pre, 1, 1])

        # Flatten (batch, step) so every step is an independent query.
        prompt_char = torch.concat([prompt_char_next, prompt_char_pre],
                                   1).flatten(0, 1)
        next_pre = torch.concat([next_query, pre_query], 1).flatten(0, 1)
        mask_pad = torch.zeros([bs * (max_len_curr + max_len_curr_pre), 1],
                               dtype=torch.float32,
                               device=device)
        mask = torch.concat([mask_next, mask_pre], 1).flatten(0, 1)
        mask = torch.concat([mask_pad, mask], 1)

        next_pre = self.cmff_decoder(next_pre, prompt_char, visual_f,
                                     mask.unsqueeze(1))
        answer1_pred = self.ques1_head(self.norm_pred(next_pre))
        logits = answer1_pred[:, :max_len_curr]
        label = torch.concat(
            [targets[2][:, :max_len_curr], targets[5][:, :max_len_curr_pre]],
            1)
        loss1 = F.cross_entropy(answer1_pred.flatten(0, 1),
                                label.flatten(0, 1),
                                ignore_index=self.ignore_index,
                                reduction='mean')
        return [{'loss': loss1}, logits]