import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.init import ones_, trunc_normal_, zeros_

from openrec.modeling.common import DropPath, Identity, Mlp
from openrec.modeling.decoders.nrtr_decoder import Embeddings


class Attention(nn.Module):
    """Multi-head cross-attention: queries come from `q`, keys/values from `kv`."""

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, q, kv, key_mask=None):
        N, C = kv.shape[1:]
        QN = q.shape[1]
        q = self.q(q).reshape([-1, QN, self.num_heads,
                               C // self.num_heads]).transpose(1, 2)
        q = q * self.scale
        k, v = self.kv(kv).reshape(
            [-1, N, 2, self.num_heads,
             C // self.num_heads]).permute(2, 0, 3, 1, 4)
        attn = q.matmul(k.transpose(2, 3))
        if key_mask is not None:
            attn = attn + key_mask.unsqueeze(1)
        attn = F.softmax(attn, -1)
        # if not self.training:
        #     self.attn_map = attn
        attn = self.attn_drop(attn)
        x = (attn.matmul(v)).transpose(1, 2).reshape((-1, QN, C))
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
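

# Shape sketch (illustrative note, not part of the upstream file; the concrete
# numbers are assumptions): with dim=384 and num_heads=12, Attention(384, 12)
# maps q of shape (B, QN, 384) and kv of shape (B, N, 384) to an output of
# shape (B, QN, 384). key_mask, if provided, is broadcast-added to the
# (B, num_heads, QN, N) attention logits after .unsqueeze(1), e.g. an additive
# 0/-inf mask of shape (1, QN, N) such as the causal mask built in CPPDDecoder.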


class EdgeDecoderLayer(nn.Module):

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=[0.0, 0.0],
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        epsilon=1e-6,
    ):
        super().__init__()
        self.head_dim = dim // num_heads
        self.scale = qk_scale or self.head_dim**-0.5

        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path1 = DropPath(
            drop_path[0]) if drop_path[0] > 0.0 else Identity()
        self.norm1 = norm_layer(dim, eps=epsilon)
        self.norm2 = norm_layer(dim, eps=epsilon)

        self.p = nn.Linear(dim, dim)
        self.cv = nn.Linear(dim, dim)
        self.pv = nn.Linear(dim, dim)

        self.dim = dim
        self.num_heads = num_heads
        self.p_proj = nn.Linear(dim, dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp_ratio = mlp_ratio
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, p, cv, pv):
        pN = p.shape[1]
        vN = cv.shape[1]
        p_shortcut = p

        p1 = self.p(p).reshape(
            [-1, pN, self.num_heads,
             self.dim // self.num_heads]).transpose(1, 2)
        cv1 = self.cv(cv).reshape(
            [-1, vN, self.num_heads,
             self.dim // self.num_heads]).transpose(1, 2)
        pv1 = self.pv(pv).reshape(
            [-1, vN, self.num_heads,
             self.dim // self.num_heads]).transpose(1, 2)

        edge = F.softmax(p1.matmul(pv1.transpose(2, 3)), -1)  # B h N N
        p_c = (edge @ cv1).transpose(1, 2).reshape((-1, pN, self.dim))

        x1 = self.norm1(p_shortcut + self.drop_path1(self.p_proj(p_c)))
        x = self.norm2(x1 + self.drop_path1(self.mlp(x1)))
        return x


class DecoderLayer(nn.Module):

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        epsilon=1e-6,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim, eps=epsilon)
        self.mixer = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
        self.norm2 = norm_layer(dim, eps=epsilon)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp_ratio = mlp_ratio
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, q, kv, key_mask=None):
        x1 = self.norm1(q + self.drop_path(self.mixer(q, kv, key_mask)))
        x = self.norm2(x1 + self.drop_path(self.mlp(x1)))
        return x


class CPPDDecoder(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_layer=2,
                 drop_path_rate=0.1,
                 max_len=25,
                 vis_seq=50,
                 iters=1,
                 pos_len=False,
                 ch=False,
                 rec_layer=1,
                 num_heads=None,
                 ds=False,
                 **kwargs):
        super(CPPDDecoder, self).__init__()

        self.out_channels = out_channels  # none + 26 + 10
        dim = in_channels
        self.dim = dim
        self.iters = iters
        self.max_len = max_len + 1  # max_len + eos
        self.pos_len = pos_len
        self.ch = ch

        self.char_node_embed = Embeddings(d_model=dim,
                                          vocab=self.out_channels,
                                          scale_embedding=True)
        self.pos_node_embed = Embeddings(d_model=dim,
                                         vocab=self.max_len,
                                         scale_embedding=True)

        dpr = np.linspace(0, drop_path_rate, num_layer + rec_layer)

        self.char_node_decoder = nn.ModuleList([
            DecoderLayer(
                dim=dim,
                num_heads=dim // 32 if num_heads is None else num_heads,
                mlp_ratio=4.0,
                qkv_bias=True,
                drop_path=dpr[i],
            ) for i in range(num_layer)
        ])
        self.pos_node_decoder = nn.ModuleList([
            DecoderLayer(
                dim=dim,
                num_heads=dim // 32 if num_heads is None else num_heads,
                mlp_ratio=4.0,
                qkv_bias=True,
                drop_path=dpr[i],
            ) for i in range(num_layer)
        ])

        self.edge_decoder = nn.ModuleList([
            DecoderLayer(
                dim=dim,
                num_heads=dim // 32 if num_heads is None else num_heads,
                mlp_ratio=4.0,
                qkv_bias=True,
                qk_scale=1.0 if (rec_layer + i) % 2 != 0 else None,
                drop_path=dpr[num_layer + i],
            ) for i in range(rec_layer)
        ])
        self.rec_layer_num = rec_layer

        # Causal (lower-triangular) additive mask for the self-attention
        # recognition layers; kept as a non-persistent buffer so it follows
        # the module to the right device without entering the state dict.
        self_mask = torch.tril(
            torch.ones([self.max_len, self.max_len], dtype=torch.float32))
        self_mask = torch.where(
            self_mask > 0,
            torch.zeros_like(self_mask, dtype=torch.float32),
            torch.full([self.max_len, self.max_len],
                       float('-inf'),
                       dtype=torch.float32),
        )
        self.register_buffer('self_mask',
                             self_mask.unsqueeze(0),
                             persistent=False)

        self.char_pos_embed = nn.Parameter(torch.zeros([1, self.max_len, dim],
                                                       dtype=torch.float32),
                                           requires_grad=True)
        self.ds = ds
        if not self.ds:
            self.vis_pos_embed = nn.Parameter(torch.zeros([1, vis_seq, dim],
                                                          dtype=torch.float32),
                                              requires_grad=True)
            trunc_normal_(self.vis_pos_embed, std=0.02)

        self.char_node_fc1 = nn.Linear(dim, max_len)
        self.pos_node_fc1 = nn.Linear(dim, self.max_len)

        self.edge_fc = nn.Linear(dim, self.out_channels)

        trunc_normal_(self.char_pos_embed, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            zeros_(m.bias)
            ones_(m.weight)

    def no_weight_decay(self):
        return {
            'char_pos_embed', 'vis_pos_embed', 'char_node_embed',
            'pos_node_embed'
        }

    def forward(self, x, data=None):
        if self.training:
            return self.forward_train(x, data)
        else:
            return self.forward_test(x)

    def forward_test(self, x):
        if not self.ds:
            visual_feats = x + self.vis_pos_embed
        else:
            visual_feats = x
        bs = visual_feats.shape[0]

        pos_node_embed = self.pos_node_embed(
            torch.arange(self.max_len).cuda(
                x.get_device())).unsqueeze(0) + self.char_pos_embed
        pos_node_embed = torch.tile(pos_node_embed, [bs, 1, 1])

        char_vis_node_query = visual_feats
        pos_vis_node_query = torch.concat([pos_node_embed, visual_feats], 1)

        for char_decoder_layer, pos_decoder_layer in zip(
                self.char_node_decoder, self.pos_node_decoder):
            char_vis_node_query = char_decoder_layer(char_vis_node_query,
                                                     char_vis_node_query)
            pos_vis_node_query = pos_decoder_layer(
                pos_vis_node_query, pos_vis_node_query[:, self.max_len:, :])

        pos_node_query = pos_vis_node_query[:, :self.max_len, :]
        char_vis_feats = char_vis_node_query
        # pos_vis_feats = pos_vis_node_query[:, self.max_len:, :]
        # pos_node_feats = self.edge_decoder(
        #     pos_node_query, char_vis_feats, pos_vis_feats
        # )  # B, 26, dim
        pos_node_feats = pos_node_query
        for layer_i in range(self.rec_layer_num):
            rec_layer = self.edge_decoder[layer_i]
            if (self.rec_layer_num + layer_i) % 2 == 0:
                pos_node_feats = rec_layer(pos_node_feats, pos_node_feats,
                                           self.self_mask)
            else:
                pos_node_feats = rec_layer(pos_node_feats, char_vis_feats)

        edge_feats = self.edge_fc(pos_node_feats)  # B, 26, 37
        edge_logits = F.softmax(
            edge_feats,
            -1)  # * F.sigmoid(pos_node_feats1.unsqueeze(-1))  # B, 26, 37

        return edge_logits

    def forward_train(self, x, targets=None):
        if not self.ds:
            visual_feats = x + self.vis_pos_embed
        else:
            visual_feats = x
        bs = visual_feats.shape[0]

        if self.ch:
            char_node_embed = self.char_node_embed(targets[-2])
        else:
            char_node_embed = self.char_node_embed(
                torch.arange(self.out_channels).cuda(
                    x.get_device())).unsqueeze(0)
            char_node_embed = torch.tile(char_node_embed, [bs, 1, 1])
        counting_char_num = char_node_embed.shape[1]

        pos_node_embed = self.pos_node_embed(
            torch.arange(self.max_len).cuda(
                x.get_device())).unsqueeze(0) + self.char_pos_embed
        pos_node_embed = torch.tile(pos_node_embed, [bs, 1, 1])

        node_feats = []

        char_vis_node_query = torch.concat([char_node_embed, visual_feats], 1)
        pos_vis_node_query = torch.concat([pos_node_embed, visual_feats], 1)

        for char_decoder_layer, pos_decoder_layer in zip(
                self.char_node_decoder, self.pos_node_decoder):
            char_vis_node_query = char_decoder_layer(
                char_vis_node_query,
                char_vis_node_query[:, counting_char_num:, :])
            pos_vis_node_query = pos_decoder_layer(
                pos_vis_node_query, pos_vis_node_query[:, self.max_len:, :])

        char_node_query = char_vis_node_query[:, :counting_char_num, :]
        pos_node_query = pos_vis_node_query[:, :self.max_len, :]
        char_vis_feats = char_vis_node_query[:, counting_char_num:, :]

        char_node_feats1 = self.char_node_fc1(char_node_query)
        pos_node_feats1 = self.pos_node_fc1(pos_node_query)

        if not self.pos_len:
            diag_mask = torch.eye(pos_node_feats1.shape[1]).unsqueeze(0).tile(
                [pos_node_feats1.shape[0], 1, 1])
            pos_node_feats1 = (
                pos_node_feats1 *
                diag_mask.cuda(pos_node_feats1.get_device())).sum(-1)

        node_feats.append(char_node_feats1)
        node_feats.append(pos_node_feats1)

        pos_node_feats = pos_node_query
        for layer_i in range(self.rec_layer_num):
            rec_layer = self.edge_decoder[layer_i]
            if (self.rec_layer_num + layer_i) % 2 == 0:
                pos_node_feats = rec_layer(pos_node_feats, pos_node_feats,
                                           self.self_mask)
            else:
                pos_node_feats = rec_layer(pos_node_feats, char_vis_feats)

        edge_feats = self.edge_fc(pos_node_feats)  # B, 26, 37

        return node_feats, edge_feats
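

if __name__ == '__main__':
    # Minimal inference smoke test (an illustrative sketch, not part of the
    # upstream file). The shapes are assumptions: 50 visual tokens of width
    # 384 from the encoder, a 37-way charset (36 symbols + EOS) and the
    # default max_len=25, giving 26 position nodes. forward_test() places its
    # index tensors with .cuda(), so the sketch only runs when a GPU is
    # available.
    if torch.cuda.is_available():
        decoder = CPPDDecoder(in_channels=384, out_channels=37, max_len=25,
                              vis_seq=50).cuda().eval()
        visual_tokens = torch.randn([2, 50, 384]).cuda()
        with torch.no_grad():
            edge_logits = decoder(visual_tokens)
        # Per-position probability distributions over the charset: (2, 26, 37).
        print(edge_logits.shape)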