# coding=utf-8
# Copyleft 2019 project LXRT.
import torch.nn as nn
from ..param import args
from ..lxrt.entry import LXRTEncoder
from ..lxrt.modeling import BertLayerNorm, GeLU
from transformers import AutoTokenizer, AutoModelForQuestionAnswering

# Max length including <bos> and <eos>
MAX_VQA_LENGTH = 20

class VQAModel(nn.Module):
    def __init__(self, num_answers):
        super().__init__()
        # # Build LXRT encoder
        # self.lxrt_encoder = LXRTEncoder(
        #     args,
        #     max_seq_length=MAX_VQA_LENGTH
        # )
        # hid_dim = self.lxrt_encoder.dim
        #
        # # VQA Answer heads
        # self.logit_fc = nn.Sequential(
        #     nn.Linear(hid_dim, hid_dim * 2),
        #     GeLU(),
        #     BertLayerNorm(hid_dim * 2, eps=1e-12),
        #     nn.Linear(hid_dim * 2, num_answers)
        # )
        # self.logit_fc.apply(self.lxrt_encoder.model.init_bert_weights)

        # The pretrained checkpoint ships with its own answer head, so
        # `num_answers` is kept only for interface compatibility.
        self.num_answers = num_answers
        self.tokenizer = AutoTokenizer.from_pretrained("unc-nlp/lxmert-vqa-uncased")
        self.model = AutoModelForQuestionAnswering.from_pretrained("unc-nlp/lxmert-vqa-uncased")

    def forward(self, feat, pos, sent):
        """
        b -- batch_size, o -- object_number, f -- visual_feature_size

        :param feat: (b, o, f) visual features
        :param pos: (b, o, 4) normalized bounding boxes
        :param sent: (b,) list of question strings
        :return: (b, num_answers) logits over the answer vocabulary
        """
        # Tokenize the raw question strings into input ids and attention
        # masks; the model cannot consume the strings directly.
        inputs = self.tokenizer(
            sent,
            padding="max_length",
            truncation=True,
            max_length=MAX_VQA_LENGTH,
            return_tensors="pt",
        ).to(feat.device)
        output = self.model(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            visual_feats=feat,
            visual_pos=pos,
        )
        return output.question_answering_score
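

if __name__ == "__main__":
    # Minimal smoke-test sketch. It assumes dummy inputs shaped like the
    # standard LXMERT setup: 36 object regions with 2048-d Faster R-CNN
    # features and (b, o, 4) normalized boxes; in practice these would
    # come from an object detector. The batch of questions is hypothetical.
    import torch

    model = VQAModel(num_answers=3129)  # unused by the pretrained head
    feat = torch.randn(2, 36, 2048)  # (b, o, f) visual features
    pos = torch.rand(2, 36, 4)       # (b, o, 4) normalized boxes
    sent = ["What color is the cat?", "How many dogs are there?"]
    with torch.no_grad():
        logits = model(feat, pos, sent)
    print(logits.shape)  # (2, num_qa_labels) answer logits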