least1924's picture
Upload 10 files
6775edf verified
raw
history blame
2.07 kB
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import logging
import math
from os.path import join as pjoin
import torch
import torch.nn as nn
import numpy as np
from torch.nn import BCEWithLogitsLoss,CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage
import models.configs as configs
from models.attention import Attention
from models.embed import Embeddings
from models.mlp import Mlp
from models.block import Block
class Encoder(nn.Module):
    """Stack of ``Block`` layers with an optional text-fusion step.

    The first ``_NUM_MULTIMODAL_LAYERS`` blocks are built with the caller's
    ``mm`` argument and are called with both the hidden states and the text.
    Every later block is built with ``mm=False`` and is called with the
    hidden states only. Right before the first of those later blocks, a
    non-``None`` text tensor is pooled and concatenated onto the pooled
    hidden states along dimension 1.
    """

    # Number of leading blocks that are called as ``block(hidden_states, text)``.
    # The block at this index is the one preceded by the fusion step.
    _NUM_MULTIMODAL_LAYERS = 2

    # Output size requested from both adaptive average pools before fusion.
    # NOTE(review): the fused tensor therefore has a last dimension of 512 and
    # is later fed to ``LayerNorm(config.hidden_size)`` -- presumably
    # ``config.hidden_size`` is expected to be 512; confirm against configs.
    _POOLED_SIZE = (1, 512)

    def __init__(self, config, vis, mm):
        """Build the block stack, the final LayerNorm and the two pools.

        Args:
            config: Model configuration. ``config.hidden_size`` and
                ``config.transformer["num_layers"]`` are read here, and the
                object is also handed to every ``Block``.
            vis: Stored as ``self.vis`` and forwarded to every ``Block``.
            mm: Forwarded to the first ``_NUM_MULTIMODAL_LAYERS`` blocks;
                all later blocks are constructed with ``mm=False``.
        """
        super(Encoder, self).__init__()
        self.vis = vis
        self.layer = nn.ModuleList()
        self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)

        for index in range(config.transformer["num_layers"]):
            if index < self._NUM_MULTIMODAL_LAYERS:
                block = Block(config, vis, mm)
            else:
                block = Block(config, vis, mm=False)
            # The registered layer is a deep copy of the freshly built block.
            self.layer.append(copy.deepcopy(block))

        self.img_adaptive_avg_pool = nn.AdaptiveAvgPool2d(
            output_size=self._POOLED_SIZE
        )
        self.txt_adaptive_avg_pool = nn.AdaptiveAvgPool2d(
            output_size=self._POOLED_SIZE
        )

    def forward(self, hidden_states, text=None):
        """Run the inputs through every block and normalize the result.

        Args:
            hidden_states: Tensor fed to the first block.
                Assumes a 3-D ``(batch, tokens, hidden)`` layout -- TODO confirm.
            text: Optional second tensor handed to the leading blocks. Each
                of those blocks returns a replacement for it. If the value
                is ``None`` when the fusion point is reached, no pooling or
                concatenation takes place.

        Returns:
            ``self.encoder_norm`` applied to the hidden states produced by the
            last block. The third value returned by each block is discarded.
        """
        fusion_index = self._NUM_MULTIMODAL_LAYERS

        for index, layer_block in enumerate(self.layer):
            if index < fusion_index:
                # Leading blocks consume and update both streams.
                hidden_states, text, _ = layer_block(hidden_states, text)
                continue

            if index == fusion_index and text is not None:
                # Pool both streams to the same trailing size, then join them
                # along dimension 1 so later blocks see a single tensor.
                hidden_states = self.img_adaptive_avg_pool(hidden_states)
                text = self.txt_adaptive_avg_pool(text)
                hidden_states = torch.cat((hidden_states, text), 1)

            # From the fusion index onwards blocks get the hidden states only;
            # their second and third return values are never used again.
            hidden_states, _, _ = layer_block(hidden_states)

        return self.encoder_norm(hidden_states)