# v-doc_abstractive_mac / extract_feature.py
import argparse, os, json
import numpy as np
from imageio import imread
from PIL import Image
import torch
import torchvision
import ssl
# SECURITY NOTE(review): this replaces the stdlib's default HTTPS context factory
# with an unverified one, disabling TLS certificate verification process-wide
# for any HTTPS client that relies on the default context (e.g. urllib).
# Presumably added so the pretrained-weight download in build_model() works
# behind a broken certificate chain — prefer fixing the CA bundle and removing it.
ssl._create_default_https_context = ssl._create_unverified_context
def build_model(model='resnet101', model_stage=3):
    """Build a truncated, ImageNet-pretrained torchvision ResNet feature extractor.

    The network keeps the ResNet stem (conv1 -> bn1 -> relu -> maxpool) followed
    by the first ``model_stage`` residual stages (``layer1`` .. ``layerN``).

    Args:
        model: Name of a ResNet-style constructor in ``torchvision.models``. It
            must expose ``conv1``, ``bn1``, ``relu``, ``maxpool`` and
            ``layer1`` .. ``layer4``.
        model_stage: Number of residual stages to keep after the stem (0-4).

    Returns:
        A ``torch.nn.Sequential`` switched to eval mode.

    Raises:
        ValueError: If ``model_stage`` is outside ``[0, 4]``.
    """
    # Validate before instantiating the network, which may download weights.
    if not 0 <= model_stage <= 4:
        raise ValueError(
            'model_stage must be between 0 and 4, got %r' % (model_stage,))

    cnn = getattr(torchvision.models, model)(pretrained=True)
    layers = [cnn.conv1, cnn.bn1, cnn.relu, cnn.maxpool]
    layers.extend(
        getattr(cnn, 'layer%d' % stage) for stage in range(1, model_stage + 1))

    extractor = torch.nn.Sequential(*layers)
    # Inference only: eval() makes batch-norm use its running statistics.
    extractor.eval()
    return extractor
def run_image(img, model, *, mean=(0.485, 0.456, 0.406),
              std=(0.229, 0.224, 0.224)):
    """Normalize a uint8-range image batch and run it through ``model``.

    Args:
        img: Array shaped ``(N, 3, H, W)`` with pixel values in ``[0, 255]``.
        model: Callable (normally a ``torch.nn.Module``) applied to the
            normalized float32 tensor.
        mean: Per-channel mean subtracted after scaling pixels to ``[0, 1]``.
        std: Per-channel standard deviation divided out after subtracting mean.
            NOTE(review): torchvision documents ImageNet std as
            (0.229, 0.224, 0.225); the third value here looks like a typo, but it
            is kept as the default so extracted features do not change — confirm
            what the downstream model was trained with before altering it.

    Returns:
        The model output converted to a NumPy array on the CPU.
    """
    mean_arr = np.asarray(mean, dtype=np.float64).reshape(1, 3, 1, 1)
    std_arr = np.asarray(std, dtype=np.float64).reshape(1, 3, 1, 1)

    image = np.asarray(img).astype(np.float32)
    image = (image / 255.0 - mean_arr) / std_arr
    tensor = torch.from_numpy(image.astype(np.float32))

    # Feed the input to wherever the model's weights live (CPU or GPU).
    params = model.parameters() if isinstance(model, torch.nn.Module) else iter(())
    first_param = next(params, None)
    if first_param is not None:
        tensor = tensor.to(first_param.device)

    # `Variable(..., volatile=True)` has had no effect since PyTorch 0.4;
    # no_grad() is what actually stops autograd from recording the forward pass.
    with torch.no_grad():
        feats = model(tensor)
    return feats.detach().cpu().numpy()
def get_img_feat(cnn_model, img, image_height=224, image_width=224):
    """Resize a single image and extract its CNN feature map.

    Args:
        cnn_model: Feature extractor forwarded to ``run_image`` (for example the
            module returned by ``build_model``).
        img: Image array of shape ``(H, W)``, ``(H, W, 3)`` or ``(H, W, 4)``.
            It is cast with ``np.uint8`` (no rescaling), so values are expected
            to already lie in ``[0, 255]``.
        image_height: Target height in pixels after resizing.
        image_width: Target width in pixels after resizing.

    Returns:
        Array of shape ``(1, C, H', W')`` holding the model's feature map.
    """
    # Convert to RGB so grayscale or RGBA inputs still yield 3 channels for the
    # CNN stem; this is a no-op for images that are already RGB.
    pil_img = Image.fromarray(np.uint8(img)).convert('RGB')
    # PIL's Image.resize expects (width, height), not (height, width).
    resized = pil_img.resize((image_width, image_height))
    # HWC -> CHW, then add a leading batch axis of size 1.
    batch = np.asarray(resized).transpose(2, 0, 1)[None]
    feats = run_image(batch, cnn_model)
    _, channels, height, width = feats.shape
    return feats.reshape(1, channels, height, width)