from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json, math
import numpy as np
import os, sys
from six.moves import cPickle
from sys import path
sys.path.insert(0, os.getcwd())
sys.path.insert(0, 'captioning/')
# print('relative captioning is called')
import captioning.utils.opts as opts
import captioning.models as models
from captioning.data.dataloader import *
from captioning.data.dataloaderraw import *
import argparse
import captioning.utils.misc as utils
import torch
import skimage.io
from torch.autograd import Variable
from torchvision import transforms as trn
preprocess = trn.Compose([
    # trn.ToTensor() is skipped: images are converted to [0, 1] float tensors by hand below
    trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
from captioning.utils.resnet_utils import myResnet
from captioning.utils.resnet_utils import ResNetBatch
import captioning.utils.resnet as resnet
import wget
import tempfile
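# The Captioner class below wraps the (relative) captioning user simulator:
# it loads a trained captioning checkpoint (downloading it first if a URL is
# given), sets up a ResNet feature extractor, and turns images or
# target/reference image pairs into captions.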
class Options:
    """Plain namespace holding the default captioner options."""
    def __init__(self):
        self.input_fc_dir = ''
        self.input_json = ''
        self.batch_size = 0  # 0 means: fall back to the value stored in infos
        self.id = ''
        self.sample_max = 1
        self.cnn_model = 'resnet101'
        self.model = ''
        self.language_eval = 0
        self.beam_size = 1
        self.temperature = 1.0
class Captioner():
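    """Pre-trained image captioner plus the ResNet feature extractor it expects.

    With ``is_relative=True`` the model consumes a target/reference feature
    pair and describes the target relative to the reference; otherwise it
    captions the target image alone.
    """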
def __init__(self, is_relative=True, model_path=None, image_feat_params=None, data_type=None, load_resnet=True, diff_feat=None):
        opt = Options()
        if image_feat_params is None:
image_feat_params = {}
image_feat_params['model'] = 'resnet101'
image_feat_params['model_root'] = ''
image_feat_params['att_size'] = 7
# inputs specific to shoe dataset
infos_path = os.path.join(model_path, 'infos_best.pkl')
model_path = os.path.join(model_path, 'model_best.pth')
opt.infos_path = infos_path
opt.model_path = model_path
opt.beam_size = 1
opt.load_resnet = load_resnet
# load pre-trained model, adjusting if URL
if opt.infos_path.startswith("http:") or opt.infos_path.startswith("https:"):
            # create a folder to store the downloaded checkpoints
            checkpoint_path = os.path.join('./checkpoints_usersim', data_type)
            os.makedirs(checkpoint_path, exist_ok=True)
            # set the location for infos
            infos_loc = os.path.join(checkpoint_path, 'infos_best.pkl')
if not os.path.exists(infos_loc):
try:
wget.download(opt.infos_path, infos_loc)
except Exception as err:
print(f"[{err}]")
else:
infos_loc = infos_path
if opt.model_path.startswith("http:") or opt.model_path.startswith("https:"):
            # create a folder to store the downloaded checkpoints
            checkpoint_path = os.path.join('./checkpoints_usersim', data_type)
            os.makedirs(checkpoint_path, exist_ok=True)
            # set the location for the model weights
            model_loc = os.path.join(checkpoint_path, 'model_best.pth')
if not os.path.exists(model_loc):
try:
wget.download(opt.model_path, model_loc)
except Exception as err:
print(f"[{err}]")
opt.model = model_loc
else:
opt.model = model_path
if os.path.exists(infos_loc):
# load existing infos
with open(infos_loc, 'rb') as f:
infos = cPickle.load(f)
self.caption_model = infos["opt"].caption_model
# override and collect parameters
if len(opt.input_fc_dir) == 0:
opt.input_fc_dir = infos['opt'].input_fc_dir
opt.input_att_dir = infos['opt'].input_att_dir
opt.input_label_h5 = infos['opt'].input_label_h5
if len(opt.input_json) == 0:
opt.input_json = infos['opt'].input_json
if opt.batch_size == 0:
opt.batch_size = infos['opt'].batch_size
if len(opt.id) == 0:
opt.id = infos['opt'].id
ignore = ["id", "batch_size", "beam_size", "start_from", "language_eval", "model"]
for k in vars(infos['opt']).keys():
if k not in ignore:
if k in vars(opt):
assert vars(opt)[k] == vars(infos['opt'])[k], k + ' option not consistent'
else:
vars(opt).update({k: vars(infos['opt'])[k]}) # copy over options from model
vocab = infos['vocab'] # ix -> word mapping
# print('opt:', opt)
# Setup the model
opt.vocab = vocab
model = models.setup(opt)
del opt.vocab
        if torch.cuda.is_available():
            model.load_state_dict(torch.load(opt.model))
            model.cuda()
        else:
            # map_location='cpu' handles checkpoints saved from any GPU, not just cuda:0
            model.load_state_dict(torch.load(opt.model, map_location='cpu'))
model.eval()
self.is_relative = is_relative
self.model = model
self.vocab = vocab
self.opt = vars(opt)
# Load ResNet for processing images
if opt.load_resnet:
            if image_feat_params['model_root'] == '':
net = getattr(resnet, image_feat_params['model'])(pretrained=True)
else:
net = getattr(resnet, image_feat_params['model'])()
net.load_state_dict(
torch.load(os.path.join(image_feat_params['model_root'], image_feat_params['model'] + '.pth')))
my_resnet = myResnet(net)
if torch.cuda.is_available():
my_resnet.cuda()
my_resnet.eval()
my_resnet_batch = ResNetBatch(net)
if torch.cuda.is_available():
my_resnet_batch.cuda()
self.my_resnet_batch = my_resnet_batch
self.my_resnet = my_resnet
self.att_size = image_feat_params['att_size']
# Control the input features of the model
        if diff_feat is None:
if self.caption_model == "show_attend_tell":
self.diff_feat = True
else:
self.diff_feat = False
else:
self.diff_feat = diff_feat
def gen_caption_from_feat(self, feat_target, feat_reference=None):
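        """Generate captions from precomputed (fc, att) feature tuples.

        Both arguments are (fc_feat, att_feat) pairs as returned by
        ``get_img_feat_batch``. Returns the sampled token-id sequence and the
        decoded sentences, or (None, None) when the relative/absolute mode
        does not match the arguments supplied.
        """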
        # relative captioning requires a reference image; absolute captioning must not get one
        if self.is_relative and feat_reference is None:
            return None, None
        if not self.is_relative and feat_reference is not None:
            return None, None
if self.is_relative:
if self.diff_feat:
fc_feat = torch.cat((feat_target[0], feat_target[0] - feat_reference[0]), dim=-1)
att_feat = torch.cat((feat_target[1], feat_target[1] - feat_reference[1]), dim=-1)
else:
fc_feat = torch.cat((feat_target[0], feat_reference[0]), dim=-1)
att_feat = torch.cat((feat_target[1], feat_reference[1]), dim=-1)
else:
fc_feat = feat_target[0]
att_feat = feat_target[1]
# Reshape to B x K x C (128,14,14,4096) --> (128,196,4096)
att_feat = att_feat.view(att_feat.shape[0], att_feat.shape[1] * att_feat.shape[2], att_feat.shape[-1])
att_masks = np.zeros(att_feat.shape[:2], dtype='float32')
for i in range(len(att_feat)):
att_masks[i, :att_feat[i].shape[0]] = 1
# set att_masks to None if attention features have same length
if att_masks.sum() == att_masks.size:
att_masks = None
if self.caption_model == 'show_attend_tell':
seq, _ = self.model.sample(fc_feat, att_feat, self.opt)
else:
seq, _ = self.model(fc_feat, att_feat, att_masks=att_masks, opt=self.opt, mode='sample')
sents = utils.decode_sequence(self.vocab, seq)
return seq, sents
def get_vocab_size(self):
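        """Return the size of the caption vocabulary (the ix -> word mapping)."""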
return len(self.vocab)
def get_img_feat(self, img_name):
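        """Load a single image from disk and return its (fc, att) ResNet features."""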
# load the image
I = skimage.io.imread(img_name)
if len(I.shape) == 2:
I = I[:, :, np.newaxis]
I = np.concatenate((I, I, I), axis=2)
I = I.astype('float32') / 255.0
I = torch.from_numpy(I.transpose([2, 0, 1]))
        if torch.cuda.is_available():
            I = I.cuda()
        with torch.no_grad():
            I = preprocess(I)
fc, att = self.my_resnet(I, self.att_size)
return fc, att
def get_img_feat_batch(self, img_names, batchsize=32):
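        """Extract (fc, att) ResNet features for a list of image paths in
        mini-batches of ``batchsize``, concatenating the per-batch results."""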
if not isinstance(img_names, list):
img_names = [img_names]
num_images = len(img_names)
        num_batches = math.ceil(num_images / batchsize)
feature_fc = []
feature_att = []
        for batch_id in range(num_batches):
            startInd = batch_id * batchsize
            endInd = min((batch_id + 1) * batchsize, num_images)
img_names_current_batch = img_names[startInd:endInd]
I_current_batch = []
for img_name in img_names_current_batch:
I = skimage.io.imread(img_name)
if len(I.shape) == 2:
I = I[:, :, np.newaxis]
I = np.concatenate((I, I, I), axis=2)
I = I.astype('float32') / 255.0
I = torch.from_numpy(I.transpose([2, 0, 1]))
I_current_batch.append(preprocess(I))
I_current_batch = torch.stack(I_current_batch, dim=0)
            if torch.cuda.is_available():
                I_current_batch = I_current_batch.cuda()
            with torch.no_grad():
fc, att = self.my_resnet_batch(I_current_batch, self.att_size)
feature_fc.append(fc)
feature_att.append(att)
feature_fc = torch.cat(feature_fc, dim=0)
feature_att = torch.cat(feature_att, dim=0)
return feature_fc, feature_att
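
# Minimal usage sketch (the checkpoint directory and image paths below are
# placeholders, not files shipped with this module): extract batched features
# for a target/reference pair, then generate a relative caption describing the
# target with respect to the reference.
if __name__ == '__main__':
    captioner = Captioner(is_relative=True,
                          model_path='./checkpoints_usersim/shoes',  # hypothetical local checkpoint dir
                          data_type='shoes')
    feat_target = captioner.get_img_feat_batch(['target.jpg'])        # placeholder image path
    feat_reference = captioner.get_img_feat_batch(['reference.jpg'])  # placeholder image path
    _, sents = captioner.gen_caption_from_feat(feat_target, feat_reference)
    print(sents)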