# lxmert/src/tasks/nlvr2_model.py
# coding=utf-8
# Copyleft 2019 project LXRT.

import torch.nn as nn

from param import args
from lxrt.entry import LXRTEncoder
from lxrt.modeling import GeLU, BertLayerNorm


class NLVR2Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.lxrt_encoder = LXRTEncoder(
            args,
            max_seq_length=20
        )
        self.hid_dim = hid_dim = self.lxrt_encoder.dim
        self.logit_fc = nn.Sequential(
            nn.Linear(hid_dim * 2, hid_dim * 2),
            GeLU(),
            BertLayerNorm(hid_dim * 2, eps=1e-12),
            nn.Linear(hid_dim * 2, 2)
        )
        self.logit_fc.apply(self.lxrt_encoder.model.init_bert_weights)
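
    # NOTE: logit_fc takes hid_dim * 2 inputs because forward() concatenates
    # the cross-modality features of the two images in each NLVR2 pair; the
    # final Linear produces logits for the two NLVR2 labels (true / false).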

    def forward(self, feat, pos, sent):
        """
        :param feat: b, 2, o, f (object features of both images per example)
        :param pos: b, 2, o, 4 (bounding-box coordinates of the objects)
        :param sent: b, (string)
        :return: b, 2 (logits over the two NLVR2 labels)
        """
        # Pairing images and sentences:
        # The input of NLVR2 is two images and one sentence. At the batch level,
        # they are stored as
        #     [[img0_0, img0_1], [img1_0, img1_1], ...] and [sent0, sent1, ...]
        # Here, we flatten them to
        #     feat/pos = [img0_0, img0_1, img1_0, img1_1, ...]
        #     sent     = [sent0, sent0, sent1, sent1, ...]
        # so that each image is paired with a copy of its sentence.
        sent = sum(zip(sent, sent), ())
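        # For example (dummy data): ("left is red", "dogs run") becomes
        # ("left is red", "left is red", "dogs run", "dogs run").
        # An equivalent, perhaps clearer, spelling of the line above:
        #     sent = tuple(s for s in sent for _ in range(2))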

        batch_size, img_num, obj_num, feat_size = feat.size()
        assert img_num == 2 and obj_num == 36 and feat_size == 2048
        feat = feat.view(batch_size * 2, obj_num, feat_size)
        pos = pos.view(batch_size * 2, obj_num, 4)

        # Extract cross-modality features, then concatenate the two images'
        # features of each example along the hidden dimension:
        # (b * 2, hid_dim) -> (b, hid_dim * 2)
        x = self.lxrt_encoder(sent, (feat, pos))
        x = x.view(-1, self.hid_dim * 2)

        # Compute logits over the two answers
        logit = self.logit_fc(x)
        return logit
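

# ---------------------------------------------------------------------------
# A minimal smoke-test sketch (not part of the original file). It assumes that
# `param.args` is populated with usable defaults when this module is imported
# and that any pretrained weights LXRTEncoder loads are available locally; the
# tensor shapes follow the assert in forward(), and the batch size and
# sentences below are hypothetical.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    model = NLVR2Model()
    b = 4  # hypothetical batch size
    feat = torch.randn(b, 2, 36, 2048)  # two images per example, 36 objects each
    pos = torch.rand(b, 2, 36, 4)       # box coordinates for each object
    sent = tuple(f"statement {i}" for i in range(b))  # dummy sentences

    with torch.no_grad():
        logit = model(feat, pos, sent)
    assert logit.shape == (b, 2)  # one (true, false) logit pair per example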