# coding=utf-8 from __future__ import absolute_import from __future__ import division from __future__ import print_function import copy import logging import math from os.path import join as pjoin import torch import torch.nn as nn import numpy as np from torch.nn import BCEWithLogitsLoss,CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm from torch.nn.modules.utils import _pair from scipy import ndimage import models.configs as configs from models.attention import Attention from models.embed import Embeddings from models.mlp import Mlp from models.block import Block class Encoder(nn.Module): def __init__(self, config, vis,mm): super(Encoder, self).__init__() self.vis = vis self.layer = nn.ModuleList() self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6) for i in range(config.transformer["num_layers"]): if i < 2: layer = Block(config, vis, mm) else: layer = Block(config, vis,mm=False) self.layer.append(copy.deepcopy(layer)) self.img_adaptive_avg_pool = nn.AdaptiveAvgPool2d(output_size=(1, 512)) self.txt_adaptive_avg_pool = nn.AdaptiveAvgPool2d(output_size=(1, 512)) def forward(self, hidden_states, text=None): for (i, layer_block) in enumerate(self.layer): if i == 2: if text is not None: hidden_states = self.img_adaptive_avg_pool(hidden_states) text = self.txt_adaptive_avg_pool(text) hidden_states = torch.cat((hidden_states, text), 1) hidden_states,text ,weights = layer_block(hidden_states) else: hidden_states, text, weights = layer_block(hidden_states) elif i < 2: hidden_states, text, weights = layer_block(hidden_states, text) else: hidden_states,text, weights = layer_block(hidden_states) encoded = self.encoder_norm(hidden_states) return encoded