|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import copy |
|
import logging |
|
import math |
|
|
|
from os.path import join as pjoin |
|
|
|
import torch |
|
import torch.nn as nn |
|
import numpy as np |
|
|
|
from torch.nn import BCEWithLogitsLoss,CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm |
|
from torch.nn.modules.utils import _pair |
|
from scipy import ndimage |
|
|
|
import models.configs as configs |
|
from models.attention import Attention |
|
from models.embed import Embeddings |
|
from models.mlp import Mlp |
|
from models.block import Block |
|
|
|
class Encoder(nn.Module):
    """Transformer encoder that fuses an image stream with a text stream mid-stack.

    The first two blocks are built with the caller-supplied ``mm`` flag and
    receive both streams; every later block is built with ``mm=False``. At
    block index 2, if a text tensor was provided, both streams are pooled to
    a fixed ``(1, 512)`` shape and concatenated along dim 1, after which the
    remaining blocks process the fused sequence alone.
    """

    def __init__(self, config, vis, mm):
        """Build the block stack, the final LayerNorm, and the two fusion pools.

        Args:
            config: model configuration; must expose ``hidden_size`` and
                ``transformer["num_layers"]``.
            vis: forwarded to each ``Block`` (attention-visualization flag);
                also stored on the instance.
            mm: multimodal flag forwarded only to the first two blocks.
        """
        super(Encoder, self).__init__()
        self.vis = vis
        self.layer = nn.ModuleList()
        self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
        for idx in range(config.transformer["num_layers"]):
            # Only the first two blocks run in multimodal mode.
            if idx < 2:
                block = Block(config, vis, mm)
            else:
                block = Block(config, vis, mm=False)
            self.layer.append(copy.deepcopy(block))
        # Pool each stream to a single (1, 512) token before fusion.
        self.img_adaptive_avg_pool = nn.AdaptiveAvgPool2d(output_size=(1, 512))
        self.txt_adaptive_avg_pool = nn.AdaptiveAvgPool2d(output_size=(1, 512))

    def forward(self, hidden_states, text=None):
        """Run the block stack, fusing ``text`` into ``hidden_states`` at block 2.

        Args:
            hidden_states: image-stream activations.
            text: optional text-stream activations; consumed by the fusion
                step and by the first two (multimodal) blocks.

        Returns:
            The LayerNorm-ed output of the final block.
        """
        for idx, block in enumerate(self.layer):
            if idx < 2:
                # Early multimodal blocks see both streams.
                hidden_states, text, weights = block(hidden_states, text)
                continue
            if idx == 2 and text is not None:
                # Fusion point: pool each stream to (1, 512) and concatenate
                # along the sequence dimension.
                hidden_states = self.img_adaptive_avg_pool(hidden_states)
                text = self.txt_adaptive_avg_pool(text)
                hidden_states = torch.cat((hidden_states, text), 1)
            # From block 2 onward the (possibly fused) sequence flows alone;
            # the returned attention weights are not accumulated here.
            hidden_states, text, weights = block(hidden_states)
        return self.encoder_norm(hidden_states)
|
|