|
import torch |
|
import torch.nn as nn |
|
import torchvision.models as models |
|
|
|
|
|
class VisualExtractor(nn.Module):
    """Convolutional backbone that returns grid features and pooled features.

    A torchvision model is looked up by name, its last two child modules are
    removed, and the remaining stack is used as a feature extractor.  A 1x1
    convolution additionally projects the feature map to ``args.nhidden``
    channels.

    Args:
        args: Configuration object providing
            ``nhidden`` (output channels of the 1x1 projection),
            ``visual_extractor`` (name of a constructor in
            ``torchvision.models``) and
            ``visual_extractor_pretrained`` (forwarded as ``pretrained=``).
    """

    def __init__(self, args):
        super().__init__()
        # NOTE: in_channels is hard-coded, so the truncated backbone must emit
        # 2048 channels (presumably a ResNet-50-style trunk -- confirm against
        # the configurations actually in use).
        self.cov1x1 = nn.Conv2d(in_channels=2048, out_channels=args.nhidden, kernel_size=(1, 1))
        self.visual_extractor = args.visual_extractor
        self.pretrained = args.visual_extractor_pretrained

        model = getattr(models, self.visual_extractor)(pretrained=self.pretrained)
        # Drop the last two children (presumably the pooling layer and the
        # classifier head) so the stack yields a spatial feature map.
        modules = list(model.children())[:-2]
        self.model = nn.Sequential(*modules)

        # Global average pooling.  The adaptive variant matches the former
        # fixed 7x7 window on 7x7 feature maps, but also collapses any other
        # spatial size to 1x1 instead of leaving several positions that the
        # subsequent reshape would fold into the batch dimension.
        self.avg_fnt = nn.AdaptiveAvgPool2d(output_size=1)
        if self.pretrained is True: print('first init the imagenet pretrained!')

    def forward(self, images):
        """Extract features from a batch of images.

        Args:
            images: Input tensor accepted by the backbone.

        Returns:
            A 4-tuple ``(patch_feats, avg_feats, att_feat_it, avg_feat_it)``:

            * ``patch_feats`` -- backbone feature map flattened to
              ``(batch, height * width, channels)``.
            * ``avg_feats`` -- spatial mean of the backbone feature map,
              ``(batch, channels)``.
            * ``att_feat_it`` -- 1x1-projected feature map, kept in
              ``(batch, nhidden, height, width)`` layout.
            * ``avg_feat_it`` -- spatial mean of the projected map,
              ``(batch, nhidden)``.
        """
        patch_feats = self.model(images)

        att_feat_it = self.cov1x1(patch_feats)
        # flatten(1) keeps the batch dimension even when batch size is 1.
        avg_feat_it = self.avg_fnt(att_feat_it).flatten(1)

        avg_feats = self.avg_fnt(patch_feats).flatten(1)
        batch_size, feat_size, _, _ = patch_feats.shape
        # (B, C, H, W) -> (B, H*W, C): one feature vector per grid cell.
        patch_feats = patch_feats.reshape(batch_size, feat_size, -1).permute(0, 2, 1)
        return patch_feats, avg_feats, att_feat_it, avg_feat_it
|
|
|
|
|
|