# Visual feature extractor module (torchvision backbone + 1x1 projection).
import torch
import torch.nn as nn
import torchvision.models as models
class VisualExtractor(nn.Module):
    """CNN backbone that extracts patch-level and globally pooled visual features.

    Wraps a torchvision classification model (named by ``args.visual_extractor``)
    with its final global-pool and FC layers stripped, plus a 1x1 convolution
    that projects the backbone's 2048-channel feature map down to
    ``args.nhidden`` channels.
    """

    def __init__(self, args):
        """Build the backbone and projection layers.

        Args:
            args: namespace providing
                - nhidden: output channels of the 1x1 projection conv.
                - visual_extractor: torchvision model name (e.g. ``'resnet101'``).
                - visual_extractor_pretrained: whether to load ImageNet weights.
        """
        super(VisualExtractor, self).__init__()
        # NOTE(review): in_channels=2048 assumes a ResNet-50/101/152-style
        # backbone whose last conv stage emits 2048 channels — confirm this
        # matches args.visual_extractor.
        self.cov1x1 = nn.Conv2d(in_channels=2048, out_channels=args.nhidden,
                                kernel_size=(1, 1))
        self.visual_extractor = args.visual_extractor
        self.pretrained = args.visual_extractor_pretrained
        model = getattr(models, self.visual_extractor)(pretrained=self.pretrained)
        # Drop the backbone's global-average-pool and FC head; keep only the
        # convolutional trunk so forward() yields a spatial feature map.
        modules = list(model.children())[:-2]
        self.model = nn.Sequential(*modules)
        # kernel_size=7 assumes a 7x7 final feature map (e.g. 224x224 input
        # through a 32x-downsampling backbone) — TODO confirm input resolution.
        self.avg_fnt = nn.AvgPool2d(kernel_size=7, stride=1, padding=0)
        # Idiomatic truthiness test (was `if self.pretrained is True:`).
        if self.pretrained:
            print('first init the imagenet pretrained!')

    def forward(self, images):
        """Extract features from a batch of images.

        Args:
            images: input batch, shape (B, 3, H, W); H and W are presumably
                224 so the backbone emits a 7x7 map — verify against caller.

        Returns:
            A 4-tuple:
                - patch_feats: (B, H'*W', 2048) per-patch backbone features.
                - avg_feats: (B, 2048) spatially averaged backbone features.
                - att_feat_it: (B, nhidden, H', W') projected feature map.
                - avg_feat_it: (B, nhidden) spatially averaged projection.
        """
        patch_feats = self.model(images)
        att_feat_it = self.cov1x1(patch_feats)
        # squeeze() then reshape(-1, C) collapses the (B, C, 1, 1) pool output
        # to (B, C); the reshape also restores the batch dim when B == 1.
        avg_feat_it = self.avg_fnt(att_feat_it).squeeze().reshape(-1, att_feat_it.size(1))
        avg_feats = self.avg_fnt(patch_feats).squeeze().reshape(-1, patch_feats.size(1))
        batch_size, feat_size, _, _ = patch_feats.shape
        # Flatten the spatial grid into a patch sequence: (B, C, H', W') ->
        # (B, H'*W', C) for attention-style consumers.
        patch_feats = patch_feats.reshape(batch_size, feat_size, -1).permute(0, 2, 1)
        return patch_feats, avg_feats, att_feat_it, avg_feat_it