Spaces:
Sleeping
Sleeping
yashonwu
commited on
Commit
•
9bf9e42
1
Parent(s):
ed7b5bc
add captioning
Browse files- captioning/.DS_Store +0 -0
- captioning/__init__.py +0 -0
- captioning/captioner.py +296 -0
- captioning/data/.DS_Store +0 -0
- captioning/data/__init__.py +0 -0
- captioning/data/dataloader.py +439 -0
- captioning/data/dataloader_recsys.py +432 -0
- captioning/data/dataloaderraw.py +151 -0
- captioning/data/pth_loader.py +300 -0
- captioning/models/.DS_Store +0 -0
- captioning/models/AoAModel.py +234 -0
- captioning/models/AttEnsemble.py +90 -0
- captioning/models/AttModel.py +977 -0
- captioning/models/BertCapModel.py +103 -0
- captioning/models/CaptionModel.py +411 -0
- captioning/models/FCModel.py +204 -0
- captioning/models/M2Transformer.py +102 -0
- captioning/models/OldModel.py +265 -0
- captioning/models/ShowTellModel.py +368 -0
- captioning/models/TransformerModel.py +367 -0
- captioning/models/__init__.py +76 -0
- captioning/models/cachedTransformer.py +420 -0
- captioning/models/utils.py +25 -0
- captioning/modules/.DS_Store +0 -0
- captioning/modules/loss_wrapper.py +65 -0
- captioning/modules/losses.py +218 -0
- captioning/utils/.DS_Store +0 -0
- captioning/utils/__init__.py +0 -0
- captioning/utils/config.py +153 -0
- captioning/utils/div_utils.py +38 -0
- captioning/utils/eval_multi.py +218 -0
- captioning/utils/eval_utils.py +323 -0
- captioning/utils/misc.py +249 -0
- captioning/utils/opts.py +365 -0
- captioning/utils/resnet.py +84 -0
- captioning/utils/resnet_utils.py +51 -0
- captioning/utils/rewards.py +136 -0
captioning/.DS_Store
ADDED
Binary file (8.2 kB). View file
|
|
captioning/__init__.py
ADDED
File without changes
|
captioning/captioner.py
ADDED
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import
|
2 |
+
from __future__ import division
|
3 |
+
from __future__ import print_function
|
4 |
+
|
5 |
+
import json, math
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
import os, sys
|
9 |
+
from six.moves import cPickle
|
10 |
+
|
11 |
+
from sys import path
|
12 |
+
|
13 |
+
sys.path.insert(0, os.getcwd())
|
14 |
+
sys.path.insert(0, 'captioning/')
|
15 |
+
# print('relative captioning is called')
|
16 |
+
|
17 |
+
import captioning.utils.opts as opts
|
18 |
+
import captioning.models as models
|
19 |
+
from captioning.data.dataloader import *
|
20 |
+
from captioning.data.dataloaderraw import *
|
21 |
+
|
22 |
+
import argparse
|
23 |
+
import captioning.utils.misc as utils
|
24 |
+
import torch
|
25 |
+
|
26 |
+
import skimage.io
|
27 |
+
from torch.autograd import Variable
|
28 |
+
from torchvision import transforms as trn
|
29 |
+
|
30 |
+
preprocess = trn.Compose([
|
31 |
+
# trn.ToTensor(),
|
32 |
+
trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
33 |
+
])
|
34 |
+
|
35 |
+
from captioning.utils.resnet_utils import myResnet
|
36 |
+
from captioning.utils.resnet_utils import ResNetBatch
|
37 |
+
import captioning.utils.resnet as resnet
|
38 |
+
|
39 |
+
import wget
|
40 |
+
import tempfile
|
41 |
+
|
42 |
+
class object:
|
43 |
+
def __init__(self):
|
44 |
+
self.input_fc_dir = ''
|
45 |
+
self.input_json = ''
|
46 |
+
self.batch_size = ''
|
47 |
+
self.id = ''
|
48 |
+
self.sample_max = 1
|
49 |
+
self.cnn_model = 'resnet101'
|
50 |
+
self.model = ''
|
51 |
+
self.language_eval = 0
|
52 |
+
self.beam_size = 1
|
53 |
+
self.temperature = 1.0
|
54 |
+
return
|
55 |
+
|
56 |
+
|
57 |
+
class Captioner():
|
58 |
+
|
59 |
+
def __init__(self, is_relative=True, model_path=None, image_feat_params=None, data_type=None, load_resnet=True, diff_feat=None):
|
60 |
+
opt = object()
|
61 |
+
|
62 |
+
if image_feat_params==None:
|
63 |
+
image_feat_params = {}
|
64 |
+
image_feat_params['model'] = 'resnet101'
|
65 |
+
image_feat_params['model_root'] = ''
|
66 |
+
image_feat_params['att_size'] = 7
|
67 |
+
|
68 |
+
# inputs specific to shoe dataset
|
69 |
+
infos_path = os.path.join(model_path, 'infos_best.pkl')
|
70 |
+
model_path = os.path.join(model_path, 'model_best.pth')
|
71 |
+
|
72 |
+
opt.infos_path = infos_path
|
73 |
+
opt.model_path = model_path
|
74 |
+
opt.beam_size = 1
|
75 |
+
opt.load_resnet = load_resnet
|
76 |
+
|
77 |
+
# load pre-trained model, adjusting if URL
|
78 |
+
if opt.infos_path.startswith("http:") or opt.infos_path.startswith("https:"):
|
79 |
+
# create a folder to store the checkpoints for downloading
|
80 |
+
if not os.path.exists('./checkpoints_usersim'):
|
81 |
+
os.mkdir('./checkpoints_usersim')
|
82 |
+
|
83 |
+
checkpoint_path = os.path.join('./checkpoints_usersim', data_type)
|
84 |
+
if not os.path.exists(checkpoint_path):
|
85 |
+
os.mkdir(checkpoint_path)
|
86 |
+
|
87 |
+
# set the location for infos
|
88 |
+
infos_loc = os.path.join(checkpoint_path, 'infos_best.pkl')
|
89 |
+
|
90 |
+
if not os.path.exists(infos_loc):
|
91 |
+
try:
|
92 |
+
wget.download(opt.infos_path, infos_loc)
|
93 |
+
except Exception as err:
|
94 |
+
print(f"[{err}]")
|
95 |
+
else:
|
96 |
+
infos_loc = infos_path
|
97 |
+
|
98 |
+
if opt.model_path.startswith("http:") or opt.model_path.startswith("https:"):
|
99 |
+
# create a folder to store the checkpoints for downloading
|
100 |
+
if not os.path.exists('./checkpoints_usersim'):
|
101 |
+
os.mkdir('./checkpoints_usersim')
|
102 |
+
|
103 |
+
checkpoint_path = os.path.join('./checkpoints_usersim', data_type)
|
104 |
+
if not os.path.exists(checkpoint_path):
|
105 |
+
os.mkdir(checkpoint_path)
|
106 |
+
|
107 |
+
# set the location for models
|
108 |
+
model_loc = os.path.join(checkpoint_path, 'model_best.pth')
|
109 |
+
|
110 |
+
if not os.path.exists(model_loc):
|
111 |
+
try:
|
112 |
+
wget.download(opt.model_path, model_loc)
|
113 |
+
except Exception as err:
|
114 |
+
print(f"[{err}]")
|
115 |
+
opt.model = model_loc
|
116 |
+
else:
|
117 |
+
opt.model = model_path
|
118 |
+
|
119 |
+
if os.path.exists(infos_loc):
|
120 |
+
# load existing infos
|
121 |
+
with open(infos_loc, 'rb') as f:
|
122 |
+
infos = cPickle.load(f)
|
123 |
+
|
124 |
+
self.caption_model = infos["opt"].caption_model
|
125 |
+
|
126 |
+
# override and collect parameters
|
127 |
+
if len(opt.input_fc_dir) == 0:
|
128 |
+
opt.input_fc_dir = infos['opt'].input_fc_dir
|
129 |
+
opt.input_att_dir = infos['opt'].input_att_dir
|
130 |
+
opt.input_label_h5 = infos['opt'].input_label_h5
|
131 |
+
if len(opt.input_json) == 0:
|
132 |
+
opt.input_json = infos['opt'].input_json
|
133 |
+
if opt.batch_size == 0:
|
134 |
+
opt.batch_size = infos['opt'].batch_size
|
135 |
+
if len(opt.id) == 0:
|
136 |
+
opt.id = infos['opt'].id
|
137 |
+
ignore = ["id", "batch_size", "beam_size", "start_from", "language_eval", "model"]
|
138 |
+
for k in vars(infos['opt']).keys():
|
139 |
+
if k not in ignore:
|
140 |
+
if k in vars(opt):
|
141 |
+
assert vars(opt)[k] == vars(infos['opt'])[k], k + ' option not consistent'
|
142 |
+
else:
|
143 |
+
vars(opt).update({k: vars(infos['opt'])[k]}) # copy over options from model
|
144 |
+
|
145 |
+
vocab = infos['vocab'] # ix -> word mapping
|
146 |
+
|
147 |
+
# print('opt:', opt)
|
148 |
+
|
149 |
+
# Setup the model
|
150 |
+
opt.vocab = vocab
|
151 |
+
model = models.setup(opt)
|
152 |
+
del opt.vocab
|
153 |
+
if torch.cuda.is_available():
|
154 |
+
model.load_state_dict(torch.load(opt.model))
|
155 |
+
model.cuda()
|
156 |
+
else:
|
157 |
+
model.load_state_dict(torch.load(opt.model, map_location={'cuda:0': 'cpu'}))
|
158 |
+
|
159 |
+
model.eval()
|
160 |
+
|
161 |
+
self.is_relative = is_relative
|
162 |
+
self.model = model
|
163 |
+
self.vocab = vocab
|
164 |
+
self.opt = vars(opt)
|
165 |
+
|
166 |
+
# Load ResNet for processing images
|
167 |
+
if opt.load_resnet:
|
168 |
+
if image_feat_params['model_root']=='':
|
169 |
+
net = getattr(resnet, image_feat_params['model'])(pretrained=True)
|
170 |
+
else:
|
171 |
+
net = getattr(resnet, image_feat_params['model'])()
|
172 |
+
net.load_state_dict(
|
173 |
+
torch.load(os.path.join(image_feat_params['model_root'], image_feat_params['model'] + '.pth')))
|
174 |
+
my_resnet = myResnet(net)
|
175 |
+
if torch.cuda.is_available():
|
176 |
+
my_resnet.cuda()
|
177 |
+
my_resnet.eval()
|
178 |
+
|
179 |
+
my_resnet_batch = ResNetBatch(net)
|
180 |
+
if torch.cuda.is_available():
|
181 |
+
my_resnet_batch.cuda()
|
182 |
+
|
183 |
+
self.my_resnet_batch = my_resnet_batch
|
184 |
+
self.my_resnet = my_resnet
|
185 |
+
self.att_size = image_feat_params['att_size']
|
186 |
+
|
187 |
+
# Control the input features of the model
|
188 |
+
if diff_feat == None:
|
189 |
+
if self.caption_model == "show_attend_tell":
|
190 |
+
self.diff_feat = True
|
191 |
+
else:
|
192 |
+
self.diff_feat = False
|
193 |
+
else:
|
194 |
+
self.diff_feat = diff_feat
|
195 |
+
|
196 |
+
def gen_caption_from_feat(self, feat_target, feat_reference=None):
|
197 |
+
if self.is_relative and feat_reference == None:
|
198 |
+
return None, None
|
199 |
+
|
200 |
+
if not self.is_relative and not feat_reference == None:
|
201 |
+
return None, None
|
202 |
+
|
203 |
+
if self.is_relative:
|
204 |
+
if self.diff_feat:
|
205 |
+
fc_feat = torch.cat((feat_target[0], feat_target[0] - feat_reference[0]), dim=-1)
|
206 |
+
att_feat = torch.cat((feat_target[1], feat_target[1] - feat_reference[1]), dim=-1)
|
207 |
+
else:
|
208 |
+
fc_feat = torch.cat((feat_target[0], feat_reference[0]), dim=-1)
|
209 |
+
att_feat = torch.cat((feat_target[1], feat_reference[1]), dim=-1)
|
210 |
+
else:
|
211 |
+
fc_feat = feat_target[0]
|
212 |
+
att_feat = feat_target[1]
|
213 |
+
|
214 |
+
# Reshape to B x K x C (128,14,14,4096) --> (128,196,4096)
|
215 |
+
att_feat = att_feat.view(att_feat.shape[0], att_feat.shape[1] * att_feat.shape[2], att_feat.shape[-1])
|
216 |
+
|
217 |
+
att_masks = np.zeros(att_feat.shape[:2], dtype='float32')
|
218 |
+
for i in range(len(att_feat)):
|
219 |
+
att_masks[i, :att_feat[i].shape[0]] = 1
|
220 |
+
# set att_masks to None if attention features have same length
|
221 |
+
if att_masks.sum() == att_masks.size:
|
222 |
+
att_masks = None
|
223 |
+
|
224 |
+
if self.caption_model == 'show_attend_tell':
|
225 |
+
seq, _ = self.model.sample(fc_feat, att_feat, self.opt)
|
226 |
+
else:
|
227 |
+
seq, _ = self.model(fc_feat, att_feat, att_masks=att_masks, opt=self.opt, mode='sample')
|
228 |
+
sents = utils.decode_sequence(self.vocab, seq)
|
229 |
+
|
230 |
+
return seq, sents
|
231 |
+
|
232 |
+
def get_vocab_size(self):
|
233 |
+
return len(self.vocab)
|
234 |
+
|
235 |
+
def get_img_feat(self, img_name):
|
236 |
+
# load the image
|
237 |
+
I = skimage.io.imread(img_name)
|
238 |
+
|
239 |
+
if len(I.shape) == 2:
|
240 |
+
I = I[:, :, np.newaxis]
|
241 |
+
I = np.concatenate((I, I, I), axis=2)
|
242 |
+
|
243 |
+
I = I.astype('float32') / 255.0
|
244 |
+
I = torch.from_numpy(I.transpose([2, 0, 1]))
|
245 |
+
if torch.cuda.is_available(): I = I.cuda()
|
246 |
+
# I = Variable(preprocess(I), volatile=True)
|
247 |
+
with torch.no_grad():
|
248 |
+
I = preprocess(I)
|
249 |
+
fc, att = self.my_resnet(I, self.att_size)
|
250 |
+
|
251 |
+
return fc, att
|
252 |
+
|
253 |
+
def get_img_feat_batch(self, img_names, batchsize=32):
|
254 |
+
if not isinstance(img_names, list):
|
255 |
+
img_names = [img_names]
|
256 |
+
|
257 |
+
num_images = len(img_names)
|
258 |
+
num_batches = math.ceil(np.float(num_images) / np.float(batchsize))
|
259 |
+
|
260 |
+
feature_fc = []
|
261 |
+
feature_att = []
|
262 |
+
|
263 |
+
for id in range(num_batches):
|
264 |
+
startInd = id * batchsize
|
265 |
+
endInd = min((id + 1) * batchsize, num_images)
|
266 |
+
|
267 |
+
img_names_current_batch = img_names[startInd:endInd]
|
268 |
+
I_current_batch = []
|
269 |
+
|
270 |
+
for img_name in img_names_current_batch:
|
271 |
+
I = skimage.io.imread(img_name)
|
272 |
+
|
273 |
+
if len(I.shape) == 2:
|
274 |
+
I = I[:, :, np.newaxis]
|
275 |
+
I = np.concatenate((I, I, I), axis=2)
|
276 |
+
|
277 |
+
I = I.astype('float32') / 255.0
|
278 |
+
I = torch.from_numpy(I.transpose([2, 0, 1]))
|
279 |
+
I_current_batch.append(preprocess(I))
|
280 |
+
|
281 |
+
I_current_batch = torch.stack(I_current_batch, dim=0)
|
282 |
+
if torch.cuda.is_available(): I_current_batch = I_current_batch.cuda()
|
283 |
+
# I_current_batch = Variable(I_current_batch, volatile=True)
|
284 |
+
with torch.no_grad():
|
285 |
+
fc, att = self.my_resnet_batch(I_current_batch, self.att_size)
|
286 |
+
|
287 |
+
feature_fc.append(fc)
|
288 |
+
feature_att.append(att)
|
289 |
+
|
290 |
+
feature_fc = torch.cat(feature_fc, dim=0)
|
291 |
+
feature_att = torch.cat(feature_att, dim=0)
|
292 |
+
|
293 |
+
return feature_fc, feature_att
|
294 |
+
|
295 |
+
|
296 |
+
|
captioning/data/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
captioning/data/__init__.py
ADDED
File without changes
|
captioning/data/dataloader.py
ADDED
@@ -0,0 +1,439 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import
|
2 |
+
from __future__ import division
|
3 |
+
from __future__ import print_function
|
4 |
+
|
5 |
+
import json
|
6 |
+
import h5py
|
7 |
+
from lmdbdict import lmdbdict
|
8 |
+
from lmdbdict.methods import DUMPS_FUNC, LOADS_FUNC
|
9 |
+
import os
|
10 |
+
import numpy as np
|
11 |
+
import numpy.random as npr
|
12 |
+
import random
|
13 |
+
from functools import partial
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torch.utils.data as data
|
17 |
+
|
18 |
+
import multiprocessing
|
19 |
+
import six
|
20 |
+
|
21 |
+
|
22 |
+
class HybridLoader:
|
23 |
+
"""
|
24 |
+
If db_path is a director, then use normal file loading
|
25 |
+
If lmdb, then load from lmdb
|
26 |
+
The loading method depend on extention.
|
27 |
+
|
28 |
+
in_memory: if in_memory is True, we save all the features in memory
|
29 |
+
For individual np(y|z)s, we don't need to do that because the system will do this for us.
|
30 |
+
Should be useful for lmdb or h5.
|
31 |
+
(Copied this idea from vilbert)
|
32 |
+
"""
|
33 |
+
def __init__(self, db_path, ext, in_memory=False):
|
34 |
+
self.db_path = db_path
|
35 |
+
self.ext = ext
|
36 |
+
if self.ext == '.npy':
|
37 |
+
self.loader = lambda x: np.load(six.BytesIO(x))
|
38 |
+
else:
|
39 |
+
def load_npz(x):
|
40 |
+
x = np.load(six.BytesIO(x))
|
41 |
+
return x['feat'] if 'feat' in x else x['z'] # normally it should be 'feat', but under cocotest_bu, the key is saved to be 'z' mistakenly.
|
42 |
+
self.loader = load_npz
|
43 |
+
if db_path.endswith('.lmdb'):
|
44 |
+
self.db_type = 'lmdb'
|
45 |
+
self.lmdb = lmdbdict(db_path, unsafe=True)
|
46 |
+
self.lmdb._key_dumps = DUMPS_FUNC['ascii']
|
47 |
+
self.lmdb._value_loads = LOADS_FUNC['identity']
|
48 |
+
elif db_path.endswith('.pth'): # Assume a key,value dictionary
|
49 |
+
self.db_type = 'pth'
|
50 |
+
self.feat_file = torch.load(db_path)
|
51 |
+
self.loader = lambda x: x
|
52 |
+
print('HybridLoader: ext is ignored')
|
53 |
+
elif db_path.endswith('h5'):
|
54 |
+
self.db_type = 'h5'
|
55 |
+
self.loader = lambda x: np.array(x).astype('float32')
|
56 |
+
else:
|
57 |
+
self.db_type = 'dir'
|
58 |
+
|
59 |
+
self.in_memory = in_memory
|
60 |
+
if self.in_memory:
|
61 |
+
self.features = {}
|
62 |
+
|
63 |
+
def get(self, key):
|
64 |
+
|
65 |
+
if self.in_memory and key in self.features:
|
66 |
+
# We save f_input because we want to save the
|
67 |
+
# compressed bytes to save memory
|
68 |
+
f_input = self.features[key]
|
69 |
+
elif self.db_type == 'lmdb':
|
70 |
+
f_input = self.lmdb[key]
|
71 |
+
elif self.db_type == 'pth':
|
72 |
+
f_input = self.feat_file[key]
|
73 |
+
elif self.db_type == 'h5':
|
74 |
+
f_input = h5py.File(self.db_path, 'r')[key]
|
75 |
+
else:
|
76 |
+
f_input = open(os.path.join(self.db_path, key + self.ext), 'rb').read()
|
77 |
+
|
78 |
+
if self.in_memory and key not in self.features:
|
79 |
+
self.features[key] = f_input
|
80 |
+
|
81 |
+
# load image
|
82 |
+
feat = self.loader(f_input)
|
83 |
+
|
84 |
+
return feat
|
85 |
+
|
86 |
+
class Dataset(data.Dataset):
|
87 |
+
|
88 |
+
def get_vocab_size(self):
|
89 |
+
return self.vocab_size
|
90 |
+
|
91 |
+
def get_vocab(self):
|
92 |
+
return self.ix_to_word
|
93 |
+
|
94 |
+
def get_seq_length(self):
|
95 |
+
return self.seq_length
|
96 |
+
|
97 |
+
def __init__(self, opt):
|
98 |
+
self.opt = opt
|
99 |
+
self.seq_per_img = opt.seq_per_img
|
100 |
+
|
101 |
+
# feature related options
|
102 |
+
self.use_fc = getattr(opt, 'use_fc', True)
|
103 |
+
self.use_att = getattr(opt, 'use_att', True)
|
104 |
+
self.use_box = getattr(opt, 'use_box', 0)
|
105 |
+
self.norm_att_feat = getattr(opt, 'norm_att_feat', 0)
|
106 |
+
self.norm_box_feat = getattr(opt, 'norm_box_feat', 0)
|
107 |
+
|
108 |
+
# load the json file which contains additional information about the dataset
|
109 |
+
print('DataLoader loading json file: ', opt.input_json)
|
110 |
+
self.info = json.load(open(self.opt.input_json))
|
111 |
+
if 'ix_to_word' in self.info:
|
112 |
+
self.ix_to_word = self.info['ix_to_word']
|
113 |
+
self.vocab_size = len(self.ix_to_word)
|
114 |
+
print('vocab size is ', self.vocab_size)
|
115 |
+
|
116 |
+
# open the hdf5 file
|
117 |
+
print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_box_dir, opt.input_label_h5)
|
118 |
+
"""
|
119 |
+
Setting input_label_h5 to none is used when only doing generation.
|
120 |
+
For example, when you need to test on coco test set.
|
121 |
+
"""
|
122 |
+
if self.opt.input_label_h5 != 'none':
|
123 |
+
self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core')
|
124 |
+
# load in the sequence data
|
125 |
+
seq_size = self.h5_label_file['labels'].shape
|
126 |
+
self.label = self.h5_label_file['labels'][:]
|
127 |
+
self.seq_length = seq_size[1]
|
128 |
+
print('max sequence length in data is', self.seq_length)
|
129 |
+
# load the pointers in full to RAM (should be small enough)
|
130 |
+
self.label_start_ix = self.h5_label_file['label_start_ix'][:]
|
131 |
+
self.label_end_ix = self.h5_label_file['label_end_ix'][:]
|
132 |
+
else:
|
133 |
+
self.seq_length = 1
|
134 |
+
|
135 |
+
self.data_in_memory = getattr(opt, 'data_in_memory', False)
|
136 |
+
self.fc_loader = HybridLoader(self.opt.input_fc_dir, '.npy', in_memory=self.data_in_memory)
|
137 |
+
self.att_loader = HybridLoader(self.opt.input_att_dir, '.npz', in_memory=self.data_in_memory)
|
138 |
+
self.box_loader = HybridLoader(self.opt.input_box_dir, '.npy', in_memory=self.data_in_memory)
|
139 |
+
|
140 |
+
self.num_images = len(self.info['images']) # self.label_start_ix.shape[0]
|
141 |
+
print('read %d image features' %(self.num_images))
|
142 |
+
|
143 |
+
# separate out indexes for each of the provided splits
|
144 |
+
self.split_ix = {'train': [], 'val': [], 'test': []}
|
145 |
+
for ix in range(len(self.info['images'])):
|
146 |
+
img = self.info['images'][ix]
|
147 |
+
if not 'split' in img:
|
148 |
+
self.split_ix['train'].append(ix)
|
149 |
+
self.split_ix['val'].append(ix)
|
150 |
+
self.split_ix['test'].append(ix)
|
151 |
+
elif img['split'] == 'train':
|
152 |
+
self.split_ix['train'].append(ix)
|
153 |
+
elif img['split'] == 'val':
|
154 |
+
self.split_ix['val'].append(ix)
|
155 |
+
elif img['split'] == 'test':
|
156 |
+
self.split_ix['test'].append(ix)
|
157 |
+
elif opt.train_only == 0: # restval
|
158 |
+
self.split_ix['train'].append(ix)
|
159 |
+
|
160 |
+
print('assigned %d images to split train' %len(self.split_ix['train']))
|
161 |
+
print('assigned %d images to split val' %len(self.split_ix['val']))
|
162 |
+
print('assigned %d images to split test' %len(self.split_ix['test']))
|
163 |
+
|
164 |
+
def get_captions(self, ix, seq_per_img):
|
165 |
+
# fetch the sequence labels
|
166 |
+
ix1 = self.label_start_ix[ix] - 1 #label_start_ix starts from 1
|
167 |
+
ix2 = self.label_end_ix[ix] - 1
|
168 |
+
ncap = ix2 - ix1 + 1 # number of captions available for this image
|
169 |
+
assert ncap > 0, 'an image does not have any label. this can be handled but right now isn\'t'
|
170 |
+
|
171 |
+
random.seed(42)
|
172 |
+
torch.manual_seed(42)
|
173 |
+
if torch.cuda.is_available():
|
174 |
+
torch.cuda.manual_seed(42)
|
175 |
+
|
176 |
+
if ncap < seq_per_img:
|
177 |
+
# we need to subsample (with replacement)
|
178 |
+
seq = np.zeros([seq_per_img, self.seq_length], dtype = 'int')
|
179 |
+
for q in range(seq_per_img):
|
180 |
+
ixl = random.randint(ix1,ix2)
|
181 |
+
seq[q, :] = self.label[ixl, :self.seq_length]
|
182 |
+
else:
|
183 |
+
ixl = random.randint(ix1, ix2 - seq_per_img + 1)
|
184 |
+
seq = self.label[ixl: ixl + seq_per_img, :self.seq_length]
|
185 |
+
|
186 |
+
return seq
|
187 |
+
|
188 |
+
def collate_func(self, batch, split):
|
189 |
+
seq_per_img = self.seq_per_img
|
190 |
+
|
191 |
+
fc_batch = []
|
192 |
+
att_batch = []
|
193 |
+
label_batch = []
|
194 |
+
|
195 |
+
wrapped = False
|
196 |
+
|
197 |
+
infos = []
|
198 |
+
gts = []
|
199 |
+
|
200 |
+
for sample in batch:
|
201 |
+
# fetch image
|
202 |
+
tmp_fc, tmp_att, tmp_seq, \
|
203 |
+
ix, it_pos_now, tmp_wrapped = sample
|
204 |
+
if tmp_wrapped:
|
205 |
+
wrapped = True
|
206 |
+
|
207 |
+
fc_batch.append(tmp_fc)
|
208 |
+
att_batch.append(tmp_att)
|
209 |
+
|
210 |
+
tmp_label = np.zeros([seq_per_img, self.seq_length + 2], dtype = 'int')
|
211 |
+
if hasattr(self, 'h5_label_file'):
|
212 |
+
# if there is ground truth
|
213 |
+
tmp_label[:, 1 : self.seq_length + 1] = tmp_seq
|
214 |
+
label_batch.append(tmp_label)
|
215 |
+
|
216 |
+
# Used for reward evaluation
|
217 |
+
if hasattr(self, 'h5_label_file'):
|
218 |
+
# if there is ground truth
|
219 |
+
gts.append(self.label[self.label_start_ix[ix] - 1: self.label_end_ix[ix]])
|
220 |
+
else:
|
221 |
+
gts.append([])
|
222 |
+
|
223 |
+
# record associated info as well
|
224 |
+
info_dict = {}
|
225 |
+
info_dict['ix'] = ix
|
226 |
+
info_dict['id'] = self.info['images'][ix]['id']
|
227 |
+
info_dict['file_path'] = self.info['images'][ix].get('file_path', '')
|
228 |
+
infos.append(info_dict)
|
229 |
+
|
230 |
+
# #sort by att_feat length
|
231 |
+
# fc_batch, att_batch, label_batch, gts, infos = \
|
232 |
+
# zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True))
|
233 |
+
fc_batch, att_batch, label_batch, gts, infos = \
|
234 |
+
zip(*sorted(zip(fc_batch, att_batch, label_batch, gts, infos), key=lambda x: 0, reverse=True))
|
235 |
+
|
236 |
+
data = {}
|
237 |
+
data['fc_feats'] = np.stack(fc_batch)
|
238 |
+
# merge att_feats
|
239 |
+
max_att_len = max([_.shape[0] for _ in att_batch])
|
240 |
+
data['att_feats'] = np.zeros([len(att_batch), max_att_len, att_batch[0].shape[1]], dtype = 'float32')
|
241 |
+
|
242 |
+
for i in range(len(att_batch)):
|
243 |
+
data['att_feats'][i, :att_batch[i].shape[0]] = att_batch[i]
|
244 |
+
|
245 |
+
data['att_masks'] = np.zeros(data['att_feats'].shape[:2], dtype='float32')
|
246 |
+
for i in range(len(att_batch)):
|
247 |
+
data['att_masks'][i, :att_batch[i].shape[0]] = 1
|
248 |
+
# set att_masks to None if attention features have same length
|
249 |
+
if data['att_masks'].sum() == data['att_masks'].size:
|
250 |
+
data['att_masks'] = None
|
251 |
+
|
252 |
+
data['labels'] = np.vstack(label_batch)
|
253 |
+
# generate mask
|
254 |
+
nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, data['labels'])))
|
255 |
+
mask_batch = np.zeros([data['labels'].shape[0], self.seq_length + 2], dtype = 'float32')
|
256 |
+
for ix, row in enumerate(mask_batch):
|
257 |
+
row[:nonzeros[ix]] = 1
|
258 |
+
data['masks'] = mask_batch
|
259 |
+
data['labels'] = data['labels'].reshape(len(batch), seq_per_img, -1)
|
260 |
+
data['masks'] = data['masks'].reshape(len(batch), seq_per_img, -1)
|
261 |
+
|
262 |
+
data['gts'] = gts # all ground truth captions of each images
|
263 |
+
data['bounds'] = {'it_pos_now': it_pos_now, # the it_pos_now of the last sample
|
264 |
+
'it_max': len(self.split_ix[split]), 'wrapped': wrapped}
|
265 |
+
data['infos'] = infos
|
266 |
+
|
267 |
+
data = {k:torch.from_numpy(v) if type(v) is np.ndarray else v for k,v in data.items()} # Turn all ndarray to torch tensor
|
268 |
+
|
269 |
+
return data
|
270 |
+
|
271 |
+
def __getitem__(self, index):
|
272 |
+
"""This function returns a tuple that is further passed to collate_fn
|
273 |
+
"""
|
274 |
+
ix, it_pos_now, wrapped = index #self.split_ix[index]
|
275 |
+
if self.use_att:
|
276 |
+
att_feat = self.att_loader.get(str(self.info['images'][ix]['id']))
|
277 |
+
# shape: (14,14,4096)
|
278 |
+
|
279 |
+
# Reshape to K x C
|
280 |
+
att_feat = att_feat.reshape(-1, att_feat.shape[-1])
|
281 |
+
# shape:(196,4096)
|
282 |
+
|
283 |
+
if self.norm_att_feat:
|
284 |
+
att_feat = att_feat / np.linalg.norm(att_feat, 2, 1, keepdims=True)
|
285 |
+
if self.use_box:
|
286 |
+
box_feat = self.box_loader.get(str(self.info['images'][ix]['id']))
|
287 |
+
# devided by image width and height
|
288 |
+
x1,y1,x2,y2 = np.hsplit(box_feat, 4)
|
289 |
+
h,w = self.info['images'][ix]['height'], self.info['images'][ix]['width']
|
290 |
+
box_feat = np.hstack((x1/w, y1/h, x2/w, y2/h, (x2-x1)*(y2-y1)/(w*h))) # question? x2-x1+1??
|
291 |
+
if self.norm_box_feat:
|
292 |
+
box_feat = box_feat / np.linalg.norm(box_feat, 2, 1, keepdims=True)
|
293 |
+
att_feat = np.hstack([att_feat, box_feat])
|
294 |
+
# sort the features by the size of boxes
|
295 |
+
att_feat = np.stack(sorted(att_feat, key=lambda x:x[-1], reverse=True))
|
296 |
+
else:
|
297 |
+
att_feat = np.zeros((0,0), dtype='float32')
|
298 |
+
if self.use_fc:
|
299 |
+
try:
|
300 |
+
fc_feat = self.fc_loader.get(str(self.info['images'][ix]['id']))
|
301 |
+
except:
|
302 |
+
# Use average of attention when there is no fc provided (For bottomup feature)
|
303 |
+
fc_feat = att_feat.mean(0)
|
304 |
+
else:
|
305 |
+
fc_feat = np.zeros((0), dtype='float32')
|
306 |
+
if hasattr(self, 'h5_label_file'):
|
307 |
+
seq = self.get_captions(ix, self.seq_per_img)
|
308 |
+
else:
|
309 |
+
seq = None
|
310 |
+
return (fc_feat,
|
311 |
+
att_feat, seq,
|
312 |
+
ix, it_pos_now, wrapped)
|
313 |
+
|
314 |
+
def __len__(self):
|
315 |
+
return len(self.info['images'])
|
316 |
+
|
317 |
+
class DataLoader:
|
318 |
+
def __init__(self, opt):
|
319 |
+
self.opt = opt
|
320 |
+
self.batch_size = self.opt.batch_size
|
321 |
+
self.dataset = Dataset(opt)
|
322 |
+
|
323 |
+
# Initialize loaders and iters
|
324 |
+
self.loaders, self.iters = {}, {}
|
325 |
+
for split in ['train', 'val', 'test']:
|
326 |
+
if split == 'train':
|
327 |
+
sampler = MySampler(self.dataset.split_ix[split], shuffle=True, wrap=True)
|
328 |
+
else:
|
329 |
+
sampler = MySampler(self.dataset.split_ix[split], shuffle=False, wrap=False)
|
330 |
+
self.loaders[split] = data.DataLoader(dataset=self.dataset,
|
331 |
+
batch_size=self.batch_size,
|
332 |
+
sampler=sampler,
|
333 |
+
pin_memory=True,
|
334 |
+
num_workers=4, # 4 is usually enough
|
335 |
+
collate_fn=partial(self.dataset.collate_func, split=split),
|
336 |
+
drop_last=False)
|
337 |
+
self.iters[split] = iter(self.loaders[split])
|
338 |
+
|
339 |
+
def get_batch(self, split):
|
340 |
+
try:
|
341 |
+
data = next(self.iters[split])
|
342 |
+
except StopIteration:
|
343 |
+
self.iters[split] = iter(self.loaders[split])
|
344 |
+
data = next(self.iters[split])
|
345 |
+
return data
|
346 |
+
|
347 |
+
def reset_iterator(self, split):
|
348 |
+
self.loaders[split].sampler._reset_iter()
|
349 |
+
self.iters[split] = iter(self.loaders[split])
|
350 |
+
|
351 |
+
def get_vocab_size(self):
|
352 |
+
return self.dataset.get_vocab_size()
|
353 |
+
|
354 |
+
@property
|
355 |
+
def vocab_size(self):
|
356 |
+
return self.get_vocab_size()
|
357 |
+
|
358 |
+
def get_vocab(self):
|
359 |
+
return self.dataset.get_vocab()
|
360 |
+
|
361 |
+
def get_seq_length(self):
|
362 |
+
return self.dataset.get_seq_length()
|
363 |
+
|
364 |
+
@property
|
365 |
+
def seq_length(self):
|
366 |
+
return self.get_seq_length()
|
367 |
+
|
368 |
+
def state_dict(self):
|
369 |
+
def get_prefetch_num(split):
|
370 |
+
if self.loaders[split].num_workers > 0:
|
371 |
+
return (self.iters[split]._send_idx - self.iters[split]._rcvd_idx) * self.batch_size
|
372 |
+
else:
|
373 |
+
return 0
|
374 |
+
return {split: loader.sampler.state_dict(get_prefetch_num(split)) \
|
375 |
+
for split, loader in self.loaders.items()}
|
376 |
+
|
377 |
+
def load_state_dict(self, state_dict=None):
|
378 |
+
if state_dict is None:
|
379 |
+
return
|
380 |
+
for split in self.loaders.keys():
|
381 |
+
self.loaders[split].sampler.load_state_dict(state_dict[split])
|
382 |
+
|
383 |
+
|
384 |
+
class MySampler(data.sampler.Sampler):
|
385 |
+
def __init__(self, index_list, shuffle, wrap):
|
386 |
+
self.index_list = index_list
|
387 |
+
self.shuffle = shuffle
|
388 |
+
self.wrap = wrap
|
389 |
+
# if wrap, there will be not stop iteration called
|
390 |
+
# wrap True used during training, and wrap False used during test.
|
391 |
+
self._reset_iter()
|
392 |
+
|
393 |
+
def __iter__(self):
|
394 |
+
return self
|
395 |
+
|
396 |
+
def __next__(self):
|
397 |
+
wrapped = False
|
398 |
+
if self.iter_counter == len(self._index_list):
|
399 |
+
self._reset_iter()
|
400 |
+
if self.wrap:
|
401 |
+
wrapped = True
|
402 |
+
else:
|
403 |
+
raise StopIteration()
|
404 |
+
if len(self._index_list) == 0: # overflow when 0 samples
|
405 |
+
return None
|
406 |
+
elem = (self._index_list[self.iter_counter], self.iter_counter+1, wrapped)
|
407 |
+
self.iter_counter += 1
|
408 |
+
return elem
|
409 |
+
|
410 |
+
def next(self):
|
411 |
+
return self.__next__()
|
412 |
+
|
413 |
+
def _reset_iter(self):
|
414 |
+
np.random.seed(42)
|
415 |
+
if self.shuffle:
|
416 |
+
rand_perm = npr.permutation(len(self.index_list))
|
417 |
+
self._index_list = [self.index_list[_] for _ in rand_perm]
|
418 |
+
else:
|
419 |
+
self._index_list = self.index_list
|
420 |
+
|
421 |
+
self.iter_counter = 0
|
422 |
+
|
423 |
+
def __len__(self):
|
424 |
+
return len(self.index_list)
|
425 |
+
|
426 |
+
def load_state_dict(self, state_dict=None):
|
427 |
+
if state_dict is None:
|
428 |
+
return
|
429 |
+
self._index_list = state_dict['index_list']
|
430 |
+
self.iter_counter = state_dict['iter_counter']
|
431 |
+
|
432 |
+
def state_dict(self, prefetched_num=None):
|
433 |
+
prefetched_num = prefetched_num or 0
|
434 |
+
return {
|
435 |
+
'index_list': self._index_list,
|
436 |
+
'iter_counter': self.iter_counter - prefetched_num
|
437 |
+
}
|
438 |
+
|
439 |
+
|
captioning/data/dataloader_recsys.py
ADDED
@@ -0,0 +1,432 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import
|
2 |
+
from __future__ import division
|
3 |
+
from __future__ import print_function
|
4 |
+
|
5 |
+
import json
|
6 |
+
import h5py
|
7 |
+
from lmdbdict import lmdbdict
|
8 |
+
from lmdbdict.methods import DUMPS_FUNC, LOADS_FUNC
|
9 |
+
import os
|
10 |
+
import numpy as np
|
11 |
+
import numpy.random as npr
|
12 |
+
import random
|
13 |
+
from functools import partial
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torch.utils.data as data
|
17 |
+
|
18 |
+
import multiprocessing
|
19 |
+
import six
|
20 |
+
|
21 |
+
|
22 |
+
class HybridLoader:
|
23 |
+
"""
|
24 |
+
If db_path is a director, then use normal file loading
|
25 |
+
If lmdb, then load from lmdb
|
26 |
+
The loading method depend on extention.
|
27 |
+
|
28 |
+
in_memory: if in_memory is True, we save all the features in memory
|
29 |
+
For individual np(y|z)s, we don't need to do that because the system will do this for us.
|
30 |
+
Should be useful for lmdb or h5.
|
31 |
+
(Copied this idea from vilbert)
|
32 |
+
"""
|
33 |
+
def __init__(self, db_path, ext, in_memory=False):
|
34 |
+
self.db_path = db_path
|
35 |
+
self.ext = ext
|
36 |
+
if self.ext == '.npy':
|
37 |
+
self.loader = lambda x: np.load(six.BytesIO(x))
|
38 |
+
else:
|
39 |
+
def load_npz(x):
|
40 |
+
x = np.load(six.BytesIO(x))
|
41 |
+
return x['feat'] if 'feat' in x else x['z'] # normally it should be 'feat', but under cocotest_bu, the key is saved to be 'z' mistakenly.
|
42 |
+
self.loader = load_npz
|
43 |
+
if db_path.endswith('.lmdb'):
|
44 |
+
self.db_type = 'lmdb'
|
45 |
+
self.lmdb = lmdbdict(db_path, unsafe=True)
|
46 |
+
self.lmdb._key_dumps = DUMPS_FUNC['ascii']
|
47 |
+
self.lmdb._value_loads = LOADS_FUNC['identity']
|
48 |
+
elif db_path.endswith('.pth'): # Assume a key,value dictionary
|
49 |
+
self.db_type = 'pth'
|
50 |
+
self.feat_file = torch.load(db_path)
|
51 |
+
self.loader = lambda x: x
|
52 |
+
print('HybridLoader: ext is ignored')
|
53 |
+
elif db_path.endswith('h5'):
|
54 |
+
self.db_type = 'h5'
|
55 |
+
self.loader = lambda x: np.array(x).astype('float32')
|
56 |
+
else:
|
57 |
+
self.db_type = 'dir'
|
58 |
+
|
59 |
+
self.in_memory = in_memory
|
60 |
+
if self.in_memory:
|
61 |
+
self.features = {}
|
62 |
+
|
63 |
+
def get(self, key):
|
64 |
+
|
65 |
+
if self.in_memory and key in self.features:
|
66 |
+
# We save f_input because we want to save the
|
67 |
+
# compressed bytes to save memory
|
68 |
+
f_input = self.features[key]
|
69 |
+
elif self.db_type == 'lmdb':
|
70 |
+
f_input = self.lmdb[key]
|
71 |
+
elif self.db_type == 'pth':
|
72 |
+
f_input = self.feat_file[key]
|
73 |
+
elif self.db_type == 'h5':
|
74 |
+
f_input = h5py.File(self.db_path, 'r')[key]
|
75 |
+
else:
|
76 |
+
f_input = open(os.path.join(self.db_path, key + self.ext), 'rb').read()
|
77 |
+
|
78 |
+
if self.in_memory and key not in self.features:
|
79 |
+
self.features[key] = f_input
|
80 |
+
|
81 |
+
# load image
|
82 |
+
feat = self.loader(f_input)
|
83 |
+
|
84 |
+
return feat
|
85 |
+
|
86 |
+
class Dataset(data.Dataset):
|
87 |
+
|
88 |
+
def get_vocab_size(self):
|
89 |
+
return self.vocab_size
|
90 |
+
|
91 |
+
def get_vocab(self):
|
92 |
+
return self.ix_to_word
|
93 |
+
|
94 |
+
def get_seq_length(self):
|
95 |
+
return self.seq_length
|
96 |
+
|
97 |
+
def __init__(self, opt):
|
98 |
+
self.opt = opt
|
99 |
+
self.seq_per_img = opt.seq_per_img
|
100 |
+
|
101 |
+
# feature related options
|
102 |
+
self.use_fc = getattr(opt, 'use_fc', True)
|
103 |
+
self.use_att = getattr(opt, 'use_att', True)
|
104 |
+
self.use_box = getattr(opt, 'use_box', 0)
|
105 |
+
self.norm_att_feat = getattr(opt, 'norm_att_feat', 0)
|
106 |
+
self.norm_box_feat = getattr(opt, 'norm_box_feat', 0)
|
107 |
+
|
108 |
+
# load the json file which contains additional information about the dataset
|
109 |
+
print('DataLoader loading json file: ', opt.input_json)
|
110 |
+
self.info = json.load(open(self.opt.input_json))
|
111 |
+
if 'ix_to_word' in self.info:
|
112 |
+
self.ix_to_word = self.info['ix_to_word']
|
113 |
+
self.vocab_size = len(self.ix_to_word)
|
114 |
+
print('vocab size is ', self.vocab_size)
|
115 |
+
|
116 |
+
# open the hdf5 file
|
117 |
+
print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_box_dir, opt.input_label_h5)
|
118 |
+
"""
|
119 |
+
Setting input_label_h5 to none is used when only doing generation.
|
120 |
+
For example, when you need to test on coco test set.
|
121 |
+
"""
|
122 |
+
if self.opt.input_label_h5 != 'none':
|
123 |
+
self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core')
|
124 |
+
# load in the sequence data
|
125 |
+
seq_size = self.h5_label_file['labels'].shape
|
126 |
+
self.label = self.h5_label_file['labels'][:]
|
127 |
+
self.seq_length = seq_size[1]
|
128 |
+
print('max sequence length in data is', self.seq_length)
|
129 |
+
# load the pointers in full to RAM (should be small enough)
|
130 |
+
self.label_start_ix = self.h5_label_file['label_start_ix'][:]
|
131 |
+
self.label_end_ix = self.h5_label_file['label_end_ix'][:]
|
132 |
+
else:
|
133 |
+
self.seq_length = 1
|
134 |
+
|
135 |
+
self.data_in_memory = getattr(opt, 'data_in_memory', False)
|
136 |
+
self.fc_loader = HybridLoader(self.opt.input_fc_dir, '.npy', in_memory=self.data_in_memory)
|
137 |
+
self.att_loader = HybridLoader(self.opt.input_att_dir, '.npz', in_memory=self.data_in_memory)
|
138 |
+
self.box_loader = HybridLoader(self.opt.input_box_dir, '.npy', in_memory=self.data_in_memory)
|
139 |
+
|
140 |
+
self.num_images = len(self.info['images']) # self.label_start_ix.shape[0]
|
141 |
+
print('read %d image features' %(self.num_images))
|
142 |
+
|
143 |
+
# separate out indexes for each of the provided splits
|
144 |
+
self.split_ix = {'train': [], 'val': [], 'test': []}
|
145 |
+
for ix in range(len(self.info['images'])):
|
146 |
+
img = self.info['images'][ix]
|
147 |
+
if not 'split' in img:
|
148 |
+
self.split_ix['train'].append(ix)
|
149 |
+
self.split_ix['val'].append(ix)
|
150 |
+
self.split_ix['test'].append(ix)
|
151 |
+
elif img['split'] == 'train':
|
152 |
+
self.split_ix['train'].append(ix)
|
153 |
+
elif img['split'] == 'val':
|
154 |
+
self.split_ix['val'].append(ix)
|
155 |
+
elif img['split'] == 'test':
|
156 |
+
self.split_ix['test'].append(ix)
|
157 |
+
elif opt.train_only == 0: # restval
|
158 |
+
self.split_ix['train'].append(ix)
|
159 |
+
|
160 |
+
print('assigned %d images to split train' %len(self.split_ix['train']))
|
161 |
+
print('assigned %d images to split val' %len(self.split_ix['val']))
|
162 |
+
print('assigned %d images to split test' %len(self.split_ix['test']))
|
163 |
+
|
164 |
+
def get_captions(self, ix, seq_per_img):
|
165 |
+
# fetch the sequence labels
|
166 |
+
ix1 = self.label_start_ix[ix] - 1 #label_start_ix starts from 1
|
167 |
+
ix2 = self.label_end_ix[ix] - 1
|
168 |
+
ncap = ix2 - ix1 + 1 # number of captions available for this image
|
169 |
+
assert ncap > 0, 'an image does not have any label. this can be handled but right now isn\'t'
|
170 |
+
|
171 |
+
random.seed(42)
|
172 |
+
torch.manual_seed(42)
|
173 |
+
if torch.cuda.is_available():
|
174 |
+
torch.cuda.manual_seed(42)
|
175 |
+
|
176 |
+
if ncap < seq_per_img:
|
177 |
+
# we need to subsample (with replacement)
|
178 |
+
seq = np.zeros([seq_per_img, self.seq_length], dtype = 'int')
|
179 |
+
for q in range(seq_per_img):
|
180 |
+
ixl = random.randint(ix1,ix2)
|
181 |
+
seq[q, :] = self.label[ixl, :self.seq_length]
|
182 |
+
else:
|
183 |
+
ixl = random.randint(ix1, ix2 - seq_per_img + 1)
|
184 |
+
seq = self.label[ixl: ixl + seq_per_img, :self.seq_length]
|
185 |
+
|
186 |
+
return seq
|
187 |
+
|
188 |
+
def collate_func(self, batch, split):
|
189 |
+
seq_per_img = self.seq_per_img
|
190 |
+
|
191 |
+
fc_batch = []
|
192 |
+
att_batch = []
|
193 |
+
label_batch = []
|
194 |
+
|
195 |
+
wrapped = False
|
196 |
+
|
197 |
+
infos = []
|
198 |
+
gts = []
|
199 |
+
|
200 |
+
for sample in batch:
|
201 |
+
# fetch image
|
202 |
+
tmp_fc, tmp_att, tmp_seq, \
|
203 |
+
ix, it_pos_now, tmp_wrapped = sample
|
204 |
+
if tmp_wrapped:
|
205 |
+
wrapped = True
|
206 |
+
|
207 |
+
fc_batch.append(tmp_fc)
|
208 |
+
att_batch.append(tmp_att)
|
209 |
+
|
210 |
+
tmp_label = np.zeros([seq_per_img, self.seq_length + 2], dtype = 'int')
|
211 |
+
if hasattr(self, 'h5_label_file'):
|
212 |
+
# if there is ground truth
|
213 |
+
tmp_label[:, 1 : self.seq_length + 1] = tmp_seq
|
214 |
+
label_batch.append(tmp_label)
|
215 |
+
|
216 |
+
# Used for reward evaluation
|
217 |
+
if hasattr(self, 'h5_label_file'):
|
218 |
+
# if there is ground truth
|
219 |
+
gts.append(self.label[self.label_start_ix[ix] - 1: self.label_end_ix[ix]])
|
220 |
+
else:
|
221 |
+
gts.append([])
|
222 |
+
|
223 |
+
# record associated info as well
|
224 |
+
info_dict = {}
|
225 |
+
info_dict['ix'] = ix
|
226 |
+
info_dict['id'] = self.info['images'][ix]['id']
|
227 |
+
info_dict['file_path'] = self.info['images'][ix].get('file_path', '')
|
228 |
+
infos.append(info_dict)
|
229 |
+
|
230 |
+
# #sort by att_feat length
|
231 |
+
# fc_batch, att_batch, label_batch, gts, infos = \
|
232 |
+
# zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True))
|
233 |
+
fc_batch, att_batch, label_batch, gts, infos = \
|
234 |
+
zip(*sorted(zip(fc_batch, att_batch, label_batch, gts, infos), key=lambda x: 0, reverse=True))
|
235 |
+
data = {}
|
236 |
+
data['fc_feats'] = np.stack(fc_batch)
|
237 |
+
# merge att_feats
|
238 |
+
max_att_len = max([_.shape[0] for _ in att_batch])
|
239 |
+
data['att_feats'] = np.zeros([len(att_batch), max_att_len, att_batch[0].shape[1]], dtype = 'float32')
|
240 |
+
for i in range(len(att_batch)):
|
241 |
+
data['att_feats'][i, :att_batch[i].shape[0]] = att_batch[i]
|
242 |
+
data['att_masks'] = np.zeros(data['att_feats'].shape[:2], dtype='float32')
|
243 |
+
for i in range(len(att_batch)):
|
244 |
+
data['att_masks'][i, :att_batch[i].shape[0]] = 1
|
245 |
+
# set att_masks to None if attention features have same length
|
246 |
+
if data['att_masks'].sum() == data['att_masks'].size:
|
247 |
+
data['att_masks'] = None
|
248 |
+
|
249 |
+
data['labels'] = np.vstack(label_batch)
|
250 |
+
# generate mask
|
251 |
+
nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, data['labels'])))
|
252 |
+
mask_batch = np.zeros([data['labels'].shape[0], self.seq_length + 2], dtype = 'float32')
|
253 |
+
for ix, row in enumerate(mask_batch):
|
254 |
+
row[:nonzeros[ix]] = 1
|
255 |
+
data['masks'] = mask_batch
|
256 |
+
data['labels'] = data['labels'].reshape(len(batch), seq_per_img, -1)
|
257 |
+
data['masks'] = data['masks'].reshape(len(batch), seq_per_img, -1)
|
258 |
+
|
259 |
+
data['gts'] = gts # all ground truth captions of each images
|
260 |
+
data['bounds'] = {'it_pos_now': it_pos_now, # the it_pos_now of the last sample
|
261 |
+
'it_max': len(self.split_ix[split]), 'wrapped': wrapped}
|
262 |
+
data['infos'] = infos
|
263 |
+
|
264 |
+
data = {k:torch.from_numpy(v) if type(v) is np.ndarray else v for k,v in data.items()} # Turn all ndarray to torch tensor
|
265 |
+
|
266 |
+
return data
|
267 |
+
|
268 |
+
def __getitem__(self, index):
|
269 |
+
"""This function returns a tuple that is further passed to collate_fn
|
270 |
+
"""
|
271 |
+
ix, it_pos_now, wrapped = index #self.split_ix[index]
|
272 |
+
if self.use_att:
|
273 |
+
att_feat = self.att_loader.get(str(self.info['images'][ix]['id']))
|
274 |
+
# Reshape to K x C
|
275 |
+
att_feat = att_feat.reshape(-1, att_feat.shape[-1])
|
276 |
+
if self.norm_att_feat:
|
277 |
+
att_feat = att_feat / np.linalg.norm(att_feat, 2, 1, keepdims=True)
|
278 |
+
if self.use_box:
|
279 |
+
box_feat = self.box_loader.get(str(self.info['images'][ix]['id']))
|
280 |
+
# devided by image width and height
|
281 |
+
x1,y1,x2,y2 = np.hsplit(box_feat, 4)
|
282 |
+
h,w = self.info['images'][ix]['height'], self.info['images'][ix]['width']
|
283 |
+
box_feat = np.hstack((x1/w, y1/h, x2/w, y2/h, (x2-x1)*(y2-y1)/(w*h))) # question? x2-x1+1??
|
284 |
+
if self.norm_box_feat:
|
285 |
+
box_feat = box_feat / np.linalg.norm(box_feat, 2, 1, keepdims=True)
|
286 |
+
att_feat = np.hstack([att_feat, box_feat])
|
287 |
+
# sort the features by the size of boxes
|
288 |
+
att_feat = np.stack(sorted(att_feat, key=lambda x:x[-1], reverse=True))
|
289 |
+
else:
|
290 |
+
att_feat = np.zeros((0,0), dtype='float32')
|
291 |
+
if self.use_fc:
|
292 |
+
try:
|
293 |
+
fc_feat = self.fc_loader.get(str(self.info['images'][ix]['id']))
|
294 |
+
except:
|
295 |
+
# Use average of attention when there is no fc provided (For bottomup feature)
|
296 |
+
fc_feat = att_feat.mean(0)
|
297 |
+
else:
|
298 |
+
fc_feat = np.zeros((0), dtype='float32')
|
299 |
+
if hasattr(self, 'h5_label_file'):
|
300 |
+
seq = self.get_captions(ix, self.seq_per_img)
|
301 |
+
else:
|
302 |
+
seq = None
|
303 |
+
return (fc_feat,
|
304 |
+
att_feat, seq,
|
305 |
+
ix, it_pos_now, wrapped)
|
306 |
+
|
307 |
+
def __len__(self):
|
308 |
+
return len(self.info['images'])
|
309 |
+
|
310 |
+
class DataLoader:
|
311 |
+
def __init__(self, opt):
|
312 |
+
self.opt = opt
|
313 |
+
self.batch_size = self.opt.batch_size
|
314 |
+
self.dataset = Dataset(opt)
|
315 |
+
|
316 |
+
# Initialize loaders and iters
|
317 |
+
self.loaders, self.iters = {}, {}
|
318 |
+
for split in ['train', 'val', 'test']:
|
319 |
+
if split == 'train':
|
320 |
+
sampler = MySampler(self.dataset.split_ix[split], shuffle=True, wrap=True)
|
321 |
+
else:
|
322 |
+
sampler = MySampler(self.dataset.split_ix[split], shuffle=False, wrap=False)
|
323 |
+
self.loaders[split] = data.DataLoader(dataset=self.dataset,
|
324 |
+
batch_size=self.batch_size,
|
325 |
+
sampler=sampler,
|
326 |
+
pin_memory=True,
|
327 |
+
num_workers=4, # 4 is usually enough
|
328 |
+
collate_fn=partial(self.dataset.collate_func, split=split),
|
329 |
+
drop_last=False)
|
330 |
+
self.iters[split] = iter(self.loaders[split])
|
331 |
+
|
332 |
+
def get_batch(self, split):
|
333 |
+
try:
|
334 |
+
data = next(self.iters[split])
|
335 |
+
except StopIteration:
|
336 |
+
self.iters[split] = iter(self.loaders[split])
|
337 |
+
data = next(self.iters[split])
|
338 |
+
return data
|
339 |
+
|
340 |
+
def reset_iterator(self, split):
|
341 |
+
self.loaders[split].sampler._reset_iter()
|
342 |
+
self.iters[split] = iter(self.loaders[split])
|
343 |
+
|
344 |
+
def get_vocab_size(self):
|
345 |
+
return self.dataset.get_vocab_size()
|
346 |
+
|
347 |
+
@property
|
348 |
+
def vocab_size(self):
|
349 |
+
return self.get_vocab_size()
|
350 |
+
|
351 |
+
def get_vocab(self):
|
352 |
+
return self.dataset.get_vocab()
|
353 |
+
|
354 |
+
def get_seq_length(self):
|
355 |
+
return self.dataset.get_seq_length()
|
356 |
+
|
357 |
+
@property
|
358 |
+
def seq_length(self):
|
359 |
+
return self.get_seq_length()
|
360 |
+
|
361 |
+
def state_dict(self):
|
362 |
+
def get_prefetch_num(split):
|
363 |
+
if self.loaders[split].num_workers > 0:
|
364 |
+
return (self.iters[split]._send_idx - self.iters[split]._rcvd_idx) * self.batch_size
|
365 |
+
else:
|
366 |
+
return 0
|
367 |
+
return {split: loader.sampler.state_dict(get_prefetch_num(split)) \
|
368 |
+
for split, loader in self.loaders.items()}
|
369 |
+
|
370 |
+
def load_state_dict(self, state_dict=None):
|
371 |
+
if state_dict is None:
|
372 |
+
return
|
373 |
+
for split in self.loaders.keys():
|
374 |
+
self.loaders[split].sampler.load_state_dict(state_dict[split])
|
375 |
+
|
376 |
+
|
377 |
+
class MySampler(data.sampler.Sampler):
|
378 |
+
def __init__(self, index_list, shuffle, wrap):
|
379 |
+
self.index_list = index_list
|
380 |
+
self.shuffle = shuffle
|
381 |
+
self.wrap = wrap
|
382 |
+
# if wrap, there will be not stop iteration called
|
383 |
+
# wrap True used during training, and wrap False used during test.
|
384 |
+
self._reset_iter()
|
385 |
+
|
386 |
+
def __iter__(self):
|
387 |
+
return self
|
388 |
+
|
389 |
+
def __next__(self):
|
390 |
+
wrapped = False
|
391 |
+
if self.iter_counter == len(self._index_list):
|
392 |
+
self._reset_iter()
|
393 |
+
if self.wrap:
|
394 |
+
wrapped = True
|
395 |
+
else:
|
396 |
+
raise StopIteration()
|
397 |
+
if len(self._index_list) == 0: # overflow when 0 samples
|
398 |
+
return None
|
399 |
+
elem = (self._index_list[self.iter_counter], self.iter_counter+1, wrapped)
|
400 |
+
self.iter_counter += 1
|
401 |
+
return elem
|
402 |
+
|
403 |
+
def next(self):
|
404 |
+
return self.__next__()
|
405 |
+
|
406 |
+
def _reset_iter(self):
|
407 |
+
np.random.seed(0)
|
408 |
+
if self.shuffle:
|
409 |
+
rand_perm = npr.permutation(len(self.index_list))
|
410 |
+
self._index_list = [self.index_list[_] for _ in rand_perm]
|
411 |
+
else:
|
412 |
+
self._index_list = self.index_list
|
413 |
+
|
414 |
+
self.iter_counter = 0
|
415 |
+
|
416 |
+
def __len__(self):
|
417 |
+
return len(self.index_list)
|
418 |
+
|
419 |
+
def load_state_dict(self, state_dict=None):
|
420 |
+
if state_dict is None:
|
421 |
+
return
|
422 |
+
self._index_list = state_dict['index_list']
|
423 |
+
self.iter_counter = state_dict['iter_counter']
|
424 |
+
|
425 |
+
def state_dict(self, prefetched_num=None):
|
426 |
+
prefetched_num = prefetched_num or 0
|
427 |
+
return {
|
428 |
+
'index_list': self._index_list,
|
429 |
+
'iter_counter': self.iter_counter - prefetched_num
|
430 |
+
}
|
431 |
+
|
432 |
+
|
captioning/data/dataloaderraw.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import
|
2 |
+
from __future__ import division
|
3 |
+
from __future__ import print_function
|
4 |
+
|
5 |
+
import json
|
6 |
+
import h5py
|
7 |
+
import os
|
8 |
+
import numpy as np
|
9 |
+
import random
|
10 |
+
import torch
|
11 |
+
import skimage
|
12 |
+
import skimage.io
|
13 |
+
import scipy.misc
|
14 |
+
|
15 |
+
from torchvision import transforms as trn
|
16 |
+
preprocess = trn.Compose([
|
17 |
+
#trn.ToTensor(),
|
18 |
+
trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
19 |
+
])
|
20 |
+
|
21 |
+
# from ..utils.resnet_utils import myResnet
|
22 |
+
# from ..utils import resnet
|
23 |
+
|
24 |
+
from captioning.utils.resnet_utils import myResnet
|
25 |
+
from captioning.utils import resnet
|
26 |
+
|
27 |
+
|
28 |
+
class DataLoaderRaw():
|
29 |
+
|
30 |
+
def __init__(self, opt):
|
31 |
+
self.opt = opt
|
32 |
+
self.coco_json = opt.get('coco_json', '')
|
33 |
+
self.folder_path = opt.get('folder_path', '')
|
34 |
+
|
35 |
+
self.batch_size = opt.get('batch_size', 1)
|
36 |
+
self.seq_per_img = 1
|
37 |
+
|
38 |
+
# Load resnet
|
39 |
+
self.cnn_model = opt.get('cnn_model', 'resnet101')
|
40 |
+
self.my_resnet = getattr(resnet, self.cnn_model)()
|
41 |
+
self.my_resnet.load_state_dict(torch.load('./data/imagenet_weights/'+self.cnn_model+'.pth'))
|
42 |
+
self.my_resnet = myResnet(self.my_resnet)
|
43 |
+
self.my_resnet.cuda()
|
44 |
+
self.my_resnet.eval()
|
45 |
+
|
46 |
+
|
47 |
+
|
48 |
+
# load the json file which contains additional information about the dataset
|
49 |
+
print('DataLoaderRaw loading images from folder: ', self.folder_path)
|
50 |
+
|
51 |
+
self.files = []
|
52 |
+
self.ids = []
|
53 |
+
|
54 |
+
print(len(self.coco_json))
|
55 |
+
if len(self.coco_json) > 0:
|
56 |
+
print('reading from ' + opt.coco_json)
|
57 |
+
# read in filenames from the coco-style json file
|
58 |
+
self.coco_annotation = json.load(open(self.coco_json))
|
59 |
+
for k,v in enumerate(self.coco_annotation['images']):
|
60 |
+
fullpath = os.path.join(self.folder_path, v['file_name'])
|
61 |
+
self.files.append(fullpath)
|
62 |
+
self.ids.append(v['id'])
|
63 |
+
else:
|
64 |
+
# read in all the filenames from the folder
|
65 |
+
print('listing all images in directory ' + self.folder_path)
|
66 |
+
def isImage(f):
|
67 |
+
supportedExt = ['.jpg','.JPG','.jpeg','.JPEG','.png','.PNG','.ppm','.PPM']
|
68 |
+
for ext in supportedExt:
|
69 |
+
start_idx = f.rfind(ext)
|
70 |
+
if start_idx >= 0 and start_idx + len(ext) == len(f):
|
71 |
+
return True
|
72 |
+
return False
|
73 |
+
|
74 |
+
n = 1
|
75 |
+
for root, dirs, files in os.walk(self.folder_path, topdown=False):
|
76 |
+
for file in files:
|
77 |
+
fullpath = os.path.join(self.folder_path, file)
|
78 |
+
if isImage(fullpath):
|
79 |
+
self.files.append(fullpath)
|
80 |
+
self.ids.append(str(n)) # just order them sequentially
|
81 |
+
n = n + 1
|
82 |
+
|
83 |
+
self.N = len(self.files)
|
84 |
+
print('DataLoaderRaw found ', self.N, ' images')
|
85 |
+
|
86 |
+
self.iterator = 0
|
87 |
+
|
88 |
+
# Nasty
|
89 |
+
self.dataset = self # to fix the bug in eval
|
90 |
+
|
91 |
+
def get_batch(self, split, batch_size=None):
|
92 |
+
batch_size = batch_size or self.batch_size
|
93 |
+
|
94 |
+
# pick an index of the datapoint to load next
|
95 |
+
fc_batch = np.ndarray((batch_size, 2048), dtype = 'float32')
|
96 |
+
att_batch = np.ndarray((batch_size, 14, 14, 2048), dtype = 'float32')
|
97 |
+
max_index = self.N
|
98 |
+
wrapped = False
|
99 |
+
infos = []
|
100 |
+
|
101 |
+
for i in range(batch_size):
|
102 |
+
ri = self.iterator
|
103 |
+
ri_next = ri + 1
|
104 |
+
if ri_next >= max_index:
|
105 |
+
ri_next = 0
|
106 |
+
wrapped = True
|
107 |
+
# wrap back around
|
108 |
+
self.iterator = ri_next
|
109 |
+
|
110 |
+
img = skimage.io.imread(self.files[ri])
|
111 |
+
|
112 |
+
if len(img.shape) == 2:
|
113 |
+
img = img[:,:,np.newaxis]
|
114 |
+
img = np.concatenate((img, img, img), axis=2)
|
115 |
+
|
116 |
+
img = img[:,:,:3].astype('float32')/255.0
|
117 |
+
img = torch.from_numpy(img.transpose([2,0,1])).cuda()
|
118 |
+
img = preprocess(img)
|
119 |
+
with torch.no_grad():
|
120 |
+
tmp_fc, tmp_att = self.my_resnet(img)
|
121 |
+
|
122 |
+
fc_batch[i] = tmp_fc.data.cpu().float().numpy()
|
123 |
+
att_batch[i] = tmp_att.data.cpu().float().numpy()
|
124 |
+
|
125 |
+
info_struct = {}
|
126 |
+
info_struct['id'] = self.ids[ri]
|
127 |
+
info_struct['file_path'] = self.files[ri]
|
128 |
+
infos.append(info_struct)
|
129 |
+
|
130 |
+
data = {}
|
131 |
+
data['fc_feats'] = fc_batch
|
132 |
+
data['att_feats'] = att_batch.reshape(batch_size, -1, 2048)
|
133 |
+
data['labels'] = np.zeros([batch_size, 0])
|
134 |
+
data['masks'] = None
|
135 |
+
data['att_masks'] = None
|
136 |
+
data['bounds'] = {'it_pos_now': self.iterator, 'it_max': self.N, 'wrapped': wrapped}
|
137 |
+
data['infos'] = infos
|
138 |
+
|
139 |
+
data = {k:torch.from_numpy(v) if type(v) is np.ndarray else v for k,v in data.items()} # Turn all ndarray to torch tensor
|
140 |
+
|
141 |
+
return data
|
142 |
+
|
143 |
+
def reset_iterator(self, split):
|
144 |
+
self.iterator = 0
|
145 |
+
|
146 |
+
def get_vocab_size(self):
|
147 |
+
return len(self.ix_to_word)
|
148 |
+
|
149 |
+
def get_vocab(self):
|
150 |
+
return self.ix_to_word
|
151 |
+
|
captioning/data/pth_loader.py
ADDED
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import
|
2 |
+
from __future__ import division
|
3 |
+
from __future__ import print_function
|
4 |
+
|
5 |
+
import json
|
6 |
+
import h5py
|
7 |
+
from lmdbdict import lmdbdict
|
8 |
+
from lmdbdict.methods import DUMPS_FUNC, LOADS_FUNC
|
9 |
+
import os
|
10 |
+
import numpy as np
|
11 |
+
import numpy.random as npr
|
12 |
+
import random
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torch.utils.data as data
|
16 |
+
|
17 |
+
import multiprocessing
|
18 |
+
import six
|
19 |
+
|
20 |
+
# random.seed(42)
|
21 |
+
# torch.manual_seed(42)
|
22 |
+
# if torch.cuda.is_available():
|
23 |
+
# torch.cuda.manual_seed(42)
|
24 |
+
|
25 |
+
class HybridLoader:
|
26 |
+
"""
|
27 |
+
If db_path is a director, then use normal file loading
|
28 |
+
If lmdb, then load from lmdb
|
29 |
+
The loading method depend on extention.
|
30 |
+
|
31 |
+
in_memory: if in_memory is True, we save all the features in memory
|
32 |
+
For individual np(y|z)s, we don't need to do that because the system will do this for us.
|
33 |
+
Should be useful for lmdb or h5.
|
34 |
+
(Copied this idea from vilbert)
|
35 |
+
"""
|
36 |
+
def __init__(self, db_path, ext, in_memory=False):
|
37 |
+
self.db_path = db_path
|
38 |
+
self.ext = ext
|
39 |
+
if self.ext == '.npy':
|
40 |
+
self.loader = lambda x: np.load(six.BytesIO(x))
|
41 |
+
else:
|
42 |
+
self.loader = lambda x: np.load(six.BytesIO(x))['feat']
|
43 |
+
if db_path.endswith('.lmdb'):
|
44 |
+
self.db_type = 'lmdb'
|
45 |
+
self.lmdb = lmdbdict(db_path, unsafe=True)
|
46 |
+
self.lmdb._key_dumps = DUMPS_FUNC['ascii']
|
47 |
+
self.lmdb._value_loads = LOADS_FUNC['identity']
|
48 |
+
elif db_path.endswith('.pth'): # Assume a key,value dictionary
|
49 |
+
self.db_type = 'pth'
|
50 |
+
self.feat_file = torch.load(db_path)
|
51 |
+
self.loader = lambda x: x
|
52 |
+
print('HybridLoader: ext is ignored')
|
53 |
+
elif db_path.endswith('h5'):
|
54 |
+
self.db_type = 'h5'
|
55 |
+
self.loader = lambda x: np.array(x).astype('float32')
|
56 |
+
else:
|
57 |
+
self.db_type = 'dir'
|
58 |
+
|
59 |
+
self.in_memory = in_memory
|
60 |
+
if self.in_memory:
|
61 |
+
self.features = {}
|
62 |
+
|
63 |
+
def get(self, key):
|
64 |
+
|
65 |
+
if self.in_memory and key in self.features:
|
66 |
+
# We save f_input because we want to save the
|
67 |
+
# compressed bytes to save memory
|
68 |
+
f_input = self.features[key]
|
69 |
+
elif self.db_type == 'lmdb':
|
70 |
+
f_input = self.lmdb[key]
|
71 |
+
elif self.db_type == 'pth':
|
72 |
+
f_input = self.feat_file[key]
|
73 |
+
elif self.db_type == 'h5':
|
74 |
+
f_input = h5py.File(self.db_path, 'r')[key]
|
75 |
+
else:
|
76 |
+
f_input = open(os.path.join(self.db_path, key + self.ext), 'rb').read()
|
77 |
+
|
78 |
+
if self.in_memory and key not in self.features:
|
79 |
+
self.features[key] = f_input
|
80 |
+
|
81 |
+
# load image
|
82 |
+
feat = self.loader(f_input)
|
83 |
+
|
84 |
+
return feat
|
85 |
+
|
86 |
+
class CaptionDataset(data.Dataset):
|
87 |
+
|
88 |
+
def get_vocab_size(self):
|
89 |
+
return self.vocab_size
|
90 |
+
|
91 |
+
def get_vocab(self):
|
92 |
+
return self.ix_to_word
|
93 |
+
|
94 |
+
def get_seq_length(self):
|
95 |
+
return self.seq_length
|
96 |
+
|
97 |
+
def __init__(self, opt):
|
98 |
+
self.opt = opt
|
99 |
+
self.seq_per_img = opt.seq_per_img
|
100 |
+
|
101 |
+
# feature related options
|
102 |
+
self.use_fc = getattr(opt, 'use_fc', True)
|
103 |
+
self.use_att = getattr(opt, 'use_att', True)
|
104 |
+
self.use_box = getattr(opt, 'use_box', 0)
|
105 |
+
self.norm_att_feat = getattr(opt, 'norm_att_feat', 0)
|
106 |
+
self.norm_box_feat = getattr(opt, 'norm_box_feat', 0)
|
107 |
+
|
108 |
+
# load the json file which contains additional information about the dataset
|
109 |
+
print('DataLoader loading json file: ', opt.input_json)
|
110 |
+
self.info = json.load(open(self.opt.input_json))
|
111 |
+
if 'ix_to_word' in self.info:
|
112 |
+
self.ix_to_word = self.info['ix_to_word']
|
113 |
+
self.vocab_size = len(self.ix_to_word)
|
114 |
+
print('vocab size is ', self.vocab_size)
|
115 |
+
|
116 |
+
# open the hdf5 file
|
117 |
+
print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_box_dir, opt.input_label_h5)
|
118 |
+
"""
|
119 |
+
Setting input_label_h5 to none is used when only doing generation.
|
120 |
+
For example, when you need to test on coco test set.
|
121 |
+
"""
|
122 |
+
if self.opt.input_label_h5 != 'none':
|
123 |
+
self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core')
|
124 |
+
# load in the sequence data
|
125 |
+
seq_size = self.h5_label_file['labels'].shape
|
126 |
+
self.label = self.h5_label_file['labels'][:]
|
127 |
+
self.seq_length = seq_size[1]
|
128 |
+
print('max sequence length in data is', self.seq_length)
|
129 |
+
# load the pointers in full to RAM (should be small enough)
|
130 |
+
self.label_start_ix = self.h5_label_file['label_start_ix'][:]
|
131 |
+
self.label_end_ix = self.h5_label_file['label_end_ix'][:]
|
132 |
+
else:
|
133 |
+
self.seq_length = 1
|
134 |
+
|
135 |
+
self.data_in_memory = getattr(opt, 'data_in_memory', False)
|
136 |
+
self.fc_loader = HybridLoader(self.opt.input_fc_dir, '.npy', in_memory=self.data_in_memory)
|
137 |
+
self.att_loader = HybridLoader(self.opt.input_att_dir, '.npz', in_memory=self.data_in_memory)
|
138 |
+
self.box_loader = HybridLoader(self.opt.input_box_dir, '.npy', in_memory=self.data_in_memory)
|
139 |
+
|
140 |
+
self.num_images = len(self.info['images']) # self.label_start_ix.shape[0]
|
141 |
+
print('read %d image features' %(self.num_images))
|
142 |
+
|
143 |
+
# separate out indexes for each of the provided splits
|
144 |
+
self.split_ix = {'train': [], 'val': [], 'test': []}
|
145 |
+
for ix in range(len(self.info['images'])):
|
146 |
+
img = self.info['images'][ix]
|
147 |
+
if not 'split' in img:
|
148 |
+
self.split_ix['train'].append(ix)
|
149 |
+
self.split_ix['val'].append(ix)
|
150 |
+
self.split_ix['test'].append(ix)
|
151 |
+
elif img['split'] == 'train':
|
152 |
+
self.split_ix['train'].append(ix)
|
153 |
+
elif img['split'] == 'val':
|
154 |
+
self.split_ix['val'].append(ix)
|
155 |
+
elif img['split'] == 'test':
|
156 |
+
self.split_ix['test'].append(ix)
|
157 |
+
elif opt.train_only == 0: # restval
|
158 |
+
self.split_ix['train'].append(ix)
|
159 |
+
|
160 |
+
print('assigned %d images to split train' %len(self.split_ix['train']))
|
161 |
+
print('assigned %d images to split val' %len(self.split_ix['val']))
|
162 |
+
print('assigned %d images to split test' %len(self.split_ix['test']))
|
163 |
+
|
164 |
+
def get_captions(self, ix, seq_per_img):
|
165 |
+
# fetch the sequence labels
|
166 |
+
ix1 = self.label_start_ix[ix] - 1 #label_start_ix starts from 1
|
167 |
+
ix2 = self.label_end_ix[ix] - 1
|
168 |
+
ncap = ix2 - ix1 + 1 # number of captions available for this image
|
169 |
+
assert ncap > 0, 'an image does not have any label. this can be handled but right now isn\'t'
|
170 |
+
|
171 |
+
random.seed(42)
|
172 |
+
|
173 |
+
if ncap < seq_per_img:
|
174 |
+
# we need to subsample (with replacement)
|
175 |
+
seq = np.zeros([seq_per_img, self.seq_length], dtype = 'int')
|
176 |
+
for q in range(seq_per_img):
|
177 |
+
ixl = random.randint(ix1,ix2)
|
178 |
+
seq[q, :] = self.label[ixl, :self.seq_length]
|
179 |
+
else:
|
180 |
+
ixl = random.randint(ix1, ix2 - seq_per_img + 1)
|
181 |
+
seq = self.label[ixl: ixl + seq_per_img, :self.seq_length]
|
182 |
+
|
183 |
+
return seq
|
184 |
+
|
185 |
+
def collate_func(self, batch):
|
186 |
+
seq_per_img = self.seq_per_img
|
187 |
+
|
188 |
+
fc_batch = []
|
189 |
+
att_batch = []
|
190 |
+
label_batch = []
|
191 |
+
|
192 |
+
wrapped = False
|
193 |
+
|
194 |
+
infos = []
|
195 |
+
gts = []
|
196 |
+
|
197 |
+
for sample in batch:
|
198 |
+
# fetch image
|
199 |
+
tmp_fc, tmp_att, tmp_seq, \
|
200 |
+
ix = sample
|
201 |
+
|
202 |
+
fc_batch.append(tmp_fc)
|
203 |
+
att_batch.append(tmp_att)
|
204 |
+
|
205 |
+
tmp_label = np.zeros([seq_per_img, self.seq_length + 2], dtype = 'int')
|
206 |
+
if hasattr(self, 'h5_label_file'):
|
207 |
+
# if there is ground truth
|
208 |
+
tmp_label[:, 1 : self.seq_length + 1] = tmp_seq
|
209 |
+
label_batch.append(tmp_label)
|
210 |
+
|
211 |
+
# Used for reward evaluation
|
212 |
+
if hasattr(self, 'h5_label_file'):
|
213 |
+
# if there is ground truth
|
214 |
+
gts.append(self.label[self.label_start_ix[ix] - 1: self.label_end_ix[ix]])
|
215 |
+
else:
|
216 |
+
gts.append([])
|
217 |
+
|
218 |
+
# record associated info as well
|
219 |
+
info_dict = {}
|
220 |
+
info_dict['ix'] = ix
|
221 |
+
info_dict['id'] = self.info['images'][ix]['id']
|
222 |
+
info_dict['file_path'] = self.info['images'][ix].get('file_path', '')
|
223 |
+
infos.append(info_dict)
|
224 |
+
|
225 |
+
# #sort by att_feat length
|
226 |
+
# fc_batch, att_batch, label_batch, gts, infos = \
|
227 |
+
# zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True))
|
228 |
+
fc_batch, att_batch, label_batch, gts, infos = \
|
229 |
+
zip(*sorted(zip(fc_batch, att_batch, label_batch, gts, infos), key=lambda x: 0, reverse=True))
|
230 |
+
data = {}
|
231 |
+
data['fc_feats'] = np.stack(fc_batch)
|
232 |
+
# merge att_feats
|
233 |
+
max_att_len = max([_.shape[0] for _ in att_batch])
|
234 |
+
data['att_feats'] = np.zeros([len(att_batch), max_att_len, att_batch[0].shape[1]], dtype = 'float32')
|
235 |
+
for i in range(len(att_batch)):
|
236 |
+
data['att_feats'][i, :att_batch[i].shape[0]] = att_batch[i]
|
237 |
+
data['att_masks'] = np.zeros(data['att_feats'].shape[:2], dtype='float32')
|
238 |
+
for i in range(len(att_batch)):
|
239 |
+
data['att_masks'][i, :att_batch[i].shape[0]] = 1
|
240 |
+
# set att_masks to None if attention features have same length
|
241 |
+
if data['att_masks'].sum() == data['att_masks'].size:
|
242 |
+
data['att_masks'] = None
|
243 |
+
|
244 |
+
data['labels'] = np.vstack(label_batch)
|
245 |
+
# generate mask
|
246 |
+
nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, data['labels'])))
|
247 |
+
mask_batch = np.zeros([data['labels'].shape[0], self.seq_length + 2], dtype = 'float32')
|
248 |
+
for ix, row in enumerate(mask_batch):
|
249 |
+
row[:nonzeros[ix]] = 1
|
250 |
+
data['masks'] = mask_batch
|
251 |
+
data['labels'] = data['labels'].reshape(len(batch), seq_per_img, -1)
|
252 |
+
data['masks'] = data['masks'].reshape(len(batch), seq_per_img, -1)
|
253 |
+
|
254 |
+
data['gts'] = gts # all ground truth captions of each images
|
255 |
+
data['infos'] = infos
|
256 |
+
|
257 |
+
data = {k:torch.from_numpy(v) if type(v) is np.ndarray else v for k,v in data.items()} # Turn all ndarray to torch tensor
|
258 |
+
|
259 |
+
return data
|
260 |
+
|
261 |
+
def __getitem__(self, ix):
|
262 |
+
"""This function returns a tuple that is further passed to collate_fn
|
263 |
+
"""
|
264 |
+
if self.use_att:
|
265 |
+
att_feat = self.att_loader.get(str(self.info['images'][ix]['id']))
|
266 |
+
# Reshape to K x C
|
267 |
+
att_feat = att_feat.reshape(-1, att_feat.shape[-1])
|
268 |
+
if self.norm_att_feat:
|
269 |
+
att_feat = att_feat / np.linalg.norm(att_feat, 2, 1, keepdims=True)
|
270 |
+
if self.use_box:
|
271 |
+
box_feat = self.box_loader.get(str(self.info['images'][ix]['id']))
|
272 |
+
# devided by image width and height
|
273 |
+
x1,y1,x2,y2 = np.hsplit(box_feat, 4)
|
274 |
+
h,w = self.info['images'][ix]['height'], self.info['images'][ix]['width']
|
275 |
+
box_feat = np.hstack((x1/w, y1/h, x2/w, y2/h, (x2-x1)*(y2-y1)/(w*h))) # question? x2-x1+1??
|
276 |
+
if self.norm_box_feat:
|
277 |
+
box_feat = box_feat / np.linalg.norm(box_feat, 2, 1, keepdims=True)
|
278 |
+
att_feat = np.hstack([att_feat, box_feat])
|
279 |
+
# sort the features by the size of boxes
|
280 |
+
att_feat = np.stack(sorted(att_feat, key=lambda x:x[-1], reverse=True))
|
281 |
+
else:
|
282 |
+
att_feat = np.zeros((0,0), dtype='float32')
|
283 |
+
if self.use_fc:
|
284 |
+
try:
|
285 |
+
fc_feat = self.fc_loader.get(str(self.info['images'][ix]['id']))
|
286 |
+
except:
|
287 |
+
# Use average of attention when there is no fc provided (For bottomup feature)
|
288 |
+
fc_feat = att_feat.mean(0)
|
289 |
+
else:
|
290 |
+
fc_feat = np.zeros((0), dtype='float32')
|
291 |
+
if hasattr(self, 'h5_label_file'):
|
292 |
+
seq = self.get_captions(ix, self.seq_per_img)
|
293 |
+
else:
|
294 |
+
seq = None
|
295 |
+
return (fc_feat,
|
296 |
+
att_feat, seq,
|
297 |
+
ix)
|
298 |
+
|
299 |
+
def __len__(self):
|
300 |
+
return len(self.info['images'])
|
captioning/models/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
captioning/models/AoAModel.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Implementation for paper 'Attention on Attention for Image Captioning'
|
2 |
+
# https://arxiv.org/abs/1908.06954
|
3 |
+
|
4 |
+
# RT: Code from original author's repo: https://github.com/husthuaan/AoANet/
|
5 |
+
|
6 |
+
from __future__ import absolute_import
|
7 |
+
from __future__ import division
|
8 |
+
from __future__ import print_function
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.nn.functional as F
|
13 |
+
|
14 |
+
from .AttModel import pack_wrapper, AttModel, Attention
|
15 |
+
from .TransformerModel import LayerNorm, attention, clones, SublayerConnection, PositionwiseFeedForward
|
16 |
+
|
17 |
+
class MultiHeadedDotAttention(nn.Module):
|
18 |
+
def __init__(self, h, d_model, dropout=0.1, scale=1, project_k_v=1, use_output_layer=1, do_aoa=0, norm_q=0, dropout_aoa=0.3):
|
19 |
+
super(MultiHeadedDotAttention, self).__init__()
|
20 |
+
assert d_model * scale % h == 0
|
21 |
+
# We assume d_v always equals d_k
|
22 |
+
self.d_k = d_model * scale // h
|
23 |
+
self.h = h
|
24 |
+
|
25 |
+
# Do we need to do linear projections on K and V?
|
26 |
+
self.project_k_v = project_k_v
|
27 |
+
|
28 |
+
# normalize the query?
|
29 |
+
if norm_q:
|
30 |
+
self.norm = LayerNorm(d_model)
|
31 |
+
else:
|
32 |
+
self.norm = lambda x:x
|
33 |
+
self.linears = clones(nn.Linear(d_model, d_model * scale), 1 + 2 * project_k_v)
|
34 |
+
|
35 |
+
# output linear layer after the multi-head attention?
|
36 |
+
self.output_layer = nn.Linear(d_model * scale, d_model)
|
37 |
+
|
38 |
+
# apply aoa after attention?
|
39 |
+
self.use_aoa = do_aoa
|
40 |
+
if self.use_aoa:
|
41 |
+
self.aoa_layer = nn.Sequential(nn.Linear((1 + scale) * d_model, 2 * d_model), nn.GLU())
|
42 |
+
# dropout to the input of AoA layer
|
43 |
+
if dropout_aoa > 0:
|
44 |
+
self.dropout_aoa = nn.Dropout(p=dropout_aoa)
|
45 |
+
else:
|
46 |
+
self.dropout_aoa = lambda x:x
|
47 |
+
|
48 |
+
if self.use_aoa or not use_output_layer:
|
49 |
+
# AoA doesn't need the output linear layer
|
50 |
+
del self.output_layer
|
51 |
+
self.output_layer = lambda x:x
|
52 |
+
|
53 |
+
self.attn = None
|
54 |
+
self.dropout = nn.Dropout(p=dropout)
|
55 |
+
|
56 |
+
def forward(self, query, value, key, mask=None):
|
57 |
+
if mask is not None:
|
58 |
+
if len(mask.size()) == 2:
|
59 |
+
mask = mask.unsqueeze(-2)
|
60 |
+
# Same mask applied to all h heads.
|
61 |
+
mask = mask.unsqueeze(1)
|
62 |
+
|
63 |
+
single_query = 0
|
64 |
+
if len(query.size()) == 2:
|
65 |
+
single_query = 1
|
66 |
+
query = query.unsqueeze(1)
|
67 |
+
|
68 |
+
nbatches = query.size(0)
|
69 |
+
|
70 |
+
query = self.norm(query)
|
71 |
+
|
72 |
+
# Do all the linear projections in batch from d_model => h x d_k
|
73 |
+
if self.project_k_v == 0:
|
74 |
+
query_ = self.linears[0](query).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
|
75 |
+
key_ = key.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
|
76 |
+
value_ = value.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
|
77 |
+
else:
|
78 |
+
query_, key_, value_ = \
|
79 |
+
[l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
|
80 |
+
for l, x in zip(self.linears, (query, key, value))]
|
81 |
+
|
82 |
+
|
83 |
+
# Apply attention on all the projected vectors in batch.
|
84 |
+
x, self.attn = attention(query_, key_, value_, mask=mask,
|
85 |
+
dropout=self.dropout)
|
86 |
+
|
87 |
+
# "Concat" using a view
|
88 |
+
x = x.transpose(1, 2).contiguous() \
|
89 |
+
.view(nbatches, -1, self.h * self.d_k)
|
90 |
+
|
91 |
+
if self.use_aoa:
|
92 |
+
# Apply AoA
|
93 |
+
x = self.aoa_layer(self.dropout_aoa(torch.cat([x, query], -1)))
|
94 |
+
# try:
|
95 |
+
# x = self.aoa_layer(self.dropout_aoa(torch.cat([x, query], -1)))
|
96 |
+
# except:
|
97 |
+
# x = self.aoa_layer(self.dropout_aoa(torch.cat([x.view(query.shape), query], -1)))
|
98 |
+
# x = self.aoa_layer(self.dropout_aoa(torch.cat([x, query.view(x.shape)], -1)))
|
99 |
+
|
100 |
+
x = self.output_layer(x)
|
101 |
+
|
102 |
+
if single_query:
|
103 |
+
query = query.squeeze(1)
|
104 |
+
x = x.squeeze(1)
|
105 |
+
return x
|
106 |
+
|
107 |
+
class AoA_Refiner_Layer(nn.Module):
|
108 |
+
def __init__(self, size, self_attn, feed_forward, dropout):
|
109 |
+
super(AoA_Refiner_Layer, self).__init__()
|
110 |
+
self.self_attn = self_attn
|
111 |
+
self.feed_forward = feed_forward
|
112 |
+
self.use_ff = 0
|
113 |
+
if self.feed_forward is not None:
|
114 |
+
self.use_ff = 1
|
115 |
+
self.sublayer = clones(SublayerConnection(size, dropout), 1+self.use_ff)
|
116 |
+
self.size = size
|
117 |
+
|
118 |
+
def forward(self, x, mask):
|
119 |
+
x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
|
120 |
+
return self.sublayer[-1](x, self.feed_forward) if self.use_ff else x
|
121 |
+
|
122 |
+
class AoA_Refiner_Core(nn.Module):
|
123 |
+
def __init__(self, opt):
|
124 |
+
super(AoA_Refiner_Core, self).__init__()
|
125 |
+
attn = MultiHeadedDotAttention(opt.num_heads, opt.rnn_size, project_k_v=1, scale=opt.multi_head_scale, do_aoa=opt.refine_aoa, norm_q=0, dropout_aoa=getattr(opt, 'dropout_aoa', 0.3))
|
126 |
+
layer = AoA_Refiner_Layer(opt.rnn_size, attn, PositionwiseFeedForward(opt.rnn_size, 2048, 0.1) if opt.use_ff else None, 0.1)
|
127 |
+
self.layers = clones(layer, 6)
|
128 |
+
self.norm = LayerNorm(layer.size)
|
129 |
+
|
130 |
+
def forward(self, x, mask):
|
131 |
+
for layer in self.layers:
|
132 |
+
x = layer(x, mask)
|
133 |
+
return self.norm(x)
|
134 |
+
|
135 |
+
class AoA_Decoder_Core(nn.Module):
|
136 |
+
def __init__(self, opt):
|
137 |
+
super(AoA_Decoder_Core, self).__init__()
|
138 |
+
self.drop_prob_lm = opt.drop_prob_lm
|
139 |
+
self.d_model = opt.rnn_size
|
140 |
+
self.use_multi_head = opt.use_multi_head
|
141 |
+
self.multi_head_scale = opt.multi_head_scale
|
142 |
+
self.use_ctx_drop = getattr(opt, 'ctx_drop', 0)
|
143 |
+
self.out_res = getattr(opt, 'out_res', 0)
|
144 |
+
self.decoder_type = getattr(opt, 'decoder_type', 'AoA')
|
145 |
+
self.att_lstm = nn.LSTMCell(opt.input_encoding_size + opt.rnn_size, opt.rnn_size) # we, fc, h^2_t-1
|
146 |
+
self.out_drop = nn.Dropout(self.drop_prob_lm)
|
147 |
+
|
148 |
+
if self.decoder_type == 'AoA':
|
149 |
+
# AoA layer
|
150 |
+
self.att2ctx = nn.Sequential(nn.Linear(self.d_model * opt.multi_head_scale + opt.rnn_size, 2 * opt.rnn_size), nn.GLU())
|
151 |
+
elif self.decoder_type == 'LSTM':
|
152 |
+
# LSTM layer
|
153 |
+
self.att2ctx = nn.LSTMCell(self.d_model * opt.multi_head_scale + opt.rnn_size, opt.rnn_size)
|
154 |
+
else:
|
155 |
+
# Base linear layer
|
156 |
+
self.att2ctx = nn.Sequential(nn.Linear(self.d_model * opt.multi_head_scale + opt.rnn_size, opt.rnn_size), nn.ReLU())
|
157 |
+
|
158 |
+
# if opt.use_multi_head == 1: # TODO, not implemented for now
|
159 |
+
# self.attention = MultiHeadedAddAttention(opt.num_heads, opt.d_model, scale=opt.multi_head_scale)
|
160 |
+
if opt.use_multi_head == 2:
|
161 |
+
self.attention = MultiHeadedDotAttention(opt.num_heads, opt.rnn_size, project_k_v=0, scale=opt.multi_head_scale, use_output_layer=0, do_aoa=0, norm_q=1)
|
162 |
+
else:
|
163 |
+
self.attention = Attention(opt)
|
164 |
+
|
165 |
+
if self.use_ctx_drop:
|
166 |
+
self.ctx_drop = nn.Dropout(self.drop_prob_lm)
|
167 |
+
else:
|
168 |
+
self.ctx_drop = lambda x :x
|
169 |
+
|
170 |
+
def forward(self, xt, mean_feats, att_feats, p_att_feats, state, att_masks=None):
|
171 |
+
|
172 |
+
# state[0][1] is the context vector at the last step
|
173 |
+
h_att, c_att = self.att_lstm(torch.cat([xt, mean_feats + self.ctx_drop(state[0][1])], 1), (state[0][0], state[1][0]))
|
174 |
+
|
175 |
+
if self.use_multi_head == 2:
|
176 |
+
att = self.attention(h_att, p_att_feats.narrow(2, 0, self.multi_head_scale * self.d_model), p_att_feats.narrow(2, self.multi_head_scale * self.d_model, self.multi_head_scale * self.d_model), att_masks)
|
177 |
+
else:
|
178 |
+
att = self.attention(h_att, att_feats, p_att_feats, att_masks)
|
179 |
+
|
180 |
+
ctx_input = torch.cat([att, h_att], 1)
|
181 |
+
if self.decoder_type == 'LSTM':
|
182 |
+
output, c_logic = self.att2ctx(ctx_input, (state[0][1], state[1][1]))
|
183 |
+
state = (torch.stack((h_att, output)), torch.stack((c_att, c_logic)))
|
184 |
+
else:
|
185 |
+
output = self.att2ctx(ctx_input)
|
186 |
+
# save the context vector to state[0][1]
|
187 |
+
state = (torch.stack((h_att, output)), torch.stack((c_att, state[1][1])))
|
188 |
+
|
189 |
+
if self.out_res:
|
190 |
+
# add residual connection
|
191 |
+
output = output + h_att
|
192 |
+
|
193 |
+
output = self.out_drop(output)
|
194 |
+
return output, state
|
195 |
+
|
196 |
+
class AoAModel(AttModel):
|
197 |
+
def __init__(self, opt):
|
198 |
+
super(AoAModel, self).__init__(opt)
|
199 |
+
self.num_layers = 2
|
200 |
+
# mean pooling
|
201 |
+
self.use_mean_feats = getattr(opt, 'mean_feats', 1)
|
202 |
+
if opt.use_multi_head == 2:
|
203 |
+
del self.ctx2att
|
204 |
+
self.ctx2att = nn.Linear(opt.rnn_size, 2 * opt.multi_head_scale * opt.rnn_size)
|
205 |
+
|
206 |
+
if self.use_mean_feats:
|
207 |
+
del self.fc_embed
|
208 |
+
if opt.refine:
|
209 |
+
self.refiner = AoA_Refiner_Core(opt)
|
210 |
+
else:
|
211 |
+
self.refiner = lambda x,y : x
|
212 |
+
self.core = AoA_Decoder_Core(opt)
|
213 |
+
|
214 |
+
|
215 |
+
def _prepare_feature(self, fc_feats, att_feats, att_masks):
|
216 |
+
att_feats, att_masks = self.clip_att(att_feats, att_masks)
|
217 |
+
|
218 |
+
# embed att feats
|
219 |
+
att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)
|
220 |
+
att_feats = self.refiner(att_feats, att_masks)
|
221 |
+
|
222 |
+
if self.use_mean_feats:
|
223 |
+
# meaning pooling
|
224 |
+
if att_masks is None:
|
225 |
+
mean_feats = torch.mean(att_feats, dim=1)
|
226 |
+
else:
|
227 |
+
mean_feats = (torch.sum(att_feats * att_masks.unsqueeze(-1), 1) / torch.sum(att_masks.unsqueeze(-1), 1))
|
228 |
+
else:
|
229 |
+
mean_feats = self.fc_embed(fc_feats)
|
230 |
+
|
231 |
+
# Project the attention feats first to reduce memory and computation.
|
232 |
+
p_att_feats = self.ctx2att(att_feats)
|
233 |
+
|
234 |
+
return mean_feats, att_feats, p_att_feats, att_masks
|
captioning/models/AttEnsemble.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This file is the implementation for ensemble evaluation.
|
2 |
+
|
3 |
+
from __future__ import absolute_import
|
4 |
+
from __future__ import division
|
5 |
+
from __future__ import print_function
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from torch.autograd import *
|
12 |
+
|
13 |
+
from .CaptionModel import CaptionModel
|
14 |
+
from .AttModel import pack_wrapper, AttModel
|
15 |
+
|
16 |
+
class AttEnsemble(AttModel):
|
17 |
+
def __init__(self, models, weights=None):
|
18 |
+
CaptionModel.__init__(self)
|
19 |
+
# super(AttEnsemble, self).__init__()
|
20 |
+
|
21 |
+
self.models = nn.ModuleList(models)
|
22 |
+
self.vocab_size = models[0].vocab_size
|
23 |
+
self.seq_length = models[0].seq_length
|
24 |
+
self.bad_endings_ix = models[0].bad_endings_ix
|
25 |
+
self.ss_prob = 0
|
26 |
+
weights = weights or [1.0] * len(self.models)
|
27 |
+
self.register_buffer('weights', torch.tensor(weights))
|
28 |
+
|
29 |
+
def init_hidden(self, batch_size):
|
30 |
+
state = [m.init_hidden(batch_size) for m in self.models]
|
31 |
+
return self.pack_state(state)
|
32 |
+
|
33 |
+
def pack_state(self, state):
|
34 |
+
self.state_lengths = [len(_) for _ in state]
|
35 |
+
return sum([list(_) for _ in state], [])
|
36 |
+
|
37 |
+
def unpack_state(self, state):
|
38 |
+
out = []
|
39 |
+
for l in self.state_lengths:
|
40 |
+
out.append(state[:l])
|
41 |
+
state = state[l:]
|
42 |
+
return out
|
43 |
+
|
44 |
+
def embed(self, it):
|
45 |
+
return [m.embed(it) for m in self.models]
|
46 |
+
|
47 |
+
def core(self, *args):
|
48 |
+
return zip(*[m.core(*_) for m, _ in zip(self.models, zip(*args))])
|
49 |
+
|
50 |
+
def get_logprobs_state(self, it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state, output_logsoftmax=1):
|
51 |
+
# 'it' contains a word index
|
52 |
+
xt = self.embed(it)
|
53 |
+
|
54 |
+
state = self.unpack_state(state)
|
55 |
+
output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, state, tmp_att_masks)
|
56 |
+
logprobs = torch.stack([F.softmax(m.logit(output[i]), dim=1) for i,m in enumerate(self.models)], 2).mul(self.weights).div(self.weights.sum()).sum(-1).log()
|
57 |
+
|
58 |
+
return logprobs, self.pack_state(state)
|
59 |
+
|
60 |
+
def _prepare_feature(self, *args):
|
61 |
+
return tuple(zip(*[m._prepare_feature(*args) for m in self.models]))
|
62 |
+
|
63 |
+
def _old_sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
|
64 |
+
beam_size = opt.get('beam_size', 10)
|
65 |
+
batch_size = fc_feats.size(0)
|
66 |
+
|
67 |
+
fc_feats, att_feats, p_att_feats, att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
|
68 |
+
|
69 |
+
assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
|
70 |
+
seq = torch.LongTensor(self.seq_length, batch_size).zero_()
|
71 |
+
seqLogprobs = torch.FloatTensor(self.seq_length, batch_size, self.vocab_size + 1)
|
72 |
+
# lets process every image independently for now, for simplicity
|
73 |
+
|
74 |
+
self.done_beams = [[] for _ in range(batch_size)]
|
75 |
+
for k in range(batch_size):
|
76 |
+
state = self.init_hidden(beam_size)
|
77 |
+
tmp_fc_feats = [fc_feats[i][k:k+1].expand(beam_size, fc_feats[i].size(1)) for i,m in enumerate(self.models)]
|
78 |
+
tmp_att_feats = [att_feats[i][k:k+1].expand(*((beam_size,)+att_feats[i].size()[1:])).contiguous() for i,m in enumerate(self.models)]
|
79 |
+
tmp_p_att_feats = [p_att_feats[i][k:k+1].expand(*((beam_size,)+p_att_feats[i].size()[1:])).contiguous() for i,m in enumerate(self.models)]
|
80 |
+
tmp_att_masks = [att_masks[i][k:k+1].expand(*((beam_size,)+att_masks[i].size()[1:])).contiguous() if att_masks[i] is not None else att_masks[i] for i,m in enumerate(self.models)]
|
81 |
+
|
82 |
+
it = fc_feats[0].data.new(beam_size).long().zero_()
|
83 |
+
logprobs, state = self.get_logprobs_state(it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state)
|
84 |
+
|
85 |
+
self.done_beams[k] = self.old_beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, opt=opt)
|
86 |
+
seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
|
87 |
+
seqLogprobs[:, k] = self.done_beams[k][0]['logps']
|
88 |
+
# return the samples and their log likelihoods
|
89 |
+
return seq.transpose(0, 1), seqLogprobs.transpose(0, 1)
|
90 |
+
# return the samples and their log likelihoods
|
captioning/models/AttModel.py
ADDED
@@ -0,0 +1,977 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This file contains Att2in2, AdaAtt, AdaAttMO, UpDown model
|
2 |
+
|
3 |
+
# AdaAtt is from Knowing When to Look: Adaptive Attention via A Visual Sentinel for Image Captioning
|
4 |
+
# https://arxiv.org/abs/1612.01887
|
5 |
+
# AdaAttMO is a modified version with maxout lstm
|
6 |
+
|
7 |
+
# Att2in is from Self-critical Sequence Training for Image Captioning
|
8 |
+
# https://arxiv.org/abs/1612.00563
|
9 |
+
# In this file we only have Att2in2, which is a slightly different version of att2in,
|
10 |
+
# in which the img feature embedding and word embedding is the same as what in adaatt.
|
11 |
+
|
12 |
+
# UpDown is from Bottom-Up and Top-Down Attention for Image Captioning and VQA
|
13 |
+
# https://arxiv.org/abs/1707.07998
|
14 |
+
# However, it may not be identical to the author's architecture.
|
15 |
+
|
16 |
+
from __future__ import absolute_import
|
17 |
+
from __future__ import division
|
18 |
+
from __future__ import print_function
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
import torch
|
22 |
+
import torch.nn as nn
|
23 |
+
import torch.nn.functional as F
|
24 |
+
from . import utils
|
25 |
+
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
|
26 |
+
|
27 |
+
from .CaptionModel import CaptionModel
|
28 |
+
|
29 |
+
bad_endings = ['a','an','the','in','for','at','of','with','before','after','on','upon','near','to','is','are','am']
|
30 |
+
bad_endings += ['UNK', 'has', 'and', 'more']
|
31 |
+
|
32 |
+
def sort_pack_padded_sequence(input, lengths):
|
33 |
+
sorted_lengths, indices = torch.sort(lengths, descending=True)
|
34 |
+
tmp = pack_padded_sequence(input[indices], sorted_lengths, batch_first=True)
|
35 |
+
inv_ix = indices.clone()
|
36 |
+
inv_ix[indices] = torch.arange(0,len(indices)).type_as(inv_ix)
|
37 |
+
return tmp, inv_ix
|
38 |
+
|
39 |
+
def pad_unsort_packed_sequence(input, inv_ix):
|
40 |
+
tmp, _ = pad_packed_sequence(input, batch_first=True)
|
41 |
+
tmp = tmp[inv_ix]
|
42 |
+
return tmp
|
43 |
+
|
44 |
+
def pack_wrapper(module, att_feats, att_masks):
|
45 |
+
if att_masks is not None:
|
46 |
+
packed, inv_ix = sort_pack_padded_sequence(att_feats, att_masks.data.long().sum(1))
|
47 |
+
return pad_unsort_packed_sequence(PackedSequence(module(packed[0]), packed[1]), inv_ix)
|
48 |
+
else:
|
49 |
+
return module(att_feats)
|
50 |
+
|
51 |
+
class AttModel(CaptionModel):
|
52 |
+
def __init__(self, opt):
|
53 |
+
super(AttModel, self).__init__()
|
54 |
+
self.vocab_size = opt.vocab_size
|
55 |
+
self.input_encoding_size = opt.input_encoding_size
|
56 |
+
#self.rnn_type = opt.rnn_type
|
57 |
+
self.rnn_size = opt.rnn_size
|
58 |
+
self.num_layers = opt.num_layers
|
59 |
+
self.drop_prob_lm = opt.drop_prob_lm
|
60 |
+
self.seq_length = getattr(opt, 'max_length', 16) or opt.seq_length # maximum sample length
|
61 |
+
self.fc_feat_size = opt.fc_feat_size
|
62 |
+
self.att_feat_size = opt.att_feat_size
|
63 |
+
self.att_hid_size = opt.att_hid_size
|
64 |
+
|
65 |
+
self.bos_idx = getattr(opt, 'bos_idx', 0)
|
66 |
+
self.eos_idx = getattr(opt, 'eos_idx', 0)
|
67 |
+
self.pad_idx = getattr(opt, 'pad_idx', 0)
|
68 |
+
|
69 |
+
self.use_bn = getattr(opt, 'use_bn', 0)
|
70 |
+
|
71 |
+
self.ss_prob = 0.0 # Schedule sampling probability
|
72 |
+
|
73 |
+
self.embed = nn.Sequential(nn.Embedding(self.vocab_size + 1, self.input_encoding_size),
|
74 |
+
nn.ReLU(),
|
75 |
+
nn.Dropout(self.drop_prob_lm))
|
76 |
+
self.fc_embed = nn.Sequential(nn.Linear(self.fc_feat_size, self.rnn_size),
|
77 |
+
nn.ReLU(),
|
78 |
+
nn.Dropout(self.drop_prob_lm))
|
79 |
+
self.att_embed = nn.Sequential(*(
|
80 |
+
((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ())+
|
81 |
+
(nn.Linear(self.att_feat_size, self.rnn_size),
|
82 |
+
nn.ReLU(),
|
83 |
+
nn.Dropout(self.drop_prob_lm))+
|
84 |
+
((nn.BatchNorm1d(self.rnn_size),) if self.use_bn==2 else ())))
|
85 |
+
|
86 |
+
self.logit_layers = getattr(opt, 'logit_layers', 1)
|
87 |
+
if self.logit_layers == 1:
|
88 |
+
self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1)
|
89 |
+
else:
|
90 |
+
self.logit = [[nn.Linear(self.rnn_size, self.rnn_size), nn.ReLU(), nn.Dropout(0.5)] for _ in range(opt.logit_layers - 1)]
|
91 |
+
self.logit = nn.Sequential(*(reduce(lambda x,y:x+y, self.logit) + [nn.Linear(self.rnn_size, self.vocab_size + 1)]))
|
92 |
+
self.ctx2att = nn.Linear(self.rnn_size, self.att_hid_size)
|
93 |
+
|
94 |
+
# For remove bad endding
|
95 |
+
self.vocab = opt.vocab
|
96 |
+
self.bad_endings_ix = [int(k) for k,v in self.vocab.items() if v in bad_endings]
|
97 |
+
|
98 |
+
def init_hidden(self, bsz):
|
99 |
+
weight = self.logit.weight \
|
100 |
+
if hasattr(self.logit, "weight") \
|
101 |
+
else self.logit[0].weight
|
102 |
+
return (weight.new_zeros(self.num_layers, bsz, self.rnn_size),
|
103 |
+
weight.new_zeros(self.num_layers, bsz, self.rnn_size))
|
104 |
+
|
105 |
+
def clip_att(self, att_feats, att_masks):
|
106 |
+
# Clip the length of att_masks and att_feats to the maximum length
|
107 |
+
if att_masks is not None:
|
108 |
+
max_len = att_masks.data.long().sum(1).max()
|
109 |
+
att_feats = att_feats[:, :max_len].contiguous()
|
110 |
+
att_masks = att_masks[:, :max_len].contiguous()
|
111 |
+
return att_feats, att_masks
|
112 |
+
|
113 |
+
def _prepare_feature(self, fc_feats, att_feats, att_masks):
|
114 |
+
att_feats, att_masks = self.clip_att(att_feats, att_masks)
|
115 |
+
|
116 |
+
# embed fc and att feats
|
117 |
+
fc_feats = self.fc_embed(fc_feats)
|
118 |
+
att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)
|
119 |
+
|
120 |
+
# Project the attention feats first to reduce memory and computation comsumptions.
|
121 |
+
p_att_feats = self.ctx2att(att_feats)
|
122 |
+
|
123 |
+
return fc_feats, att_feats, p_att_feats, att_masks
|
124 |
+
|
125 |
+
def _forward(self, fc_feats, att_feats, seq, att_masks=None):
|
126 |
+
|
127 |
+
batch_size = fc_feats.size(0)
|
128 |
+
if seq.ndim == 3: # B * seq_per_img * seq_len
|
129 |
+
seq = seq.reshape(-1, seq.shape[2])
|
130 |
+
seq_per_img = seq.shape[0] // batch_size
|
131 |
+
state = self.init_hidden(batch_size*seq_per_img)
|
132 |
+
|
133 |
+
outputs = fc_feats.new_zeros(batch_size*seq_per_img, seq.size(1), self.vocab_size+1)
|
134 |
+
|
135 |
+
# Prepare the features
|
136 |
+
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
|
137 |
+
# pp_att_feats is used for attention, we cache it in advance to reduce computation cost
|
138 |
+
|
139 |
+
if seq_per_img > 1:
|
140 |
+
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(seq_per_img,
|
141 |
+
[p_fc_feats, p_att_feats, pp_att_feats, p_att_masks]
|
142 |
+
)
|
143 |
+
|
144 |
+
for i in range(seq.size(1)):
|
145 |
+
if self.training and i >= 1 and self.ss_prob > 0.0: # otherwiste no need to sample
|
146 |
+
sample_prob = fc_feats.new(batch_size*seq_per_img).uniform_(0, 1)
|
147 |
+
sample_mask = sample_prob < self.ss_prob
|
148 |
+
if sample_mask.sum() == 0:
|
149 |
+
it = seq[:, i].clone()
|
150 |
+
else:
|
151 |
+
sample_ind = sample_mask.nonzero().view(-1)
|
152 |
+
it = seq[:, i].data.clone()
|
153 |
+
prob_prev = torch.exp(outputs[:, i-1].detach()) # fetch prev distribution: shape Nx(M+1)
|
154 |
+
it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
|
155 |
+
else:
|
156 |
+
it = seq[:, i].clone()
|
157 |
+
# break if all the sequences end
|
158 |
+
if i >= 1 and seq[:, i].sum() == 0:
|
159 |
+
break
|
160 |
+
|
161 |
+
output, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state)
|
162 |
+
outputs[:, i] = output
|
163 |
+
|
164 |
+
return outputs
|
165 |
+
|
166 |
+
def get_logprobs_state(self, it, fc_feats, att_feats, p_att_feats, att_masks, state, output_logsoftmax=1):
|
167 |
+
# 'it' contains a word index
|
168 |
+
xt = self.embed(it)
|
169 |
+
|
170 |
+
output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state, att_masks)
|
171 |
+
if output_logsoftmax:
|
172 |
+
logprobs = F.log_softmax(self.logit(output), dim=1)
|
173 |
+
else:
|
174 |
+
logprobs = self.logit(output)
|
175 |
+
|
176 |
+
return logprobs, state
|
177 |
+
|
178 |
+
def _old_sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
|
179 |
+
beam_size = opt.get('beam_size', 10)
|
180 |
+
group_size = opt.get('group_size', 1)
|
181 |
+
sample_n = opt.get('sample_n', 10)
|
182 |
+
# when sample_n == beam_size then each beam is a sample.
|
183 |
+
assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search'
|
184 |
+
batch_size = fc_feats.size(0)
|
185 |
+
|
186 |
+
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
|
187 |
+
|
188 |
+
assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
|
189 |
+
seq = fc_feats.new_full((batch_size*sample_n, self.seq_length), self.pad_idx, dtype=torch.long)
|
190 |
+
seqLogprobs = fc_feats.new_zeros(batch_size*sample_n, self.seq_length, self.vocab_size + 1)
|
191 |
+
# lets process every image independently for now, for simplicity
|
192 |
+
|
193 |
+
self.done_beams = [[] for _ in range(batch_size)]
|
194 |
+
for k in range(batch_size):
|
195 |
+
state = self.init_hidden(beam_size)
|
196 |
+
tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks = utils.repeat_tensors(beam_size,
|
197 |
+
[p_fc_feats[k:k+1], p_att_feats[k:k+1], pp_att_feats[k:k+1], p_att_masks[k:k+1] if att_masks is not None else None]
|
198 |
+
)
|
199 |
+
|
200 |
+
for t in range(1):
|
201 |
+
if t == 0: # input <bos>
|
202 |
+
it = fc_feats.new_full([beam_size], self.bos_idx, dtype=torch.long)
|
203 |
+
|
204 |
+
logprobs, state = self.get_logprobs_state(it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state)
|
205 |
+
|
206 |
+
self.done_beams[k] = self.old_beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, opt=opt)
|
207 |
+
if sample_n == beam_size:
|
208 |
+
for _n in range(sample_n):
|
209 |
+
seq[k*sample_n+_n, :] = self.done_beams[k][_n]['seq']
|
210 |
+
seqLogprobs[k*sample_n+_n, :] = self.done_beams[k][_n]['logps']
|
211 |
+
else:
|
212 |
+
seq[k, :] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
|
213 |
+
seqLogprobs[k, :] = self.done_beams[k][0]['logps']
|
214 |
+
# return the samples and their log likelihoods
|
215 |
+
return seq, seqLogprobs
|
216 |
+
|
217 |
+
|
218 |
+
def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
|
219 |
+
beam_size = opt.get('beam_size', 10)
|
220 |
+
group_size = opt.get('group_size', 1)
|
221 |
+
sample_n = opt.get('sample_n', 10)
|
222 |
+
# when sample_n == beam_size then each beam is a sample.
|
223 |
+
assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search'
|
224 |
+
batch_size = fc_feats.size(0)
|
225 |
+
|
226 |
+
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
|
227 |
+
|
228 |
+
assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
|
229 |
+
seq = fc_feats.new_full((batch_size*sample_n, self.seq_length), self.pad_idx, dtype=torch.long)
|
230 |
+
seqLogprobs = fc_feats.new_zeros(batch_size*sample_n, self.seq_length, self.vocab_size + 1)
|
231 |
+
# lets process every image independently for now, for simplicity
|
232 |
+
|
233 |
+
self.done_beams = [[] for _ in range(batch_size)]
|
234 |
+
|
235 |
+
state = self.init_hidden(batch_size)
|
236 |
+
|
237 |
+
# first step, feed bos
|
238 |
+
it = fc_feats.new_full([batch_size], self.bos_idx, dtype=torch.long)
|
239 |
+
logprobs, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state)
|
240 |
+
|
241 |
+
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(beam_size,
|
242 |
+
[p_fc_feats, p_att_feats, pp_att_feats, p_att_masks]
|
243 |
+
)
|
244 |
+
self.done_beams = self.beam_search(state, logprobs, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, opt=opt)
|
245 |
+
|
246 |
+
for k in range(batch_size):
|
247 |
+
if sample_n == beam_size:
|
248 |
+
for _n in range(sample_n):
|
249 |
+
seq_len = self.done_beams[k][_n]['seq'].shape[0]
|
250 |
+
seq[k*sample_n+_n, :seq_len] = self.done_beams[k][_n]['seq']
|
251 |
+
seqLogprobs[k*sample_n+_n, :seq_len] = self.done_beams[k][_n]['logps']
|
252 |
+
else:
|
253 |
+
seq_len = self.done_beams[k][0]['seq'].shape[0]
|
254 |
+
seq[k, :seq_len] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
|
255 |
+
seqLogprobs[k, :seq_len] = self.done_beams[k][0]['logps']
|
256 |
+
# return the samples and their log likelihoods
|
257 |
+
return seq, seqLogprobs
|
258 |
+
|
259 |
+
def _sample(self, fc_feats, att_feats, att_masks=None, opt={}):
|
260 |
+
|
261 |
+
sample_method = opt.get('sample_method', 'greedy')
|
262 |
+
beam_size = opt.get('beam_size', 1)
|
263 |
+
temperature = opt.get('temperature', 1.0)
|
264 |
+
sample_n = int(opt.get('sample_n', 1))
|
265 |
+
group_size = opt.get('group_size', 1)
|
266 |
+
output_logsoftmax = opt.get('output_logsoftmax', 1)
|
267 |
+
decoding_constraint = opt.get('decoding_constraint', 0)
|
268 |
+
block_trigrams = opt.get('block_trigrams', 0)
|
269 |
+
remove_bad_endings = opt.get('remove_bad_endings', 1)
|
270 |
+
suppress_UNK = opt.get('suppress_UNK', 1)
|
271 |
+
|
272 |
+
if beam_size > 1 and sample_method in ['greedy', 'beam_search']:
|
273 |
+
return self._sample_beam(fc_feats, att_feats, att_masks, opt)
|
274 |
+
if group_size > 1:
|
275 |
+
return self._diverse_sample(fc_feats, att_feats, att_masks, opt)
|
276 |
+
|
277 |
+
batch_size = fc_feats.size(0)
|
278 |
+
state = self.init_hidden(batch_size*sample_n)
|
279 |
+
|
280 |
+
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
|
281 |
+
|
282 |
+
if sample_n > 1:
|
283 |
+
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(sample_n,
|
284 |
+
[p_fc_feats, p_att_feats, pp_att_feats, p_att_masks]
|
285 |
+
)
|
286 |
+
|
287 |
+
trigrams = [] # will be a list of batch_size dictionaries
|
288 |
+
|
289 |
+
seq = fc_feats.new_full((batch_size*sample_n, self.seq_length), self.pad_idx, dtype=torch.long)
|
290 |
+
seqLogprobs = fc_feats.new_zeros(batch_size*sample_n, self.seq_length, self.vocab_size + 1)
|
291 |
+
for t in range(self.seq_length + 1):
|
292 |
+
if t == 0: # input <bos>
|
293 |
+
it = fc_feats.new_full([batch_size*sample_n], self.bos_idx, dtype=torch.long)
|
294 |
+
|
295 |
+
logprobs, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state, output_logsoftmax=output_logsoftmax)
|
296 |
+
|
297 |
+
if decoding_constraint and t > 0:
|
298 |
+
tmp = logprobs.new_zeros(logprobs.size())
|
299 |
+
tmp.scatter_(1, seq[:,t-1].data.unsqueeze(1), float('-inf'))
|
300 |
+
logprobs = logprobs + tmp
|
301 |
+
|
302 |
+
if remove_bad_endings and t > 0:
|
303 |
+
logprobs[torch.from_numpy(np.isin(seq[:,t-1].data.cpu().numpy(), self.bad_endings_ix)), 0] = float('-inf')
|
304 |
+
# suppress UNK tokens in the decoding
|
305 |
+
if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobs.size(1)-1)] == 'UNK':
|
306 |
+
logprobs[:,logprobs.size(1)-1] = logprobs[:, logprobs.size(1)-1] - 1000
|
307 |
+
|
308 |
+
# if remove_bad_endings and t > 0:
|
309 |
+
# tmp = logprobs.new_zeros(logprobs.size())
|
310 |
+
# prev_bad = np.isin(seq[:,t-1].data.cpu().numpy(), self.bad_endings_ix)
|
311 |
+
# # Make it impossible to generate bad_endings
|
312 |
+
# tmp[torch.from_numpy(prev_bad.astype('uint8')), 0] = float('-inf')
|
313 |
+
# logprobs = logprobs + tmp
|
314 |
+
|
315 |
+
# Mess with trigrams
|
316 |
+
# Copy from https://github.com/lukemelas/image-paragraph-captioning
|
317 |
+
if block_trigrams and t >= 3:
|
318 |
+
# Store trigram generated at last step
|
319 |
+
prev_two_batch = seq[:,t-3:t-1]
|
320 |
+
for i in range(batch_size): # = seq.size(0)
|
321 |
+
prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
|
322 |
+
current = seq[i][t-1]
|
323 |
+
if t == 3: # initialize
|
324 |
+
trigrams.append({prev_two: [current]}) # {LongTensor: list containing 1 int}
|
325 |
+
elif t > 3:
|
326 |
+
if prev_two in trigrams[i]: # add to list
|
327 |
+
trigrams[i][prev_two].append(current)
|
328 |
+
else: # create list
|
329 |
+
trigrams[i][prev_two] = [current]
|
330 |
+
# Block used trigrams at next step
|
331 |
+
prev_two_batch = seq[:,t-2:t]
|
332 |
+
mask = torch.zeros(logprobs.size(), requires_grad=False).to(logprobs.device) # batch_size x vocab_size
|
333 |
+
for i in range(batch_size):
|
334 |
+
prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
|
335 |
+
if prev_two in trigrams[i]:
|
336 |
+
for j in trigrams[i][prev_two]:
|
337 |
+
mask[i,j] += 1
|
338 |
+
# Apply mask to log probs
|
339 |
+
#logprobs = logprobs - (mask * 1e9)
|
340 |
+
alpha = 2.0 # = 4
|
341 |
+
logprobs = logprobs + (mask * -0.693 * alpha) # ln(1/2) * alpha (alpha -> infty works best)
|
342 |
+
|
343 |
+
# sample the next word
|
344 |
+
if t == self.seq_length: # skip if we achieve maximum length
|
345 |
+
break
|
346 |
+
it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, temperature)
|
347 |
+
|
348 |
+
# stop when all finished
|
349 |
+
if t == 0:
|
350 |
+
unfinished = it != self.eos_idx
|
351 |
+
else:
|
352 |
+
it[~unfinished] = self.pad_idx # This allows eos_idx not being overwritten to 0
|
353 |
+
logprobs = logprobs * unfinished.unsqueeze(1).to(logprobs)
|
354 |
+
unfinished = unfinished & (it != self.eos_idx)
|
355 |
+
seq[:,t] = it
|
356 |
+
seqLogprobs[:,t] = logprobs
|
357 |
+
# quit loop if all sequences have finished
|
358 |
+
if unfinished.sum() == 0:
|
359 |
+
break
|
360 |
+
return seq, seqLogprobs
|
361 |
+
|
362 |
+
def _diverse_sample(self, fc_feats, att_feats, att_masks=None, opt={}):
|
363 |
+
|
364 |
+
sample_method = opt.get('sample_method', 'greedy')
|
365 |
+
beam_size = opt.get('beam_size', 1)
|
366 |
+
temperature = opt.get('temperature', 1.0)
|
367 |
+
group_size = opt.get('group_size', 1)
|
368 |
+
diversity_lambda = opt.get('diversity_lambda', 0.5)
|
369 |
+
decoding_constraint = opt.get('decoding_constraint', 0)
|
370 |
+
block_trigrams = opt.get('block_trigrams', 0)
|
371 |
+
remove_bad_endings = opt.get('remove_bad_endings', 1)
|
372 |
+
|
373 |
+
batch_size = fc_feats.size(0)
|
374 |
+
state = self.init_hidden(batch_size)
|
375 |
+
|
376 |
+
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
|
377 |
+
|
378 |
+
trigrams_table = [[] for _ in range(group_size)] # will be a list of batch_size dictionaries
|
379 |
+
|
380 |
+
seq_table = [fc_feats.new_full((batch_size, self.seq_length), self.pad_idx, dtype=torch.long) for _ in range(group_size)]
|
381 |
+
seqLogprobs_table = [fc_feats.new_zeros(batch_size, self.seq_length) for _ in range(group_size)]
|
382 |
+
state_table = [self.init_hidden(batch_size) for _ in range(group_size)]
|
383 |
+
|
384 |
+
for tt in range(self.seq_length + group_size):
|
385 |
+
for divm in range(group_size):
|
386 |
+
t = tt - divm
|
387 |
+
seq = seq_table[divm]
|
388 |
+
seqLogprobs = seqLogprobs_table[divm]
|
389 |
+
trigrams = trigrams_table[divm]
|
390 |
+
if t >= 0 and t <= self.seq_length-1:
|
391 |
+
if t == 0: # input <bos>
|
392 |
+
it = fc_feats.new_full([batch_size], self.bos_idx, dtype=torch.long)
|
393 |
+
else:
|
394 |
+
it = seq[:, t-1] # changed
|
395 |
+
|
396 |
+
logprobs, state_table[divm] = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state_table[divm]) # changed
|
397 |
+
logprobs = F.log_softmax(logprobs / temperature, dim=-1)
|
398 |
+
|
399 |
+
# Add diversity
|
400 |
+
if divm > 0:
|
401 |
+
unaug_logprobs = logprobs.clone()
|
402 |
+
for prev_choice in range(divm):
|
403 |
+
prev_decisions = seq_table[prev_choice][:, t]
|
404 |
+
logprobs[:, prev_decisions] = logprobs[:, prev_decisions] - diversity_lambda
|
405 |
+
|
406 |
+
if decoding_constraint and t > 0:
|
407 |
+
tmp = logprobs.new_zeros(logprobs.size())
|
408 |
+
tmp.scatter_(1, seq[:,t-1].data.unsqueeze(1), float('-inf'))
|
409 |
+
logprobs = logprobs + tmp
|
410 |
+
|
411 |
+
if remove_bad_endings and t > 0:
|
412 |
+
tmp = logprobs.new_zeros(logprobs.size())
|
413 |
+
prev_bad = np.isin(seq[:,t-1].data.cpu().numpy(), self.bad_endings_ix)
|
414 |
+
# Impossible to generate remove_bad_endings
|
415 |
+
tmp[torch.from_numpy(prev_bad.astype('uint8')), 0] = float('-inf')
|
416 |
+
logprobs = logprobs + tmp
|
417 |
+
|
418 |
+
# Mess with trigrams
|
419 |
+
if block_trigrams and t >= 3:
|
420 |
+
# Store trigram generated at last step
|
421 |
+
prev_two_batch = seq[:,t-3:t-1]
|
422 |
+
for i in range(batch_size): # = seq.size(0)
|
423 |
+
prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
|
424 |
+
current = seq[i][t-1]
|
425 |
+
if t == 3: # initialize
|
426 |
+
trigrams.append({prev_two: [current]}) # {LongTensor: list containing 1 int}
|
427 |
+
elif t > 3:
|
428 |
+
if prev_two in trigrams[i]: # add to list
|
429 |
+
trigrams[i][prev_two].append(current)
|
430 |
+
else: # create list
|
431 |
+
trigrams[i][prev_two] = [current]
|
432 |
+
# Block used trigrams at next step
|
433 |
+
prev_two_batch = seq[:,t-2:t]
|
434 |
+
mask = torch.zeros(logprobs.size(), requires_grad=False).cuda() # batch_size x vocab_size
|
435 |
+
for i in range(batch_size):
|
436 |
+
prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
|
437 |
+
if prev_two in trigrams[i]:
|
438 |
+
for j in trigrams[i][prev_two]:
|
439 |
+
mask[i,j] += 1
|
440 |
+
# Apply mask to log probs
|
441 |
+
#logprobs = logprobs - (mask * 1e9)
|
442 |
+
alpha = 2.0 # = 4
|
443 |
+
logprobs = logprobs + (mask * -0.693 * alpha) # ln(1/2) * alpha (alpha -> infty works best)
|
444 |
+
|
445 |
+
it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, 1)
|
446 |
+
|
447 |
+
# stop when all finished
|
448 |
+
if t == 0:
|
449 |
+
unfinished = it != self.eos_idx
|
450 |
+
else:
|
451 |
+
unfinished = (seq[:,t-1] != self.pad_idx) & (seq[:,t-1] != self.eos_idx)
|
452 |
+
it[~unfinished] = self.pad_idx
|
453 |
+
unfinished = unfinished & (it != self.eos_idx) # changed
|
454 |
+
seq[:,t] = it
|
455 |
+
seqLogprobs[:,t] = sampleLogprobs.view(-1)
|
456 |
+
|
457 |
+
return torch.stack(seq_table, 1).reshape(batch_size * group_size, -1), torch.stack(seqLogprobs_table, 1).reshape(batch_size * group_size, -1)
|
458 |
+
|
459 |
+
class AdaAtt_lstm(nn.Module):
|
460 |
+
def __init__(self, opt, use_maxout=True):
|
461 |
+
super(AdaAtt_lstm, self).__init__()
|
462 |
+
self.input_encoding_size = opt.input_encoding_size
|
463 |
+
#self.rnn_type = opt.rnn_type
|
464 |
+
self.rnn_size = opt.rnn_size
|
465 |
+
self.num_layers = opt.num_layers
|
466 |
+
self.drop_prob_lm = opt.drop_prob_lm
|
467 |
+
self.fc_feat_size = opt.fc_feat_size
|
468 |
+
self.att_feat_size = opt.att_feat_size
|
469 |
+
self.att_hid_size = opt.att_hid_size
|
470 |
+
|
471 |
+
self.use_maxout = use_maxout
|
472 |
+
|
473 |
+
# Build a LSTM
|
474 |
+
self.w2h = nn.Linear(self.input_encoding_size, (4+(use_maxout==True)) * self.rnn_size)
|
475 |
+
self.v2h = nn.Linear(self.rnn_size, (4+(use_maxout==True)) * self.rnn_size)
|
476 |
+
|
477 |
+
self.i2h = nn.ModuleList([nn.Linear(self.rnn_size, (4+(use_maxout==True)) * self.rnn_size) for _ in range(self.num_layers - 1)])
|
478 |
+
self.h2h = nn.ModuleList([nn.Linear(self.rnn_size, (4+(use_maxout==True)) * self.rnn_size) for _ in range(self.num_layers)])
|
479 |
+
|
480 |
+
# Layers for getting the fake region
|
481 |
+
if self.num_layers == 1:
|
482 |
+
self.r_w2h = nn.Linear(self.input_encoding_size, self.rnn_size)
|
483 |
+
self.r_v2h = nn.Linear(self.rnn_size, self.rnn_size)
|
484 |
+
else:
|
485 |
+
self.r_i2h = nn.Linear(self.rnn_size, self.rnn_size)
|
486 |
+
self.r_h2h = nn.Linear(self.rnn_size, self.rnn_size)
|
487 |
+
|
488 |
+
|
489 |
+
def forward(self, xt, img_fc, state):
|
490 |
+
|
491 |
+
hs = []
|
492 |
+
cs = []
|
493 |
+
for L in range(self.num_layers):
|
494 |
+
# c,h from previous timesteps
|
495 |
+
prev_h = state[0][L]
|
496 |
+
prev_c = state[1][L]
|
497 |
+
# the input to this layer
|
498 |
+
if L == 0:
|
499 |
+
x = xt
|
500 |
+
i2h = self.w2h(x) + self.v2h(img_fc)
|
501 |
+
else:
|
502 |
+
x = hs[-1]
|
503 |
+
x = F.dropout(x, self.drop_prob_lm, self.training)
|
504 |
+
i2h = self.i2h[L-1](x)
|
505 |
+
|
506 |
+
all_input_sums = i2h+self.h2h[L](prev_h)
|
507 |
+
|
508 |
+
sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size)
|
509 |
+
sigmoid_chunk = torch.sigmoid(sigmoid_chunk)
|
510 |
+
# decode the gates
|
511 |
+
in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size)
|
512 |
+
forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size)
|
513 |
+
out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size)
|
514 |
+
# decode the write inputs
|
515 |
+
if not self.use_maxout:
|
516 |
+
in_transform = torch.tanh(all_input_sums.narrow(1, 3 * self.rnn_size, self.rnn_size))
|
517 |
+
else:
|
518 |
+
in_transform = all_input_sums.narrow(1, 3 * self.rnn_size, 2 * self.rnn_size)
|
519 |
+
in_transform = torch.max(\
|
520 |
+
in_transform.narrow(1, 0, self.rnn_size),
|
521 |
+
in_transform.narrow(1, self.rnn_size, self.rnn_size))
|
522 |
+
# perform the LSTM update
|
523 |
+
next_c = forget_gate * prev_c + in_gate * in_transform
|
524 |
+
# gated cells form the output
|
525 |
+
tanh_nex_c = torch.tanh(next_c)
|
526 |
+
next_h = out_gate * tanh_nex_c
|
527 |
+
if L == self.num_layers-1:
|
528 |
+
if L == 0:
|
529 |
+
i2h = self.r_w2h(x) + self.r_v2h(img_fc)
|
530 |
+
else:
|
531 |
+
i2h = self.r_i2h(x)
|
532 |
+
n5 = i2h+self.r_h2h(prev_h)
|
533 |
+
fake_region = torch.sigmoid(n5) * tanh_nex_c
|
534 |
+
|
535 |
+
cs.append(next_c)
|
536 |
+
hs.append(next_h)
|
537 |
+
|
538 |
+
# set up the decoder
|
539 |
+
top_h = hs[-1]
|
540 |
+
top_h = F.dropout(top_h, self.drop_prob_lm, self.training)
|
541 |
+
fake_region = F.dropout(fake_region, self.drop_prob_lm, self.training)
|
542 |
+
|
543 |
+
state = (torch.cat([_.unsqueeze(0) for _ in hs], 0),
|
544 |
+
torch.cat([_.unsqueeze(0) for _ in cs], 0))
|
545 |
+
return top_h, fake_region, state
|
546 |
+
|
547 |
+
class AdaAtt_attention(nn.Module):
|
548 |
+
def __init__(self, opt):
|
549 |
+
super(AdaAtt_attention, self).__init__()
|
550 |
+
self.input_encoding_size = opt.input_encoding_size
|
551 |
+
#self.rnn_type = opt.rnn_type
|
552 |
+
self.rnn_size = opt.rnn_size
|
553 |
+
self.drop_prob_lm = opt.drop_prob_lm
|
554 |
+
self.att_hid_size = opt.att_hid_size
|
555 |
+
|
556 |
+
# fake region embed
|
557 |
+
self.fr_linear = nn.Sequential(
|
558 |
+
nn.Linear(self.rnn_size, self.input_encoding_size),
|
559 |
+
nn.ReLU(),
|
560 |
+
nn.Dropout(self.drop_prob_lm))
|
561 |
+
self.fr_embed = nn.Linear(self.input_encoding_size, self.att_hid_size)
|
562 |
+
|
563 |
+
# h out embed
|
564 |
+
self.ho_linear = nn.Sequential(
|
565 |
+
nn.Linear(self.rnn_size, self.input_encoding_size),
|
566 |
+
nn.Tanh(),
|
567 |
+
nn.Dropout(self.drop_prob_lm))
|
568 |
+
self.ho_embed = nn.Linear(self.input_encoding_size, self.att_hid_size)
|
569 |
+
|
570 |
+
self.alpha_net = nn.Linear(self.att_hid_size, 1)
|
571 |
+
self.att2h = nn.Linear(self.rnn_size, self.rnn_size)
|
572 |
+
|
573 |
+
def forward(self, h_out, fake_region, conv_feat, conv_feat_embed, att_masks=None):
|
574 |
+
|
575 |
+
# View into three dimensions
|
576 |
+
att_size = conv_feat.numel() // conv_feat.size(0) // self.rnn_size
|
577 |
+
conv_feat = conv_feat.view(-1, att_size, self.rnn_size)
|
578 |
+
conv_feat_embed = conv_feat_embed.view(-1, att_size, self.att_hid_size)
|
579 |
+
|
580 |
+
# view neighbor from bach_size * neighbor_num x rnn_size to bach_size x rnn_size * neighbor_num
|
581 |
+
fake_region = self.fr_linear(fake_region)
|
582 |
+
fake_region_embed = self.fr_embed(fake_region)
|
583 |
+
|
584 |
+
h_out_linear = self.ho_linear(h_out)
|
585 |
+
h_out_embed = self.ho_embed(h_out_linear)
|
586 |
+
|
587 |
+
txt_replicate = h_out_embed.unsqueeze(1).expand(h_out_embed.size(0), att_size + 1, h_out_embed.size(1))
|
588 |
+
|
589 |
+
img_all = torch.cat([fake_region.view(-1,1,self.input_encoding_size), conv_feat], 1)
|
590 |
+
img_all_embed = torch.cat([fake_region_embed.view(-1,1,self.input_encoding_size), conv_feat_embed], 1)
|
591 |
+
|
592 |
+
hA = torch.tanh(img_all_embed + txt_replicate)
|
593 |
+
hA = F.dropout(hA,self.drop_prob_lm, self.training)
|
594 |
+
|
595 |
+
hAflat = self.alpha_net(hA.view(-1, self.att_hid_size))
|
596 |
+
PI = F.softmax(hAflat.view(-1, att_size + 1), dim=1)
|
597 |
+
|
598 |
+
if att_masks is not None:
|
599 |
+
att_masks = att_masks.view(-1, att_size)
|
600 |
+
PI = PI * torch.cat([att_masks[:,:1], att_masks], 1) # assume one one at the first time step.
|
601 |
+
PI = PI / PI.sum(1, keepdim=True)
|
602 |
+
|
603 |
+
visAtt = torch.bmm(PI.unsqueeze(1), img_all)
|
604 |
+
visAttdim = visAtt.squeeze(1)
|
605 |
+
|
606 |
+
atten_out = visAttdim + h_out_linear
|
607 |
+
|
608 |
+
h = torch.tanh(self.att2h(atten_out))
|
609 |
+
h = F.dropout(h, self.drop_prob_lm, self.training)
|
610 |
+
return h
|
611 |
+
|
612 |
+
class AdaAttCore(nn.Module):
|
613 |
+
def __init__(self, opt, use_maxout=False):
|
614 |
+
super(AdaAttCore, self).__init__()
|
615 |
+
self.lstm = AdaAtt_lstm(opt, use_maxout)
|
616 |
+
self.attention = AdaAtt_attention(opt)
|
617 |
+
|
618 |
+
def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None):
|
619 |
+
h_out, p_out, state = self.lstm(xt, fc_feats, state)
|
620 |
+
atten_out = self.attention(h_out, p_out, att_feats, p_att_feats, att_masks)
|
621 |
+
return atten_out, state
|
622 |
+
|
623 |
+
class UpDownCore(nn.Module):
|
624 |
+
def __init__(self, opt, use_maxout=False):
|
625 |
+
super(UpDownCore, self).__init__()
|
626 |
+
self.drop_prob_lm = opt.drop_prob_lm
|
627 |
+
|
628 |
+
self.att_lstm = nn.LSTMCell(opt.input_encoding_size + opt.rnn_size * 2, opt.rnn_size) # we, fc, h^2_t-1
|
629 |
+
self.lang_lstm = nn.LSTMCell(opt.rnn_size * 2, opt.rnn_size) # h^1_t, \hat v
|
630 |
+
self.attention = Attention(opt)
|
631 |
+
|
632 |
+
def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None):
|
633 |
+
prev_h = state[0][-1]
|
634 |
+
att_lstm_input = torch.cat([prev_h, fc_feats, xt], 1)
|
635 |
+
|
636 |
+
h_att, c_att = self.att_lstm(att_lstm_input, (state[0][0], state[1][0]))
|
637 |
+
|
638 |
+
att = self.attention(h_att, att_feats, p_att_feats, att_masks)
|
639 |
+
|
640 |
+
lang_lstm_input = torch.cat([att, h_att], 1)
|
641 |
+
# lang_lstm_input = torch.cat([att, F.dropout(h_att, self.drop_prob_lm, self.training)], 1) ?????
|
642 |
+
|
643 |
+
h_lang, c_lang = self.lang_lstm(lang_lstm_input, (state[0][1], state[1][1]))
|
644 |
+
|
645 |
+
output = F.dropout(h_lang, self.drop_prob_lm, self.training)
|
646 |
+
state = (torch.stack([h_att, h_lang]), torch.stack([c_att, c_lang]))
|
647 |
+
|
648 |
+
return output, state
|
649 |
+
|
650 |
+
|
651 |
+
############################################################################
|
652 |
+
# Notice:
|
653 |
+
# StackAtt and DenseAtt are models that I randomly designed.
|
654 |
+
# They are not related to any paper.
|
655 |
+
############################################################################
|
656 |
+
|
657 |
+
from .FCModel import LSTMCore
|
658 |
+
class StackAttCore(nn.Module):
|
659 |
+
def __init__(self, opt, use_maxout=False):
|
660 |
+
super(StackAttCore, self).__init__()
|
661 |
+
self.drop_prob_lm = opt.drop_prob_lm
|
662 |
+
|
663 |
+
# self.att0 = Attention(opt)
|
664 |
+
self.att1 = Attention(opt)
|
665 |
+
self.att2 = Attention(opt)
|
666 |
+
|
667 |
+
opt_input_encoding_size = opt.input_encoding_size
|
668 |
+
opt.input_encoding_size = opt.input_encoding_size + opt.rnn_size
|
669 |
+
self.lstm0 = LSTMCore(opt) # att_feat + word_embedding
|
670 |
+
opt.input_encoding_size = opt.rnn_size * 2
|
671 |
+
self.lstm1 = LSTMCore(opt)
|
672 |
+
self.lstm2 = LSTMCore(opt)
|
673 |
+
opt.input_encoding_size = opt_input_encoding_size
|
674 |
+
|
675 |
+
# self.emb1 = nn.Linear(opt.rnn_size, opt.rnn_size)
|
676 |
+
self.emb2 = nn.Linear(opt.rnn_size, opt.rnn_size)
|
677 |
+
|
678 |
+
def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None):
|
679 |
+
# att_res_0 = self.att0(state[0][-1], att_feats, p_att_feats, att_masks)
|
680 |
+
h_0, state_0 = self.lstm0(torch.cat([xt,fc_feats],1), [state[0][0:1], state[1][0:1]])
|
681 |
+
att_res_1 = self.att1(h_0, att_feats, p_att_feats, att_masks)
|
682 |
+
h_1, state_1 = self.lstm1(torch.cat([h_0,att_res_1],1), [state[0][1:2], state[1][1:2]])
|
683 |
+
att_res_2 = self.att2(h_1 + self.emb2(att_res_1), att_feats, p_att_feats, att_masks)
|
684 |
+
h_2, state_2 = self.lstm2(torch.cat([h_1,att_res_2],1), [state[0][2:3], state[1][2:3]])
|
685 |
+
|
686 |
+
return h_2, [torch.cat(_, 0) for _ in zip(state_0, state_1, state_2)]
|
687 |
+
|
688 |
+
class DenseAttCore(nn.Module):
|
689 |
+
def __init__(self, opt, use_maxout=False):
|
690 |
+
super(DenseAttCore, self).__init__()
|
691 |
+
self.drop_prob_lm = opt.drop_prob_lm
|
692 |
+
|
693 |
+
# self.att0 = Attention(opt)
|
694 |
+
self.att1 = Attention(opt)
|
695 |
+
self.att2 = Attention(opt)
|
696 |
+
|
697 |
+
opt_input_encoding_size = opt.input_encoding_size
|
698 |
+
opt.input_encoding_size = opt.input_encoding_size + opt.rnn_size
|
699 |
+
self.lstm0 = LSTMCore(opt) # att_feat + word_embedding
|
700 |
+
opt.input_encoding_size = opt.rnn_size * 2
|
701 |
+
self.lstm1 = LSTMCore(opt)
|
702 |
+
self.lstm2 = LSTMCore(opt)
|
703 |
+
opt.input_encoding_size = opt_input_encoding_size
|
704 |
+
|
705 |
+
# self.emb1 = nn.Linear(opt.rnn_size, opt.rnn_size)
|
706 |
+
self.emb2 = nn.Linear(opt.rnn_size, opt.rnn_size)
|
707 |
+
|
708 |
+
# fuse h_0 and h_1
|
709 |
+
self.fusion1 = nn.Sequential(nn.Linear(opt.rnn_size*2, opt.rnn_size),
|
710 |
+
nn.ReLU(),
|
711 |
+
nn.Dropout(opt.drop_prob_lm))
|
712 |
+
# fuse h_0, h_1 and h_2
|
713 |
+
self.fusion2 = nn.Sequential(nn.Linear(opt.rnn_size*3, opt.rnn_size),
|
714 |
+
nn.ReLU(),
|
715 |
+
nn.Dropout(opt.drop_prob_lm))
|
716 |
+
|
717 |
+
def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None):
|
718 |
+
# att_res_0 = self.att0(state[0][-1], att_feats, p_att_feats, att_masks)
|
719 |
+
h_0, state_0 = self.lstm0(torch.cat([xt,fc_feats],1), [state[0][0:1], state[1][0:1]])
|
720 |
+
att_res_1 = self.att1(h_0, att_feats, p_att_feats, att_masks)
|
721 |
+
h_1, state_1 = self.lstm1(torch.cat([h_0,att_res_1],1), [state[0][1:2], state[1][1:2]])
|
722 |
+
att_res_2 = self.att2(h_1 + self.emb2(att_res_1), att_feats, p_att_feats, att_masks)
|
723 |
+
h_2, state_2 = self.lstm2(torch.cat([self.fusion1(torch.cat([h_0, h_1], 1)),att_res_2],1), [state[0][2:3], state[1][2:3]])
|
724 |
+
|
725 |
+
return self.fusion2(torch.cat([h_0, h_1, h_2], 1)), [torch.cat(_, 0) for _ in zip(state_0, state_1, state_2)]
|
726 |
+
|
727 |
+
class Attention(nn.Module):
|
728 |
+
def __init__(self, opt):
|
729 |
+
super(Attention, self).__init__()
|
730 |
+
self.rnn_size = opt.rnn_size
|
731 |
+
self.att_hid_size = opt.att_hid_size
|
732 |
+
|
733 |
+
self.h2att = nn.Linear(self.rnn_size, self.att_hid_size)
|
734 |
+
self.alpha_net = nn.Linear(self.att_hid_size, 1)
|
735 |
+
|
736 |
+
def forward(self, h, att_feats, p_att_feats, att_masks=None):
|
737 |
+
# The p_att_feats here is already projected
|
738 |
+
att_size = att_feats.numel() // att_feats.size(0) // att_feats.size(-1)
|
739 |
+
att = p_att_feats.view(-1, att_size, self.att_hid_size)
|
740 |
+
|
741 |
+
att_h = self.h2att(h) # batch * att_hid_size
|
742 |
+
att_h = att_h.unsqueeze(1).expand_as(att) # batch * att_size * att_hid_size
|
743 |
+
dot = att + att_h # batch * att_size * att_hid_size
|
744 |
+
dot = torch.tanh(dot) # batch * att_size * att_hid_size
|
745 |
+
dot = dot.view(-1, self.att_hid_size) # (batch * att_size) * att_hid_size
|
746 |
+
dot = self.alpha_net(dot) # (batch * att_size) * 1
|
747 |
+
dot = dot.view(-1, att_size) # batch * att_size
|
748 |
+
|
749 |
+
weight = F.softmax(dot, dim=1) # batch * att_size
|
750 |
+
if att_masks is not None:
|
751 |
+
weight = weight * att_masks.view(-1, att_size).to(weight)
|
752 |
+
weight = weight / weight.sum(1, keepdim=True) # normalize to 1
|
753 |
+
att_feats_ = att_feats.view(-1, att_size, att_feats.size(-1)) # batch * att_size * att_feat_size
|
754 |
+
att_res = torch.bmm(weight.unsqueeze(1), att_feats_).squeeze(1) # batch * att_feat_size
|
755 |
+
|
756 |
+
return att_res
|
757 |
+
|
758 |
+
class Att2in2Core(nn.Module):
|
759 |
+
def __init__(self, opt):
|
760 |
+
super(Att2in2Core, self).__init__()
|
761 |
+
self.input_encoding_size = opt.input_encoding_size
|
762 |
+
#self.rnn_type = opt.rnn_type
|
763 |
+
self.rnn_size = opt.rnn_size
|
764 |
+
#self.num_layers = opt.num_layers
|
765 |
+
self.drop_prob_lm = opt.drop_prob_lm
|
766 |
+
self.fc_feat_size = opt.fc_feat_size
|
767 |
+
self.att_feat_size = opt.att_feat_size
|
768 |
+
self.att_hid_size = opt.att_hid_size
|
769 |
+
|
770 |
+
# Build a LSTM
|
771 |
+
self.a2c = nn.Linear(self.rnn_size, 2 * self.rnn_size)
|
772 |
+
self.i2h = nn.Linear(self.input_encoding_size, 5 * self.rnn_size)
|
773 |
+
self.h2h = nn.Linear(self.rnn_size, 5 * self.rnn_size)
|
774 |
+
self.dropout = nn.Dropout(self.drop_prob_lm)
|
775 |
+
|
776 |
+
self.attention = Attention(opt)
|
777 |
+
|
778 |
+
def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None):
|
779 |
+
att_res = self.attention(state[0][-1], att_feats, p_att_feats, att_masks)
|
780 |
+
|
781 |
+
all_input_sums = self.i2h(xt) + self.h2h(state[0][-1])
|
782 |
+
sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size)
|
783 |
+
sigmoid_chunk = torch.sigmoid(sigmoid_chunk)
|
784 |
+
in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size)
|
785 |
+
forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size)
|
786 |
+
out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size)
|
787 |
+
|
788 |
+
in_transform = all_input_sums.narrow(1, 3 * self.rnn_size, 2 * self.rnn_size) + \
|
789 |
+
self.a2c(att_res)
|
790 |
+
in_transform = torch.max(\
|
791 |
+
in_transform.narrow(1, 0, self.rnn_size),
|
792 |
+
in_transform.narrow(1, self.rnn_size, self.rnn_size))
|
793 |
+
next_c = forget_gate * state[1][-1] + in_gate * in_transform
|
794 |
+
next_h = out_gate * torch.tanh(next_c)
|
795 |
+
|
796 |
+
output = self.dropout(next_h)
|
797 |
+
state = (next_h.unsqueeze(0), next_c.unsqueeze(0))
|
798 |
+
return output, state
|
799 |
+
|
800 |
+
class Att2inCore(Att2in2Core):
|
801 |
+
def __init__(self, opt):
|
802 |
+
super(Att2inCore, self).__init__(opt)
|
803 |
+
del self.a2c
|
804 |
+
self.a2c = nn.Linear(self.att_feat_size, 2 * self.rnn_size)
|
805 |
+
|
806 |
+
"""
|
807 |
+
Note this is my attempt to replicate att2all model in self-critical paper.
|
808 |
+
However, this is not a correct replication actually. Will fix it.
|
809 |
+
"""
|
810 |
+
class Att2all2Core(nn.Module):
|
811 |
+
def __init__(self, opt):
|
812 |
+
super(Att2all2Core, self).__init__()
|
813 |
+
self.input_encoding_size = opt.input_encoding_size
|
814 |
+
#self.rnn_type = opt.rnn_type
|
815 |
+
self.rnn_size = opt.rnn_size
|
816 |
+
#self.num_layers = opt.num_layers
|
817 |
+
self.drop_prob_lm = opt.drop_prob_lm
|
818 |
+
self.fc_feat_size = opt.fc_feat_size
|
819 |
+
self.att_feat_size = opt.att_feat_size
|
820 |
+
self.att_hid_size = opt.att_hid_size
|
821 |
+
|
822 |
+
# Build a LSTM
|
823 |
+
self.a2h = nn.Linear(self.rnn_size, 5 * self.rnn_size)
|
824 |
+
self.i2h = nn.Linear(self.input_encoding_size, 5 * self.rnn_size)
|
825 |
+
self.h2h = nn.Linear(self.rnn_size, 5 * self.rnn_size)
|
826 |
+
self.dropout = nn.Dropout(self.drop_prob_lm)
|
827 |
+
|
828 |
+
self.attention = Attention(opt)
|
829 |
+
|
830 |
+
def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None):
|
831 |
+
att_res = self.attention(state[0][-1], att_feats, p_att_feats, att_masks)
|
832 |
+
|
833 |
+
all_input_sums = self.i2h(xt) + self.h2h(state[0][-1]) + self.a2h(att_res)
|
834 |
+
sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size)
|
835 |
+
sigmoid_chunk = torch.sigmoid(sigmoid_chunk)
|
836 |
+
in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size)
|
837 |
+
forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size)
|
838 |
+
out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size)
|
839 |
+
|
840 |
+
in_transform = all_input_sums.narrow(1, 3 * self.rnn_size, 2 * self.rnn_size)
|
841 |
+
in_transform = torch.max(\
|
842 |
+
in_transform.narrow(1, 0, self.rnn_size),
|
843 |
+
in_transform.narrow(1, self.rnn_size, self.rnn_size))
|
844 |
+
next_c = forget_gate * state[1][-1] + in_gate * in_transform
|
845 |
+
next_h = out_gate * torch.tanh(next_c)
|
846 |
+
|
847 |
+
output = self.dropout(next_h)
|
848 |
+
state = (next_h.unsqueeze(0), next_c.unsqueeze(0))
|
849 |
+
return output, state
|
850 |
+
|
851 |
+
class AdaAttModel(AttModel):
|
852 |
+
def __init__(self, opt):
|
853 |
+
super(AdaAttModel, self).__init__(opt)
|
854 |
+
self.core = AdaAttCore(opt)
|
855 |
+
|
856 |
+
# AdaAtt with maxout lstm
|
857 |
+
class AdaAttMOModel(AttModel):
|
858 |
+
def __init__(self, opt):
|
859 |
+
super(AdaAttMOModel, self).__init__(opt)
|
860 |
+
self.core = AdaAttCore(opt, True)
|
861 |
+
|
862 |
+
class Att2in2Model(AttModel):
|
863 |
+
def __init__(self, opt):
|
864 |
+
super(Att2in2Model, self).__init__(opt)
|
865 |
+
self.core = Att2in2Core(opt)
|
866 |
+
delattr(self, 'fc_embed')
|
867 |
+
self.fc_embed = lambda x : x
|
868 |
+
|
869 |
+
class Att2all2Model(AttModel):
|
870 |
+
def __init__(self, opt):
|
871 |
+
super(Att2all2Model, self).__init__(opt)
|
872 |
+
self.core = Att2all2Core(opt)
|
873 |
+
delattr(self, 'fc_embed')
|
874 |
+
self.fc_embed = lambda x : x
|
875 |
+
|
876 |
+
class UpDownModel(AttModel):
|
877 |
+
def __init__(self, opt):
|
878 |
+
super(UpDownModel, self).__init__(opt)
|
879 |
+
self.num_layers = 2
|
880 |
+
self.core = UpDownCore(opt)
|
881 |
+
|
882 |
+
class StackAttModel(AttModel):
|
883 |
+
def __init__(self, opt):
|
884 |
+
super(StackAttModel, self).__init__(opt)
|
885 |
+
self.num_layers = 3
|
886 |
+
self.core = StackAttCore(opt)
|
887 |
+
|
888 |
+
class DenseAttModel(AttModel):
|
889 |
+
def __init__(self, opt):
|
890 |
+
super(DenseAttModel, self).__init__(opt)
|
891 |
+
self.num_layers = 3
|
892 |
+
self.core = DenseAttCore(opt)
|
893 |
+
|
894 |
+
class Att2inModel(AttModel):
|
895 |
+
def __init__(self, opt):
|
896 |
+
super(Att2inModel, self).__init__(opt)
|
897 |
+
del self.embed, self.fc_embed, self.att_embed
|
898 |
+
self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
|
899 |
+
self.fc_embed = self.att_embed = lambda x: x
|
900 |
+
del self.ctx2att
|
901 |
+
self.ctx2att = nn.Linear(self.att_feat_size, self.att_hid_size)
|
902 |
+
self.core = Att2inCore(opt)
|
903 |
+
self.init_weights()
|
904 |
+
|
905 |
+
def init_weights(self):
|
906 |
+
initrange = 0.1
|
907 |
+
self.embed.weight.data.uniform_(-initrange, initrange)
|
908 |
+
self.logit.bias.data.fill_(0)
|
909 |
+
self.logit.weight.data.uniform_(-initrange, initrange)
|
910 |
+
|
911 |
+
|
912 |
+
class NewFCModel(AttModel):
|
913 |
+
def __init__(self, opt):
|
914 |
+
super(NewFCModel, self).__init__(opt)
|
915 |
+
self.fc_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size)
|
916 |
+
self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
|
917 |
+
self._core = LSTMCore(opt)
|
918 |
+
delattr(self, 'att_embed')
|
919 |
+
self.att_embed = lambda x : x
|
920 |
+
delattr(self, 'ctx2att')
|
921 |
+
self.ctx2att = lambda x: x
|
922 |
+
|
923 |
+
def core(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks):
|
924 |
+
# Step 0, feed the input image
|
925 |
+
# if (self.training and state[0].is_leaf) or \
|
926 |
+
# (not self.training and state[0].sum() == 0):
|
927 |
+
# _, state = self._core(fc_feats, state)
|
928 |
+
# three cases
|
929 |
+
# normal mle training
|
930 |
+
# Sample
|
931 |
+
# beam search (diverse beam search)
|
932 |
+
# fixed captioning module.
|
933 |
+
is_first_step = (state[0]==0).all(2).all(0) # size: B
|
934 |
+
if is_first_step.all():
|
935 |
+
_, state = self._core(fc_feats, state)
|
936 |
+
elif is_first_step.any():
|
937 |
+
# This is mostly for diverse beam search I think
|
938 |
+
new_state = [torch.zeros_like(_) for _ in state]
|
939 |
+
new_state[0][:, ~is_first_step] = state[0][:, ~is_first_step]
|
940 |
+
new_state[1][:, ~is_first_step] = state[1][:, ~is_first_step]
|
941 |
+
_, state = self._core(fc_feats, state)
|
942 |
+
new_state[0][:, is_first_step] = state[0][:, is_first_step]
|
943 |
+
new_state[1][:, is_first_step] = state[1][:, is_first_step]
|
944 |
+
state = new_state
|
945 |
+
# if (state[0]==0).all():
|
946 |
+
# # Let's forget about diverse beam search first
|
947 |
+
# _, state = self._core(fc_feats, state)
|
948 |
+
return self._core(xt, state)
|
949 |
+
|
950 |
+
def _prepare_feature(self, fc_feats, att_feats, att_masks):
|
951 |
+
fc_feats = self.fc_embed(fc_feats)
|
952 |
+
|
953 |
+
return fc_feats, att_feats, att_feats, att_masks
|
954 |
+
|
955 |
+
|
956 |
+
class LMModel(AttModel):
|
957 |
+
def __init__(self, opt):
|
958 |
+
super(LMModel, self).__init__(opt)
|
959 |
+
delattr(self, 'fc_embed')
|
960 |
+
self.fc_embed = lambda x: x.new_zeros(x.shape[0], self.input_encoding_size)
|
961 |
+
self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
|
962 |
+
self._core = LSTMCore(opt)
|
963 |
+
delattr(self, 'att_embed')
|
964 |
+
self.att_embed = lambda x : x
|
965 |
+
delattr(self, 'ctx2att')
|
966 |
+
self.ctx2att = lambda x: x
|
967 |
+
|
968 |
+
def core(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks):
|
969 |
+
if (state[0]==0).all():
|
970 |
+
# Let's forget about diverse beam search first
|
971 |
+
_, state = self._core(fc_feats, state)
|
972 |
+
return self._core(xt, state)
|
973 |
+
|
974 |
+
def _prepare_feature(self, fc_feats, att_feats, att_masks):
|
975 |
+
fc_feats = self.fc_embed(fc_feats)
|
976 |
+
|
977 |
+
return fc_feats, None, None, None
|
captioning/models/BertCapModel.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
BertCapModel is using huggingface transformer bert model as seq2seq model.
|
3 |
+
The result is not as goog as original transformer.
|
4 |
+
"""
|
5 |
+
|
6 |
+
from __future__ import absolute_import
|
7 |
+
from __future__ import division
|
8 |
+
from __future__ import print_function
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.nn.functional as F
|
13 |
+
|
14 |
+
import copy
|
15 |
+
import math
|
16 |
+
import numpy as np
|
17 |
+
|
18 |
+
from .CaptionModel import CaptionModel
|
19 |
+
from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel
|
20 |
+
try:
|
21 |
+
from transformers import BertModel, BertConfig
|
22 |
+
except:
|
23 |
+
print('Hugginface transformers not installed; please visit https://github.com/huggingface/transformers')
|
24 |
+
from .TransformerModel import subsequent_mask, TransformerModel, Generator
|
25 |
+
|
26 |
+
class EncoderDecoder(nn.Module):
|
27 |
+
"""
|
28 |
+
A standard Encoder-Decoder architecture. Base for this and many
|
29 |
+
other models.
|
30 |
+
"""
|
31 |
+
def __init__(self, encoder, decoder, generator):
|
32 |
+
super(EncoderDecoder, self).__init__()
|
33 |
+
self.encoder = encoder
|
34 |
+
self.decoder = decoder
|
35 |
+
self.generator = generator
|
36 |
+
|
37 |
+
def forward(self, src, tgt, src_mask, tgt_mask):
|
38 |
+
"Take in and process masked src and target sequences."
|
39 |
+
return self.decode(self.encode(src, src_mask), src_mask,
|
40 |
+
tgt, tgt_mask)
|
41 |
+
|
42 |
+
def encode(self, src, src_mask):
|
43 |
+
return self.encoder(inputs_embeds=src,
|
44 |
+
attention_mask=src_mask)[0]
|
45 |
+
|
46 |
+
def decode(self, memory, src_mask, tgt, tgt_mask):
|
47 |
+
return self.decoder(input_ids=tgt,
|
48 |
+
attention_mask=tgt_mask,
|
49 |
+
encoder_hidden_states=memory,
|
50 |
+
encoder_attention_mask=src_mask)[0]
|
51 |
+
|
52 |
+
|
53 |
+
class BertCapModel(TransformerModel):
|
54 |
+
|
55 |
+
def make_model(self, src_vocab, tgt_vocab, N_enc=6, N_dec=6,
|
56 |
+
d_model=512, d_ff=2048, h=8, dropout=0.1):
|
57 |
+
"Helper: Construct a model from hyperparameters."
|
58 |
+
enc_config = BertConfig(vocab_size=1,
|
59 |
+
hidden_size=d_model,
|
60 |
+
num_hidden_layers=N_enc,
|
61 |
+
num_attention_heads=h,
|
62 |
+
intermediate_size=d_ff,
|
63 |
+
hidden_dropout_prob=dropout,
|
64 |
+
attention_probs_dropout_prob=dropout,
|
65 |
+
max_position_embeddings=1,
|
66 |
+
type_vocab_size=1)
|
67 |
+
dec_config = BertConfig(vocab_size=tgt_vocab,
|
68 |
+
hidden_size=d_model,
|
69 |
+
num_hidden_layers=N_dec,
|
70 |
+
num_attention_heads=h,
|
71 |
+
intermediate_size=d_ff,
|
72 |
+
hidden_dropout_prob=dropout,
|
73 |
+
attention_probs_dropout_prob=dropout,
|
74 |
+
max_position_embeddings=17,
|
75 |
+
type_vocab_size=1,
|
76 |
+
is_decoder=True)
|
77 |
+
encoder = BertModel(enc_config)
|
78 |
+
def return_embeds(*args, **kwargs):
|
79 |
+
return kwargs['inputs_embeds']
|
80 |
+
del encoder.embeddings; encoder.embeddings = return_embeds
|
81 |
+
decoder = BertModel(dec_config)
|
82 |
+
model = EncoderDecoder(
|
83 |
+
encoder,
|
84 |
+
decoder,
|
85 |
+
Generator(d_model, tgt_vocab))
|
86 |
+
return model
|
87 |
+
|
88 |
+
def __init__(self, opt):
|
89 |
+
super(BertCapModel, self).__init__(opt)
|
90 |
+
|
91 |
+
def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask):
|
92 |
+
"""
|
93 |
+
state = [ys.unsqueeze(0)]
|
94 |
+
"""
|
95 |
+
if len(state) == 0:
|
96 |
+
ys = it.unsqueeze(1)
|
97 |
+
else:
|
98 |
+
ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1)
|
99 |
+
out = self.model.decode(memory, mask,
|
100 |
+
ys,
|
101 |
+
subsequent_mask(ys.size(1))
|
102 |
+
.to(memory.device))
|
103 |
+
return out[:, -1], [ys.unsqueeze(0)]
|
captioning/models/CaptionModel.py
ADDED
@@ -0,0 +1,411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This file contains ShowAttendTell and AllImg model
|
2 |
+
|
3 |
+
# ShowAttendTell is from Show, Attend and Tell: Neural Image Caption Generation with Visual Attention
|
4 |
+
# https://arxiv.org/abs/1502.03044
|
5 |
+
|
6 |
+
# AllImg is a model where
|
7 |
+
# img feature is concatenated with word embedding at every time step as the input of lstm
|
8 |
+
from __future__ import absolute_import
|
9 |
+
from __future__ import division
|
10 |
+
from __future__ import print_function
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
import torch.nn as nn
|
15 |
+
import torch.nn.functional as F
|
16 |
+
from torch.autograd import *
|
17 |
+
# from ..utils import misc as utils
|
18 |
+
from captioning.utils import misc as utils
|
19 |
+
from . import utils as model_utils
|
20 |
+
|
21 |
+
# torch.manual_seed(42)
|
22 |
+
# if torch.cuda.is_available():
|
23 |
+
# torch.cuda.manual_seed(42)
|
24 |
+
|
25 |
+
class CaptionModel(nn.Module):
|
26 |
+
def __init__(self):
|
27 |
+
super(CaptionModel, self).__init__()
|
28 |
+
|
29 |
+
# implements beam search
|
30 |
+
# calls beam_step and returns the final set of beams
|
31 |
+
# augments log-probabilities with diversity terms when number of groups > 1
|
32 |
+
|
33 |
+
def forward(self, *args, **kwargs):
|
34 |
+
mode = kwargs.get('mode', 'forward')
|
35 |
+
if 'mode' in kwargs:
|
36 |
+
del kwargs['mode']
|
37 |
+
return getattr(self, '_'+mode)(*args, **kwargs)
|
38 |
+
|
39 |
+
def beam_search(self, init_state, init_logprobs, *args, **kwargs):
|
40 |
+
|
41 |
+
# function computes the similarity score to be augmented
|
42 |
+
def add_diversity(beam_seq_table, logprobs, t, divm, diversity_lambda, bdash):
|
43 |
+
local_time = t - divm
|
44 |
+
unaug_logprobs = logprobs.clone()
|
45 |
+
batch_size = beam_seq_table[0].shape[0]
|
46 |
+
|
47 |
+
if divm > 0:
|
48 |
+
change = logprobs.new_zeros(batch_size, logprobs.shape[-1])
|
49 |
+
for prev_choice in range(divm):
|
50 |
+
prev_decisions = beam_seq_table[prev_choice][:, :, local_time] # Nxb
|
51 |
+
for prev_labels in range(bdash):
|
52 |
+
change.scatter_add_(1, prev_decisions[:, prev_labels].unsqueeze(-1), change.new_ones(batch_size, 1))
|
53 |
+
|
54 |
+
if local_time == 0:
|
55 |
+
logprobs = logprobs - change * diversity_lambda
|
56 |
+
else:
|
57 |
+
logprobs = logprobs - self.repeat_tensor(bdash, change) * diversity_lambda
|
58 |
+
|
59 |
+
return logprobs, unaug_logprobs
|
60 |
+
|
61 |
+
|
62 |
+
# does one step of classical beam search
|
63 |
+
|
64 |
+
def beam_step(logprobs, unaug_logprobs, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state):
|
65 |
+
#INPUTS:
|
66 |
+
#logprobs: probabilities augmented after diversity N*bxV
|
67 |
+
#beam_size: obvious
|
68 |
+
#t : time instant
|
69 |
+
#beam_seq : tensor contanining the beams
|
70 |
+
#beam_seq_logprobs: tensor contanining the beam logprobs
|
71 |
+
#beam_logprobs_sum: tensor contanining joint logprobs
|
72 |
+
#OUPUTS:
|
73 |
+
#beam_seq : tensor containing the word indices of the decoded captions Nxbxl
|
74 |
+
#beam_seq_logprobs : log-probability of each decision made, NxbxlxV
|
75 |
+
#beam_logprobs_sum : joint log-probability of each beam Nxb
|
76 |
+
|
77 |
+
batch_size = beam_logprobs_sum.shape[0]
|
78 |
+
vocab_size = logprobs.shape[-1]
|
79 |
+
logprobs = logprobs.reshape(batch_size, -1, vocab_size) # NxbxV
|
80 |
+
if t == 0:
|
81 |
+
assert logprobs.shape[1] == 1
|
82 |
+
beam_logprobs_sum = beam_logprobs_sum[:, :1]
|
83 |
+
candidate_logprobs = beam_logprobs_sum.unsqueeze(-1) + logprobs # beam_logprobs_sum Nxb logprobs is NxbxV
|
84 |
+
ys, ix = torch.sort(candidate_logprobs.reshape(candidate_logprobs.shape[0], -1), -1, True)
|
85 |
+
ys, ix = ys[:,:beam_size], ix[:,:beam_size]
|
86 |
+
beam_ix = ix // vocab_size # Nxb which beam
|
87 |
+
selected_ix = ix % vocab_size # Nxb # which world
|
88 |
+
state_ix = (beam_ix + torch.arange(batch_size).type_as(beam_ix).unsqueeze(-1) * logprobs.shape[1]).reshape(-1) # N*b which in Nxb beams
|
89 |
+
|
90 |
+
|
91 |
+
if t > 0:
|
92 |
+
# gather according to beam_ix
|
93 |
+
assert (beam_seq.gather(1, beam_ix.unsqueeze(-1).expand_as(beam_seq)) == beam_seq.reshape(-1, beam_seq.shape[-1])[state_ix].view_as(beam_seq)).all()
|
94 |
+
beam_seq = beam_seq.gather(1, beam_ix.unsqueeze(-1).expand_as(beam_seq))
|
95 |
+
|
96 |
+
beam_seq_logprobs = beam_seq_logprobs.gather(1, beam_ix.unsqueeze(-1).unsqueeze(-1).expand_as(beam_seq_logprobs))
|
97 |
+
|
98 |
+
beam_seq = torch.cat([beam_seq, selected_ix.unsqueeze(-1)], -1) # beam_seq Nxbxl
|
99 |
+
beam_logprobs_sum = beam_logprobs_sum.gather(1, beam_ix) + \
|
100 |
+
logprobs.reshape(batch_size, -1).gather(1, ix)
|
101 |
+
assert (beam_logprobs_sum == ys).all()
|
102 |
+
_tmp_beam_logprobs = unaug_logprobs[state_ix].reshape(batch_size, -1, vocab_size)
|
103 |
+
beam_logprobs = unaug_logprobs.reshape(batch_size, -1, vocab_size).gather(1, beam_ix.unsqueeze(-1).expand(-1, -1, vocab_size)) # NxbxV
|
104 |
+
assert (_tmp_beam_logprobs == beam_logprobs).all()
|
105 |
+
beam_seq_logprobs = torch.cat([
|
106 |
+
beam_seq_logprobs,
|
107 |
+
beam_logprobs.reshape(batch_size, -1, 1, vocab_size)], 2)
|
108 |
+
|
109 |
+
new_state = [None for _ in state]
|
110 |
+
for _ix in range(len(new_state)):
|
111 |
+
# copy over state in previous beam q to new beam at vix
|
112 |
+
new_state[_ix] = state[_ix][:, state_ix]
|
113 |
+
state = new_state
|
114 |
+
return beam_seq,beam_seq_logprobs,beam_logprobs_sum,state
|
115 |
+
|
116 |
+
# Start diverse_beam_search
|
117 |
+
opt = kwargs['opt']
|
118 |
+
temperature = opt.get('temperature', 1) # This should not affect beam search, but will affect dbs
|
119 |
+
beam_size = opt.get('beam_size', 10)
|
120 |
+
group_size = opt.get('group_size', 1)
|
121 |
+
diversity_lambda = opt.get('diversity_lambda', 0.5)
|
122 |
+
decoding_constraint = opt.get('decoding_constraint', 0)
|
123 |
+
remove_bad_endings = opt.get('remove_bad_endings', 1)
|
124 |
+
suppress_UNK = opt.get('suppress_UNK', 1)
|
125 |
+
length_penalty = utils.penalty_builder(opt.get('length_penalty', ''))
|
126 |
+
bdash = beam_size // group_size # beam per group
|
127 |
+
|
128 |
+
batch_size = init_logprobs.shape[0]
|
129 |
+
device = init_logprobs.device
|
130 |
+
# INITIALIZATIONS
|
131 |
+
beam_seq_table = [torch.LongTensor(batch_size, bdash, 0).to(device) for _ in range(group_size)]
|
132 |
+
beam_seq_logprobs_table = [torch.FloatTensor(batch_size, bdash, 0, self.vocab_size + 1).to(device) for _ in range(group_size)]
|
133 |
+
beam_logprobs_sum_table = [torch.zeros(batch_size, bdash).to(device) for _ in range(group_size)]
|
134 |
+
|
135 |
+
# logprobs # logprobs predicted in last time step, shape (beam_size, vocab_size+1)
|
136 |
+
done_beams_table = [[[] for __ in range(group_size)] for _ in range(batch_size)]
|
137 |
+
# state_table = [list(torch.unbind(_)) for _ in torch.stack(init_state).chunk(group_size, 2)]
|
138 |
+
# state_table = list(zip(*[_.reshape(-1, batch_size * bdash, group_size, *_.shape[2:]).chunk(group_size, 2) for _ in init_state]))
|
139 |
+
state_table = [[_.clone() for _ in init_state] for _ in range(group_size)]
|
140 |
+
# logprobs_table = list(init_logprobs.reshape(batch_size * bdash, group_size, -1).chunk(group_size, 0))
|
141 |
+
logprobs_table = [init_logprobs.clone() for _ in range(group_size)]
|
142 |
+
# END INIT
|
143 |
+
|
144 |
+
# Chunk elements in the args
|
145 |
+
args = list(args)
|
146 |
+
args = model_utils.split_tensors(group_size, args) # For each arg, turn (Bbg)x... to (Bb)x(g)x...
|
147 |
+
if self.__class__.__name__ == 'AttEnsemble':
|
148 |
+
args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))] for k in range(group_size)] # group_name, arg_name, model_name
|
149 |
+
else:
|
150 |
+
args = [[args[i][j] for i in range(len(args))] for j in range(group_size)]
|
151 |
+
|
152 |
+
for t in range(self.seq_length + group_size - 1):
|
153 |
+
for divm in range(group_size):
|
154 |
+
if t >= divm and t <= self.seq_length + divm - 1:
|
155 |
+
# add diversity
|
156 |
+
logprobs = logprobs_table[divm]
|
157 |
+
# suppress previous word
|
158 |
+
if decoding_constraint and t-divm > 0:
|
159 |
+
logprobs.scatter_(1, beam_seq_table[divm][:, :, t-divm-1].reshape(-1, 1).to(device), float('-inf'))
|
160 |
+
if remove_bad_endings and t-divm > 0:
|
161 |
+
logprobs[torch.from_numpy(np.isin(beam_seq_table[divm][:, :, t-divm-1].cpu().numpy(), self.bad_endings_ix)).reshape(-1), 0] = float('-inf')
|
162 |
+
# suppress UNK tokens in the decoding
|
163 |
+
if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobs.size(1)-1)] == 'UNK':
|
164 |
+
logprobs[:,logprobs.size(1)-1] = logprobs[:, logprobs.size(1)-1] - 1000
|
165 |
+
# diversity is added here
|
166 |
+
# the function directly modifies the logprobs values and hence, we need to return
|
167 |
+
# the unaugmented ones for sorting the candidates in the end. # for historical
|
168 |
+
# reasons :-)
|
169 |
+
logprobs, unaug_logprobs = add_diversity(beam_seq_table,logprobs,t,divm,diversity_lambda,bdash)
|
170 |
+
|
171 |
+
# infer new beams
|
172 |
+
beam_seq_table[divm],\
|
173 |
+
beam_seq_logprobs_table[divm],\
|
174 |
+
beam_logprobs_sum_table[divm],\
|
175 |
+
state_table[divm] = beam_step(logprobs,
|
176 |
+
unaug_logprobs,
|
177 |
+
bdash,
|
178 |
+
t-divm,
|
179 |
+
beam_seq_table[divm],
|
180 |
+
beam_seq_logprobs_table[divm],
|
181 |
+
beam_logprobs_sum_table[divm],
|
182 |
+
state_table[divm])
|
183 |
+
|
184 |
+
# if time's up... or if end token is reached then copy beams
|
185 |
+
for b in range(batch_size):
|
186 |
+
is_end = beam_seq_table[divm][b, :, t-divm] == self.eos_idx
|
187 |
+
assert beam_seq_table[divm].shape[-1] == t-divm+1
|
188 |
+
if t == self.seq_length + divm - 1:
|
189 |
+
is_end.fill_(1)
|
190 |
+
for vix in range(bdash):
|
191 |
+
if is_end[vix]:
|
192 |
+
final_beam = {
|
193 |
+
'seq': beam_seq_table[divm][b, vix].clone(),
|
194 |
+
'logps': beam_seq_logprobs_table[divm][b, vix].clone(),
|
195 |
+
'unaug_p': beam_seq_logprobs_table[divm][b, vix].sum().item(),
|
196 |
+
'p': beam_logprobs_sum_table[divm][b, vix].item()
|
197 |
+
}
|
198 |
+
final_beam['p'] = length_penalty(t-divm+1, final_beam['p'])
|
199 |
+
done_beams_table[b][divm].append(final_beam)
|
200 |
+
beam_logprobs_sum_table[divm][b, is_end] -= 1000
|
201 |
+
|
202 |
+
# move the current group one step forward in time
|
203 |
+
|
204 |
+
it = beam_seq_table[divm][:, :, t-divm].reshape(-1).to(logprobs.device)
|
205 |
+
logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it, *(args[divm] + [state_table[divm]]))
|
206 |
+
logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1)
|
207 |
+
|
208 |
+
# all beams are sorted by their log-probabilities
|
209 |
+
done_beams_table = [[sorted(done_beams_table[b][i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)] for b in range(batch_size)]
|
210 |
+
done_beams = [sum(_, []) for _ in done_beams_table]
|
211 |
+
return done_beams
|
212 |
+
|
213 |
+
def old_beam_search(self, init_state, init_logprobs, *args, **kwargs):
|
214 |
+
|
215 |
+
# function computes the similarity score to be augmented
|
216 |
+
def add_diversity(beam_seq_table, logprobsf, t, divm, diversity_lambda, bdash):
|
217 |
+
local_time = t - divm
|
218 |
+
unaug_logprobsf = logprobsf.clone()
|
219 |
+
for prev_choice in range(divm):
|
220 |
+
prev_decisions = beam_seq_table[prev_choice][local_time]
|
221 |
+
for sub_beam in range(bdash):
|
222 |
+
for prev_labels in range(bdash):
|
223 |
+
logprobsf[sub_beam][prev_decisions[prev_labels]] = logprobsf[sub_beam][prev_decisions[prev_labels]] - diversity_lambda
|
224 |
+
return unaug_logprobsf
|
225 |
+
|
226 |
+
# does one step of classical beam search
|
227 |
+
|
228 |
+
def beam_step(logprobsf, unaug_logprobsf, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state):
|
229 |
+
#INPUTS:
|
230 |
+
#logprobsf: probabilities augmented after diversity
|
231 |
+
#beam_size: obvious
|
232 |
+
#t : time instant
|
233 |
+
#beam_seq : tensor contanining the beams
|
234 |
+
#beam_seq_logprobs: tensor contanining the beam logprobs
|
235 |
+
#beam_logprobs_sum: tensor contanining joint logprobs
|
236 |
+
#OUPUTS:
|
237 |
+
#beam_seq : tensor containing the word indices of the decoded captions
|
238 |
+
#beam_seq_logprobs : log-probability of each decision made, same size as beam_seq
|
239 |
+
#beam_logprobs_sum : joint log-probability of each beam
|
240 |
+
|
241 |
+
ys,ix = torch.sort(logprobsf,1,True)
|
242 |
+
candidates = []
|
243 |
+
cols = min(beam_size, ys.size(1))
|
244 |
+
rows = beam_size
|
245 |
+
if t == 0:
|
246 |
+
rows = 1
|
247 |
+
for c in range(cols): # for each column (word, essentially)
|
248 |
+
for q in range(rows): # for each beam expansion
|
249 |
+
#compute logprob of expanding beam q with word in (sorted) position c
|
250 |
+
local_logprob = ys[q,c].item()
|
251 |
+
candidate_logprob = beam_logprobs_sum[q] + local_logprob
|
252 |
+
# local_unaug_logprob = unaug_logprobsf[q,ix[q,c]]
|
253 |
+
candidates.append({'c':ix[q,c], 'q':q, 'p':candidate_logprob, 'r':unaug_logprobsf[q]})
|
254 |
+
candidates = sorted(candidates, key=lambda x: -x['p'])
|
255 |
+
|
256 |
+
new_state = [_.clone() for _ in state]
|
257 |
+
#beam_seq_prev, beam_seq_logprobs_prev
|
258 |
+
if t >= 1:
|
259 |
+
#we''ll need these as reference when we fork beams around
|
260 |
+
beam_seq_prev = beam_seq[:t].clone()
|
261 |
+
beam_seq_logprobs_prev = beam_seq_logprobs[:t].clone()
|
262 |
+
for vix in range(beam_size):
|
263 |
+
v = candidates[vix]
|
264 |
+
#fork beam index q into index vix
|
265 |
+
if t >= 1:
|
266 |
+
beam_seq[:t, vix] = beam_seq_prev[:, v['q']]
|
267 |
+
beam_seq_logprobs[:t, vix] = beam_seq_logprobs_prev[:, v['q']]
|
268 |
+
#rearrange recurrent states
|
269 |
+
for state_ix in range(len(new_state)):
|
270 |
+
# copy over state in previous beam q to new beam at vix
|
271 |
+
new_state[state_ix][:, vix] = state[state_ix][:, v['q']] # dimension one is time step
|
272 |
+
#append new end terminal at the end of this beam
|
273 |
+
beam_seq[t, vix] = v['c'] # c'th word is the continuation
|
274 |
+
beam_seq_logprobs[t, vix] = v['r'] # the raw logprob here
|
275 |
+
beam_logprobs_sum[vix] = v['p'] # the new (sum) logprob along this beam
|
276 |
+
state = new_state
|
277 |
+
return beam_seq,beam_seq_logprobs,beam_logprobs_sum,state,candidates
|
278 |
+
|
279 |
+
# Start diverse_beam_search
|
280 |
+
opt = kwargs['opt']
|
281 |
+
temperature = opt.get('temperature', 1) # This should not affect beam search, but will affect dbs
|
282 |
+
beam_size = opt.get('beam_size', 10)
|
283 |
+
group_size = opt.get('group_size', 1)
|
284 |
+
diversity_lambda = opt.get('diversity_lambda', 0.5)
|
285 |
+
decoding_constraint = opt.get('decoding_constraint', 0)
|
286 |
+
remove_bad_endings = opt.get('remove_bad_endings', 1)
|
287 |
+
suppress_UNK = opt.get('suppress_UNK', 1)
|
288 |
+
length_penalty = utils.penalty_builder(opt.get('length_penalty', ''))
|
289 |
+
bdash = beam_size // group_size # beam per group
|
290 |
+
|
291 |
+
# INITIALIZATIONS
|
292 |
+
beam_seq_table = [torch.LongTensor(self.seq_length, bdash).zero_() for _ in range(group_size)]
|
293 |
+
beam_seq_logprobs_table = [torch.FloatTensor(self.seq_length, bdash, self.vocab_size + 1).zero_() for _ in range(group_size)]
|
294 |
+
beam_logprobs_sum_table = [torch.zeros(bdash) for _ in range(group_size)]
|
295 |
+
|
296 |
+
# logprobs # logprobs predicted in last time step, shape (beam_size, vocab_size+1)
|
297 |
+
done_beams_table = [[] for _ in range(group_size)]
|
298 |
+
# state_table = [list(torch.unbind(_)) for _ in torch.stack(init_state).chunk(group_size, 2)]
|
299 |
+
state_table = list(zip(*[_.chunk(group_size, 1) for _ in init_state]))
|
300 |
+
logprobs_table = list(init_logprobs.chunk(group_size, 0))
|
301 |
+
# END INIT
|
302 |
+
|
303 |
+
# Chunk elements in the args
|
304 |
+
args = list(args)
|
305 |
+
if self.__class__.__name__ == 'AttEnsemble':
|
306 |
+
args = [[_.chunk(group_size) if _ is not None else [None]*group_size for _ in args_] for args_ in args] # arg_name, model_name, group_name
|
307 |
+
args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))] for k in range(group_size)] # group_name, arg_name, model_name
|
308 |
+
else:
|
309 |
+
args = [_.chunk(group_size) if _ is not None else [None]*group_size for _ in args]
|
310 |
+
args = [[args[i][j] for i in range(len(args))] for j in range(group_size)]
|
311 |
+
|
312 |
+
for t in range(self.seq_length + group_size - 1):
|
313 |
+
for divm in range(group_size):
|
314 |
+
if t >= divm and t <= self.seq_length + divm - 1:
|
315 |
+
# add diversity
|
316 |
+
logprobsf = logprobs_table[divm]
|
317 |
+
# suppress previous word
|
318 |
+
if decoding_constraint and t-divm > 0:
|
319 |
+
logprobsf.scatter_(1, beam_seq_table[divm][t-divm-1].unsqueeze(1).to(logprobsf.device), float('-inf'))
|
320 |
+
if remove_bad_endings and t-divm > 0:
|
321 |
+
logprobsf[torch.from_numpy(np.isin(beam_seq_table[divm][t-divm-1].cpu().numpy(), self.bad_endings_ix)), 0] = float('-inf')
|
322 |
+
# suppress UNK tokens in the decoding
|
323 |
+
if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobsf.size(1)-1)] == 'UNK':
|
324 |
+
logprobsf[:,logprobsf.size(1)-1] = logprobsf[:, logprobsf.size(1)-1] - 1000
|
325 |
+
# diversity is added here
|
326 |
+
# the function directly modifies the logprobsf values and hence, we need to return
|
327 |
+
# the unaugmented ones for sorting the candidates in the end. # for historical
|
328 |
+
# reasons :-)
|
329 |
+
unaug_logprobsf = add_diversity(beam_seq_table,logprobsf,t,divm,diversity_lambda,bdash)
|
330 |
+
|
331 |
+
# infer new beams
|
332 |
+
beam_seq_table[divm],\
|
333 |
+
beam_seq_logprobs_table[divm],\
|
334 |
+
beam_logprobs_sum_table[divm],\
|
335 |
+
state_table[divm],\
|
336 |
+
candidates_divm = beam_step(logprobsf,
|
337 |
+
unaug_logprobsf,
|
338 |
+
bdash,
|
339 |
+
t-divm,
|
340 |
+
beam_seq_table[divm],
|
341 |
+
beam_seq_logprobs_table[divm],
|
342 |
+
beam_logprobs_sum_table[divm],
|
343 |
+
state_table[divm])
|
344 |
+
|
345 |
+
# if time's up... or if end token is reached then copy beams
|
346 |
+
for vix in range(bdash):
|
347 |
+
if beam_seq_table[divm][t-divm,vix] == self.eos_idx or t == self.seq_length + divm - 1:
|
348 |
+
final_beam = {
|
349 |
+
'seq': beam_seq_table[divm][:, vix].clone(),
|
350 |
+
'logps': beam_seq_logprobs_table[divm][:, vix].clone(),
|
351 |
+
'unaug_p': beam_seq_logprobs_table[divm][:, vix].sum().item(),
|
352 |
+
'p': beam_logprobs_sum_table[divm][vix].item()
|
353 |
+
}
|
354 |
+
final_beam['p'] = length_penalty(t-divm+1, final_beam['p'])
|
355 |
+
done_beams_table[divm].append(final_beam)
|
356 |
+
# don't continue beams from finished sequences
|
357 |
+
beam_logprobs_sum_table[divm][vix] = -1000
|
358 |
+
|
359 |
+
# move the current group one step forward in time
|
360 |
+
|
361 |
+
it = beam_seq_table[divm][t-divm].to(logprobsf.device)
|
362 |
+
logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it, *(args[divm] + [state_table[divm]]))
|
363 |
+
logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1)
|
364 |
+
|
365 |
+
# all beams are sorted by their log-probabilities
|
366 |
+
done_beams_table = [sorted(done_beams_table[i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)]
|
367 |
+
done_beams = sum(done_beams_table, [])
|
368 |
+
return done_beams
|
369 |
+
|
370 |
+
def sample_next_word(self, logprobs, sample_method, temperature):
|
371 |
+
if sample_method == 'greedy':
|
372 |
+
sampleLogprobs, it = torch.max(logprobs.data, 1)
|
373 |
+
it = it.view(-1).long()
|
374 |
+
elif sample_method == 'gumbel': # gumbel softmax
|
375 |
+
# ref: https://gist.github.com/yzh119/fd2146d2aeb329d067568a493b20172f
|
376 |
+
def sample_gumbel(shape, eps=1e-20):
|
377 |
+
U = torch.rand(shape).to(logprobs.device)
|
378 |
+
return -torch.log(-torch.log(U + eps) + eps)
|
379 |
+
def gumbel_softmax_sample(logits, temperature):
|
380 |
+
y = logits + sample_gumbel(logits.size())
|
381 |
+
return F.log_softmax(y / temperature, dim=-1)
|
382 |
+
_logprobs = gumbel_softmax_sample(logprobs, temperature)
|
383 |
+
_, it = torch.max(_logprobs.data, 1)
|
384 |
+
sampleLogprobs = logprobs.gather(1, it.unsqueeze(1)) # gather the logprobs at sampled positions
|
385 |
+
else:
|
386 |
+
logprobs = logprobs / temperature
|
387 |
+
if sample_method.startswith('top'): # topk sampling
|
388 |
+
top_num = float(sample_method[3:])
|
389 |
+
if 0 < top_num < 1:
|
390 |
+
# nucleus sampling from # The Curious Case of Neural Text Degeneration
|
391 |
+
probs = F.softmax(logprobs, dim=1)
|
392 |
+
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=1)
|
393 |
+
_cumsum = sorted_probs.cumsum(1)
|
394 |
+
mask = _cumsum < top_num
|
395 |
+
mask = torch.cat([torch.ones_like(mask[:,:1]), mask[:,:-1]], 1)
|
396 |
+
sorted_probs = sorted_probs * mask.to(sorted_probs)
|
397 |
+
sorted_probs = sorted_probs / sorted_probs.sum(1, keepdim=True)
|
398 |
+
logprobs.scatter_(1, sorted_indices, sorted_probs.log())
|
399 |
+
else:
|
400 |
+
the_k = int(top_num)
|
401 |
+
tmp = torch.empty_like(logprobs).fill_(float('-inf'))
|
402 |
+
topk, indices = torch.topk(logprobs, the_k, dim=1)
|
403 |
+
tmp = tmp.scatter(1, indices, topk)
|
404 |
+
logprobs = tmp
|
405 |
+
it = torch.distributions.Categorical(logits=logprobs.detach()).sample()
|
406 |
+
sampleLogprobs = logprobs.gather(1, it.unsqueeze(1)) # gather the logprobs at sampled positions
|
407 |
+
return it, sampleLogprobs
|
408 |
+
|
409 |
+
|
410 |
+
def decode_sequence(self, seq):
|
411 |
+
return utils.decode_sequence(self.vocab, seq)
|
captioning/models/FCModel.py
ADDED
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import
|
2 |
+
from __future__ import division
|
3 |
+
from __future__ import print_function
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from torch.autograd import *
|
9 |
+
from . import utils
|
10 |
+
|
11 |
+
from .CaptionModel import CaptionModel
|
12 |
+
|
13 |
+
class LSTMCore(nn.Module):
|
14 |
+
def __init__(self, opt):
|
15 |
+
super(LSTMCore, self).__init__()
|
16 |
+
self.input_encoding_size = opt.input_encoding_size
|
17 |
+
self.rnn_size = opt.rnn_size
|
18 |
+
self.drop_prob_lm = opt.drop_prob_lm
|
19 |
+
|
20 |
+
# Build a LSTM
|
21 |
+
self.i2h = nn.Linear(self.input_encoding_size, 5 * self.rnn_size)
|
22 |
+
self.h2h = nn.Linear(self.rnn_size, 5 * self.rnn_size)
|
23 |
+
self.dropout = nn.Dropout(self.drop_prob_lm)
|
24 |
+
|
25 |
+
def forward(self, xt, state):
|
26 |
+
|
27 |
+
all_input_sums = self.i2h(xt) + self.h2h(state[0][-1])
|
28 |
+
sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size)
|
29 |
+
sigmoid_chunk = torch.sigmoid(sigmoid_chunk)
|
30 |
+
in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size)
|
31 |
+
forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size)
|
32 |
+
out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size)
|
33 |
+
|
34 |
+
in_transform = torch.max(\
|
35 |
+
all_input_sums.narrow(1, 3 * self.rnn_size, self.rnn_size),
|
36 |
+
all_input_sums.narrow(1, 4 * self.rnn_size, self.rnn_size))
|
37 |
+
next_c = forget_gate * state[1][-1] + in_gate * in_transform
|
38 |
+
next_h = out_gate * torch.tanh(next_c)
|
39 |
+
|
40 |
+
output = self.dropout(next_h)
|
41 |
+
state = (next_h.unsqueeze(0), next_c.unsqueeze(0))
|
42 |
+
return output, state
|
43 |
+
|
44 |
+
class FCModel(CaptionModel):
|
45 |
+
def __init__(self, opt):
|
46 |
+
super(FCModel, self).__init__()
|
47 |
+
self.vocab_size = opt.vocab_size
|
48 |
+
self.input_encoding_size = opt.input_encoding_size
|
49 |
+
self.rnn_type = opt.rnn_type
|
50 |
+
self.rnn_size = opt.rnn_size
|
51 |
+
self.num_layers = opt.num_layers
|
52 |
+
self.drop_prob_lm = opt.drop_prob_lm
|
53 |
+
self.seq_length = opt.seq_length
|
54 |
+
self.fc_feat_size = opt.fc_feat_size
|
55 |
+
|
56 |
+
self.ss_prob = 0.0 # Schedule sampling probability
|
57 |
+
|
58 |
+
self.img_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size)
|
59 |
+
self.core = LSTMCore(opt)
|
60 |
+
self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
|
61 |
+
self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1)
|
62 |
+
|
63 |
+
self.init_weights()
|
64 |
+
|
65 |
+
def init_weights(self):
|
66 |
+
initrange = 0.1
|
67 |
+
self.embed.weight.data.uniform_(-initrange, initrange)
|
68 |
+
self.logit.bias.data.fill_(0)
|
69 |
+
self.logit.weight.data.uniform_(-initrange, initrange)
|
70 |
+
|
71 |
+
def init_hidden(self, bsz):
|
72 |
+
weight = self.logit.weight
|
73 |
+
if self.rnn_type == 'lstm':
|
74 |
+
return (weight.new_zeros(self.num_layers, bsz, self.rnn_size),
|
75 |
+
weight.new_zeros(self.num_layers, bsz, self.rnn_size))
|
76 |
+
else:
|
77 |
+
return weight.new_zeros(self.num_layers, bsz, self.rnn_size)
|
78 |
+
|
79 |
+
def _forward(self, fc_feats, att_feats, seq, att_masks=None):
|
80 |
+
batch_size = fc_feats.size(0)
|
81 |
+
seq_per_img = seq.shape[0] // batch_size
|
82 |
+
state = self.init_hidden(batch_size*seq_per_img)
|
83 |
+
outputs = []
|
84 |
+
|
85 |
+
if seq_per_img > 1:
|
86 |
+
fc_feats = utils.repeat_tensors(seq_per_img, fc_feats)
|
87 |
+
|
88 |
+
for i in range(seq.size(1) + 1):
|
89 |
+
if i == 0:
|
90 |
+
xt = self.img_embed(fc_feats)
|
91 |
+
else:
|
92 |
+
if self.training and i >= 2 and self.ss_prob > 0.0: # otherwiste no need to sample
|
93 |
+
sample_prob = fc_feats.data.new(batch_size*seq_per_img).uniform_(0, 1)
|
94 |
+
sample_mask = sample_prob < self.ss_prob
|
95 |
+
if sample_mask.sum() == 0:
|
96 |
+
it = seq[:, i-1].clone()
|
97 |
+
else:
|
98 |
+
sample_ind = sample_mask.nonzero().view(-1)
|
99 |
+
it = seq[:, i-1].data.clone()
|
100 |
+
#prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1)
|
101 |
+
#it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1))
|
102 |
+
prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1)
|
103 |
+
it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
|
104 |
+
else:
|
105 |
+
it = seq[:, i-1].clone()
|
106 |
+
# break if all the sequences end
|
107 |
+
if i >= 2 and seq[:, i-1].sum() == 0:
|
108 |
+
break
|
109 |
+
xt = self.embed(it)
|
110 |
+
|
111 |
+
output, state = self.core(xt, state)
|
112 |
+
output = F.log_softmax(self.logit(output), dim=1)
|
113 |
+
outputs.append(output)
|
114 |
+
|
115 |
+
return torch.cat([_.unsqueeze(1) for _ in outputs[1:]], 1).contiguous()
|
116 |
+
|
117 |
+
def get_logprobs_state(self, it, state):
|
118 |
+
# 'it' is contains a word index
|
119 |
+
xt = self.embed(it)
|
120 |
+
|
121 |
+
output, state = self.core(xt, state)
|
122 |
+
logprobs = F.log_softmax(self.logit(output), dim=1)
|
123 |
+
|
124 |
+
return logprobs, state
|
125 |
+
|
126 |
+
def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
|
127 |
+
beam_size = opt.get('beam_size', 10)
|
128 |
+
batch_size = fc_feats.size(0)
|
129 |
+
|
130 |
+
assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
|
131 |
+
seq = torch.LongTensor(self.seq_length, batch_size).zero_()
|
132 |
+
seqLogprobs = torch.FloatTensor(self.seq_length, batch_size, self.vocab_size + 1)
|
133 |
+
# lets process every image independently for now, for simplicity
|
134 |
+
|
135 |
+
self.done_beams = [[] for _ in range(batch_size)]
|
136 |
+
for k in range(batch_size):
|
137 |
+
state = self.init_hidden(beam_size)
|
138 |
+
for t in range(2):
|
139 |
+
if t == 0:
|
140 |
+
xt = self.img_embed(fc_feats[k:k+1]).expand(beam_size, self.input_encoding_size)
|
141 |
+
elif t == 1: # input <bos>
|
142 |
+
it = fc_feats.data.new(beam_size).long().zero_()
|
143 |
+
xt = self.embed(it)
|
144 |
+
|
145 |
+
output, state = self.core(xt, state)
|
146 |
+
logprobs = F.log_softmax(self.logit(output), dim=1)
|
147 |
+
|
148 |
+
self.done_beams[k] = self.beam_search(state, logprobs, opt=opt)
|
149 |
+
seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
|
150 |
+
seqLogprobs[:, k] = self.done_beams[k][0]['logps']
|
151 |
+
# return the samples and their log likelihoods
|
152 |
+
return seq.transpose(0, 1), seqLogprobs.transpose(0, 1)
|
153 |
+
|
154 |
+
def _sample(self, fc_feats, att_feats, att_masks=None, opt={}):
|
155 |
+
sample_method = opt.get('sample_method', 'greedy')
|
156 |
+
beam_size = opt.get('beam_size', 1)
|
157 |
+
temperature = opt.get('temperature', 1.0)
|
158 |
+
if beam_size > 1 and sample_method in ['greedy', 'beam_search']:
|
159 |
+
return self._sample_beam(fc_feats, att_feats, opt)
|
160 |
+
|
161 |
+
batch_size = fc_feats.size(0)
|
162 |
+
state = self.init_hidden(batch_size)
|
163 |
+
seq = fc_feats.new_zeros(batch_size, self.seq_length, dtype=torch.long)
|
164 |
+
seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length, self.vocab_size + 1)
|
165 |
+
for t in range(self.seq_length + 2):
|
166 |
+
if t == 0:
|
167 |
+
xt = self.img_embed(fc_feats)
|
168 |
+
else:
|
169 |
+
if t == 1: # input <bos>
|
170 |
+
it = fc_feats.data.new(batch_size).long().zero_()
|
171 |
+
xt = self.embed(it)
|
172 |
+
|
173 |
+
output, state = self.core(xt, state)
|
174 |
+
logprobs = F.log_softmax(self.logit(output), dim=1)
|
175 |
+
|
176 |
+
# sample the next_word
|
177 |
+
if t == self.seq_length + 1: # skip if we achieve maximum length
|
178 |
+
break
|
179 |
+
if sample_method == 'greedy':
|
180 |
+
sampleLogprobs, it = torch.max(logprobs.data, 1)
|
181 |
+
it = it.view(-1).long()
|
182 |
+
else:
|
183 |
+
if temperature == 1.0:
|
184 |
+
prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1)
|
185 |
+
else:
|
186 |
+
# scale logprobs by temperature
|
187 |
+
prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu()
|
188 |
+
it = torch.multinomial(prob_prev, 1).to(logprobs.device)
|
189 |
+
sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions
|
190 |
+
it = it.view(-1).long() # and flatten indices for downstream processing
|
191 |
+
|
192 |
+
if t >= 1:
|
193 |
+
# stop when all finished
|
194 |
+
if t == 1:
|
195 |
+
unfinished = it > 0
|
196 |
+
else:
|
197 |
+
unfinished = unfinished & (it > 0)
|
198 |
+
it = it * unfinished.type_as(it)
|
199 |
+
seq[:,t-1] = it #seq[t] the input of t+2 time step
|
200 |
+
seqLogprobs[:,t-1] = sampleLogprobs.view(-1)
|
201 |
+
if unfinished.sum() == 0:
|
202 |
+
break
|
203 |
+
|
204 |
+
return seq, seqLogprobs
|
captioning/models/M2Transformer.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Instruction to use meshed_memory_transformer (https://arxiv.org/abs/1912.08226)
|
3 |
+
|
4 |
+
pip install git+https://github.com/ruotianluo/meshed-memory-transformer.git
|
5 |
+
|
6 |
+
Note:
|
7 |
+
Currently m2transformer is not performing as well as original transformer. Not sure why? Still investigating.
|
8 |
+
"""
|
9 |
+
|
10 |
+
from __future__ import absolute_import
|
11 |
+
from __future__ import division
|
12 |
+
from __future__ import print_function
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
import torch.nn.functional as F
|
17 |
+
|
18 |
+
import copy
|
19 |
+
import math
|
20 |
+
import numpy as np
|
21 |
+
|
22 |
+
from .CaptionModel import CaptionModel
|
23 |
+
from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel
|
24 |
+
|
25 |
+
try:
|
26 |
+
from m2transformer.models.transformer import Transformer, MemoryAugmentedEncoder, MeshedDecoder, ScaledDotProductAttentionMemory
|
27 |
+
except:
|
28 |
+
print('meshed-memory-transformer not installed; please run `pip install git+https://github.com/ruotianluo/meshed-memory-transformer.git`')
|
29 |
+
from .TransformerModel import subsequent_mask, TransformerModel
|
30 |
+
|
31 |
+
|
32 |
+
class M2TransformerModel(TransformerModel):
|
33 |
+
|
34 |
+
def make_model(self, src_vocab, tgt_vocab, N_enc=6, N_dec=6,
|
35 |
+
d_model=512, d_ff=2048, h=8, dropout=0.1):
|
36 |
+
"Helper: Construct a model from hyperparameters."
|
37 |
+
encoder = MemoryAugmentedEncoder(N_enc, 0, attention_module=ScaledDotProductAttentionMemory,
|
38 |
+
attention_module_kwargs={'m': 40})
|
39 |
+
# Another implementation is to use MultiLevelEncoder + att_embed
|
40 |
+
decoder = MeshedDecoder(tgt_vocab, 54, N_dec, -1) # -1 is padding;
|
41 |
+
model = Transformer(0, encoder, decoder) # 0 is bos
|
42 |
+
return model
|
43 |
+
|
44 |
+
def __init__(self, opt):
|
45 |
+
super(M2TransformerModel, self).__init__(opt)
|
46 |
+
delattr(self, 'att_embed')
|
47 |
+
self.att_embed = lambda x: x # The visual embed is in the MAEncoder
|
48 |
+
# Notes: The dropout in MAEncoder is different from my att_embed, mine is 0.5?
|
49 |
+
# Also the attention mask seems wrong in MAEncoder too...intersting
|
50 |
+
|
51 |
+
def logit(self, x): # unsafe way
|
52 |
+
return x # M2transformer always output logsoftmax
|
53 |
+
|
54 |
+
def _prepare_feature(self, fc_feats, att_feats, att_masks):
|
55 |
+
|
56 |
+
att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks)
|
57 |
+
memory, att_masks = self.model.encoder(att_feats)
|
58 |
+
|
59 |
+
return fc_feats[...,:0], att_feats[...,:0], memory, att_masks
|
60 |
+
|
61 |
+
def _forward(self, fc_feats, att_feats, seq, att_masks=None):
|
62 |
+
if seq.ndim == 3: # B * seq_per_img * seq_len
|
63 |
+
seq = seq.reshape(-1, seq.shape[2])
|
64 |
+
att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq)
|
65 |
+
|
66 |
+
seq = seq.clone()
|
67 |
+
seq[~seq_mask.any(-2)] = -1 # Make padding to be -1 (my dataloader uses 0 as padding)
|
68 |
+
outputs = self.model(att_feats, seq)
|
69 |
+
|
70 |
+
# out = self.model(att_feats, seq, att_masks, seq_mask)
|
71 |
+
|
72 |
+
# outputs = self.model.generator(out)
|
73 |
+
|
74 |
+
return outputs
|
75 |
+
|
76 |
+
def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask):
|
77 |
+
"""
|
78 |
+
state = [ys.unsqueeze(0)]
|
79 |
+
"""
|
80 |
+
if len(state) == 0:
|
81 |
+
ys = it.unsqueeze(1)
|
82 |
+
else:
|
83 |
+
ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1)
|
84 |
+
out = self.model.decoder(ys, memory, mask)
|
85 |
+
return out[:, -1], [ys.unsqueeze(0)]
|
86 |
+
|
87 |
+
def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
|
88 |
+
beam_size = opt.get('beam_size', 10)
|
89 |
+
group_size = opt.get('group_size', 1)
|
90 |
+
sample_n = opt.get('sample_n', 10)
|
91 |
+
assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search'
|
92 |
+
|
93 |
+
att_feats, _, __, ___ = self._prepare_feature_forward(att_feats, att_masks)
|
94 |
+
seq, logprobs, seqLogprobs = self.model.beam_search(att_feats, self.seq_length, 0,
|
95 |
+
beam_size, return_probs=True, out_size=beam_size)
|
96 |
+
seq = seq.reshape(-1, *seq.shape[2:])
|
97 |
+
seqLogprobs = seqLogprobs.reshape(-1, *seqLogprobs.shape[2:])
|
98 |
+
|
99 |
+
# if not (seqLogprobs.gather(-1, seq.unsqueeze(-1)).squeeze(-1) == logprobs.reshape(-1, logprobs.shape[-1])).all():
|
100 |
+
# import pudb;pu.db
|
101 |
+
# seqLogprobs = logprobs.reshape(-1, logprobs.shape[-1]).unsqueeze(-1).expand(-1,-1,seqLogprobs.shape[-1])
|
102 |
+
return seq, seqLogprobs
|
captioning/models/OldModel.py
ADDED
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This file contains ShowAttendTell and AllImg model
|
2 |
+
|
3 |
+
# ShowAttendTell is from Show, Attend and Tell: Neural Image Caption Generation with Visual Attention
|
4 |
+
# https://arxiv.org/abs/1502.03044
|
5 |
+
|
6 |
+
# AllImg is a model where
|
7 |
+
# img feature is concatenated with word embedding at every time step as the input of lstm
|
8 |
+
from __future__ import absolute_import
|
9 |
+
from __future__ import division
|
10 |
+
from __future__ import print_function
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
import torch.nn.functional as F
|
15 |
+
from torch.autograd import *
|
16 |
+
# import misc.utils as utils
|
17 |
+
# import utils as utils
|
18 |
+
from . import utils
|
19 |
+
|
20 |
+
from .CaptionModel import CaptionModel
|
21 |
+
|
22 |
+
|
23 |
+
class OldModel(CaptionModel):
|
24 |
+
def __init__(self, opt):
|
25 |
+
super(OldModel, self).__init__()
|
26 |
+
self.vocab_size = opt.vocab_size
|
27 |
+
self.input_encoding_size = opt.input_encoding_size
|
28 |
+
self.rnn_type = opt.rnn_type
|
29 |
+
self.rnn_size = opt.rnn_size
|
30 |
+
self.num_layers = opt.num_layers
|
31 |
+
self.drop_prob_lm = opt.drop_prob_lm
|
32 |
+
self.seq_length = opt.seq_length
|
33 |
+
self.fc_feat_size = opt.fc_feat_size
|
34 |
+
self.att_feat_size = opt.att_feat_size
|
35 |
+
|
36 |
+
self.ss_prob = 0.0 # Schedule sampling probability
|
37 |
+
|
38 |
+
self.linear = nn.Linear(self.fc_feat_size, self.num_layers * self.rnn_size) # feature to rnn_size
|
39 |
+
self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
|
40 |
+
self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1)
|
41 |
+
self.dropout = nn.Dropout(self.drop_prob_lm)
|
42 |
+
|
43 |
+
self.init_weights()
|
44 |
+
|
45 |
+
def init_weights(self):
|
46 |
+
initrange = 0.1
|
47 |
+
self.embed.weight.data.uniform_(-initrange, initrange)
|
48 |
+
self.logit.bias.data.fill_(0)
|
49 |
+
self.logit.weight.data.uniform_(-initrange, initrange)
|
50 |
+
|
51 |
+
def init_hidden(self, fc_feats):
|
52 |
+
image_map = self.linear(fc_feats).view(-1, self.num_layers, self.rnn_size).transpose(0, 1)
|
53 |
+
if self.rnn_type == 'lstm':
|
54 |
+
return (image_map, image_map)
|
55 |
+
else:
|
56 |
+
return image_map
|
57 |
+
|
58 |
+
def forward(self, fc_feats, att_feats, seq):
|
59 |
+
batch_size = fc_feats.size(0)
|
60 |
+
state = self.init_hidden(fc_feats)
|
61 |
+
|
62 |
+
outputs = []
|
63 |
+
|
64 |
+
for i in range(seq.size(1) - 1):
|
65 |
+
if self.training and i >= 1 and self.ss_prob > 0.0: # otherwiste no need to sample
|
66 |
+
sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1)
|
67 |
+
sample_mask = sample_prob < self.ss_prob
|
68 |
+
if sample_mask.sum() == 0:
|
69 |
+
it = seq[:, i].clone()
|
70 |
+
else:
|
71 |
+
sample_ind = sample_mask.nonzero().view(-1)
|
72 |
+
it = seq[:, i].data.clone()
|
73 |
+
# prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1)
|
74 |
+
# it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1))
|
75 |
+
prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1)
|
76 |
+
it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
|
77 |
+
it = Variable(it, requires_grad=False)
|
78 |
+
else:
|
79 |
+
it = seq[:, i].clone()
|
80 |
+
# break if all the sequences end
|
81 |
+
if i >= 1 and seq[:, i].data.sum() == 0:
|
82 |
+
break
|
83 |
+
|
84 |
+
xt = self.embed(it)
|
85 |
+
|
86 |
+
output, state = self.core(xt, fc_feats, att_feats, state)
|
87 |
+
output = F.log_softmax(self.logit(self.dropout(output)))
|
88 |
+
outputs.append(output)
|
89 |
+
|
90 |
+
return torch.cat([_.unsqueeze(1) for _ in outputs], 1)
|
91 |
+
|
92 |
+
def get_logprobs_state(self, it, tmp_fc_feats, tmp_att_feats, state):
|
93 |
+
# 'it' is Variable contraining a word index
|
94 |
+
xt = self.embed(it)
|
95 |
+
|
96 |
+
output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, state)
|
97 |
+
logprobs = F.log_softmax(self.logit(self.dropout(output)))
|
98 |
+
|
99 |
+
return logprobs, state
|
100 |
+
|
101 |
+
def sample_beam(self, fc_feats, att_feats, opt={}):
|
102 |
+
beam_size = opt.get('beam_size', 10)
|
103 |
+
batch_size = fc_feats.size(0)
|
104 |
+
|
105 |
+
assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
|
106 |
+
seq = torch.LongTensor(self.seq_length, batch_size).zero_()
|
107 |
+
seqLogprobs = torch.FloatTensor(self.seq_length, batch_size)
|
108 |
+
# lets process every image independently for now, for simplicity
|
109 |
+
|
110 |
+
self.done_beams = [[] for _ in range(batch_size)]
|
111 |
+
for k in range(batch_size):
|
112 |
+
tmp_fc_feats = fc_feats[k:k + 1].expand(beam_size, self.fc_feat_size)
|
113 |
+
tmp_att_feats = att_feats[k:k + 1].expand(*((beam_size,) + att_feats.size()[1:])).contiguous()
|
114 |
+
|
115 |
+
state = self.init_hidden(tmp_fc_feats)
|
116 |
+
|
117 |
+
beam_seq = torch.LongTensor(self.seq_length, beam_size).zero_()
|
118 |
+
beam_seq_logprobs = torch.FloatTensor(self.seq_length, beam_size).zero_()
|
119 |
+
beam_logprobs_sum = torch.zeros(beam_size) # running sum of logprobs for each beam
|
120 |
+
done_beams = []
|
121 |
+
for t in range(1):
|
122 |
+
if t == 0: # input <bos>
|
123 |
+
it = fc_feats.data.new(beam_size).long().zero_()
|
124 |
+
xt = self.embed(Variable(it, requires_grad=False))
|
125 |
+
|
126 |
+
output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, state)
|
127 |
+
logprobs = F.log_softmax(self.logit(self.dropout(output)))
|
128 |
+
|
129 |
+
self.done_beams[k] = self.beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, opt=opt)
|
130 |
+
seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
|
131 |
+
seqLogprobs[:, k] = self.done_beams[k][0]['logps']
|
132 |
+
# return the samples and their log likelihoods
|
133 |
+
return seq.transpose(0, 1), seqLogprobs.transpose(0, 1)
|
134 |
+
|
135 |
+
def sample(self, fc_feats, att_feats, opt={}):
|
136 |
+
sample_max = opt.get('sample_max', 1)
|
137 |
+
beam_size = opt.get('beam_size', 1)
|
138 |
+
temperature = opt.get('temperature', 1.0)
|
139 |
+
if beam_size > 1:
|
140 |
+
return self.sample_beam(fc_feats, att_feats, opt)
|
141 |
+
|
142 |
+
batch_size = fc_feats.size(0)
|
143 |
+
state = self.init_hidden(fc_feats)
|
144 |
+
|
145 |
+
seq = []
|
146 |
+
seqLogprobs = []
|
147 |
+
for t in range(self.seq_length + 1):
|
148 |
+
if t == 0: # input <bos>
|
149 |
+
it = fc_feats.data.new(batch_size).long().zero_()
|
150 |
+
elif sample_max:
|
151 |
+
sampleLogprobs, it = torch.max(logprobs.data, 1)
|
152 |
+
it = it.view(-1).long()
|
153 |
+
else:
|
154 |
+
if temperature == 1.0:
|
155 |
+
prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1)
|
156 |
+
else:
|
157 |
+
# scale logprobs by temperature
|
158 |
+
prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu()
|
159 |
+
it = torch.multinomial(prob_prev, 1).cuda()
|
160 |
+
sampleLogprobs = logprobs.gather(1, Variable(it,
|
161 |
+
requires_grad=False)) # gather the logprobs at sampled positions
|
162 |
+
it = it.view(-1).long() # and flatten indices for downstream processing
|
163 |
+
|
164 |
+
xt = self.embed(Variable(it, requires_grad=False))
|
165 |
+
|
166 |
+
if t >= 1:
|
167 |
+
# stop when all finished
|
168 |
+
if t == 1:
|
169 |
+
unfinished = it > 0
|
170 |
+
else:
|
171 |
+
unfinished = unfinished * (it > 0)
|
172 |
+
if unfinished.sum() == 0:
|
173 |
+
break
|
174 |
+
it = it * unfinished.type_as(it)
|
175 |
+
seq.append(it) # seq[t] the input of t+2 time step
|
176 |
+
seqLogprobs.append(sampleLogprobs.view(-1))
|
177 |
+
|
178 |
+
output, state = self.core(xt, fc_feats, att_feats, state)
|
179 |
+
logprobs = F.log_softmax(self.logit(self.dropout(output)), -1)
|
180 |
+
|
181 |
+
return torch.cat([_.unsqueeze(1) for _ in seq], 1), torch.cat([_.unsqueeze(1) for _ in seqLogprobs], 1)
|
182 |
+
|
183 |
+
|
184 |
+
class ShowAttendTellCore(nn.Module):
|
185 |
+
def __init__(self, opt):
|
186 |
+
super(ShowAttendTellCore, self).__init__()
|
187 |
+
self.input_encoding_size = opt.input_encoding_size
|
188 |
+
self.rnn_type = opt.rnn_type
|
189 |
+
self.rnn_size = opt.rnn_size
|
190 |
+
self.num_layers = opt.num_layers
|
191 |
+
self.drop_prob_lm = opt.drop_prob_lm
|
192 |
+
self.fc_feat_size = opt.fc_feat_size
|
193 |
+
self.att_feat_size = opt.att_feat_size
|
194 |
+
self.att_hid_size = opt.att_hid_size
|
195 |
+
|
196 |
+
self.rnn = getattr(nn, self.rnn_type.upper())(self.input_encoding_size + self.att_feat_size,
|
197 |
+
self.rnn_size, self.num_layers, bias=False,
|
198 |
+
dropout=self.drop_prob_lm)
|
199 |
+
|
200 |
+
if self.att_hid_size > 0:
|
201 |
+
self.ctx2att = nn.Linear(self.att_feat_size, self.att_hid_size)
|
202 |
+
self.h2att = nn.Linear(self.rnn_size, self.att_hid_size)
|
203 |
+
self.alpha_net = nn.Linear(self.att_hid_size, 1)
|
204 |
+
else:
|
205 |
+
self.ctx2att = nn.Linear(self.att_feat_size, 1)
|
206 |
+
self.h2att = nn.Linear(self.rnn_size, 1)
|
207 |
+
|
208 |
+
def forward(self, xt, fc_feats, att_feats, state):
|
209 |
+
att_size = att_feats.numel() // att_feats.size(0) // self.att_feat_size
|
210 |
+
att = att_feats.view(-1, self.att_feat_size)
|
211 |
+
if self.att_hid_size > 0:
|
212 |
+
att = self.ctx2att(att) # (batch * att_size) * att_hid_size
|
213 |
+
att = att.view(-1, att_size, self.att_hid_size) # batch * att_size * att_hid_size
|
214 |
+
att_h = self.h2att(state[0][-1]) # batch * att_hid_size
|
215 |
+
att_h = att_h.unsqueeze(1).expand_as(att) # batch * att_size * att_hid_size
|
216 |
+
dot = att + att_h # batch * att_size * att_hid_size
|
217 |
+
dot = torch.tanh(dot) # batch * att_size * att_hid_size
|
218 |
+
dot = dot.view(-1, self.att_hid_size) # (batch * att_size) * att_hid_size
|
219 |
+
dot = self.alpha_net(dot) # (batch * att_size) * 1
|
220 |
+
dot = dot.view(-1, att_size) # batch * att_size
|
221 |
+
else:
|
222 |
+
att = self.ctx2att(att)(att) # (batch * att_size) * 1
|
223 |
+
att = att.view(-1, att_size) # batch * att_size
|
224 |
+
att_h = self.h2att(state[0][-1]) # batch * 1
|
225 |
+
att_h = att_h.expand_as(att) # batch * att_size
|
226 |
+
dot = att_h + att # batch * att_size
|
227 |
+
|
228 |
+
weight = F.softmax(dot, -1)
|
229 |
+
att_feats_ = att_feats.view(-1, att_size, self.att_feat_size) # batch * att_size * att_feat_size
|
230 |
+
att_res = torch.bmm(weight.unsqueeze(1), att_feats_).squeeze(1) # batch * att_feat_size
|
231 |
+
|
232 |
+
output, state = self.rnn(torch.cat([xt, att_res], 1).unsqueeze(0), state)
|
233 |
+
return output.squeeze(0), state
|
234 |
+
|
235 |
+
|
236 |
+
class AllImgCore(nn.Module):
|
237 |
+
def __init__(self, opt):
|
238 |
+
super(AllImgCore, self).__init__()
|
239 |
+
self.input_encoding_size = opt.input_encoding_size
|
240 |
+
self.rnn_type = opt.rnn_type
|
241 |
+
self.rnn_size = opt.rnn_size
|
242 |
+
self.num_layers = opt.num_layers
|
243 |
+
self.drop_prob_lm = opt.drop_prob_lm
|
244 |
+
self.fc_feat_size = opt.fc_feat_size
|
245 |
+
|
246 |
+
self.rnn = getattr(nn, self.rnn_type.upper())(self.input_encoding_size + self.fc_feat_size,
|
247 |
+
self.rnn_size, self.num_layers, bias=False,
|
248 |
+
dropout=self.drop_prob_lm)
|
249 |
+
|
250 |
+
def forward(self, xt, fc_feats, att_feats, state):
|
251 |
+
output, state = self.rnn(torch.cat([xt, fc_feats], 1).unsqueeze(0), state)
|
252 |
+
return output.squeeze(0), state
|
253 |
+
|
254 |
+
|
255 |
+
class ShowAttendTellModel(OldModel):
|
256 |
+
def __init__(self, opt):
|
257 |
+
super(ShowAttendTellModel, self).__init__(opt)
|
258 |
+
self.core = ShowAttendTellCore(opt)
|
259 |
+
|
260 |
+
|
261 |
+
class AllImgModel(OldModel):
|
262 |
+
def __init__(self, opt):
|
263 |
+
super(AllImgModel, self).__init__(opt)
|
264 |
+
self.core = AllImgCore(opt)
|
265 |
+
|
captioning/models/ShowTellModel.py
ADDED
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import
|
2 |
+
from __future__ import division
|
3 |
+
from __future__ import print_function
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from torch.autograd import *
|
10 |
+
from . import utils
|
11 |
+
|
12 |
+
from .CaptionModel import CaptionModel
|
13 |
+
|
14 |
+
bad_endings = ['a','an','the','in','for','at','of','with','before','after','on','upon','near','to','is','are','am']
|
15 |
+
bad_endings += ['UNK', 'has', 'and', 'more']
|
16 |
+
|
17 |
+
# torch.manual_seed(42)
|
18 |
+
# if torch.cuda.is_available():
|
19 |
+
# torch.cuda.manual_seed(42)
|
20 |
+
|
21 |
+
class ShowTellModel(CaptionModel):
|
22 |
+
def __init__(self, opt):
|
23 |
+
super(ShowTellModel, self).__init__()
|
24 |
+
self.vocab_size = opt.vocab_size
|
25 |
+
self.input_encoding_size = opt.input_encoding_size
|
26 |
+
self.rnn_type = opt.rnn_type
|
27 |
+
self.rnn_size = opt.rnn_size
|
28 |
+
self.num_layers = opt.num_layers
|
29 |
+
self.drop_prob_lm = opt.drop_prob_lm
|
30 |
+
self.seq_length = opt.seq_length
|
31 |
+
self.fc_feat_size = opt.fc_feat_size
|
32 |
+
|
33 |
+
self.eos_idx = getattr(opt, 'eos_idx', 0)
|
34 |
+
self.pad_idx = getattr(opt, 'pad_idx', 0)
|
35 |
+
|
36 |
+
self.ss_prob = 0.0 # Schedule sampling probability
|
37 |
+
|
38 |
+
self.img_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size)
|
39 |
+
self.core = getattr(nn, self.rnn_type.upper())(self.input_encoding_size, self.rnn_size, self.num_layers, bias=False, dropout=self.drop_prob_lm)
|
40 |
+
self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
|
41 |
+
self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1)
|
42 |
+
self.dropout = nn.Dropout(self.drop_prob_lm)
|
43 |
+
|
44 |
+
# For remove bad endding
|
45 |
+
self.vocab = opt.vocab
|
46 |
+
self.bad_endings_ix = [int(k) for k,v in self.vocab.items() if v in bad_endings]
|
47 |
+
|
48 |
+
self.init_weights()
|
49 |
+
|
50 |
+
def init_weights(self):
|
51 |
+
initrange = 0.1
|
52 |
+
self.embed.weight.data.uniform_(-initrange, initrange)
|
53 |
+
self.logit.bias.data.fill_(0)
|
54 |
+
self.logit.weight.data.uniform_(-initrange, initrange)
|
55 |
+
|
56 |
+
def init_hidden(self, bsz):
|
57 |
+
weight = self.logit.weight
|
58 |
+
if self.rnn_type == 'lstm':
|
59 |
+
return (weight.new_zeros(self.num_layers, bsz, self.rnn_size),
|
60 |
+
weight.new_zeros(self.num_layers, bsz, self.rnn_size))
|
61 |
+
else:
|
62 |
+
return weight.new_zeros(self.num_layers, bsz, self.rnn_size)
|
63 |
+
|
64 |
+
def _forward(self, fc_feats, att_feats, seq, att_masks=None):
|
65 |
+
|
66 |
+
batch_size = fc_feats.size(0)
|
67 |
+
if seq.ndim == 3: # B * seq_per_img * seq_len
|
68 |
+
seq = seq.reshape(-1, seq.shape[2])
|
69 |
+
seq_per_img = seq.shape[0] // batch_size
|
70 |
+
state = self.init_hidden(batch_size*seq_per_img)
|
71 |
+
outputs = []
|
72 |
+
|
73 |
+
if seq_per_img > 1:
|
74 |
+
fc_feats = utils.repeat_tensors(seq_per_img, fc_feats)
|
75 |
+
|
76 |
+
for i in range(seq.size(1)+1):
|
77 |
+
if i == 0:
|
78 |
+
xt = self.img_embed(fc_feats)
|
79 |
+
else:
|
80 |
+
if self.training and i >= 2 and self.ss_prob > 0.0: # otherwiste no need to sample
|
81 |
+
sample_prob = fc_feats.data.new(batch_size*seq_per_img).uniform_(0, 1)
|
82 |
+
sample_mask = sample_prob < self.ss_prob
|
83 |
+
if sample_mask.sum() == 0:
|
84 |
+
it = seq[:, i-1].clone()
|
85 |
+
else:
|
86 |
+
sample_ind = sample_mask.nonzero().view(-1)
|
87 |
+
it = seq[:, i-1].data.clone()
|
88 |
+
#prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1)
|
89 |
+
#it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1))
|
90 |
+
prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1)
|
91 |
+
it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
|
92 |
+
else:
|
93 |
+
it = seq[:, i-1].clone()
|
94 |
+
# break if all the sequences end
|
95 |
+
if i >= 2 and seq[:, i-1].data.sum() == 0:
|
96 |
+
break
|
97 |
+
xt = self.embed(it)
|
98 |
+
|
99 |
+
output, state = self.core(xt.unsqueeze(0), state)
|
100 |
+
|
101 |
+
output = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1)
|
102 |
+
outputs.append(output)
|
103 |
+
|
104 |
+
return torch.cat([_.unsqueeze(1) for _ in outputs[1:]], 1).contiguous()
|
105 |
+
|
106 |
+
def get_logprobs_state(self, it, state):
|
107 |
+
# 'it' contains a word index
|
108 |
+
xt = self.embed(it)
|
109 |
+
|
110 |
+
output, state = self.core(xt.unsqueeze(0), state)
|
111 |
+
logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1)
|
112 |
+
|
113 |
+
return logprobs, state
|
114 |
+
|
115 |
+
def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
|
116 |
+
# beam_size = opt.get('beam_size', 10)
|
117 |
+
# batch_size = fc_feats.size(0)
|
118 |
+
|
119 |
+
# assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
|
120 |
+
# seq = torch.LongTensor(self.seq_length, batch_size).zero_()
|
121 |
+
# seqLogprobs = torch.FloatTensor(self.seq_length, batch_size)
|
122 |
+
# # lets process every image independently for now, for simplicity
|
123 |
+
|
124 |
+
|
125 |
+
beam_size = opt.get('beam_size', 10)
|
126 |
+
group_size = opt.get('group_size', 1)
|
127 |
+
sample_n = opt.get('sample_n', 10)
|
128 |
+
# when sample_n == beam_size then each beam is a sample.
|
129 |
+
assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search'
|
130 |
+
batch_size = fc_feats.size(0)
|
131 |
+
|
132 |
+
assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
|
133 |
+
seq = fc_feats.new_full((batch_size*sample_n, self.seq_length), self.pad_idx, dtype=torch.long)
|
134 |
+
seqLogprobs = fc_feats.new_zeros(batch_size*sample_n, self.seq_length, self.vocab_size + 1)
|
135 |
+
|
136 |
+
self.done_beams = [[] for _ in range(batch_size)]
|
137 |
+
for k in range(batch_size):
|
138 |
+
state = self.init_hidden(beam_size)
|
139 |
+
for t in range(2):
|
140 |
+
if t == 0:
|
141 |
+
xt = self.img_embed(fc_feats[k:k+1]).expand(beam_size, self.input_encoding_size)
|
142 |
+
elif t == 1: # input <bos>
|
143 |
+
it = fc_feats.data.new(beam_size).long().zero_()
|
144 |
+
xt = self.embed(it)
|
145 |
+
|
146 |
+
output, state = self.core(xt.unsqueeze(0), state)
|
147 |
+
logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1)
|
148 |
+
|
149 |
+
self.done_beams[k] = self.old_beam_search(state, logprobs, opt=opt)
|
150 |
+
if sample_n == beam_size:
|
151 |
+
for _n in range(sample_n):
|
152 |
+
seq[k*sample_n+_n, :] = self.done_beams[k][_n]['seq']
|
153 |
+
seqLogprobs[k*sample_n+_n, :] = self.done_beams[k][_n]['logps']
|
154 |
+
else:
|
155 |
+
seq[k, :] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
|
156 |
+
seqLogprobs[k, :] = self.done_beams[k][0]['logps']
|
157 |
+
# return the samples and their log likelihoods
|
158 |
+
return seq, seqLogprobs
|
159 |
+
|
160 |
+
# seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
|
161 |
+
# seqLogprobs[:, k] = self.done_beams[k][0]['logps']
|
162 |
+
# # return the samples and their log likelihoods
|
163 |
+
# return seq.transpose(0, 1), seqLogprobs.transpose(0, 1)
|
164 |
+
|
165 |
+
def _new_sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
|
166 |
+
|
167 |
+
beam_size = opt.get('beam_size', 10)
|
168 |
+
group_size = opt.get('group_size', 1)
|
169 |
+
sample_n = opt.get('sample_n', 10)
|
170 |
+
# when sample_n == beam_size then each beam is a sample.
|
171 |
+
assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search'
|
172 |
+
batch_size = fc_feats.size(0)
|
173 |
+
|
174 |
+
assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
|
175 |
+
seq = fc_feats.new_full((batch_size*sample_n, self.seq_length), self.pad_idx, dtype=torch.long)
|
176 |
+
seqLogprobs = fc_feats.new_zeros(batch_size*sample_n, self.seq_length, self.vocab_size + 1)
|
177 |
+
|
178 |
+
self.done_beams = [[] for _ in range(batch_size)]
|
179 |
+
|
180 |
+
state = self.init_hidden(batch_size)
|
181 |
+
|
182 |
+
it = fc_feats.data.new(batch_size).long().zero_()
|
183 |
+
xt = self.embed(it)
|
184 |
+
|
185 |
+
output, state = self.core(xt.unsqueeze(0), state)
|
186 |
+
logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1)
|
187 |
+
|
188 |
+
self.done_beams = self.beam_search(state, logprobs, opt=opt)
|
189 |
+
|
190 |
+
for k in range(batch_size):
|
191 |
+
if sample_n == beam_size:
|
192 |
+
for _n in range(sample_n):
|
193 |
+
seq_len = self.done_beams[k][_n]['seq'].shape[0]
|
194 |
+
seq[k*sample_n+_n, :seq_len] = self.done_beams[k][_n]['seq']
|
195 |
+
seqLogprobs[k*sample_n+_n, :seq_len] = self.done_beams[k][_n]['logps']
|
196 |
+
else:
|
197 |
+
seq_len = self.done_beams[k][0]['seq'].shape[0]
|
198 |
+
seq[k, :seq_len] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
|
199 |
+
seqLogprobs[k, :seq_len] = self.done_beams[k][0]['logps']
|
200 |
+
# return the samples and their log likelihoods
|
201 |
+
return seq, seqLogprobs
|
202 |
+
|
203 |
+
def _old_sample(self, fc_feats, att_feats, att_masks=None, opt={}):
|
204 |
+
sample_method = opt.get('sample_method', 'greedy')
|
205 |
+
beam_size = opt.get('beam_size', 1)
|
206 |
+
temperature = opt.get('temperature', 1.0)
|
207 |
+
if beam_size > 1 and sample_method in ['greedy', 'beam_search']:
|
208 |
+
return self._sample_beam(fc_feats, att_feats, opt)
|
209 |
+
|
210 |
+
batch_size = fc_feats.size(0)
|
211 |
+
state = self.init_hidden(batch_size)
|
212 |
+
seq = fc_feats.new_zeros(batch_size, self.seq_length, dtype=torch.long)
|
213 |
+
seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length)
|
214 |
+
for t in range(self.seq_length + 2):
|
215 |
+
if t == 0:
|
216 |
+
xt = self.img_embed(fc_feats)
|
217 |
+
else:
|
218 |
+
if t == 1: # input <bos>
|
219 |
+
it = fc_feats.data.new(batch_size).long().zero_()
|
220 |
+
xt = self.embed(it)
|
221 |
+
|
222 |
+
output, state = self.core(xt.unsqueeze(0), state)
|
223 |
+
logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1)
|
224 |
+
|
225 |
+
# sample the next word
|
226 |
+
if t == self.seq_length + 1: # skip if we achieve maximum length
|
227 |
+
break
|
228 |
+
if sample_method == 'greedy':
|
229 |
+
sampleLogprobs, it = torch.max(logprobs.data, 1)
|
230 |
+
it = it.view(-1).long()
|
231 |
+
else:
|
232 |
+
if temperature == 1.0:
|
233 |
+
prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1)
|
234 |
+
else:
|
235 |
+
# scale logprobs by temperature
|
236 |
+
prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu()
|
237 |
+
it = torch.multinomial(prob_prev, 1).to(logprobs.device)
|
238 |
+
sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions
|
239 |
+
it = it.view(-1).long() # and flatten indices for downstream processing
|
240 |
+
|
241 |
+
if t >= 1:
|
242 |
+
# stop when all finished
|
243 |
+
if t == 1:
|
244 |
+
unfinished = it > 0
|
245 |
+
else:
|
246 |
+
unfinished = unfinished & (it > 0)
|
247 |
+
it = it * unfinished.type_as(it)
|
248 |
+
seq[:,t-1] = it #seq[t] the input of t+2 time step
|
249 |
+
seqLogprobs[:,t-1] = sampleLogprobs.view(-1)
|
250 |
+
if unfinished.sum() == 0:
|
251 |
+
break
|
252 |
+
return seq, seqLogprobs
|
253 |
+
|
254 |
+
|
255 |
+
# remove bad endings and UNK
|
256 |
+
def _sample(self, fc_feats, att_feats, att_masks=None, opt={}):
|
257 |
+
sample_method = opt.get('sample_method', 'greedy')
|
258 |
+
beam_size = opt.get('beam_size', 1)
|
259 |
+
temperature = opt.get('temperature', 1.0)
|
260 |
+
|
261 |
+
sample_n = int(opt.get('sample_n', 1))
|
262 |
+
sample_n = 1
|
263 |
+
group_size = opt.get('group_size', 1)
|
264 |
+
output_logsoftmax = opt.get('output_logsoftmax', 1)
|
265 |
+
decoding_constraint = opt.get('decoding_constraint', 0)
|
266 |
+
block_trigrams = opt.get('block_trigrams', 0)
|
267 |
+
remove_bad_endings = opt.get('remove_bad_endings', 1)
|
268 |
+
suppress_UNK = opt.get('suppress_UNK', 1)
|
269 |
+
|
270 |
+
if beam_size > 1 and sample_method in ['greedy', 'beam_search']:
|
271 |
+
return self._sample_beam(fc_feats, att_feats, opt=opt)
|
272 |
+
|
273 |
+
batch_size = fc_feats.size(0)
|
274 |
+
state = self.init_hidden(batch_size)
|
275 |
+
|
276 |
+
trigrams = [] # will be a list of batch_size dictionaries
|
277 |
+
|
278 |
+
# seq = fc_feats.new_zeros(batch_size, self.seq_length, dtype=torch.long)
|
279 |
+
# seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length)
|
280 |
+
|
281 |
+
seq = fc_feats.new_full((batch_size*sample_n, self.seq_length), self.pad_idx, dtype=torch.long)
|
282 |
+
seqLogprobs = fc_feats.new_zeros(batch_size*sample_n, self.seq_length, self.vocab_size + 1)
|
283 |
+
for t in range(self.seq_length + 1):
|
284 |
+
if t == 0:
|
285 |
+
xt = self.img_embed(fc_feats)
|
286 |
+
else:
|
287 |
+
if t == 1: # input <bos>
|
288 |
+
it = fc_feats.data.new(batch_size).long().zero_()
|
289 |
+
xt = self.embed(it)
|
290 |
+
|
291 |
+
output, state = self.core(xt.unsqueeze(0), state)
|
292 |
+
logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1)
|
293 |
+
|
294 |
+
if decoding_constraint and t > 0:
|
295 |
+
tmp = logprobs.new_zeros(logprobs.size())
|
296 |
+
tmp.scatter_(1, seq[:,t-1].data.unsqueeze(1), float('-inf'))
|
297 |
+
logprobs = logprobs + tmp
|
298 |
+
|
299 |
+
# print('seq', seq)
|
300 |
+
# print('self.seq_length',self.seq_length)
|
301 |
+
# print('seq shape', seq.shape)
|
302 |
+
if remove_bad_endings and t > 0:
|
303 |
+
logprobs[torch.from_numpy(np.isin(seq[:,t-1].data.cpu().numpy(), self.bad_endings_ix)), 0] = float('-inf')
|
304 |
+
|
305 |
+
# suppress UNK tokens in the decoding
|
306 |
+
if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobs.size(1)-1)] == 'UNK':
|
307 |
+
logprobs[:,logprobs.size(1)-1] = logprobs[:, logprobs.size(1)-1] - 1000
|
308 |
+
|
309 |
+
# if remove_bad_endings and t > 0:
|
310 |
+
# tmp = logprobs.new_zeros(logprobs.size())
|
311 |
+
# prev_bad = np.isin(seq[:,t-1].data.cpu().numpy(), self.bad_endings_ix)
|
312 |
+
# # Make it impossible to generate bad_endings
|
313 |
+
# tmp[torch.from_numpy(prev_bad.astype('uint8')), 0] = float('-inf')
|
314 |
+
# # tmp[torch.from_numpy(prev_bad.bool()), 0] = float('-inf')
|
315 |
+
# logprobs = logprobs + tmp
|
316 |
+
|
317 |
+
# Mess with trigrams
|
318 |
+
# Copy from https://github.com/lukemelas/image-paragraph-captioning
|
319 |
+
if block_trigrams and t >= 3:
|
320 |
+
# Store trigram generated at last step
|
321 |
+
prev_two_batch = seq[:,t-3:t-1]
|
322 |
+
for i in range(batch_size): # = seq.size(0)
|
323 |
+
prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
|
324 |
+
current = seq[i][t-1]
|
325 |
+
if t == 3: # initialize
|
326 |
+
trigrams.append({prev_two: [current]}) # {LongTensor: list containing 1 int}
|
327 |
+
elif t > 3:
|
328 |
+
if prev_two in trigrams[i]: # add to list
|
329 |
+
trigrams[i][prev_two].append(current)
|
330 |
+
else: # create list
|
331 |
+
trigrams[i][prev_two] = [current]
|
332 |
+
# Block used trigrams at next step
|
333 |
+
prev_two_batch = seq[:,t-2:t]
|
334 |
+
mask = torch.zeros(logprobs.size(), requires_grad=False).to(logprobs.device) # batch_size x vocab_size
|
335 |
+
for i in range(batch_size):
|
336 |
+
prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
|
337 |
+
if prev_two in trigrams[i]:
|
338 |
+
for j in trigrams[i][prev_two]:
|
339 |
+
mask[i,j] += 1
|
340 |
+
# Apply mask to log probs
|
341 |
+
#logprobs = logprobs - (mask * 1e9)
|
342 |
+
alpha = 2.0 # = 4
|
343 |
+
logprobs = logprobs + (mask * -0.693 * alpha) # ln(1/2) * alpha (alpha -> infty works best)
|
344 |
+
|
345 |
+
# sample the next word
|
346 |
+
if t == self.seq_length+1: # skip if we achieve maximum length
|
347 |
+
break
|
348 |
+
it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, temperature)
|
349 |
+
|
350 |
+
# stop when all finished
|
351 |
+
if t == 0:
|
352 |
+
unfinished = it != self.eos_idx
|
353 |
+
else:
|
354 |
+
it[~unfinished] = self.pad_idx # This allows eos_idx not being overwritten to 0
|
355 |
+
logprobs = logprobs * unfinished.unsqueeze(1).to(logprobs)
|
356 |
+
unfinished = unfinished & (it != self.eos_idx)
|
357 |
+
|
358 |
+
# print('-------logprobs shape:',logprobs.shape)
|
359 |
+
# print('-------it shape:',it.shape)
|
360 |
+
|
361 |
+
seq[:,t-1] = it
|
362 |
+
seqLogprobs[:,t-1] = logprobs
|
363 |
+
# quit loop if all sequences have finished
|
364 |
+
if unfinished.sum() == 0:
|
365 |
+
break
|
366 |
+
# print('-------seqLogprobs shape:',seqLogprobs.shape)
|
367 |
+
# print('-------seq shape:',seq.shape)
|
368 |
+
return seq, seqLogprobs
|
captioning/models/TransformerModel.py
ADDED
@@ -0,0 +1,367 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This file contains Transformer network
|
2 |
+
# Most of the code is copied from http://nlp.seas.harvard.edu/2018/04/03/attention.html
|
3 |
+
|
4 |
+
# The cfg name correspondance:
|
5 |
+
# N=num_layers
|
6 |
+
# d_model=input_encoding_size
|
7 |
+
# d_ff=rnn_size
|
8 |
+
# h is always 8
|
9 |
+
|
10 |
+
from __future__ import absolute_import
|
11 |
+
from __future__ import division
|
12 |
+
from __future__ import print_function
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
import torch.nn.functional as F
|
17 |
+
from . import utils
|
18 |
+
|
19 |
+
import copy
|
20 |
+
import math
|
21 |
+
import numpy as np
|
22 |
+
|
23 |
+
from .CaptionModel import CaptionModel
|
24 |
+
from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel
|
25 |
+
|
26 |
+
# torch.manual_seed(42)
|
27 |
+
# if torch.cuda.is_available():
|
28 |
+
# torch.cuda.manual_seed(42)
|
29 |
+
|
30 |
+
class EncoderDecoder(nn.Module):
|
31 |
+
"""
|
32 |
+
A standard Encoder-Decoder architecture. Base for this and many
|
33 |
+
other models.
|
34 |
+
"""
|
35 |
+
def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
|
36 |
+
super(EncoderDecoder, self).__init__()
|
37 |
+
self.encoder = encoder
|
38 |
+
self.decoder = decoder
|
39 |
+
self.src_embed = src_embed
|
40 |
+
self.tgt_embed = tgt_embed
|
41 |
+
self.generator = generator
|
42 |
+
|
43 |
+
def forward(self, src, tgt, src_mask, tgt_mask):
|
44 |
+
"Take in and process masked src and target sequences."
|
45 |
+
return self.decode(self.encode(src, src_mask), src_mask,
|
46 |
+
tgt, tgt_mask)
|
47 |
+
|
48 |
+
def encode(self, src, src_mask):
|
49 |
+
return self.encoder(self.src_embed(src), src_mask)
|
50 |
+
|
51 |
+
def decode(self, memory, src_mask, tgt, tgt_mask):
|
52 |
+
return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
|
53 |
+
|
54 |
+
class Generator(nn.Module):
|
55 |
+
"Define standard linear + softmax generation step."
|
56 |
+
def __init__(self, d_model, vocab):
|
57 |
+
super(Generator, self).__init__()
|
58 |
+
self.proj = nn.Linear(d_model, vocab)
|
59 |
+
|
60 |
+
def forward(self, x):
|
61 |
+
return F.log_softmax(self.proj(x), dim=-1)
|
62 |
+
|
63 |
+
def clones(module, N):
|
64 |
+
"Produce N identical layers."
|
65 |
+
return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
|
66 |
+
|
67 |
+
class Encoder(nn.Module):
|
68 |
+
"Core encoder is a stack of N layers"
|
69 |
+
def __init__(self, layer, N):
|
70 |
+
super(Encoder, self).__init__()
|
71 |
+
self.layers = clones(layer, N)
|
72 |
+
self.norm = LayerNorm(layer.size)
|
73 |
+
|
74 |
+
def forward(self, x, mask):
|
75 |
+
"Pass the input (and mask) through each layer in turn."
|
76 |
+
for layer in self.layers:
|
77 |
+
x = layer(x, mask)
|
78 |
+
return self.norm(x)
|
79 |
+
|
80 |
+
class LayerNorm(nn.Module):
|
81 |
+
"Construct a layernorm module (See citation for details)."
|
82 |
+
def __init__(self, features, eps=1e-6):
|
83 |
+
super(LayerNorm, self).__init__()
|
84 |
+
self.a_2 = nn.Parameter(torch.ones(features))
|
85 |
+
self.b_2 = nn.Parameter(torch.zeros(features))
|
86 |
+
self.eps = eps
|
87 |
+
|
88 |
+
def forward(self, x):
|
89 |
+
mean = x.mean(-1, keepdim=True)
|
90 |
+
std = x.std(-1, keepdim=True)
|
91 |
+
return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
|
92 |
+
|
93 |
+
class SublayerConnection(nn.Module):
|
94 |
+
"""
|
95 |
+
A residual connection followed by a layer norm.
|
96 |
+
Note for code simplicity the norm is first as opposed to last.
|
97 |
+
"""
|
98 |
+
def __init__(self, size, dropout):
|
99 |
+
super(SublayerConnection, self).__init__()
|
100 |
+
self.norm = LayerNorm(size)
|
101 |
+
self.dropout = nn.Dropout(dropout)
|
102 |
+
|
103 |
+
def forward(self, x, sublayer):
|
104 |
+
"Apply residual connection to any sublayer with the same size."
|
105 |
+
return x + self.dropout(sublayer(self.norm(x)))
|
106 |
+
|
107 |
+
class EncoderLayer(nn.Module):
|
108 |
+
"Encoder is made up of self-attn and feed forward (defined below)"
|
109 |
+
def __init__(self, size, self_attn, feed_forward, dropout):
|
110 |
+
super(EncoderLayer, self).__init__()
|
111 |
+
self.self_attn = self_attn
|
112 |
+
self.feed_forward = feed_forward
|
113 |
+
self.sublayer = clones(SublayerConnection(size, dropout), 2)
|
114 |
+
self.size = size
|
115 |
+
|
116 |
+
def forward(self, x, mask):
|
117 |
+
"Follow Figure 1 (left) for connections."
|
118 |
+
x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
|
119 |
+
return self.sublayer[1](x, self.feed_forward)
|
120 |
+
|
121 |
+
class Decoder(nn.Module):
|
122 |
+
"Generic N layer decoder with masking."
|
123 |
+
def __init__(self, layer, N):
|
124 |
+
super(Decoder, self).__init__()
|
125 |
+
self.layers = clones(layer, N)
|
126 |
+
self.norm = LayerNorm(layer.size)
|
127 |
+
|
128 |
+
def forward(self, x, memory, src_mask, tgt_mask):
|
129 |
+
for layer in self.layers:
|
130 |
+
x = layer(x, memory, src_mask, tgt_mask)
|
131 |
+
return self.norm(x)
|
132 |
+
|
133 |
+
class DecoderLayer(nn.Module):
|
134 |
+
"Decoder is made of self-attn, src-attn, and feed forward (defined below)"
|
135 |
+
def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
|
136 |
+
super(DecoderLayer, self).__init__()
|
137 |
+
self.size = size
|
138 |
+
self.self_attn = self_attn
|
139 |
+
self.src_attn = src_attn
|
140 |
+
self.feed_forward = feed_forward
|
141 |
+
self.sublayer = clones(SublayerConnection(size, dropout), 3)
|
142 |
+
|
143 |
+
def forward(self, x, memory, src_mask, tgt_mask):
|
144 |
+
"Follow Figure 1 (right) for connections."
|
145 |
+
m = memory
|
146 |
+
x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
|
147 |
+
x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
|
148 |
+
return self.sublayer[2](x, self.feed_forward)
|
149 |
+
|
150 |
+
def subsequent_mask(size):
|
151 |
+
"Mask out subsequent positions."
|
152 |
+
attn_shape = (1, size, size)
|
153 |
+
subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
|
154 |
+
return torch.from_numpy(subsequent_mask) == 0
|
155 |
+
|
156 |
+
def attention(query, key, value, mask=None, dropout=None):
|
157 |
+
"Compute 'Scaled Dot Product Attention'"
|
158 |
+
d_k = query.size(-1)
|
159 |
+
scores = torch.matmul(query, key.transpose(-2, -1)) \
|
160 |
+
/ math.sqrt(d_k)
|
161 |
+
if mask is not None:
|
162 |
+
scores = scores.masked_fill(mask == 0, float('-inf'))
|
163 |
+
p_attn = F.softmax(scores, dim = -1)
|
164 |
+
if dropout is not None:
|
165 |
+
p_attn = dropout(p_attn)
|
166 |
+
return torch.matmul(p_attn, value), p_attn
|
167 |
+
|
168 |
+
class MultiHeadedAttention(nn.Module):
|
169 |
+
def __init__(self, h, d_model, dropout=0.1):
|
170 |
+
"Take in model size and number of heads."
|
171 |
+
super(MultiHeadedAttention, self).__init__()
|
172 |
+
assert d_model % h == 0
|
173 |
+
# We assume d_v always equals d_k
|
174 |
+
self.d_k = d_model // h
|
175 |
+
self.h = h
|
176 |
+
self.linears = clones(nn.Linear(d_model, d_model), 4)
|
177 |
+
self.attn = None
|
178 |
+
self.dropout = nn.Dropout(p=dropout)
|
179 |
+
|
180 |
+
def forward(self, query, key, value, mask=None):
|
181 |
+
"Implements Figure 2"
|
182 |
+
if mask is not None:
|
183 |
+
# Same mask applied to all h heads.
|
184 |
+
mask = mask.unsqueeze(1)
|
185 |
+
nbatches = query.size(0)
|
186 |
+
|
187 |
+
# 1) Do all the linear projections in batch from d_model => h x d_k
|
188 |
+
query, key, value = \
|
189 |
+
[l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
|
190 |
+
for l, x in zip(self.linears, (query, key, value))]
|
191 |
+
|
192 |
+
# 2) Apply attention on all the projected vectors in batch.
|
193 |
+
x, self.attn = attention(query, key, value, mask=mask,
|
194 |
+
dropout=self.dropout)
|
195 |
+
|
196 |
+
# 3) "Concat" using a view and apply a final linear.
|
197 |
+
x = x.transpose(1, 2).contiguous() \
|
198 |
+
.view(nbatches, -1, self.h * self.d_k)
|
199 |
+
return self.linears[-1](x)
|
200 |
+
|
201 |
+
class PositionwiseFeedForward(nn.Module):
|
202 |
+
"Implements FFN equation."
|
203 |
+
def __init__(self, d_model, d_ff, dropout=0.1):
|
204 |
+
super(PositionwiseFeedForward, self).__init__()
|
205 |
+
self.w_1 = nn.Linear(d_model, d_ff)
|
206 |
+
self.w_2 = nn.Linear(d_ff, d_model)
|
207 |
+
self.dropout = nn.Dropout(dropout)
|
208 |
+
|
209 |
+
def forward(self, x):
|
210 |
+
return self.w_2(self.dropout(F.relu(self.w_1(x))))
|
211 |
+
|
212 |
+
class Embeddings(nn.Module):
|
213 |
+
def __init__(self, d_model, vocab):
|
214 |
+
super(Embeddings, self).__init__()
|
215 |
+
self.lut = nn.Embedding(vocab, d_model)
|
216 |
+
self.d_model = d_model
|
217 |
+
|
218 |
+
def forward(self, x):
|
219 |
+
return self.lut(x) * math.sqrt(self.d_model)
|
220 |
+
|
221 |
+
class PositionalEncoding(nn.Module):
|
222 |
+
"Implement the PE function."
|
223 |
+
def __init__(self, d_model, dropout, max_len=5000):
|
224 |
+
super(PositionalEncoding, self).__init__()
|
225 |
+
self.dropout = nn.Dropout(p=dropout)
|
226 |
+
|
227 |
+
# Compute the positional encodings once in log space.
|
228 |
+
pe = torch.zeros(max_len, d_model)
|
229 |
+
position = torch.arange(0, max_len).unsqueeze(1).float()
|
230 |
+
div_term = torch.exp(torch.arange(0, d_model, 2).float() *
|
231 |
+
-(math.log(10000.0) / d_model))
|
232 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
233 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
234 |
+
pe = pe.unsqueeze(0)
|
235 |
+
self.register_buffer('pe', pe)
|
236 |
+
|
237 |
+
def forward(self, x):
|
238 |
+
x = x + self.pe[:, :x.size(1)]
|
239 |
+
return self.dropout(x)
|
240 |
+
|
241 |
+
class TransformerModel(AttModel):
|
242 |
+
|
243 |
+
def make_model(self, src_vocab, tgt_vocab, N_enc=6, N_dec=6,
|
244 |
+
d_model=512, d_ff=2048, h=8, dropout=0.1):
|
245 |
+
"Helper: Construct a model from hyperparameters."
|
246 |
+
c = copy.deepcopy
|
247 |
+
attn = MultiHeadedAttention(h, d_model, dropout)
|
248 |
+
ff = PositionwiseFeedForward(d_model, d_ff, dropout)
|
249 |
+
position = PositionalEncoding(d_model, dropout)
|
250 |
+
model = EncoderDecoder(
|
251 |
+
Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N_enc),
|
252 |
+
Decoder(DecoderLayer(d_model, c(attn), c(attn),
|
253 |
+
c(ff), dropout), N_dec),
|
254 |
+
lambda x:x, # nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
|
255 |
+
nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
|
256 |
+
Generator(d_model, tgt_vocab))
|
257 |
+
|
258 |
+
# This was important from their code.
|
259 |
+
# Initialize parameters with Glorot / fan_avg.
|
260 |
+
for p in model.parameters():
|
261 |
+
if p.dim() > 1:
|
262 |
+
nn.init.xavier_uniform_(p)
|
263 |
+
return model
|
264 |
+
|
265 |
+
def __init__(self, opt):
|
266 |
+
super(TransformerModel, self).__init__(opt)
|
267 |
+
self.opt = opt
|
268 |
+
# self.config = yaml.load(open(opt.config_file))
|
269 |
+
|
270 |
+
self.N_enc = getattr(opt, 'N_enc', opt.num_layers)
|
271 |
+
self.N_dec = getattr(opt, 'N_dec', opt.num_layers)
|
272 |
+
self.d_model = getattr(opt, 'd_model', opt.input_encoding_size)
|
273 |
+
self.d_ff = getattr(opt, 'd_ff', opt.rnn_size)
|
274 |
+
self.h = getattr(opt, 'num_att_heads', 8)
|
275 |
+
self.dropout = getattr(opt, 'dropout', 0.1)
|
276 |
+
|
277 |
+
delattr(self, 'att_embed')
|
278 |
+
self.att_embed = nn.Sequential(*(
|
279 |
+
((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ())+
|
280 |
+
(nn.Linear(self.att_feat_size, self.d_model),
|
281 |
+
nn.ReLU(),
|
282 |
+
nn.Dropout(self.drop_prob_lm))+
|
283 |
+
((nn.BatchNorm1d(self.d_model),) if self.use_bn==2 else ())))
|
284 |
+
|
285 |
+
delattr(self, 'embed')
|
286 |
+
self.embed = lambda x : x
|
287 |
+
delattr(self, 'fc_embed')
|
288 |
+
self.fc_embed = lambda x : x
|
289 |
+
delattr(self, 'logit')
|
290 |
+
del self.ctx2att
|
291 |
+
|
292 |
+
tgt_vocab = self.vocab_size + 1
|
293 |
+
|
294 |
+
|
295 |
+
self.model = self.make_model(0, tgt_vocab,
|
296 |
+
N_enc=self.N_enc,
|
297 |
+
N_dec=self.N_dec,
|
298 |
+
d_model=self.d_model,
|
299 |
+
d_ff=self.d_ff,
|
300 |
+
h=self.h,
|
301 |
+
dropout=self.dropout)
|
302 |
+
|
303 |
+
def logit(self, x): # unsafe way
|
304 |
+
return self.model.generator.proj(x)
|
305 |
+
|
306 |
+
def init_hidden(self, bsz):
|
307 |
+
return []
|
308 |
+
|
309 |
+
def _prepare_feature(self, fc_feats, att_feats, att_masks):
|
310 |
+
|
311 |
+
att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks)
|
312 |
+
memory = self.model.encode(att_feats, att_masks)
|
313 |
+
|
314 |
+
return fc_feats[...,:0], att_feats[...,:0], memory, att_masks
|
315 |
+
|
316 |
+
def _prepare_feature_forward(self, att_feats, att_masks=None, seq=None):
|
317 |
+
att_feats, att_masks = self.clip_att(att_feats, att_masks)
|
318 |
+
|
319 |
+
att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)
|
320 |
+
|
321 |
+
if att_masks is None:
|
322 |
+
att_masks = att_feats.new_ones(att_feats.shape[:2], dtype=torch.long)
|
323 |
+
att_masks = att_masks.unsqueeze(-2)
|
324 |
+
|
325 |
+
if seq is not None:
|
326 |
+
# crop the last one
|
327 |
+
# seq = seq[:,:-1]
|
328 |
+
seq_mask = (seq.data != self.eos_idx) & (seq.data != self.pad_idx)
|
329 |
+
seq_mask[:,0] = 1 # bos
|
330 |
+
|
331 |
+
seq_mask = seq_mask.unsqueeze(-2)
|
332 |
+
seq_mask = seq_mask & subsequent_mask(seq.size(-1)).to(seq_mask)
|
333 |
+
|
334 |
+
seq_per_img = seq.shape[0] // att_feats.shape[0]
|
335 |
+
if seq_per_img > 1:
|
336 |
+
att_feats, att_masks = utils.repeat_tensors(seq_per_img,
|
337 |
+
[att_feats, att_masks]
|
338 |
+
)
|
339 |
+
else:
|
340 |
+
seq_mask = None
|
341 |
+
|
342 |
+
return att_feats, seq, att_masks, seq_mask
|
343 |
+
|
344 |
+
def _forward(self, fc_feats, att_feats, seq, att_masks=None):
|
345 |
+
if seq.ndim == 3: # B * seq_per_img * seq_len
|
346 |
+
seq = seq.reshape(-1, seq.shape[2])
|
347 |
+
att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq)
|
348 |
+
|
349 |
+
out = self.model(att_feats, seq, att_masks, seq_mask)
|
350 |
+
|
351 |
+
outputs = self.model.generator(out)
|
352 |
+
return outputs
|
353 |
+
# return torch.cat([_.unsqueeze(1) for _ in outputs], 1)
|
354 |
+
|
355 |
+
def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask):
|
356 |
+
"""
|
357 |
+
state = [ys.unsqueeze(0)]
|
358 |
+
"""
|
359 |
+
if len(state) == 0:
|
360 |
+
ys = it.unsqueeze(1)
|
361 |
+
else:
|
362 |
+
ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1)
|
363 |
+
out = self.model.decode(memory, mask,
|
364 |
+
ys,
|
365 |
+
subsequent_mask(ys.size(1))
|
366 |
+
.to(memory.device))
|
367 |
+
return out[:, -1], [ys.unsqueeze(0)]
|
captioning/models/__init__.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import
|
2 |
+
from __future__ import division
|
3 |
+
from __future__ import print_function
|
4 |
+
|
5 |
+
import os
|
6 |
+
import copy
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
|
11 |
+
from .ShowTellModel import ShowTellModel
|
12 |
+
from .FCModel import FCModel
|
13 |
+
from .AttModel import *
|
14 |
+
from .TransformerModel import TransformerModel
|
15 |
+
from .cachedTransformer import TransformerModel as cachedTransformer
|
16 |
+
from .BertCapModel import BertCapModel
|
17 |
+
from .M2Transformer import M2TransformerModel
|
18 |
+
from .AoAModel import AoAModel
|
19 |
+
from .OldModel import ShowAttendTellModel
|
20 |
+
|
21 |
+
def setup(opt):
|
22 |
+
if opt.caption_model in ['fc', 'show_tell']:
|
23 |
+
print('Warning: %s model is mostly deprecated; many new features are not supported.' %opt.caption_model)
|
24 |
+
if opt.caption_model == 'fc':
|
25 |
+
print('Use newfc instead of fc')
|
26 |
+
if opt.caption_model == 'fc':
|
27 |
+
model = FCModel(opt)
|
28 |
+
elif opt.caption_model == 'language_model':
|
29 |
+
model = LMModel(opt)
|
30 |
+
elif opt.caption_model == 'newfc':
|
31 |
+
model = NewFCModel(opt)
|
32 |
+
elif opt.caption_model == 'show_tell':
|
33 |
+
model = ShowTellModel(opt)
|
34 |
+
elif opt.caption_model == 'show_attend_tell':
|
35 |
+
model = ShowAttendTellModel(opt)
|
36 |
+
# Att2in model in self-critical
|
37 |
+
elif opt.caption_model == 'att2in':
|
38 |
+
model = Att2inModel(opt)
|
39 |
+
# Att2in model with two-layer MLP img embedding and word embedding
|
40 |
+
elif opt.caption_model == 'att2in2':
|
41 |
+
model = Att2in2Model(opt)
|
42 |
+
elif opt.caption_model == 'att2all2':
|
43 |
+
print('Warning: this is not a correct implementation of the att2all model in the original paper.')
|
44 |
+
model = Att2all2Model(opt)
|
45 |
+
# Adaptive Attention model from Knowing when to look
|
46 |
+
elif opt.caption_model == 'adaatt':
|
47 |
+
model = AdaAttModel(opt)
|
48 |
+
# Adaptive Attention with maxout lstm
|
49 |
+
elif opt.caption_model == 'adaattmo':
|
50 |
+
model = AdaAttMOModel(opt)
|
51 |
+
# Top-down attention model
|
52 |
+
elif opt.caption_model in ['topdown', 'updown']:
|
53 |
+
model = UpDownModel(opt)
|
54 |
+
# StackAtt
|
55 |
+
elif opt.caption_model == 'stackatt':
|
56 |
+
model = StackAttModel(opt)
|
57 |
+
# DenseAtt
|
58 |
+
elif opt.caption_model == 'denseatt':
|
59 |
+
model = DenseAttModel(opt)
|
60 |
+
# Transformer
|
61 |
+
elif opt.caption_model == 'transformer':
|
62 |
+
if getattr(opt, 'cached_transformer', False):
|
63 |
+
model = cachedTransformer(opt)
|
64 |
+
else:
|
65 |
+
model = TransformerModel(opt)
|
66 |
+
# AoANet
|
67 |
+
elif opt.caption_model == 'aoa':
|
68 |
+
model = AoAModel(opt)
|
69 |
+
elif opt.caption_model == 'bert':
|
70 |
+
model = BertCapModel(opt)
|
71 |
+
elif opt.caption_model == 'm2transformer':
|
72 |
+
model = M2TransformerModel(opt)
|
73 |
+
else:
|
74 |
+
raise Exception("Caption model not supported: {}".format(opt.caption_model))
|
75 |
+
|
76 |
+
return model
|
captioning/models/cachedTransformer.py
ADDED
@@ -0,0 +1,420 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This file contains Transformer network
|
2 |
+
# Most of the code is copied from http://nlp.seas.harvard.edu/2018/04/03/attention.html
|
3 |
+
|
4 |
+
# The cfg name correspondance:
|
5 |
+
# N=num_layers
|
6 |
+
# d_model=input_encoding_size
|
7 |
+
# d_ff=rnn_size
|
8 |
+
# h is always 8
|
9 |
+
|
10 |
+
from __future__ import absolute_import
|
11 |
+
from __future__ import division
|
12 |
+
from __future__ import print_function
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
import torch.nn.functional as F
|
17 |
+
from . import utils
|
18 |
+
|
19 |
+
import copy
|
20 |
+
import math
|
21 |
+
import numpy as np
|
22 |
+
|
23 |
+
from .CaptionModel import CaptionModel
|
24 |
+
from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel
|
25 |
+
|
26 |
+
class EncoderDecoder(nn.Module):
|
27 |
+
"""
|
28 |
+
A standard Encoder-Decoder architecture. Base for this and many
|
29 |
+
other models.
|
30 |
+
"""
|
31 |
+
def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
|
32 |
+
super(EncoderDecoder, self).__init__()
|
33 |
+
self.encoder = encoder
|
34 |
+
self.decoder = decoder
|
35 |
+
self.src_embed = src_embed
|
36 |
+
self.tgt_embed = tgt_embed
|
37 |
+
self.generator = generator
|
38 |
+
|
39 |
+
def forward(self, src, tgt, src_mask, tgt_mask):
|
40 |
+
"Take in and process masked src and target sequences."
|
41 |
+
return self.decode(self.encode(src, src_mask), src_mask,
|
42 |
+
tgt, tgt_mask)
|
43 |
+
|
44 |
+
def encode(self, src, src_mask):
|
45 |
+
return self.encoder(self.src_embed(src), src_mask)
|
46 |
+
|
47 |
+
def decode(self, memory, src_mask, tgt, tgt_mask, past=None):
|
48 |
+
return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask, past=past)
|
49 |
+
|
50 |
+
class Generator(nn.Module):
|
51 |
+
"Define standard linear + softmax generation step."
|
52 |
+
def __init__(self, d_model, vocab):
|
53 |
+
super(Generator, self).__init__()
|
54 |
+
self.proj = nn.Linear(d_model, vocab)
|
55 |
+
|
56 |
+
def forward(self, x):
|
57 |
+
return F.log_softmax(self.proj(x), dim=-1)
|
58 |
+
|
59 |
+
def clones(module, N):
|
60 |
+
"Produce N identical layers."
|
61 |
+
return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
|
62 |
+
|
63 |
+
class Encoder(nn.Module):
|
64 |
+
"Core encoder is a stack of N layers"
|
65 |
+
def __init__(self, layer, N):
|
66 |
+
super(Encoder, self).__init__()
|
67 |
+
self.layers = clones(layer, N)
|
68 |
+
self.norm = LayerNorm(layer.size)
|
69 |
+
|
70 |
+
def forward(self, x, mask):
|
71 |
+
"Pass the input (and mask) through each layer in turn."
|
72 |
+
for layer in self.layers:
|
73 |
+
x = layer(x, mask)
|
74 |
+
return self.norm(x)
|
75 |
+
|
76 |
+
class LayerNorm(nn.Module):
|
77 |
+
"Construct a layernorm module (See citation for details)."
|
78 |
+
def __init__(self, features, eps=1e-6):
|
79 |
+
super(LayerNorm, self).__init__()
|
80 |
+
self.a_2 = nn.Parameter(torch.ones(features))
|
81 |
+
self.b_2 = nn.Parameter(torch.zeros(features))
|
82 |
+
self.eps = eps
|
83 |
+
|
84 |
+
def forward(self, x):
|
85 |
+
mean = x.mean(-1, keepdim=True)
|
86 |
+
std = x.std(-1, keepdim=True)
|
87 |
+
return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
|
88 |
+
|
89 |
+
class SublayerConnection(nn.Module):
|
90 |
+
"""
|
91 |
+
A residual connection followed by a layer norm.
|
92 |
+
Note for code simplicity the norm is first as opposed to last.
|
93 |
+
"""
|
94 |
+
def __init__(self, size, dropout):
|
95 |
+
super(SublayerConnection, self).__init__()
|
96 |
+
self.norm = LayerNorm(size)
|
97 |
+
self.dropout = nn.Dropout(dropout)
|
98 |
+
|
99 |
+
def forward(self, x, sublayer):
|
100 |
+
"Apply residual connection to any sublayer with the same size."
|
101 |
+
_x = sublayer(self.norm(x))
|
102 |
+
if type(_x) is tuple: # for multi-head attention that returns past
|
103 |
+
return x + self.dropout(_x[0]), _x[1]
|
104 |
+
return x + self.dropout(_x)
|
105 |
+
|
106 |
+
class EncoderLayer(nn.Module):
|
107 |
+
"Encoder is made up of self-attn and feed forward (defined below)"
|
108 |
+
def __init__(self, size, self_attn, feed_forward, dropout):
|
109 |
+
super(EncoderLayer, self).__init__()
|
110 |
+
self.self_attn = self_attn
|
111 |
+
self.feed_forward = feed_forward
|
112 |
+
self.sublayer = clones(SublayerConnection(size, dropout), 2)
|
113 |
+
self.size = size
|
114 |
+
|
115 |
+
def forward(self, x, mask):
|
116 |
+
"Follow Figure 1 (left) for connections."
|
117 |
+
x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
|
118 |
+
return self.sublayer[1](x, self.feed_forward)
|
119 |
+
|
120 |
+
class Decoder(nn.Module):
|
121 |
+
"Generic N layer decoder with masking."
|
122 |
+
def __init__(self, layer, N):
|
123 |
+
super(Decoder, self).__init__()
|
124 |
+
self.layers = clones(layer, N)
|
125 |
+
self.norm = LayerNorm(layer.size)
|
126 |
+
|
127 |
+
def forward(self, x, memory, src_mask, tgt_mask, past=None):
|
128 |
+
if past is not None:
|
129 |
+
present = [[], []]
|
130 |
+
x = x[:, -1:]
|
131 |
+
tgt_mask = tgt_mask[:, -1:] if tgt_mask is not None else None
|
132 |
+
past = list(zip(past[0].split(2, dim=0), past[1].split(2, dim=0)))
|
133 |
+
else:
|
134 |
+
past = [None] * len(self.layers)
|
135 |
+
for i, (layer, layer_past) in enumerate(zip(self.layers, past)):
|
136 |
+
x = layer(x, memory, src_mask, tgt_mask,
|
137 |
+
layer_past)
|
138 |
+
if layer_past is not None:
|
139 |
+
present[0].append(x[1][0])
|
140 |
+
present[1].append(x[1][1])
|
141 |
+
x = x[0]
|
142 |
+
if past[0] is None:
|
143 |
+
return self.norm(x)
|
144 |
+
else:
|
145 |
+
return self.norm(x), [torch.cat(present[0], 0), torch.cat(present[1], 0)]
|
146 |
+
|
147 |
+
|
148 |
+
class DecoderLayer(nn.Module):
|
149 |
+
"Decoder is made of self-attn, src-attn, and feed forward (defined below)"
|
150 |
+
def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
|
151 |
+
super(DecoderLayer, self).__init__()
|
152 |
+
self.size = size
|
153 |
+
self.self_attn = self_attn
|
154 |
+
self.src_attn = src_attn
|
155 |
+
self.feed_forward = feed_forward
|
156 |
+
self.sublayer = clones(SublayerConnection(size, dropout), 3)
|
157 |
+
|
158 |
+
def forward(self, x, memory, src_mask, tgt_mask, layer_past=None):
|
159 |
+
"Follow Figure 1 (right) for connections."
|
160 |
+
m = memory
|
161 |
+
if layer_past is None:
|
162 |
+
x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
|
163 |
+
x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
|
164 |
+
return self.sublayer[2](x, self.feed_forward)
|
165 |
+
else:
|
166 |
+
present = [None, None]
|
167 |
+
x, present[0] = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask, layer_past[0]))
|
168 |
+
x, present[1] = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask, layer_past[1]))
|
169 |
+
return self.sublayer[2](x, self.feed_forward), present
|
170 |
+
|
171 |
+
def subsequent_mask(size):
|
172 |
+
"Mask out subsequent positions."
|
173 |
+
attn_shape = (1, size, size)
|
174 |
+
subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
|
175 |
+
return torch.from_numpy(subsequent_mask) == 0
|
176 |
+
|
177 |
+
def attention(query, key, value, mask=None, dropout=None):
|
178 |
+
"Compute 'Scaled Dot Product Attention'"
|
179 |
+
d_k = query.size(-1)
|
180 |
+
scores = torch.matmul(query, key.transpose(-2, -1)) \
|
181 |
+
/ math.sqrt(d_k)
|
182 |
+
if mask is not None:
|
183 |
+
scores = scores.masked_fill(mask == 0, float('-inf'))
|
184 |
+
p_attn = F.softmax(scores, dim = -1)
|
185 |
+
if dropout is not None:
|
186 |
+
p_attn = dropout(p_attn)
|
187 |
+
return torch.matmul(p_attn, value), p_attn
|
188 |
+
|
189 |
+
class MultiHeadedAttention(nn.Module):
|
190 |
+
def __init__(self, h, d_model, dropout=0.1):
|
191 |
+
"Take in model size and number of heads."
|
192 |
+
super(MultiHeadedAttention, self).__init__()
|
193 |
+
assert d_model % h == 0
|
194 |
+
# We assume d_v always equals d_k
|
195 |
+
self.d_k = d_model // h
|
196 |
+
self.h = h
|
197 |
+
self.linears = clones(nn.Linear(d_model, d_model), 4)
|
198 |
+
self.attn = None
|
199 |
+
self.dropout = nn.Dropout(p=dropout)
|
200 |
+
|
201 |
+
def forward(self, query, key, value, mask=None, layer_past=None):
|
202 |
+
"Implements Figure 2"
|
203 |
+
if mask is not None:
|
204 |
+
# Same mask applied to all h heads.
|
205 |
+
mask = mask.unsqueeze(1)
|
206 |
+
nbatches = query.size(0)
|
207 |
+
|
208 |
+
# The past works differently here. For self attn, the query and key be updated incrementailly
|
209 |
+
# For src_attn the past is fixed.
|
210 |
+
|
211 |
+
# For src_attn, when the layer past is ready
|
212 |
+
if layer_past is not None and layer_past.shape[2] == key.shape[1] > 1: # suppose memory size always greater than 1
|
213 |
+
query = self.linears[0](query)
|
214 |
+
key, value = layer_past[0], layer_past[1]
|
215 |
+
present = torch.stack([key, value])
|
216 |
+
else:
|
217 |
+
# 1) Do all the linear projections in batch from d_model => h x d_k
|
218 |
+
query, key, value = \
|
219 |
+
[l(x) for l, x in zip(self.linears, (query, key, value))]
|
220 |
+
|
221 |
+
# self attn + past OR the first time step of src attn
|
222 |
+
if layer_past is not None and not (layer_past.shape[2] == key.shape[1] > 1):
|
223 |
+
past_key, past_value = layer_past[0], layer_past[1]
|
224 |
+
key = torch.cat((past_key, key), dim=1)
|
225 |
+
value = torch.cat((past_value, value), dim=1)
|
226 |
+
present = torch.stack([key, value])
|
227 |
+
|
228 |
+
query, key, value = \
|
229 |
+
[x.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
|
230 |
+
for x in [query, key, value]]
|
231 |
+
|
232 |
+
# 2) Apply attention on all the projected vectors in batch.
|
233 |
+
x, self.attn = attention(query, key, value, mask=mask,
|
234 |
+
dropout=self.dropout)
|
235 |
+
|
236 |
+
# 3) "Concat" using a view and apply a final linear.
|
237 |
+
x = x.transpose(1, 2).contiguous() \
|
238 |
+
.view(nbatches, -1, self.h * self.d_k)
|
239 |
+
if layer_past is not None:
|
240 |
+
return self.linears[-1](x), present
|
241 |
+
else:
|
242 |
+
return self.linears[-1](x)
|
243 |
+
|
244 |
+
class PositionwiseFeedForward(nn.Module):
|
245 |
+
"Implements FFN equation."
|
246 |
+
def __init__(self, d_model, d_ff, dropout=0.1):
|
247 |
+
super(PositionwiseFeedForward, self).__init__()
|
248 |
+
self.w_1 = nn.Linear(d_model, d_ff)
|
249 |
+
self.w_2 = nn.Linear(d_ff, d_model)
|
250 |
+
self.dropout = nn.Dropout(dropout)
|
251 |
+
|
252 |
+
def forward(self, x):
|
253 |
+
return self.w_2(self.dropout(F.relu(self.w_1(x))))
|
254 |
+
|
255 |
+
class Embeddings(nn.Module):
|
256 |
+
def __init__(self, d_model, vocab):
|
257 |
+
super(Embeddings, self).__init__()
|
258 |
+
self.lut = nn.Embedding(vocab, d_model)
|
259 |
+
self.d_model = d_model
|
260 |
+
|
261 |
+
def forward(self, x):
|
262 |
+
return self.lut(x) * math.sqrt(self.d_model)
|
263 |
+
|
264 |
+
class PositionalEncoding(nn.Module):
|
265 |
+
"Implement the PE function."
|
266 |
+
def __init__(self, d_model, dropout, max_len=5000):
|
267 |
+
super(PositionalEncoding, self).__init__()
|
268 |
+
self.dropout = nn.Dropout(p=dropout)
|
269 |
+
|
270 |
+
# Compute the positional encodings once in log space.
|
271 |
+
pe = torch.zeros(max_len, d_model)
|
272 |
+
position = torch.arange(0, max_len).unsqueeze(1).float()
|
273 |
+
div_term = torch.exp(torch.arange(0, d_model, 2).float() *
|
274 |
+
-(math.log(10000.0) / d_model))
|
275 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
276 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
277 |
+
pe = pe.unsqueeze(0)
|
278 |
+
self.register_buffer('pe', pe)
|
279 |
+
|
280 |
+
def forward(self, x):
|
281 |
+
x = x + self.pe[:, :x.size(1)]
|
282 |
+
return self.dropout(x)
|
283 |
+
|
284 |
+
class TransformerModel(AttModel):
|
285 |
+
|
286 |
+
def make_model(self, src_vocab, tgt_vocab, N_enc=6, N_dec=6,
|
287 |
+
d_model=512, d_ff=2048, h=8, dropout=0.1):
|
288 |
+
"Helper: Construct a model from hyperparameters."
|
289 |
+
c = copy.deepcopy
|
290 |
+
attn = MultiHeadedAttention(h, d_model, dropout)
|
291 |
+
ff = PositionwiseFeedForward(d_model, d_ff, dropout)
|
292 |
+
position = PositionalEncoding(d_model, dropout)
|
293 |
+
model = EncoderDecoder(
|
294 |
+
Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N_enc),
|
295 |
+
Decoder(DecoderLayer(d_model, c(attn), c(attn),
|
296 |
+
c(ff), dropout), N_dec),
|
297 |
+
lambda x:x, # nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
|
298 |
+
nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
|
299 |
+
Generator(d_model, tgt_vocab))
|
300 |
+
|
301 |
+
# This was important from their code.
|
302 |
+
# Initialize parameters with Glorot / fan_avg.
|
303 |
+
for p in model.parameters():
|
304 |
+
if p.dim() > 1:
|
305 |
+
nn.init.xavier_uniform_(p)
|
306 |
+
return model
|
307 |
+
|
308 |
+
def __init__(self, opt):
|
309 |
+
super(TransformerModel, self).__init__(opt)
|
310 |
+
self.opt = opt
|
311 |
+
# self.config = yaml.load(open(opt.config_file))
|
312 |
+
|
313 |
+
self.N_enc = getattr(opt, 'N_enc', opt.num_layers)
|
314 |
+
self.N_dec = getattr(opt, 'N_dec', opt.num_layers)
|
315 |
+
self.d_model = getattr(opt, 'd_model', opt.input_encoding_size)
|
316 |
+
self.d_ff = getattr(opt, 'd_ff', opt.rnn_size)
|
317 |
+
self.h = getattr(opt, 'num_att_heads', 8)
|
318 |
+
self.dropout = getattr(opt, 'dropout', 0.1)
|
319 |
+
|
320 |
+
delattr(self, 'att_embed')
|
321 |
+
self.att_embed = nn.Sequential(*(
|
322 |
+
((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ())+
|
323 |
+
(nn.Linear(self.att_feat_size, self.d_model),
|
324 |
+
nn.ReLU(),
|
325 |
+
nn.Dropout(self.drop_prob_lm))+
|
326 |
+
((nn.BatchNorm1d(self.d_model),) if self.use_bn==2 else ())))
|
327 |
+
|
328 |
+
delattr(self, 'embed')
|
329 |
+
self.embed = lambda x : x
|
330 |
+
delattr(self, 'fc_embed')
|
331 |
+
self.fc_embed = lambda x : x
|
332 |
+
delattr(self, 'logit')
|
333 |
+
del self.ctx2att
|
334 |
+
|
335 |
+
tgt_vocab = self.vocab_size + 1
|
336 |
+
|
337 |
+
|
338 |
+
self.model = self.make_model(0, tgt_vocab,
|
339 |
+
N_enc=self.N_enc,
|
340 |
+
N_dec=self.N_dec,
|
341 |
+
d_model=self.d_model,
|
342 |
+
d_ff=self.d_ff,
|
343 |
+
h=self.h,
|
344 |
+
dropout=self.dropout)
|
345 |
+
|
346 |
+
def logit(self, x): # unsafe way
|
347 |
+
return self.model.generator.proj(x)
|
348 |
+
|
349 |
+
def init_hidden(self, bsz):
|
350 |
+
return []
|
351 |
+
|
352 |
+
def _prepare_feature(self, fc_feats, att_feats, att_masks):
|
353 |
+
|
354 |
+
att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks)
|
355 |
+
memory = self.model.encode(att_feats, att_masks)
|
356 |
+
|
357 |
+
return fc_feats[...,:0], att_feats[...,:0], memory, att_masks
|
358 |
+
|
359 |
+
def _prepare_feature_forward(self, att_feats, att_masks=None, seq=None):
|
360 |
+
att_feats, att_masks = self.clip_att(att_feats, att_masks)
|
361 |
+
|
362 |
+
att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)
|
363 |
+
|
364 |
+
if att_masks is None:
|
365 |
+
att_masks = att_feats.new_ones(att_feats.shape[:2], dtype=torch.long)
|
366 |
+
att_masks = att_masks.unsqueeze(-2)
|
367 |
+
|
368 |
+
if seq is not None:
|
369 |
+
# crop the last one
|
370 |
+
# seq = seq[:,:-1]
|
371 |
+
seq_mask = (seq.data != self.eos_idx) & (seq.data != self.pad_idx)
|
372 |
+
seq_mask[:,0] = 1 # bos
|
373 |
+
|
374 |
+
seq_mask = seq_mask.unsqueeze(-2)
|
375 |
+
seq_mask = seq_mask & subsequent_mask(seq.size(-1)).to(seq_mask)
|
376 |
+
|
377 |
+
seq_per_img = seq.shape[0] // att_feats.shape[0]
|
378 |
+
if seq_per_img > 1:
|
379 |
+
att_feats, att_masks = utils.repeat_tensors(seq_per_img,
|
380 |
+
[att_feats, att_masks]
|
381 |
+
)
|
382 |
+
else:
|
383 |
+
seq_mask = None
|
384 |
+
|
385 |
+
return att_feats, seq, att_masks, seq_mask
|
386 |
+
|
387 |
+
def _forward(self, fc_feats, att_feats, seq, att_masks=None):
|
388 |
+
if seq.ndim == 3: # B * seq_per_img * seq_len
|
389 |
+
seq = seq.reshape(-1, seq.shape[2])
|
390 |
+
att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq)
|
391 |
+
|
392 |
+
out = self.model(att_feats, seq, att_masks, seq_mask)
|
393 |
+
|
394 |
+
outputs = self.model.generator(out)
|
395 |
+
return outputs
|
396 |
+
# return torch.cat([_.unsqueeze(1) for _ in outputs], 1)
|
397 |
+
|
398 |
+
def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask):
|
399 |
+
"""
|
400 |
+
state is the precomputed key/value. N_dec x seq_len x d_model
|
401 |
+
Note: due to the layer norm, it's not equivalant to stateless,
|
402 |
+
but it seems behaving similar
|
403 |
+
"""
|
404 |
+
# state is tokens + past
|
405 |
+
if len(state) == 0:
|
406 |
+
ys = it.unsqueeze(1)
|
407 |
+
# basically empty state, just to let it know to return past
|
408 |
+
# The second dim has to be batch_size, for beam search purpose
|
409 |
+
past = [fc_feats_ph.new_zeros(self.N_dec * 2, fc_feats_ph.shape[0], 0, self.d_model), # self
|
410 |
+
fc_feats_ph.new_zeros(self.N_dec * 2, fc_feats_ph.shape[0], 0, self.d_model)] # src
|
411 |
+
# 2 for self attn, 2 for src attn
|
412 |
+
else:
|
413 |
+
ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1)
|
414 |
+
past = state[1:]
|
415 |
+
out, past = self.model.decode(memory, mask,
|
416 |
+
ys, # We still feed the full past words, because we need it for position embedding to know the position id
|
417 |
+
subsequent_mask(ys.size(1))
|
418 |
+
.to(memory.device),
|
419 |
+
past=past)
|
420 |
+
return out[:, -1], [ys.unsqueeze(0)] + past
|
captioning/models/utils.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
def repeat_tensors(n, x):
|
4 |
+
"""
|
5 |
+
For a tensor of size Bx..., we repeat it n times, and make it Bnx...
|
6 |
+
For collections, do nested repeat
|
7 |
+
"""
|
8 |
+
if torch.is_tensor(x):
|
9 |
+
x = x.unsqueeze(1) # Bx1x...
|
10 |
+
x = x.expand(-1, n, *([-1]*len(x.shape[2:]))) # Bxnx...
|
11 |
+
x = x.reshape(x.shape[0]*n, *x.shape[2:]) # Bnx...
|
12 |
+
elif type(x) is list or type(x) is tuple:
|
13 |
+
x = [repeat_tensors(n, _) for _ in x]
|
14 |
+
return x
|
15 |
+
|
16 |
+
|
17 |
+
def split_tensors(n, x):
|
18 |
+
if torch.is_tensor(x):
|
19 |
+
assert x.shape[0] % n == 0
|
20 |
+
x = x.reshape(x.shape[0] // n, n, *x.shape[1:]).unbind(1)
|
21 |
+
elif type(x) is list or type(x) is tuple:
|
22 |
+
x = [split_tensors(n, _) for _ in x]
|
23 |
+
elif x is None:
|
24 |
+
x = [None] * n
|
25 |
+
return x
|
captioning/modules/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
captioning/modules/loss_wrapper.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from . import losses
|
3 |
+
from ..utils.rewards import init_scorer, get_self_critical_reward
|
4 |
+
|
5 |
+
class LossWrapper(torch.nn.Module):
|
6 |
+
def __init__(self, model, opt):
|
7 |
+
super(LossWrapper, self).__init__()
|
8 |
+
self.opt = opt
|
9 |
+
self.model = model
|
10 |
+
if opt.label_smoothing > 0:
|
11 |
+
self.crit = losses.LabelSmoothing(smoothing=opt.label_smoothing)
|
12 |
+
else:
|
13 |
+
self.crit = losses.LanguageModelCriterion()
|
14 |
+
self.rl_crit = losses.RewardCriterion()
|
15 |
+
self.struc_crit = losses.StructureLosses(opt)
|
16 |
+
|
17 |
+
def forward(self, fc_feats, att_feats, labels, masks, att_masks, gts, gt_indices,
|
18 |
+
sc_flag, struc_flag):
|
19 |
+
opt = self.opt
|
20 |
+
|
21 |
+
out = {}
|
22 |
+
if struc_flag:
|
23 |
+
if opt.structure_loss_weight < 1:
|
24 |
+
lm_loss = self.crit(self.model(fc_feats, att_feats, labels[..., :-1], att_masks), labels[..., 1:], masks[..., 1:])
|
25 |
+
else:
|
26 |
+
lm_loss = torch.tensor(0).type_as(fc_feats)
|
27 |
+
if opt.structure_loss_weight > 0:
|
28 |
+
gen_result, sample_logprobs = self.model(fc_feats, att_feats, att_masks,
|
29 |
+
opt={'sample_method':opt.train_sample_method,
|
30 |
+
'beam_size':opt.train_beam_size,
|
31 |
+
'output_logsoftmax': opt.struc_use_logsoftmax or opt.structure_loss_type == 'softmax_margin'\
|
32 |
+
or not 'margin' in opt.structure_loss_type,
|
33 |
+
'sample_n': opt.train_sample_n},
|
34 |
+
mode='sample')
|
35 |
+
gts = [gts[_] for _ in gt_indices.tolist()]
|
36 |
+
struc_loss = self.struc_crit(sample_logprobs, gen_result, gts)
|
37 |
+
else:
|
38 |
+
struc_loss = {'loss': torch.tensor(0).type_as(fc_feats),
|
39 |
+
'reward': torch.tensor(0).type_as(fc_feats)}
|
40 |
+
loss = (1-opt.structure_loss_weight) * lm_loss + opt.structure_loss_weight * struc_loss['loss']
|
41 |
+
out['lm_loss'] = lm_loss
|
42 |
+
out['struc_loss'] = struc_loss['loss']
|
43 |
+
out['reward'] = struc_loss['reward']
|
44 |
+
elif not sc_flag:
|
45 |
+
loss = self.crit(self.model(fc_feats, att_feats, labels[..., :-1], att_masks), labels[..., 1:], masks[..., 1:])
|
46 |
+
else:
|
47 |
+
self.model.eval()
|
48 |
+
with torch.no_grad():
|
49 |
+
greedy_res, _ = self.model(fc_feats, att_feats, att_masks,
|
50 |
+
mode='sample',
|
51 |
+
opt={'sample_method': opt.sc_sample_method,
|
52 |
+
'beam_size': opt.sc_beam_size})
|
53 |
+
self.model.train()
|
54 |
+
gen_result, sample_logprobs = self.model(fc_feats, att_feats, att_masks,
|
55 |
+
opt={'sample_method':opt.train_sample_method,
|
56 |
+
'beam_size':opt.train_beam_size,
|
57 |
+
'sample_n': opt.train_sample_n},
|
58 |
+
mode='sample')
|
59 |
+
gts = [gts[_] for _ in gt_indices.tolist()]
|
60 |
+
reward = get_self_critical_reward(greedy_res, gts, gen_result, self.opt)
|
61 |
+
reward = torch.from_numpy(reward).to(sample_logprobs)
|
62 |
+
loss = self.rl_crit(sample_logprobs, gen_result.data, reward)
|
63 |
+
out['reward'] = reward[:,0].mean()
|
64 |
+
out['loss'] = loss
|
65 |
+
return out
|
captioning/modules/losses.py
ADDED
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from ..utils.rewards import get_scores, get_self_cider_scores
|
4 |
+
|
5 |
+
class RewardCriterion(nn.Module):
|
6 |
+
def __init__(self):
|
7 |
+
super(RewardCriterion, self).__init__()
|
8 |
+
|
9 |
+
def forward(self, input, seq, reward):
|
10 |
+
input = input.gather(2, seq.unsqueeze(2)).squeeze(2)
|
11 |
+
|
12 |
+
input = input.reshape(-1)
|
13 |
+
reward = reward.reshape(-1)
|
14 |
+
mask = (seq>0).to(input)
|
15 |
+
mask = torch.cat([mask.new(mask.size(0), 1).fill_(1), mask[:, :-1]], 1).reshape(-1)
|
16 |
+
output = - input * reward * mask
|
17 |
+
output = torch.sum(output) / torch.sum(mask)
|
18 |
+
|
19 |
+
return output
|
20 |
+
|
21 |
+
class StructureLosses(nn.Module):
|
22 |
+
"""
|
23 |
+
This loss is inspired by Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018).
|
24 |
+
"""
|
25 |
+
def __init__(self, opt):
|
26 |
+
super(StructureLosses, self).__init__()
|
27 |
+
self.opt = opt
|
28 |
+
self.loss_type = opt.structure_loss_type
|
29 |
+
|
30 |
+
def forward(self, input, seq, data_gts):
|
31 |
+
"""
|
32 |
+
Input is either logits or log softmax
|
33 |
+
"""
|
34 |
+
out = {}
|
35 |
+
|
36 |
+
batch_size = input.size(0)# batch_size = sample_size * seq_per_img
|
37 |
+
seq_per_img = batch_size // len(data_gts)
|
38 |
+
|
39 |
+
assert seq_per_img == self.opt.train_sample_n, seq_per_img
|
40 |
+
|
41 |
+
mask = (seq>0).to(input)
|
42 |
+
mask = torch.cat([mask.new_full((mask.size(0), 1), 1), mask[:, :-1]], 1)
|
43 |
+
|
44 |
+
scores = get_scores(data_gts, seq, self.opt)
|
45 |
+
scores = torch.from_numpy(scores).type_as(input).view(-1, seq_per_img)
|
46 |
+
out['reward'] = scores #.mean()
|
47 |
+
if self.opt.entropy_reward_weight > 0:
|
48 |
+
entropy = - (F.softmax(input, dim=2) * F.log_softmax(input, dim=2)).sum(2).data
|
49 |
+
entropy = (entropy * mask).sum(1) / mask.sum(1)
|
50 |
+
print('entropy', entropy.mean().item())
|
51 |
+
scores = scores + self.opt.entropy_reward_weight * entropy.view(-1, seq_per_img)
|
52 |
+
# rescale cost to [0,1]
|
53 |
+
costs = - scores
|
54 |
+
if self.loss_type == 'risk' or self.loss_type == 'softmax_margin':
|
55 |
+
costs = costs - costs.min(1, keepdim=True)[0]
|
56 |
+
costs = costs / costs.max(1, keepdim=True)[0]
|
57 |
+
# in principle
|
58 |
+
# Only risk need such rescale
|
59 |
+
# margin should be alright; Let's try.
|
60 |
+
|
61 |
+
# Gather input: BxTxD -> BxT
|
62 |
+
input = input.gather(2, seq.unsqueeze(2)).squeeze(2)
|
63 |
+
|
64 |
+
if self.loss_type == 'seqnll':
|
65 |
+
# input is logsoftmax
|
66 |
+
input = input * mask
|
67 |
+
input = input.sum(1) / mask.sum(1)
|
68 |
+
input = input.view(-1, seq_per_img)
|
69 |
+
|
70 |
+
target = costs.min(1)[1]
|
71 |
+
output = F.cross_entropy(input, target)
|
72 |
+
elif self.loss_type == 'risk':
|
73 |
+
# input is logsoftmax
|
74 |
+
input = input * mask
|
75 |
+
input = input.sum(1)
|
76 |
+
input = input.view(-1, seq_per_img)
|
77 |
+
|
78 |
+
output = (F.softmax(input.exp()) * costs).sum(1).mean()
|
79 |
+
|
80 |
+
# test
|
81 |
+
# avg_scores = input
|
82 |
+
# probs = F.softmax(avg_scores.exp_())
|
83 |
+
# loss = (probs * costs.type_as(probs)).sum() / input.size(0)
|
84 |
+
# print(output.item(), loss.item())
|
85 |
+
|
86 |
+
elif self.loss_type == 'max_margin':
|
87 |
+
# input is logits
|
88 |
+
input = input * mask
|
89 |
+
input = input.sum(1) / mask.sum(1)
|
90 |
+
input = input.view(-1, seq_per_img)
|
91 |
+
_, __ = costs.min(1, keepdim=True)
|
92 |
+
costs_star = _
|
93 |
+
input_star = input.gather(1, __)
|
94 |
+
output = F.relu(costs - costs_star - input_star + input).max(1)[0] / 2
|
95 |
+
output = output.mean()
|
96 |
+
|
97 |
+
# sanity test
|
98 |
+
# avg_scores = input + costs
|
99 |
+
# scores_with_high_target = avg_scores.clone()
|
100 |
+
# scores_with_high_target.scatter_(1, costs.min(1)[1].view(-1, 1), 1e10)
|
101 |
+
|
102 |
+
# target_and_offender_index = scores_with_high_target.sort(1, True)[1][:, 0:2]
|
103 |
+
# avg_scores = avg_scores.gather(1, target_and_offender_index)
|
104 |
+
# target_index = avg_scores.new_zeros(avg_scores.size(0), dtype=torch.long)
|
105 |
+
# loss = F.multi_margin_loss(avg_scores, target_index, size_average=True, margin=0)
|
106 |
+
# print(loss.item() * 2, output.item())
|
107 |
+
|
108 |
+
elif self.loss_type == 'multi_margin':
|
109 |
+
# input is logits
|
110 |
+
input = input * mask
|
111 |
+
input = input.sum(1) / mask.sum(1)
|
112 |
+
input = input.view(-1, seq_per_img)
|
113 |
+
_, __ = costs.min(1, keepdim=True)
|
114 |
+
costs_star = _
|
115 |
+
input_star = input.gather(1, __)
|
116 |
+
output = F.relu(costs - costs_star - input_star + input)
|
117 |
+
output = output.mean()
|
118 |
+
|
119 |
+
# sanity test
|
120 |
+
# avg_scores = input + costs
|
121 |
+
# loss = F.multi_margin_loss(avg_scores, costs.min(1)[1], margin=0)
|
122 |
+
# print(output, loss)
|
123 |
+
|
124 |
+
elif self.loss_type == 'softmax_margin':
|
125 |
+
# input is logsoftmax
|
126 |
+
input = input * mask
|
127 |
+
input = input.sum(1) / mask.sum(1)
|
128 |
+
input = input.view(-1, seq_per_img)
|
129 |
+
|
130 |
+
input = input + costs
|
131 |
+
target = costs.min(1)[1]
|
132 |
+
output = F.cross_entropy(input, target)
|
133 |
+
|
134 |
+
elif self.loss_type == 'real_softmax_margin':
|
135 |
+
# input is logits
|
136 |
+
# This is what originally defined in Kevin's paper
|
137 |
+
# The result should be equivalent to softmax_margin
|
138 |
+
input = input * mask
|
139 |
+
input = input.sum(1) / mask.sum(1)
|
140 |
+
input = input.view(-1, seq_per_img)
|
141 |
+
|
142 |
+
input = input + costs
|
143 |
+
target = costs.min(1)[1]
|
144 |
+
output = F.cross_entropy(input, target)
|
145 |
+
|
146 |
+
elif self.loss_type == 'new_self_critical':
|
147 |
+
"""
|
148 |
+
A different self critical
|
149 |
+
Self critical uses greedy decoding score as baseline;
|
150 |
+
This setting uses the average score of the rest samples as baseline
|
151 |
+
(suppose c1...cn n samples, reward1 = score1 - 1/(n-1)(score2+..+scoren) )
|
152 |
+
"""
|
153 |
+
baseline = (scores.sum(1, keepdim=True) - scores) / (scores.shape[1] - 1)
|
154 |
+
scores = scores - baseline
|
155 |
+
# self cider used as reward to promote diversity (not working that much in this way)
|
156 |
+
if getattr(self.opt, 'self_cider_reward_weight', 0) > 0:
|
157 |
+
_scores = get_self_cider_scores(data_gts, seq, self.opt)
|
158 |
+
_scores = torch.from_numpy(_scores).type_as(scores).view(-1, 1)
|
159 |
+
_scores = _scores.expand_as(scores - 1)
|
160 |
+
scores += self.opt.self_cider_reward_weight * _scores
|
161 |
+
output = - input * mask * scores.view(-1, 1)
|
162 |
+
output = torch.sum(output) / torch.sum(mask)
|
163 |
+
|
164 |
+
out['loss'] = output
|
165 |
+
return out
|
166 |
+
|
167 |
+
class LanguageModelCriterion(nn.Module):
|
168 |
+
def __init__(self):
|
169 |
+
super(LanguageModelCriterion, self).__init__()
|
170 |
+
|
171 |
+
def forward(self, input, target, mask):
|
172 |
+
if target.ndim == 3:
|
173 |
+
target = target.reshape(-1, target.shape[2])
|
174 |
+
mask = mask.reshape(-1, mask.shape[2])
|
175 |
+
# truncate to the same size
|
176 |
+
target = target[:, :input.size(1)]
|
177 |
+
mask = mask[:, :input.size(1)].to(input)
|
178 |
+
|
179 |
+
output = -input.gather(2, target.unsqueeze(2)).squeeze(2) * mask
|
180 |
+
# Average over each token
|
181 |
+
output = torch.sum(output) / torch.sum(mask)
|
182 |
+
|
183 |
+
return output
|
184 |
+
|
185 |
+
class LabelSmoothing(nn.Module):
|
186 |
+
"Implement label smoothing."
|
187 |
+
def __init__(self, size=0, padding_idx=0, smoothing=0.0):
|
188 |
+
super(LabelSmoothing, self).__init__()
|
189 |
+
self.criterion = nn.KLDivLoss(size_average=False, reduce=False)
|
190 |
+
# self.padding_idx = padding_idx
|
191 |
+
self.confidence = 1.0 - smoothing
|
192 |
+
self.smoothing = smoothing
|
193 |
+
# self.size = size
|
194 |
+
self.true_dist = None
|
195 |
+
|
196 |
+
def forward(self, input, target, mask):
|
197 |
+
if target.ndim == 3:
|
198 |
+
target = target.reshape(-1, target.shape[2])
|
199 |
+
mask = mask.reshape(-1, mask.shape[2])
|
200 |
+
# truncate to the same size
|
201 |
+
target = target[:, :input.size(1)]
|
202 |
+
mask = mask[:, :input.size(1)]
|
203 |
+
|
204 |
+
input = input.reshape(-1, input.size(-1))
|
205 |
+
target = target.reshape(-1)
|
206 |
+
mask = mask.reshape(-1).to(input)
|
207 |
+
|
208 |
+
# assert x.size(1) == self.size
|
209 |
+
self.size = input.size(1)
|
210 |
+
# true_dist = x.data.clone()
|
211 |
+
true_dist = input.data.clone()
|
212 |
+
# true_dist.fill_(self.smoothing / (self.size - 2))
|
213 |
+
true_dist.fill_(self.smoothing / (self.size - 1))
|
214 |
+
true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
|
215 |
+
# true_dist[:, self.padding_idx] = 0
|
216 |
+
# mask = torch.nonzero(target.data == self.padding_idx)
|
217 |
+
# self.true_dist = true_dist
|
218 |
+
return (self.criterion(input, true_dist).sum(1) * mask).sum() / mask.sum()
|
captioning/utils/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
captioning/utils/__init__.py
ADDED
File without changes
|
captioning/utils/config.py
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
2 |
+
# Copy from fvcore
|
3 |
+
|
4 |
+
import logging
|
5 |
+
import os
|
6 |
+
from typing import Any
|
7 |
+
import yaml
|
8 |
+
from yacs.config import CfgNode as _CfgNode
|
9 |
+
|
10 |
+
import io as PathManager
|
11 |
+
|
12 |
+
BASE_KEY = "_BASE_"
|
13 |
+
|
14 |
+
|
15 |
+
class CfgNode(_CfgNode):
|
16 |
+
"""
|
17 |
+
Our own extended version of :class:`yacs.config.CfgNode`.
|
18 |
+
It contains the following extra features:
|
19 |
+
|
20 |
+
1. The :meth:`merge_from_file` method supports the "_BASE_" key,
|
21 |
+
which allows the new CfgNode to inherit all the attributes from the
|
22 |
+
base configuration file.
|
23 |
+
2. Keys that start with "COMPUTED_" are treated as insertion-only
|
24 |
+
"computed" attributes. They can be inserted regardless of whether
|
25 |
+
the CfgNode is frozen or not.
|
26 |
+
3. With "allow_unsafe=True", it supports pyyaml tags that evaluate
|
27 |
+
expressions in config. See examples in
|
28 |
+
https://pyyaml.org/wiki/PyYAMLDocumentation#yaml-tags-and-python-types
|
29 |
+
Note that this may lead to arbitrary code execution: you must not
|
30 |
+
load a config file from untrusted sources before manually inspecting
|
31 |
+
the content of the file.
|
32 |
+
"""
|
33 |
+
|
34 |
+
@staticmethod
|
35 |
+
def load_yaml_with_base(filename, allow_unsafe = False):
|
36 |
+
"""
|
37 |
+
Just like `yaml.load(open(filename))`, but inherit attributes from its
|
38 |
+
`_BASE_`.
|
39 |
+
|
40 |
+
Args:
|
41 |
+
filename (str): the file name of the current config. Will be used to
|
42 |
+
find the base config file.
|
43 |
+
allow_unsafe (bool): whether to allow loading the config file with
|
44 |
+
`yaml.unsafe_load`.
|
45 |
+
|
46 |
+
Returns:
|
47 |
+
(dict): the loaded yaml
|
48 |
+
"""
|
49 |
+
with PathManager.open(filename, "r") as f:
|
50 |
+
try:
|
51 |
+
cfg = yaml.safe_load(f)
|
52 |
+
except yaml.constructor.ConstructorError:
|
53 |
+
if not allow_unsafe:
|
54 |
+
raise
|
55 |
+
logger = logging.getLogger(__name__)
|
56 |
+
logger.warning(
|
57 |
+
"Loading config {} with yaml.unsafe_load. Your machine may "
|
58 |
+
"be at risk if the file contains malicious content.".format(
|
59 |
+
filename
|
60 |
+
)
|
61 |
+
)
|
62 |
+
f.close()
|
63 |
+
with open(filename, "r") as f:
|
64 |
+
cfg = yaml.unsafe_load(f)
|
65 |
+
|
66 |
+
def merge_a_into_b(a, b):
|
67 |
+
# merge dict a into dict b. values in a will overwrite b.
|
68 |
+
for k, v in a.items():
|
69 |
+
if isinstance(v, dict) and k in b:
|
70 |
+
assert isinstance(
|
71 |
+
b[k], dict
|
72 |
+
), "Cannot inherit key '{}' from base!".format(k)
|
73 |
+
merge_a_into_b(v, b[k])
|
74 |
+
else:
|
75 |
+
b[k] = v
|
76 |
+
|
77 |
+
if BASE_KEY in cfg:
|
78 |
+
base_cfg_file = cfg[BASE_KEY]
|
79 |
+
if base_cfg_file.startswith("~"):
|
80 |
+
base_cfg_file = os.path.expanduser(base_cfg_file)
|
81 |
+
if not any(
|
82 |
+
map(base_cfg_file.startswith, ["/", "https://", "http://"])
|
83 |
+
):
|
84 |
+
# the path to base cfg is relative to the config file itself.
|
85 |
+
base_cfg_file = os.path.join(
|
86 |
+
os.path.dirname(filename), base_cfg_file
|
87 |
+
)
|
88 |
+
base_cfg = CfgNode.load_yaml_with_base(
|
89 |
+
base_cfg_file, allow_unsafe=allow_unsafe
|
90 |
+
)
|
91 |
+
del cfg[BASE_KEY]
|
92 |
+
|
93 |
+
merge_a_into_b(cfg, base_cfg)
|
94 |
+
return base_cfg
|
95 |
+
return cfg
|
96 |
+
|
97 |
+
def merge_from_file(self, cfg_filename, allow_unsafe = False):
|
98 |
+
"""
|
99 |
+
Merge configs from a given yaml file.
|
100 |
+
|
101 |
+
Args:
|
102 |
+
cfg_filename: the file name of the yaml config.
|
103 |
+
allow_unsafe: whether to allow loading the config file with
|
104 |
+
`yaml.unsafe_load`.
|
105 |
+
"""
|
106 |
+
loaded_cfg = CfgNode.load_yaml_with_base(
|
107 |
+
cfg_filename, allow_unsafe=allow_unsafe
|
108 |
+
)
|
109 |
+
loaded_cfg = type(self)(loaded_cfg)
|
110 |
+
self.merge_from_other_cfg(loaded_cfg)
|
111 |
+
|
112 |
+
# Forward the following calls to base, but with a check on the BASE_KEY.
|
113 |
+
def merge_from_other_cfg(self, cfg_other):
|
114 |
+
"""
|
115 |
+
Args:
|
116 |
+
cfg_other (CfgNode): configs to merge from.
|
117 |
+
"""
|
118 |
+
assert (
|
119 |
+
BASE_KEY not in cfg_other
|
120 |
+
), "The reserved key '{}' can only be used in files!".format(BASE_KEY)
|
121 |
+
return super().merge_from_other_cfg(cfg_other)
|
122 |
+
|
123 |
+
def merge_from_list(self, cfg_list):
|
124 |
+
"""
|
125 |
+
Args:
|
126 |
+
cfg_list (list): list of configs to merge from.
|
127 |
+
"""
|
128 |
+
keys = set(cfg_list[0::2])
|
129 |
+
assert (
|
130 |
+
BASE_KEY not in keys
|
131 |
+
), "The reserved key '{}' can only be used in files!".format(BASE_KEY)
|
132 |
+
return super().merge_from_list(cfg_list)
|
133 |
+
|
134 |
+
def __setattr__(self, name, val):
|
135 |
+
if name.startswith("COMPUTED_"):
|
136 |
+
if name in self:
|
137 |
+
old_val = self[name]
|
138 |
+
if old_val == val:
|
139 |
+
return
|
140 |
+
raise KeyError(
|
141 |
+
"Computed attributed '{}' already exists "
|
142 |
+
"with a different value! old={}, new={}.".format(
|
143 |
+
name, old_val, val
|
144 |
+
)
|
145 |
+
)
|
146 |
+
self[name] = val
|
147 |
+
else:
|
148 |
+
super().__setattr__(name, val)
|
149 |
+
|
150 |
+
|
151 |
+
if __name__ == '__main__':
|
152 |
+
cfg = CfgNode.load_yaml_with_base('configs/updown_long.yml')
|
153 |
+
print(cfg)
|
captioning/utils/div_utils.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from random import uniform
|
2 |
+
import numpy as np
|
3 |
+
from collections import OrderedDict, defaultdict
|
4 |
+
from itertools import tee
|
5 |
+
import time
|
6 |
+
|
7 |
+
# -----------------------------------------------
|
8 |
+
def find_ngrams(input_list, n):
|
9 |
+
return zip(*[input_list[i:] for i in range(n)])
|
10 |
+
|
11 |
+
def compute_div_n(caps,n=1):
|
12 |
+
aggr_div = []
|
13 |
+
for k in caps:
|
14 |
+
all_ngrams = set()
|
15 |
+
lenT = 0.
|
16 |
+
for c in caps[k]:
|
17 |
+
tkns = c.split()
|
18 |
+
lenT += len(tkns)
|
19 |
+
ng = find_ngrams(tkns, n)
|
20 |
+
all_ngrams.update(ng)
|
21 |
+
aggr_div.append(float(len(all_ngrams))/ (1e-6 + float(lenT)))
|
22 |
+
return np.array(aggr_div).mean(), np.array(aggr_div)
|
23 |
+
|
24 |
+
def compute_global_div_n(caps,n=1):
|
25 |
+
aggr_div = []
|
26 |
+
all_ngrams = set()
|
27 |
+
lenT = 0.
|
28 |
+
for k in caps:
|
29 |
+
for c in caps[k]:
|
30 |
+
tkns = c.split()
|
31 |
+
lenT += len(tkns)
|
32 |
+
ng = find_ngrams(tkns, n)
|
33 |
+
all_ngrams.update(ng)
|
34 |
+
if n == 1:
|
35 |
+
aggr_div.append(float(len(all_ngrams)))
|
36 |
+
else:
|
37 |
+
aggr_div.append(float(len(all_ngrams))/ (1e-6 + float(lenT)))
|
38 |
+
return aggr_div[0], np.repeat(np.array(aggr_div),len(caps))
|
captioning/utils/eval_multi.py
ADDED
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import
|
2 |
+
from __future__ import division
|
3 |
+
from __future__ import print_function
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import json
|
10 |
+
from json import encoder
|
11 |
+
import random
|
12 |
+
import string
|
13 |
+
import time
|
14 |
+
import os
|
15 |
+
import sys
|
16 |
+
from . import misc as utils
|
17 |
+
from eval_utils import getCOCO
|
18 |
+
|
19 |
+
from .div_utils import compute_div_n, compute_global_div_n
|
20 |
+
|
21 |
+
import sys
|
22 |
+
try:
|
23 |
+
sys.path.append("coco-caption")
|
24 |
+
annFile = 'coco-caption/annotations/captions_val2014.json'
|
25 |
+
from pycocotools.coco import COCO
|
26 |
+
from pycocoevalcap.eval import COCOEvalCap
|
27 |
+
from pycocoevalcap.eval_spice import COCOEvalCapSpice
|
28 |
+
from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
|
29 |
+
from pycocoevalcap.bleu.bleu import Bleu
|
30 |
+
sys.path.append("cider")
|
31 |
+
from pyciderevalcap.cider.cider import Cider
|
32 |
+
except:
|
33 |
+
print('Warning: requirements for eval_multi not satisfied')
|
34 |
+
|
35 |
+
|
36 |
+
def eval_allspice(dataset, preds_n, model_id, split):
|
37 |
+
coco = getCOCO(dataset)
|
38 |
+
valids = coco.getImgIds()
|
39 |
+
|
40 |
+
capsById = {}
|
41 |
+
for d in preds_n:
|
42 |
+
capsById[d['image_id']] = capsById.get(d['image_id'], []) + [d]
|
43 |
+
|
44 |
+
# filter results to only those in MSCOCO validation set (will be about a third)
|
45 |
+
preds_filt_n = [p for p in preds_n if p['image_id'] in valids]
|
46 |
+
print('using %d/%d predictions_n' % (len(preds_filt_n), len(preds_n)))
|
47 |
+
cache_path_n = os.path.join('eval_results/', model_id + '_' + split + '_n.json')
|
48 |
+
json.dump(preds_filt_n, open(cache_path_n, 'w')) # serialize to temporary json file. Sigh, COCO API...
|
49 |
+
|
50 |
+
# Eval AllSPICE
|
51 |
+
cocoRes_n = coco.loadRes(cache_path_n)
|
52 |
+
cocoEvalAllSPICE = COCOEvalCapSpice(coco, cocoRes_n)
|
53 |
+
cocoEvalAllSPICE.params['image_id'] = cocoRes_n.getImgIds()
|
54 |
+
cocoEvalAllSPICE.evaluate()
|
55 |
+
|
56 |
+
out = {}
|
57 |
+
for metric, score in cocoEvalAllSPICE.eval.items():
|
58 |
+
out['All'+metric] = score
|
59 |
+
|
60 |
+
imgToEvalAllSPICE = cocoEvalAllSPICE.imgToEval
|
61 |
+
# collect SPICE_sub_score
|
62 |
+
for k in list(imgToEvalAllSPICE.values())[0]['SPICE'].keys():
|
63 |
+
if k != 'All':
|
64 |
+
out['AllSPICE_'+k] = np.array([v['SPICE'][k]['f'] for v in imgToEvalAllSPICE.values()])
|
65 |
+
out['AllSPICE_'+k] = (out['AllSPICE_'+k][out['AllSPICE_'+k]==out['AllSPICE_'+k]]).mean()
|
66 |
+
for p in preds_filt_n:
|
67 |
+
image_id, caption = p['image_id'], p['caption']
|
68 |
+
imgToEvalAllSPICE[image_id]['caption'] = capsById[image_id]
|
69 |
+
return {'overall': out, 'imgToEvalAllSPICE': imgToEvalAllSPICE}
|
70 |
+
|
71 |
+
def eval_oracle(dataset, preds_n, model_id, split):
|
72 |
+
cache_path = os.path.join('eval_results/', model_id + '_' + split + '_n.json')
|
73 |
+
|
74 |
+
coco = getCOCO(dataset)
|
75 |
+
valids = coco.getImgIds()
|
76 |
+
|
77 |
+
capsById = {}
|
78 |
+
for d in preds_n:
|
79 |
+
capsById[d['image_id']] = capsById.get(d['image_id'], []) + [d]
|
80 |
+
|
81 |
+
sample_n = capsById[list(capsById.keys())[0]]
|
82 |
+
for i in range(len(capsById[list(capsById.keys())[0]])):
|
83 |
+
preds = [_[i] for _ in capsById.values()]
|
84 |
+
|
85 |
+
json.dump(preds, open(cache_path, 'w')) # serialize to temporary json file. Sigh, COCO API...
|
86 |
+
|
87 |
+
cocoRes = coco.loadRes(cache_path)
|
88 |
+
cocoEval = COCOEvalCap(coco, cocoRes)
|
89 |
+
cocoEval.params['image_id'] = cocoRes.getImgIds()
|
90 |
+
cocoEval.evaluate()
|
91 |
+
|
92 |
+
imgToEval = cocoEval.imgToEval
|
93 |
+
for img_id in capsById.keys():
|
94 |
+
tmp = imgToEval[img_id]
|
95 |
+
for k in tmp['SPICE'].keys():
|
96 |
+
if k != 'All':
|
97 |
+
tmp['SPICE_'+k] = tmp['SPICE'][k]['f']
|
98 |
+
if tmp['SPICE_'+k] != tmp['SPICE_'+k]: # nan
|
99 |
+
tmp['SPICE_'+k] = -100
|
100 |
+
tmp['SPICE'] = tmp['SPICE']['All']['f']
|
101 |
+
if tmp['SPICE'] != tmp['SPICE']: tmp['SPICE'] = -100
|
102 |
+
capsById[img_id][i]['scores'] = imgToEval[img_id]
|
103 |
+
|
104 |
+
out = {'overall': {}, 'ImgToEval': {}}
|
105 |
+
for img_id in capsById.keys():
|
106 |
+
out['ImgToEval'][img_id] = {}
|
107 |
+
for metric in capsById[img_id][0]['scores'].keys():
|
108 |
+
if metric == 'image_id': continue
|
109 |
+
out['ImgToEval'][img_id]['oracle_'+metric] = max([_['scores'][metric] for _ in capsById[img_id]])
|
110 |
+
out['ImgToEval'][img_id]['avg_'+metric] = sum([_['scores'][metric] for _ in capsById[img_id]]) / len(capsById[img_id])
|
111 |
+
out['ImgToEval'][img_id]['captions'] = capsById[img_id]
|
112 |
+
for metric in list(out['ImgToEval'].values())[0].keys():
|
113 |
+
if metric == 'captions':
|
114 |
+
continue
|
115 |
+
tmp = np.array([_[metric] for _ in out['ImgToEval'].values()])
|
116 |
+
tmp = tmp[tmp!=-100]
|
117 |
+
out['overall'][metric] = tmp.mean()
|
118 |
+
|
119 |
+
return out
|
120 |
+
|
121 |
+
def eval_div_stats(dataset, preds_n, model_id, split):
|
122 |
+
tokenizer = PTBTokenizer()
|
123 |
+
|
124 |
+
capsById = {}
|
125 |
+
for i, d in enumerate(preds_n):
|
126 |
+
d['id'] = i
|
127 |
+
capsById[d['image_id']] = capsById.get(d['image_id'], []) + [d]
|
128 |
+
|
129 |
+
n_caps_perimg = len(capsById[list(capsById.keys())[0]])
|
130 |
+
print(n_caps_perimg)
|
131 |
+
_capsById = capsById # save the untokenized version
|
132 |
+
capsById = tokenizer.tokenize(capsById)
|
133 |
+
|
134 |
+
div_1, adiv_1 = compute_div_n(capsById,1)
|
135 |
+
div_2, adiv_2 = compute_div_n(capsById,2)
|
136 |
+
|
137 |
+
globdiv_1, _= compute_global_div_n(capsById,1)
|
138 |
+
|
139 |
+
print('Diversity Statistics are as follows: \n Div1: %.2f, Div2: %.2f, gDiv1: %d\n'%(div_1,div_2, globdiv_1))
|
140 |
+
|
141 |
+
# compute mbleu
|
142 |
+
scorer = Bleu(4)
|
143 |
+
all_scrs = []
|
144 |
+
scrperimg = np.zeros((n_caps_perimg, len(capsById)))
|
145 |
+
|
146 |
+
for i in range(n_caps_perimg):
|
147 |
+
tempRefsById = {}
|
148 |
+
candsById = {}
|
149 |
+
for k in capsById:
|
150 |
+
tempRefsById[k] = capsById[k][:i] + capsById[k][i+1:]
|
151 |
+
candsById[k] = [capsById[k][i]]
|
152 |
+
|
153 |
+
score, scores = scorer.compute_score(tempRefsById, candsById)
|
154 |
+
all_scrs.append(score)
|
155 |
+
scrperimg[i,:] = scores[1]
|
156 |
+
|
157 |
+
all_scrs = np.array(all_scrs)
|
158 |
+
|
159 |
+
out = {}
|
160 |
+
out['overall'] = {'Div1': div_1, 'Div2': div_2, 'gDiv1': globdiv_1}
|
161 |
+
for k, score in zip(range(4), all_scrs.mean(axis=0).tolist()):
|
162 |
+
out['overall'].update({'mBLeu_%d'%(k+1): score})
|
163 |
+
imgToEval = {}
|
164 |
+
for i,imgid in enumerate(capsById.keys()):
|
165 |
+
imgToEval[imgid] = {'mBleu_2' : scrperimg[:,i].mean()}
|
166 |
+
imgToEval[imgid]['individuals'] = []
|
167 |
+
for j, d in enumerate(_capsById[imgid]):
|
168 |
+
imgToEval[imgid]['individuals'].append(preds_n[d['id']])
|
169 |
+
imgToEval[imgid]['individuals'][-1]['mBleu_2'] = scrperimg[j,i]
|
170 |
+
out['ImgToEval'] = imgToEval
|
171 |
+
|
172 |
+
print('Mean mutual Bleu scores on this set is:\nmBLeu_1, mBLeu_2, mBLeu_3, mBLeu_4')
|
173 |
+
print(all_scrs.mean(axis=0))
|
174 |
+
|
175 |
+
return out
|
176 |
+
|
177 |
+
def eval_self_cider(dataset, preds_n, model_id, split):
|
178 |
+
cache_path = os.path.join('eval_results/', model_id + '_' + split + '_n.json')
|
179 |
+
|
180 |
+
coco = getCOCO(dataset)
|
181 |
+
valids = coco.getImgIds()
|
182 |
+
|
183 |
+
# Get Cider_scorer
|
184 |
+
Cider_scorer = Cider(df='corpus')
|
185 |
+
|
186 |
+
tokenizer = PTBTokenizer()
|
187 |
+
gts = {}
|
188 |
+
for imgId in valids:
|
189 |
+
gts[imgId] = coco.imgToAnns[imgId]
|
190 |
+
gts = tokenizer.tokenize(gts)
|
191 |
+
|
192 |
+
for imgId in valids:
|
193 |
+
Cider_scorer.cider_scorer += (None, gts[imgId])
|
194 |
+
Cider_scorer.cider_scorer.compute_doc_freq()
|
195 |
+
Cider_scorer.cider_scorer.ref_len = np.log(float(len(Cider_scorer.cider_scorer.crefs)))
|
196 |
+
|
197 |
+
# Prepare captions
|
198 |
+
capsById = {}
|
199 |
+
for d in preds_n:
|
200 |
+
capsById[d['image_id']] = capsById.get(d['image_id'], []) + [d]
|
201 |
+
|
202 |
+
capsById = tokenizer.tokenize(capsById)
|
203 |
+
imgIds = list(capsById.keys())
|
204 |
+
scores = Cider_scorer.my_self_cider([capsById[_] for _ in imgIds])
|
205 |
+
|
206 |
+
def get_div(eigvals):
|
207 |
+
eigvals = np.clip(eigvals, 0, None)
|
208 |
+
return -np.log(np.sqrt(eigvals[-1]) / (np.sqrt(eigvals).sum())) / np.log(len(eigvals))
|
209 |
+
sc_scores = [get_div(np.linalg.eigvalsh(_/10)) for _ in scores]
|
210 |
+
score = np.mean(np.array(sc_scores))
|
211 |
+
|
212 |
+
imgToEval = {}
|
213 |
+
for i, image_id in enumerate(imgIds):
|
214 |
+
imgToEval[image_id] = {'self_cider': sc_scores[i], 'self_cider_mat': scores[i].tolist()}
|
215 |
+
return {'overall': {'self_cider': score}, 'imgToEval': imgToEval}
|
216 |
+
|
217 |
+
|
218 |
+
return score
|
captioning/utils/eval_utils.py
ADDED
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import
|
2 |
+
from __future__ import division
|
3 |
+
from __future__ import print_function
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import json
|
11 |
+
from json import encoder
|
12 |
+
import random
|
13 |
+
import string
|
14 |
+
import time
|
15 |
+
import os
|
16 |
+
import sys
|
17 |
+
from . import misc as utils
|
18 |
+
|
19 |
+
# sys.path.insert(0, os.getcwd())
|
20 |
+
|
21 |
+
# sys.path.append("coco-caption")
|
22 |
+
|
23 |
+
# load coco-caption if available
|
24 |
+
|
25 |
+
from coco_caption.pycocotools.coco import COCO
|
26 |
+
from coco_caption.pycocoevalcap.eval import COCOEvalCap
|
27 |
+
|
28 |
+
# try:
|
29 |
+
# # sys.path.append("coco-caption")
|
30 |
+
# # from pycocotools.coco import COCO
|
31 |
+
# # from pycocoevalcap.eval import COCOEvalCap
|
32 |
+
# from coco_caption.pycocotools.coco import COCO
|
33 |
+
# from coco_caption.pycocoevalcap.eval import COCOEvalCap
|
34 |
+
# except:
|
35 |
+
# print('Warning: coco-caption not available')
|
36 |
+
|
37 |
+
bad_endings = ['a','an','the','in','for','at','of','with','before','after','on','upon','near','to','is','are','am']
|
38 |
+
bad_endings += ['UNK', 'has', 'and', 'more']
|
39 |
+
|
40 |
+
|
41 |
+
def count_bad(sen):
|
42 |
+
sen = sen.split(' ')
|
43 |
+
if sen[-1] in bad_endings:
|
44 |
+
return 1
|
45 |
+
else:
|
46 |
+
return 0
|
47 |
+
|
48 |
+
|
49 |
+
def getCOCO(dataset):
|
50 |
+
if 'coco' in dataset:
|
51 |
+
annFile = 'coco-caption/annotations/captions_val2014.json'
|
52 |
+
elif 'flickr30k' in dataset or 'f30k' in dataset:
|
53 |
+
annFile = 'data/f30k_captions4eval.json'
|
54 |
+
# elif 'relative' in dataset:
|
55 |
+
# annFile = 'data/dress/features_simulator/caption_relative.json'
|
56 |
+
elif 'dress' in dataset:
|
57 |
+
annFile = 'data/dress/features_simulator/caption_relative.json'
|
58 |
+
elif 'shirt' in dataset:
|
59 |
+
annFile = 'data/shirt/features_simulator/caption_relative.json'
|
60 |
+
elif 'toptee' in dataset:
|
61 |
+
annFile = 'data/toptee/features_simulator/caption_relative.json'
|
62 |
+
elif 'fashion-gen' in dataset:
|
63 |
+
annFile = 'data/fashion-gen/features_simulator/caption_direct.json'
|
64 |
+
elif 'shoe' in dataset:
|
65 |
+
annFile = 'data/shoe/features_simulator/caption_relative.json'
|
66 |
+
return COCO(annFile)
|
67 |
+
|
68 |
+
|
69 |
+
def language_eval(dataset, preds, preds_n, eval_kwargs, split):
|
70 |
+
model_id = eval_kwargs['id']
|
71 |
+
eval_oracle = eval_kwargs.get('eval_oracle', 0)
|
72 |
+
|
73 |
+
# create output dictionary
|
74 |
+
out = {}
|
75 |
+
|
76 |
+
if len(preds_n) > 0:
|
77 |
+
# vocab size and novel sentences
|
78 |
+
if 'coco' in dataset:
|
79 |
+
dataset_file = 'data/dataset_coco.json'
|
80 |
+
elif 'flickr30k' in dataset or 'f30k' in dataset:
|
81 |
+
dataset_file = 'data/dataset_flickr30k.json'
|
82 |
+
# elif 'relative' in dataset:
|
83 |
+
# dataset_file = 'data/dress/features_simulator/caption_relative.json'
|
84 |
+
elif 'dress' in dataset:
|
85 |
+
annFile = 'data/dress/features_simulator/caption_relative.json'
|
86 |
+
elif 'shirt' in dataset:
|
87 |
+
annFile = 'data/shirt/features_simulator/caption_relative.json'
|
88 |
+
elif 'toptee' in dataset:
|
89 |
+
annFile = 'data/toptee/features_simulator/caption_relative.json'
|
90 |
+
elif 'fashion-gen' in dataset:
|
91 |
+
annFile = 'data/fashion-gen/features_simulator/caption_direct.json'
|
92 |
+
elif 'shoe' in dataset:
|
93 |
+
annFile = 'data/shoe/features_simulator/caption_relative.json'
|
94 |
+
training_sentences = set([' '.join(__['tokens']) for _ in json.load(open(dataset_file))['images'] if not _['split'] in ['val', 'test'] for __ in _['sentences']])
|
95 |
+
generated_sentences = set([_['caption'] for _ in preds_n])
|
96 |
+
novels = generated_sentences - training_sentences
|
97 |
+
out['novel_sentences'] = float(len(novels)) / len(preds_n)
|
98 |
+
tmp = [_.split() for _ in generated_sentences]
|
99 |
+
words = []
|
100 |
+
for _ in tmp:
|
101 |
+
words += _
|
102 |
+
out['vocab_size'] = len(set(words))
|
103 |
+
|
104 |
+
# encoder.FLOAT_REPR = lambda o: format(o, '.3f')
|
105 |
+
|
106 |
+
# cache_path = os.path.join('eval_results/', '.cache_'+ model_id + '_' + split + '.json')\
|
107 |
+
cache_path = os.path.join('results/log_'+eval_kwargs['topic']+'_'+model_id+'/eval_results_'+eval_kwargs['topic']+'/', '.cache_'+ model_id + '_' + split + '.json')
|
108 |
+
|
109 |
+
coco = getCOCO(dataset)
|
110 |
+
valids = coco.getImgIds()
|
111 |
+
|
112 |
+
# filter results to only those in MSCOCO validation set
|
113 |
+
preds_filt = [p for p in preds if p['image_id'] in valids]
|
114 |
+
mean_perplexity = sum([_['perplexity'] for _ in preds_filt]) / len(preds_filt)
|
115 |
+
mean_entropy = sum([_['entropy'] for _ in preds_filt]) / len(preds_filt)
|
116 |
+
print('using %d/%d predictions' % (len(preds_filt), len(preds)))
|
117 |
+
json.dump(preds_filt, open(cache_path, 'w')) # serialize to temporary json file. Sigh, COCO API...
|
118 |
+
|
119 |
+
cocoRes = coco.loadRes(cache_path)
|
120 |
+
cocoEval = COCOEvalCap(coco, cocoRes)
|
121 |
+
cocoEval.params['image_id'] = cocoRes.getImgIds()
|
122 |
+
cocoEval.evaluate()
|
123 |
+
|
124 |
+
for metric, score in cocoEval.eval.items():
|
125 |
+
out[metric] = score
|
126 |
+
# Add mean perplexity
|
127 |
+
out['perplexity'] = mean_perplexity
|
128 |
+
out['entropy'] = mean_entropy
|
129 |
+
|
130 |
+
imgToEval = cocoEval.imgToEval
|
131 |
+
for k in list(imgToEval.values())[0]['SPICE'].keys():
|
132 |
+
if k != 'All':
|
133 |
+
out['SPICE_'+k] = np.array([v['SPICE'][k]['f'] for v in imgToEval.values()])
|
134 |
+
out['SPICE_'+k] = (out['SPICE_'+k][out['SPICE_'+k]==out['SPICE_'+k]]).mean()
|
135 |
+
for p in preds_filt:
|
136 |
+
image_id, caption = p['image_id'], p['caption']
|
137 |
+
imgToEval[image_id]['caption'] = caption
|
138 |
+
|
139 |
+
if len(preds_n) > 0:
|
140 |
+
from . import eval_multi
|
141 |
+
# cache_path_n = os.path.join('eval_results/', '.cache_'+ model_id + '_' + split + '_n.json')
|
142 |
+
cache_path_n = os.path.join('results/log_'+eval_kwargs['topic']+'_'+model_id+'/eval_results_'+eval_kwargs['topic']+'/', '.cache_'+ model_id + '_' + split + '_n.json')
|
143 |
+
allspice = eval_multi.eval_allspice(dataset, preds_n, model_id, split)
|
144 |
+
out.update(allspice['overall'])
|
145 |
+
div_stats = eval_multi.eval_div_stats(dataset, preds_n, model_id, split)
|
146 |
+
out.update(div_stats['overall'])
|
147 |
+
if eval_oracle:
|
148 |
+
oracle = eval_multi.eval_oracle(dataset, preds_n, model_id, split)
|
149 |
+
out.update(oracle['overall'])
|
150 |
+
else:
|
151 |
+
oracle = None
|
152 |
+
self_cider = eval_multi.eval_self_cider(dataset, preds_n, model_id, split)
|
153 |
+
out.update(self_cider['overall'])
|
154 |
+
with open(cache_path_n, 'w') as outfile:
|
155 |
+
json.dump({'allspice': allspice, 'div_stats': div_stats, 'oracle': oracle, 'self_cider': self_cider}, outfile)
|
156 |
+
|
157 |
+
out['bad_count_rate'] = sum([count_bad(_['caption']) for _ in preds_filt]) / float(len(preds_filt))
|
158 |
+
# outfile_path = os.path.join('eval_results/', model_id + '_' + split + '.json')
|
159 |
+
outfile_path = os.path.join('results/log_'+eval_kwargs['topic']+'_'+model_id+'/eval_results_'+eval_kwargs['topic']+'/', model_id + '_' + split + '.json')
|
160 |
+
with open(outfile_path, 'w') as outfile:
|
161 |
+
json.dump({'overall': out, 'imgToEval': imgToEval}, outfile)
|
162 |
+
|
163 |
+
return out
|
164 |
+
|
165 |
+
def eval_split(model, crit, loader, eval_kwargs={}):
|
166 |
+
verbose = eval_kwargs.get('verbose', True)
|
167 |
+
verbose_beam = eval_kwargs.get('verbose_beam', 0)
|
168 |
+
verbose_loss = eval_kwargs.get('verbose_loss', 1)
|
169 |
+
num_images = eval_kwargs.get('num_images', eval_kwargs.get('val_images_use', -1))
|
170 |
+
split = eval_kwargs.get('split', 'val')
|
171 |
+
lang_eval = eval_kwargs.get('language_eval', 0)
|
172 |
+
dataset = eval_kwargs.get('dataset', 'coco')
|
173 |
+
beam_size = eval_kwargs.get('beam_size', 1)
|
174 |
+
sample_n = eval_kwargs.get('sample_n', 1)
|
175 |
+
remove_bad_endings = eval_kwargs.get('remove_bad_endings', 1)
|
176 |
+
os.environ["REMOVE_BAD_ENDINGS"] = str(remove_bad_endings) # Use this nasty way to make other code clean since it's a global configuration
|
177 |
+
device = eval_kwargs.get('device', 'cuda')
|
178 |
+
|
179 |
+
# Make sure in the evaluation mode
|
180 |
+
model.eval()
|
181 |
+
|
182 |
+
loader.reset_iterator(split)
|
183 |
+
|
184 |
+
n = 0
|
185 |
+
loss = 0
|
186 |
+
loss_sum = 0
|
187 |
+
loss_evals = 1e-8
|
188 |
+
predictions = []
|
189 |
+
n_predictions = [] # when sample_n > 1
|
190 |
+
while True:
|
191 |
+
data = loader.get_batch(split)
|
192 |
+
n = n + len(data['infos'])
|
193 |
+
|
194 |
+
tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']]
|
195 |
+
tmp = [_.to(device) if _ is not None else _ for _ in tmp]
|
196 |
+
fc_feats, att_feats, labels, masks, att_masks = tmp
|
197 |
+
|
198 |
+
if labels is not None and verbose_loss:
|
199 |
+
# forward the model to get loss
|
200 |
+
with torch.no_grad():
|
201 |
+
loss = crit(model(fc_feats, att_feats, labels[..., :-1], att_masks), labels[..., 1:], masks[..., 1:]).item()
|
202 |
+
loss_sum = loss_sum + loss
|
203 |
+
loss_evals = loss_evals + 1
|
204 |
+
|
205 |
+
# forward the model to also get generated samples for each image
|
206 |
+
with torch.no_grad():
|
207 |
+
tmp_eval_kwargs = eval_kwargs.copy()
|
208 |
+
tmp_eval_kwargs.update({'sample_n': 1})
|
209 |
+
seq, seq_logprobs = model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
|
210 |
+
seq = seq.data
|
211 |
+
|
212 |
+
entropy = - (F.softmax(seq_logprobs, dim=2) * seq_logprobs).sum(2).sum(1) / ((seq>0).to(seq_logprobs).sum(1)+1)
|
213 |
+
perplexity = - seq_logprobs.gather(2, seq.unsqueeze(2)).squeeze(2).sum(1) / ((seq>0).to(seq_logprobs).sum(1)+1)
|
214 |
+
|
215 |
+
# Print beam search
|
216 |
+
if beam_size > 1 and verbose_beam:
|
217 |
+
for i in range(fc_feats.shape[0]):
|
218 |
+
print('\n'.join([utils.decode_sequence(model.vocab, _['seq'].unsqueeze(0))[0] for _ in model.done_beams[i]]))
|
219 |
+
print('--' * 10)
|
220 |
+
sents = utils.decode_sequence(model.vocab, seq)
|
221 |
+
|
222 |
+
for k, sent in enumerate(sents):
|
223 |
+
entry = {'image_id': data['infos'][k]['id'], 'caption': sent, 'perplexity': perplexity[k].item(), 'entropy': entropy[k].item()}
|
224 |
+
if eval_kwargs.get('dump_path', 0) == 1:
|
225 |
+
entry['file_name'] = data['infos'][k]['file_path']
|
226 |
+
predictions.append(entry)
|
227 |
+
if eval_kwargs.get('dump_images', 0) == 1:
|
228 |
+
# dump the raw image to vis/ folder
|
229 |
+
cmd = 'cp "' + os.path.join(eval_kwargs['image_root'], data['infos'][k]['file_path']) + '" vis/imgs/img' + str(len(predictions)) + '.jpg' # bit gross
|
230 |
+
print(cmd)
|
231 |
+
os.system(cmd)
|
232 |
+
|
233 |
+
if verbose:
|
234 |
+
print('image %s: %s' %(entry['image_id'], entry['caption']))
|
235 |
+
|
236 |
+
if sample_n > 1:
|
237 |
+
eval_split_n(model, n_predictions, [fc_feats, att_feats, att_masks, data], eval_kwargs)
|
238 |
+
|
239 |
+
# ix0 = data['bounds']['it_pos_now']
|
240 |
+
ix1 = data['bounds']['it_max']
|
241 |
+
if num_images != -1:
|
242 |
+
ix1 = min(ix1, num_images)
|
243 |
+
else:
|
244 |
+
num_images = ix1
|
245 |
+
for i in range(n - ix1):
|
246 |
+
predictions.pop()
|
247 |
+
|
248 |
+
if verbose:
|
249 |
+
print('evaluating validation preformance... %d/%d (%f)' %(n, ix1, loss))
|
250 |
+
|
251 |
+
if num_images >= 0 and n >= num_images:
|
252 |
+
break
|
253 |
+
|
254 |
+
lang_stats = None
|
255 |
+
if len(n_predictions) > 0 and 'perplexity' in n_predictions[0]:
|
256 |
+
n_predictions = sorted(n_predictions, key=lambda x: x['perplexity'])
|
257 |
+
# if not os.path.isdir('eval_results'):
|
258 |
+
# os.mkdir('eval_results')
|
259 |
+
if not os.path.isdir('results/log_'+eval_kwargs['topic']+'_'+eval_kwargs['id']+'/eval_results_'+eval_kwargs['topic']):
|
260 |
+
os.mkdir('results/log_'+eval_kwargs['topic']+'_'+eval_kwargs['id']+'/eval_results_'+eval_kwargs['topic'])
|
261 |
+
# torch.save((predictions, n_predictions), os.path.join('eval_results/', '.saved_pred_'+ eval_kwargs['id'] + '_' + split + '.pth'))
|
262 |
+
torch.save((predictions, n_predictions), os.path.join('results/log_'+eval_kwargs['topic']+'_'+eval_kwargs['id']+'/eval_results_'+eval_kwargs['topic']+'/', '.saved_pred_'+ eval_kwargs['id'] + '_' + split + '.pth'))
|
263 |
+
if lang_eval == 1:
|
264 |
+
lang_stats = language_eval(dataset, predictions, n_predictions, eval_kwargs, split)
|
265 |
+
|
266 |
+
# Switch back to training mode
|
267 |
+
model.train()
|
268 |
+
return loss_sum/loss_evals, predictions, lang_stats
|
269 |
+
|
270 |
+
|
271 |
+
# Only run when sample_n > 0
|
272 |
+
def eval_split_n(model, n_predictions, input_data, eval_kwargs={}):
|
273 |
+
verbose = eval_kwargs.get('verbose', True)
|
274 |
+
beam_size = eval_kwargs.get('beam_size', 1)
|
275 |
+
sample_n = eval_kwargs.get('sample_n', 1)
|
276 |
+
sample_n_method = eval_kwargs.get('sample_n_method', 'sample')
|
277 |
+
|
278 |
+
fc_feats, att_feats, att_masks, data = input_data
|
279 |
+
|
280 |
+
tmp_eval_kwargs = eval_kwargs.copy()
|
281 |
+
if sample_n_method == 'bs':
|
282 |
+
# case 1 sample_n == beam size
|
283 |
+
tmp_eval_kwargs.update({'sample_n': 1, 'beam_size': sample_n, 'group_size': 1}) # randomness from softmax
|
284 |
+
with torch.no_grad():
|
285 |
+
model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
|
286 |
+
for k in range(fc_feats.shape[0]):
|
287 |
+
_sents = utils.decode_sequence(model.vocab, torch.stack([model.done_beams[k][_]['seq'] for _ in range(sample_n)]))
|
288 |
+
for sent in _sents:
|
289 |
+
entry = {'image_id': data['infos'][k]['id'], 'caption': sent}
|
290 |
+
n_predictions.append(entry)
|
291 |
+
# case 2 sample / gumbel / topk sampling/ nucleus sampling
|
292 |
+
elif sample_n_method == 'sample' or \
|
293 |
+
sample_n_method == 'gumbel' or \
|
294 |
+
sample_n_method.startswith('top'):
|
295 |
+
tmp_eval_kwargs.update({'sample_n': sample_n, 'sample_method': sample_n_method, 'beam_size': 1}) # randomness from sample
|
296 |
+
with torch.no_grad():
|
297 |
+
_seq, _sampleLogprobs = model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
|
298 |
+
_sents = utils.decode_sequence(model.vocab, _seq)
|
299 |
+
_perplexity = - _sampleLogprobs.gather(2, _seq.unsqueeze(2)).squeeze(2).sum(1) / ((_seq>0).to(_sampleLogprobs).sum(1)+1)
|
300 |
+
for k, sent in enumerate(_sents):
|
301 |
+
entry = {'image_id': data['infos'][k // sample_n]['id'], 'caption': sent, 'perplexity': _perplexity[k].item()}
|
302 |
+
n_predictions.append(entry)
|
303 |
+
elif sample_n_method == 'dbs':
|
304 |
+
# Use diverse beam search
|
305 |
+
tmp_eval_kwargs.update({'beam_size': sample_n * beam_size, 'group_size': sample_n}) # randomness from softmax
|
306 |
+
with torch.no_grad():
|
307 |
+
model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
|
308 |
+
for k in range(loader.batch_size):
|
309 |
+
_sents = utils.decode_sequence(model.vocab, torch.stack([model.done_beams[k][_]['seq'] for _ in range(0, sample_n*beam_size, beam_size)]))
|
310 |
+
for sent in _sents:
|
311 |
+
entry = {'image_id': data['infos'][k]['id'], 'caption': sent}
|
312 |
+
n_predictions.append(entry)
|
313 |
+
else:
|
314 |
+
tmp_eval_kwargs.update({'sample_method': sample_n_method[1:], 'group_size': sample_n, 'beam_size':1}) # randomness from softmax
|
315 |
+
with torch.no_grad():
|
316 |
+
_seq, _sampleLogprobs = model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
|
317 |
+
_sents = utils.decode_sequence(model.vocab, _seq)
|
318 |
+
for k, sent in enumerate(_sents):
|
319 |
+
entry = {'image_id': data['infos'][k // sample_n]['id'], 'caption': sent}
|
320 |
+
n_predictions.append(entry)
|
321 |
+
if verbose:
|
322 |
+
for entry in sorted(n_predictions[-fc_feats.shape[0] * sample_n:], key=lambda x: x['image_id']):
|
323 |
+
print('image %s: %s' %(entry['image_id'], entry['caption']))
|
captioning/utils/misc.py
ADDED
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import
|
2 |
+
from __future__ import division
|
3 |
+
from __future__ import print_function
|
4 |
+
|
5 |
+
import collections
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import numpy as np
|
9 |
+
import torch.optim as optim
|
10 |
+
import os
|
11 |
+
|
12 |
+
import torch.nn.functional as F
|
13 |
+
|
14 |
+
import six
|
15 |
+
from six.moves import cPickle
|
16 |
+
|
17 |
+
bad_endings = ['a','an','the','in','for','at','of','with','before','after','on','upon','near','to','is','are','am']
|
18 |
+
bad_endings += ['UNK', 'has', 'and', 'more']
|
19 |
+
|
20 |
+
def pickle_load(f):
|
21 |
+
""" Load a pickle.
|
22 |
+
Parameters
|
23 |
+
----------
|
24 |
+
f: file-like object
|
25 |
+
"""
|
26 |
+
if six.PY3:
|
27 |
+
return cPickle.load(f, encoding='latin-1')
|
28 |
+
else:
|
29 |
+
return cPickle.load(f)
|
30 |
+
|
31 |
+
|
32 |
+
def pickle_dump(obj, f):
|
33 |
+
""" Dump a pickle.
|
34 |
+
Parameters
|
35 |
+
----------
|
36 |
+
obj: pickled object
|
37 |
+
f: file-like object
|
38 |
+
"""
|
39 |
+
if six.PY3:
|
40 |
+
return cPickle.dump(obj, f, protocol=2)
|
41 |
+
else:
|
42 |
+
return cPickle.dump(obj, f)
|
43 |
+
|
44 |
+
|
45 |
+
# modified from https://github.com/facebookresearch/detectron2/blob/master/detectron2/utils/comm.py
|
46 |
+
def serialize_to_tensor(data):
|
47 |
+
device = torch.device("cpu")
|
48 |
+
|
49 |
+
buffer = cPickle.dumps(data)
|
50 |
+
storage = torch.ByteStorage.from_buffer(buffer)
|
51 |
+
tensor = torch.ByteTensor(storage).to(device=device)
|
52 |
+
return tensor
|
53 |
+
|
54 |
+
|
55 |
+
def deserialize(tensor):
|
56 |
+
buffer = tensor.cpu().numpy().tobytes()
|
57 |
+
return cPickle.loads(buffer)
|
58 |
+
|
59 |
+
|
60 |
+
# Input: seq, N*D numpy array, with element 0 .. vocab_size. 0 is END token.
|
61 |
+
def decode_sequence(ix_to_word, seq):
|
62 |
+
N, D = seq.size()
|
63 |
+
out = []
|
64 |
+
for i in range(N):
|
65 |
+
txt = ''
|
66 |
+
for j in range(D):
|
67 |
+
ix = seq[i,j]
|
68 |
+
if ix > 0 :
|
69 |
+
if j >= 1:
|
70 |
+
txt = txt + ' '
|
71 |
+
txt = txt + ix_to_word[str(ix.item())]
|
72 |
+
else:
|
73 |
+
break
|
74 |
+
if int(os.getenv('REMOVE_BAD_ENDINGS', '0')):
|
75 |
+
flag = 0
|
76 |
+
words = txt.split(' ')
|
77 |
+
for j in range(len(words)):
|
78 |
+
if words[-j-1] not in bad_endings:
|
79 |
+
flag = -j
|
80 |
+
break
|
81 |
+
txt = ' '.join(words[0:len(words)+flag])
|
82 |
+
out.append(txt.replace('@@ ', ''))
|
83 |
+
return out
|
84 |
+
|
85 |
+
|
86 |
+
def save_checkpoint(opt, model, infos, optimizer, histories=None, append=''):
|
87 |
+
if len(append) > 0:
|
88 |
+
append = '_' + append
|
89 |
+
# if checkpoint_path doesn't exist
|
90 |
+
if not os.path.isdir(opt.checkpoint_path):
|
91 |
+
os.makedirs(opt.checkpoint_path)
|
92 |
+
checkpoint_path = os.path.join(opt.checkpoint_path, 'model%s.pth' %(append))
|
93 |
+
torch.save(model.state_dict(), checkpoint_path)
|
94 |
+
print("model saved to {}".format(checkpoint_path))
|
95 |
+
optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer%s.pth' %(append))
|
96 |
+
torch.save(optimizer.state_dict(), optimizer_path)
|
97 |
+
with open(os.path.join(opt.checkpoint_path, 'infos%s.pkl' %(append)), 'wb') as f:
|
98 |
+
pickle_dump(infos, f)
|
99 |
+
if histories:
|
100 |
+
with open(os.path.join(opt.checkpoint_path, 'histories%s.pkl' %(append)), 'wb') as f:
|
101 |
+
pickle_dump(histories, f)
|
102 |
+
|
103 |
+
|
104 |
+
def set_lr(optimizer, lr):
|
105 |
+
for group in optimizer.param_groups:
|
106 |
+
group['lr'] = lr
|
107 |
+
|
108 |
+
def get_lr(optimizer):
|
109 |
+
for group in optimizer.param_groups:
|
110 |
+
return group['lr']
|
111 |
+
|
112 |
+
|
113 |
+
def build_optimizer(params, opt):
|
114 |
+
if opt.optim == 'rmsprop':
|
115 |
+
return optim.RMSprop(params, opt.learning_rate, opt.optim_alpha, opt.optim_epsilon, weight_decay=opt.weight_decay)
|
116 |
+
elif opt.optim == 'adagrad':
|
117 |
+
return optim.Adagrad(params, opt.learning_rate, weight_decay=opt.weight_decay)
|
118 |
+
elif opt.optim == 'sgd':
|
119 |
+
return optim.SGD(params, opt.learning_rate, weight_decay=opt.weight_decay)
|
120 |
+
elif opt.optim == 'sgdm':
|
121 |
+
return optim.SGD(params, opt.learning_rate, opt.optim_alpha, weight_decay=opt.weight_decay)
|
122 |
+
elif opt.optim == 'sgdmom':
|
123 |
+
return optim.SGD(params, opt.learning_rate, opt.optim_alpha, weight_decay=opt.weight_decay, nesterov=True)
|
124 |
+
elif opt.optim == 'adam':
|
125 |
+
return optim.Adam(params, opt.learning_rate, (opt.optim_alpha, opt.optim_beta), opt.optim_epsilon, weight_decay=opt.weight_decay)
|
126 |
+
elif opt.optim == 'adamw':
|
127 |
+
return optim.AdamW(params, opt.learning_rate, (opt.optim_alpha, opt.optim_beta), opt.optim_epsilon, weight_decay=opt.weight_decay)
|
128 |
+
else:
|
129 |
+
raise Exception("bad option opt.optim: {}".format(opt.optim))
|
130 |
+
|
131 |
+
|
132 |
+
def penalty_builder(penalty_config):
|
133 |
+
if penalty_config == '':
|
134 |
+
return lambda x,y: y
|
135 |
+
pen_type, alpha = penalty_config.split('_')
|
136 |
+
alpha = float(alpha)
|
137 |
+
if pen_type == 'wu':
|
138 |
+
return lambda x,y: length_wu(x,y,alpha)
|
139 |
+
if pen_type == 'avg':
|
140 |
+
return lambda x,y: length_average(x,y,alpha)
|
141 |
+
|
142 |
+
def length_wu(length, logprobs, alpha=0.):
|
143 |
+
"""
|
144 |
+
NMT length re-ranking score from
|
145 |
+
"Google's Neural Machine Translation System" :cite:`wu2016google`.
|
146 |
+
"""
|
147 |
+
|
148 |
+
modifier = (((5 + length) ** alpha) /
|
149 |
+
((5 + 1) ** alpha))
|
150 |
+
return (logprobs / modifier)
|
151 |
+
|
152 |
+
def length_average(length, logprobs, alpha=0.):
|
153 |
+
"""
|
154 |
+
Returns the average probability of tokens in a sequence.
|
155 |
+
"""
|
156 |
+
return logprobs / length
|
157 |
+
|
158 |
+
|
159 |
+
class NoamOpt(object):
|
160 |
+
"Optim wrapper that implements rate."
|
161 |
+
def __init__(self, model_size, factor, warmup, optimizer):
|
162 |
+
self.optimizer = optimizer
|
163 |
+
self._step = 0
|
164 |
+
self.warmup = warmup
|
165 |
+
self.factor = factor
|
166 |
+
self.model_size = model_size
|
167 |
+
self._rate = 0
|
168 |
+
|
169 |
+
def step(self):
|
170 |
+
"Update parameters and rate"
|
171 |
+
self._step += 1
|
172 |
+
rate = self.rate()
|
173 |
+
for p in self.optimizer.param_groups:
|
174 |
+
p['lr'] = rate
|
175 |
+
self._rate = rate
|
176 |
+
self.optimizer.step()
|
177 |
+
|
178 |
+
def rate(self, step = None):
|
179 |
+
"Implement `lrate` above"
|
180 |
+
if step is None:
|
181 |
+
step = self._step
|
182 |
+
return self.factor * \
|
183 |
+
(self.model_size ** (-0.5) *
|
184 |
+
min(step ** (-0.5), step * self.warmup ** (-1.5)))
|
185 |
+
|
186 |
+
def __getattr__(self, name):
|
187 |
+
return getattr(self.optimizer, name)
|
188 |
+
|
189 |
+
def state_dict(self):
|
190 |
+
state_dict = self.optimizer.state_dict()
|
191 |
+
state_dict['_step'] = self._step
|
192 |
+
return state_dict
|
193 |
+
|
194 |
+
def load_state_dict(self, state_dict):
|
195 |
+
if '_step' in state_dict:
|
196 |
+
self._step = state_dict['_step']
|
197 |
+
del state_dict['_step']
|
198 |
+
self.optimizer.load_state_dict(state_dict)
|
199 |
+
|
200 |
+
class ReduceLROnPlateau(object):
|
201 |
+
"Optim wrapper that implements rate."
|
202 |
+
def __init__(self, optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08):
|
203 |
+
self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode, factor, patience, verbose, threshold, threshold_mode, cooldown, min_lr, eps)
|
204 |
+
self.optimizer = optimizer
|
205 |
+
self.current_lr = get_lr(optimizer)
|
206 |
+
|
207 |
+
def step(self):
|
208 |
+
"Update parameters and rate"
|
209 |
+
self.optimizer.step()
|
210 |
+
|
211 |
+
def scheduler_step(self, val):
|
212 |
+
self.scheduler.step(val)
|
213 |
+
self.current_lr = get_lr(self.optimizer)
|
214 |
+
|
215 |
+
def state_dict(self):
|
216 |
+
return {'current_lr':self.current_lr,
|
217 |
+
'scheduler_state_dict': self.scheduler.state_dict(),
|
218 |
+
'optimizer_state_dict': self.optimizer.state_dict()}
|
219 |
+
|
220 |
+
def load_state_dict(self, state_dict):
|
221 |
+
if 'current_lr' not in state_dict:
|
222 |
+
# it's normal optimizer
|
223 |
+
self.optimizer.load_state_dict(state_dict)
|
224 |
+
set_lr(self.optimizer, self.current_lr) # use the lr fromt the option
|
225 |
+
else:
|
226 |
+
# it's a schduler
|
227 |
+
self.current_lr = state_dict['current_lr']
|
228 |
+
self.scheduler.load_state_dict(state_dict['scheduler_state_dict'])
|
229 |
+
self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
|
230 |
+
# current_lr is actually useless in this case
|
231 |
+
|
232 |
+
def rate(self, step = None):
|
233 |
+
"Implement `lrate` above"
|
234 |
+
if step is None:
|
235 |
+
step = self._step
|
236 |
+
return self.factor * \
|
237 |
+
(self.model_size ** (-0.5) *
|
238 |
+
min(step ** (-0.5), step * self.warmup ** (-1.5)))
|
239 |
+
|
240 |
+
def __getattr__(self, name):
|
241 |
+
return getattr(self.optimizer, name)
|
242 |
+
|
243 |
+
def get_std_opt(model, optim_func='adam', factor=1, warmup=2000):
|
244 |
+
# return NoamOpt(model.tgt_embed[0].d_model, 2, 4000,
|
245 |
+
# torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
|
246 |
+
optim_func = dict(adam=torch.optim.Adam,
|
247 |
+
adamw=torch.optim.AdamW)[optim_func]
|
248 |
+
return NoamOpt(model.d_model, factor, warmup,
|
249 |
+
optim_func(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
|
captioning/utils/opts.py
ADDED
@@ -0,0 +1,365 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import print_function
|
2 |
+
import argparse
|
3 |
+
|
4 |
+
|
5 |
+
def if_use_feat(caption_model):
|
6 |
+
# Decide if load attention feature according to caption model
|
7 |
+
if caption_model in ['show_tell', 'all_img', 'fc', 'newfc']:
|
8 |
+
use_att, use_fc = False, True
|
9 |
+
elif caption_model == 'language_model':
|
10 |
+
use_att, use_fc = False, False
|
11 |
+
elif caption_model in ['updown', 'topdown']:
|
12 |
+
use_fc, use_att = True, True
|
13 |
+
else:
|
14 |
+
use_att, use_fc = True, False
|
15 |
+
return use_fc, use_att
|
16 |
+
|
17 |
+
|
18 |
+
def parse_opt():
|
19 |
+
parser = argparse.ArgumentParser()
|
20 |
+
# Data input settings
|
21 |
+
parser.add_argument('--input_json', type=str, default='data/coco.json',
|
22 |
+
help='path to the json file containing additional info and vocab')
|
23 |
+
parser.add_argument('--input_fc_dir', type=str, default='data/cocotalk_fc',
|
24 |
+
help='path to the directory containing the preprocessed fc feats')
|
25 |
+
parser.add_argument('--input_att_dir', type=str, default='data/cocotalk_att',
|
26 |
+
help='path to the directory containing the preprocessed att feats')
|
27 |
+
parser.add_argument('--input_box_dir', type=str, default='data/cocotalk_box',
|
28 |
+
help='path to the directory containing the boxes of att feats')
|
29 |
+
parser.add_argument('--input_label_h5', type=str, default='data/coco_label.h5',
|
30 |
+
help='path to the h5file containing the preprocessed dataset')
|
31 |
+
parser.add_argument('--data_in_memory', action='store_true',
|
32 |
+
help='True if we want to save the features in memory')
|
33 |
+
parser.add_argument('--start_from', type=str, default=None,
|
34 |
+
help="""continue training from saved model at this path. Path must contain files saved by previous training process:
|
35 |
+
'infos.pkl' : configuration;
|
36 |
+
'model.pth' : weights
|
37 |
+
""")
|
38 |
+
parser.add_argument('--cached_tokens', type=str, default='coco-train-idxs',
|
39 |
+
help='Cached token file for calculating cider score during self critical training.')
|
40 |
+
|
41 |
+
# Model settings
|
42 |
+
parser.add_argument('--caption_model', type=str, default="show_tell",
|
43 |
+
help='show_tell, show_attend_tell, all_img, fc, att2in, att2in2, att2all2, adaatt, adaattmo, updown, stackatt, denseatt, transformer')
|
44 |
+
parser.add_argument('--rnn_size', type=int, default=512,
|
45 |
+
help='size of the rnn in number of hidden nodes in each layer')
|
46 |
+
parser.add_argument('--num_layers', type=int, default=1,
|
47 |
+
help='number of layers in the RNN')
|
48 |
+
parser.add_argument('--rnn_type', type=str, default='lstm',
|
49 |
+
help='rnn, gru, or lstm')
|
50 |
+
parser.add_argument('--input_encoding_size', type=int, default=512,
|
51 |
+
help='the encoding size of each token in the vocabulary, and the image.')
|
52 |
+
parser.add_argument('--att_hid_size', type=int, default=512,
|
53 |
+
help='the hidden size of the attention MLP; only useful in show_attend_tell; 0 if not using hidden layer')
|
54 |
+
parser.add_argument('--fc_feat_size', type=int, default=2048,
|
55 |
+
help='2048 for resnet, 4096 for vgg')
|
56 |
+
parser.add_argument('--att_feat_size', type=int, default=2048,
|
57 |
+
help='2048 for resnet, 512 for vgg')
|
58 |
+
parser.add_argument('--logit_layers', type=int, default=1,
|
59 |
+
help='number of layers in the RNN')
|
60 |
+
|
61 |
+
|
62 |
+
parser.add_argument('--use_bn', type=int, default=0,
|
63 |
+
help='If 1, then do batch_normalization first in att_embed, if 2 then do bn both in the beginning and the end of att_embed')
|
64 |
+
|
65 |
+
# feature manipulation
|
66 |
+
parser.add_argument('--norm_att_feat', type=int, default=0,
|
67 |
+
help='If normalize attention features')
|
68 |
+
parser.add_argument('--use_box', type=int, default=0,
|
69 |
+
help='If use box features')
|
70 |
+
parser.add_argument('--norm_box_feat', type=int, default=0,
|
71 |
+
help='If use box, do we normalize box feature')
|
72 |
+
|
73 |
+
# Optimization: General
|
74 |
+
parser.add_argument('--max_epochs', type=int, default=-1,
|
75 |
+
help='number of epochs')
|
76 |
+
parser.add_argument('--batch_size', type=int, default=16,
|
77 |
+
help='minibatch size')
|
78 |
+
parser.add_argument('--grad_clip_mode', type=str, default='value',
|
79 |
+
help='value or norm')
|
80 |
+
parser.add_argument('--grad_clip_value', type=float, default=0.1,
|
81 |
+
help='clip gradients at this value/max_norm, 0 means no clipping')
|
82 |
+
parser.add_argument('--drop_prob_lm', type=float, default=0.5,
|
83 |
+
help='strength of dropout in the Language Model RNN')
|
84 |
+
parser.add_argument('--self_critical_after', type=int, default=-1,
|
85 |
+
help='After what epoch do we start finetuning the CNN? (-1 = disable; never finetune, 0 = finetune from start)')
|
86 |
+
parser.add_argument('--seq_per_img', type=int, default=5,
|
87 |
+
help='number of captions to sample for each image during training. Done for efficiency since CNN forward pass is expensive. E.g. coco has 5 sents/image')
|
88 |
+
|
89 |
+
# Sample related
|
90 |
+
add_eval_sample_opts(parser)
|
91 |
+
|
92 |
+
#Optimization: for the Language Model
|
93 |
+
parser.add_argument('--optim', type=str, default='adam',
|
94 |
+
help='what update to use? rmsprop|sgd|sgdmom|adagrad|adam|adamw')
|
95 |
+
parser.add_argument('--learning_rate', type=float, default=4e-4,
|
96 |
+
help='learning rate')
|
97 |
+
parser.add_argument('--learning_rate_decay_start', type=int, default=-1,
|
98 |
+
help='at what iteration to start decaying learning rate? (-1 = dont) (in epoch)')
|
99 |
+
parser.add_argument('--learning_rate_decay_every', type=int, default=3,
|
100 |
+
help='every how many iterations thereafter to drop LR?(in epoch)')
|
101 |
+
parser.add_argument('--learning_rate_decay_rate', type=float, default=0.8,
|
102 |
+
help='every how many iterations thereafter to drop LR?(in epoch)')
|
103 |
+
parser.add_argument('--optim_alpha', type=float, default=0.9,
|
104 |
+
help='alpha for adam')
|
105 |
+
parser.add_argument('--optim_beta', type=float, default=0.999,
|
106 |
+
help='beta used for adam')
|
107 |
+
parser.add_argument('--optim_epsilon', type=float, default=1e-8,
|
108 |
+
help='epsilon that goes into denominator for smoothing')
|
109 |
+
parser.add_argument('--weight_decay', type=float, default=0,
|
110 |
+
help='weight_decay')
|
111 |
+
# Transformer
|
112 |
+
parser.add_argument('--label_smoothing', type=float, default=0,
|
113 |
+
help='')
|
114 |
+
parser.add_argument('--noamopt', action='store_true',
|
115 |
+
help='')
|
116 |
+
parser.add_argument('--noamopt_warmup', type=int, default=2000,
|
117 |
+
help='')
|
118 |
+
parser.add_argument('--noamopt_factor', type=float, default=1,
|
119 |
+
help='')
|
120 |
+
parser.add_argument('--reduce_on_plateau', action='store_true',
|
121 |
+
help='')
|
122 |
+
parser.add_argument('--reduce_on_plateau_factor', type=float, default=0.5,
|
123 |
+
help='')
|
124 |
+
parser.add_argument('--reduce_on_plateau_patience', type=int, default=3,
|
125 |
+
help='')
|
126 |
+
parser.add_argument('--cached_transformer', action='store_true',
|
127 |
+
help='')
|
128 |
+
|
129 |
+
|
130 |
+
parser.add_argument('--use_warmup', action='store_true',
|
131 |
+
help='warm up the learing rate?')
|
132 |
+
|
133 |
+
parser.add_argument('--scheduled_sampling_start', type=int, default=-1,
|
134 |
+
help='at what iteration to start decay gt probability')
|
135 |
+
parser.add_argument('--scheduled_sampling_increase_every', type=int, default=5,
|
136 |
+
help='every how many iterations thereafter to gt probability')
|
137 |
+
parser.add_argument('--scheduled_sampling_increase_prob', type=float, default=0.05,
|
138 |
+
help='How much to update the prob')
|
139 |
+
parser.add_argument('--scheduled_sampling_max_prob', type=float, default=0.25,
|
140 |
+
help='Maximum scheduled sampling prob.')
|
141 |
+
|
142 |
+
|
143 |
+
# Evaluation/Checkpointing
|
144 |
+
parser.add_argument('--val_images_use', type=int, default=3200,
|
145 |
+
help='how many images to use when periodically evaluating the validation loss? (-1 = all)')
|
146 |
+
parser.add_argument('--save_checkpoint_every', type=int, default=2500,
|
147 |
+
help='how often to save a model checkpoint (in iterations)?')
|
148 |
+
parser.add_argument('--save_every_epoch', action='store_true',
|
149 |
+
help='Save checkpoint every epoch, will overwrite save_checkpoint_every')
|
150 |
+
parser.add_argument('--save_history_ckpt', type=int, default=0,
|
151 |
+
help='If save checkpoints at every save point')
|
152 |
+
parser.add_argument('--checkpoint_path', type=str, default=None,
|
153 |
+
help='directory to store checkpointed models')
|
154 |
+
parser.add_argument('--language_eval', type=int, default=0,
|
155 |
+
help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.')
|
156 |
+
parser.add_argument('--losses_log_every', type=int, default=25,
|
157 |
+
help='How often do we snapshot losses, for inclusion in the progress dump? (0 = disable)')
|
158 |
+
parser.add_argument('--load_best_score', type=int, default=1,
|
159 |
+
help='Do we load previous best score when resuming training.')
|
160 |
+
|
161 |
+
# misc
|
162 |
+
parser.add_argument('--id', type=str, default='',
|
163 |
+
help='an id identifying this run/job. used in cross-val and appended when writing progress files')
|
164 |
+
parser.add_argument('--train_only', type=int, default=0,
|
165 |
+
help='if true then use 80k, else use 110k')
|
166 |
+
parser.add_argument('--topic', type=str, default='dress',
|
167 |
+
help='type of datasets, such as dress, shirt, toptee')
|
168 |
+
|
169 |
+
|
170 |
+
# Reward
|
171 |
+
parser.add_argument('--cider_reward_weight', type=float, default=1,
|
172 |
+
help='The reward weight from cider')
|
173 |
+
parser.add_argument('--bleu_reward_weight', type=float, default=0,
|
174 |
+
help='The reward weight from bleu4')
|
175 |
+
|
176 |
+
|
177 |
+
# Structure_loss
|
178 |
+
parser.add_argument('--structure_loss_weight', type=float, default=1,
|
179 |
+
help='')
|
180 |
+
parser.add_argument('--structure_after', type=int, default=-1,
|
181 |
+
help='T')
|
182 |
+
parser.add_argument('--structure_loss_type', type=str, default='seqnll',
|
183 |
+
help='')
|
184 |
+
parser.add_argument('--struc_use_logsoftmax', action='store_true', help='')
|
185 |
+
parser.add_argument('--entropy_reward_weight', type=float, default=0,
|
186 |
+
help='Entropy reward, seems very interesting')
|
187 |
+
parser.add_argument('--self_cider_reward_weight', type=float, default=0,
|
188 |
+
help='self cider reward')
|
189 |
+
|
190 |
+
# Used for self critical or structure. Used when sampling is need during training
|
191 |
+
parser.add_argument('--train_sample_n', type=int, default=1,
|
192 |
+
help='The reward weight from cider')
|
193 |
+
parser.add_argument('--train_sample_method', type=str, default='sample',
|
194 |
+
help='')
|
195 |
+
parser.add_argument('--train_beam_size', type=int, default=1,
|
196 |
+
help='')
|
197 |
+
|
198 |
+
# Used for self critical
|
199 |
+
parser.add_argument('--sc_sample_method', type=str, default='greedy',
|
200 |
+
help='')
|
201 |
+
parser.add_argument('--sc_beam_size', type=int, default=1,
|
202 |
+
help='')
|
203 |
+
|
204 |
+
parser.add_argument('--seed', type=int, default=42,
|
205 |
+
help='')
|
206 |
+
|
207 |
+
# For diversity evaluation during training
|
208 |
+
add_diversity_opts(parser)
|
209 |
+
|
210 |
+
|
211 |
+
# config
|
212 |
+
parser.add_argument('--cfg', type=str, default=None,
|
213 |
+
help='configuration; similar to what is used in detectron')
|
214 |
+
parser.add_argument(
|
215 |
+
'--set_cfgs', dest='set_cfgs',
|
216 |
+
help='Set config keys. Key value sequence seperate by whitespace.'
|
217 |
+
'e.g. [key] [value] [key] [value]\n This has higher priority'
|
218 |
+
'than cfg file but lower than other args. (You can only overwrite'
|
219 |
+
'arguments that have alerady been defined in config file.)',
|
220 |
+
default=[], nargs='+')
|
221 |
+
# How will config be used
|
222 |
+
# 1) read cfg argument, and load the cfg file if it's not None
|
223 |
+
# 2) Overwrite cfg argument with set_cfgs
|
224 |
+
# 3) parse config argument to args.
|
225 |
+
# 4) in the end, parse command line argument and overwrite args
|
226 |
+
|
227 |
+
# step 1: read cfg_fn
|
228 |
+
args = parser.parse_args()
|
229 |
+
if args.cfg is not None or args.set_cfgs is not None:
|
230 |
+
from .config import CfgNode
|
231 |
+
if args.cfg is not None:
|
232 |
+
cn = CfgNode(CfgNode.load_yaml_with_base(args.cfg))
|
233 |
+
else:
|
234 |
+
cn = CfgNode()
|
235 |
+
if args.set_cfgs is not None:
|
236 |
+
cn.merge_from_list(args.set_cfgs)
|
237 |
+
for k,v in cn.items():
|
238 |
+
if not hasattr(args, k):
|
239 |
+
print('Warning: key %s not in args' %k)
|
240 |
+
setattr(args, k, v)
|
241 |
+
args = parser.parse_args(namespace=args)
|
242 |
+
|
243 |
+
# Check if args are valid
|
244 |
+
assert args.rnn_size > 0, "rnn_size should be greater than 0"
|
245 |
+
assert args.num_layers > 0, "num_layers should be greater than 0"
|
246 |
+
assert args.input_encoding_size > 0, "input_encoding_size should be greater than 0"
|
247 |
+
assert args.batch_size > 0, "batch_size should be greater than 0"
|
248 |
+
assert args.drop_prob_lm >= 0 and args.drop_prob_lm < 1, "drop_prob_lm should be between 0 and 1"
|
249 |
+
assert args.seq_per_img > 0, "seq_per_img should be greater than 0"
|
250 |
+
assert args.beam_size > 0, "beam_size should be greater than 0"
|
251 |
+
assert args.save_checkpoint_every > 0, "save_checkpoint_every should be greater than 0"
|
252 |
+
assert args.losses_log_every > 0, "losses_log_every should be greater than 0"
|
253 |
+
assert args.language_eval == 0 or args.language_eval == 1, "language_eval should be 0 or 1"
|
254 |
+
assert args.load_best_score == 0 or args.load_best_score == 1, "language_eval should be 0 or 1"
|
255 |
+
assert args.train_only == 0 or args.train_only == 1, "language_eval should be 0 or 1"
|
256 |
+
|
257 |
+
# default value for start_from and checkpoint_path
|
258 |
+
# args.checkpoint_path = args.checkpoint_path or './log_%s' %args.id
|
259 |
+
args.checkpoint_path = args.checkpoint_path or './results/log_{}_{}'.format(args.topic, args.id)
|
260 |
+
args.start_from = args.start_from or args.checkpoint_path
|
261 |
+
|
262 |
+
# Deal with feature things before anything
|
263 |
+
args.use_fc, args.use_att = if_use_feat(args.caption_model)
|
264 |
+
if args.use_box: args.att_feat_size = args.att_feat_size + 5
|
265 |
+
|
266 |
+
return args
|
267 |
+
|
268 |
+
|
269 |
+
def add_eval_options(parser):
|
270 |
+
# Basic options
|
271 |
+
parser.add_argument('--batch_size', type=int, default=0,
|
272 |
+
help='if > 0 then overrule, otherwise load from checkpoint.')
|
273 |
+
parser.add_argument('--num_images', type=int, default=-1,
|
274 |
+
help='how many images to use when periodically evaluating the loss? (-1 = all)')
|
275 |
+
parser.add_argument('--language_eval', type=int, default=0,
|
276 |
+
help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.')
|
277 |
+
parser.add_argument('--dump_images', type=int, default=1,
|
278 |
+
help='Dump images into vis/imgs folder for vis? (1=yes,0=no)')
|
279 |
+
parser.add_argument('--dump_json', type=int, default=1,
|
280 |
+
help='Dump json with predictions into vis folder? (1=yes,0=no)')
|
281 |
+
parser.add_argument('--dump_path', type=int, default=0,
|
282 |
+
help='Write image paths along with predictions into vis json? (1=yes,0=no)')
|
283 |
+
|
284 |
+
# Sampling options
|
285 |
+
add_eval_sample_opts(parser)
|
286 |
+
|
287 |
+
# For evaluation on a folder of images:
|
288 |
+
parser.add_argument('--image_folder', type=str, default='',
|
289 |
+
help='If this is nonempty then will predict on the images in this folder path')
|
290 |
+
parser.add_argument('--image_root', type=str, default='',
|
291 |
+
help='In case the image paths have to be preprended with a root path to an image folder')
|
292 |
+
# For evaluation on MSCOCO images from some split:
|
293 |
+
parser.add_argument('--input_fc_dir', type=str, default='',
|
294 |
+
help='path to the h5file containing the preprocessed dataset')
|
295 |
+
parser.add_argument('--input_att_dir', type=str, default='',
|
296 |
+
help='path to the h5file containing the preprocessed dataset')
|
297 |
+
parser.add_argument('--input_box_dir', type=str, default='',
|
298 |
+
help='path to the h5file containing the preprocessed dataset')
|
299 |
+
parser.add_argument('--input_label_h5', type=str, default='',
|
300 |
+
help='path to the h5file containing the preprocessed dataset')
|
301 |
+
parser.add_argument('--input_json', type=str, default='',
|
302 |
+
help='path to the json file containing additional info and vocab. empty = fetch from model checkpoint.')
|
303 |
+
parser.add_argument('--split', type=str, default='test',
|
304 |
+
help='if running on MSCOCO images, which split to use: val|test|train')
|
305 |
+
parser.add_argument('--coco_json', type=str, default='',
|
306 |
+
help='if nonempty then use this file in DataLoaderRaw (see docs there). Used only in MSCOCO test evaluation, where we have a specific json file of only test set images.')
|
307 |
+
# misc
|
308 |
+
parser.add_argument('--id', type=str, default='',
|
309 |
+
help='an id identifying this run/job. used only if language_eval = 1 for appending to intermediate files')
|
310 |
+
parser.add_argument('--verbose_beam', type=int, default=1,
|
311 |
+
help='if we need to print out all beam search beams.')
|
312 |
+
parser.add_argument('--verbose_loss', type=int, default=0,
|
313 |
+
help='If calculate loss using ground truth during evaluation')
|
314 |
+
|
315 |
+
parser.add_argument('--seed', type=int, default=42,
|
316 |
+
help='')
|
317 |
+
|
318 |
+
def add_diversity_opts(parser):
|
319 |
+
parser.add_argument('--sample_n', type=int, default=1,
|
320 |
+
help='Diverse sampling')
|
321 |
+
parser.add_argument('--sample_n_method', type=str, default='sample',
|
322 |
+
help='sample, bs, dbs, gumbel, topk, dgreedy, dsample, dtopk, dtopp')
|
323 |
+
parser.add_argument('--eval_oracle', type=int, default=1,
|
324 |
+
help='if we need to calculate loss.')
|
325 |
+
|
326 |
+
|
327 |
+
# Sampling related options
|
328 |
+
def add_eval_sample_opts(parser):
|
329 |
+
parser.add_argument('--sample_method', type=str, default='greedy',
|
330 |
+
help='greedy; sample; gumbel; top<int>, top<0-1>')
|
331 |
+
parser.add_argument('--beam_size', type=int, default=1,
|
332 |
+
help='used when sample_method = greedy, indicates number of beams in beam search. Usually 2 or 3 works well. More is not better. Set this to 1 for faster runtime but a bit worse performance.')
|
333 |
+
parser.add_argument('--max_length', type=int, default=8,
|
334 |
+
help='Maximum length during sampling')
|
335 |
+
parser.add_argument('--length_penalty', type=str, default='',
|
336 |
+
help='wu_X or avg_X, X is the alpha')
|
337 |
+
parser.add_argument('--group_size', type=int, default=1,
|
338 |
+
help='used for diverse beam search. if group_size is 1, then it\'s normal beam search')
|
339 |
+
parser.add_argument('--diversity_lambda', type=float, default=0.5,
|
340 |
+
help='used for diverse beam search. Usually from 0.2 to 0.8. Higher value of lambda produces a more diverse list')
|
341 |
+
parser.add_argument('--temperature', type=float, default=1.0,
|
342 |
+
help='temperature when sampling from distributions (i.e. when sample_method = sample). Lower = "safer" predictions.')
|
343 |
+
parser.add_argument('--decoding_constraint', type=int, default=0,
|
344 |
+
help='If 1, not allowing same word in a row')
|
345 |
+
parser.add_argument('--block_trigrams', type=int, default=0,
|
346 |
+
help='block repeated trigram.')
|
347 |
+
parser.add_argument('--remove_bad_endings', type=int, default=1,
|
348 |
+
help='Remove bad endings')
|
349 |
+
parser.add_argument('--suppress_UNK', type=int, default=1,
|
350 |
+
help='Not predicting UNK')
|
351 |
+
|
352 |
+
|
353 |
+
if __name__ == '__main__':
|
354 |
+
import sys
|
355 |
+
sys.argv = [sys.argv[0]]
|
356 |
+
args = parse_opt()
|
357 |
+
print(args)
|
358 |
+
print()
|
359 |
+
sys.argv = [sys.argv[0], '--cfg', 'configs/updown_long.yml']
|
360 |
+
args1 = parse_opt()
|
361 |
+
print(dict(set(vars(args1).items()) - set(vars(args).items())))
|
362 |
+
print()
|
363 |
+
sys.argv = [sys.argv[0], '--cfg', 'configs/updown_long.yml', '--caption_model', 'att2in2']
|
364 |
+
args2 = parse_opt()
|
365 |
+
print(dict(set(vars(args2).items()) - set(vars(args1).items())))
|
captioning/utils/resnet.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torchvision.models.resnet
|
4 |
+
from torchvision.models.resnet import BasicBlock, Bottleneck
|
5 |
+
import torch.utils.model_zoo as model_zoo
|
6 |
+
|
7 |
+
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
|
8 |
+
'resnet152']
|
9 |
+
|
10 |
+
model_urls = {
|
11 |
+
'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth',
|
12 |
+
'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth',
|
13 |
+
'resnet50': 'https://download.pytorch.org/models/resnet50-11ad3fa6.pth',
|
14 |
+
'resnet101': 'https://download.pytorch.org/models/resnet101-cd907fc2.pth',
|
15 |
+
'resnet152': 'https://download.pytorch.org/models/resnet152-f82ba261.pth',
|
16 |
+
}
|
17 |
+
|
18 |
+
class ResNet(torchvision.models.resnet.ResNet):
|
19 |
+
def __init__(self, block, layers, num_classes=1000):
|
20 |
+
super(ResNet, self).__init__(block, layers, num_classes)
|
21 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) # change
|
22 |
+
for i in range(2, 5):
|
23 |
+
getattr(self, 'layer%d'%i)[0].conv1.stride = (2,2)
|
24 |
+
getattr(self, 'layer%d'%i)[0].conv2.stride = (1,1)
|
25 |
+
|
26 |
+
def resnet18(pretrained=False):
|
27 |
+
"""Constructs a ResNet-18 model.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
31 |
+
"""
|
32 |
+
model = ResNet(BasicBlock, [2, 2, 2, 2])
|
33 |
+
if pretrained:
|
34 |
+
model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
|
35 |
+
return model
|
36 |
+
|
37 |
+
|
38 |
+
def resnet34(pretrained=False):
|
39 |
+
"""Constructs a ResNet-34 model.
|
40 |
+
|
41 |
+
Args:
|
42 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
43 |
+
"""
|
44 |
+
model = ResNet(BasicBlock, [3, 4, 6, 3])
|
45 |
+
if pretrained:
|
46 |
+
model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
|
47 |
+
return model
|
48 |
+
|
49 |
+
|
50 |
+
def resnet50(pretrained=False):
|
51 |
+
"""Constructs a ResNet-50 model.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
55 |
+
"""
|
56 |
+
model = ResNet(Bottleneck, [3, 4, 6, 3])
|
57 |
+
if pretrained:
|
58 |
+
model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
|
59 |
+
return model
|
60 |
+
|
61 |
+
|
62 |
+
def resnet101(pretrained=False):
|
63 |
+
"""Constructs a ResNet-101 model.
|
64 |
+
|
65 |
+
Args:
|
66 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
67 |
+
"""
|
68 |
+
|
69 |
+
model = ResNet(Bottleneck, [3, 4, 23, 3])
|
70 |
+
if pretrained:
|
71 |
+
model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
|
72 |
+
return model
|
73 |
+
|
74 |
+
|
75 |
+
def resnet152(pretrained=False):
|
76 |
+
"""Constructs a ResNet-152 model.
|
77 |
+
|
78 |
+
Args:
|
79 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
80 |
+
"""
|
81 |
+
model = ResNet(Bottleneck, [3, 8, 36, 3])
|
82 |
+
if pretrained:
|
83 |
+
model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
|
84 |
+
return model
|
captioning/utils/resnet_utils.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
class myResnet(nn.Module):
|
6 |
+
def __init__(self, resnet):
|
7 |
+
super(myResnet, self).__init__()
|
8 |
+
self.resnet = resnet
|
9 |
+
|
10 |
+
def forward(self, img, att_size=14):
|
11 |
+
x = img.unsqueeze(0)
|
12 |
+
|
13 |
+
x = self.resnet.conv1(x)
|
14 |
+
x = self.resnet.bn1(x)
|
15 |
+
x = self.resnet.relu(x)
|
16 |
+
x = self.resnet.maxpool(x)
|
17 |
+
|
18 |
+
x = self.resnet.layer1(x)
|
19 |
+
x = self.resnet.layer2(x)
|
20 |
+
x = self.resnet.layer3(x)
|
21 |
+
x = self.resnet.layer4(x)
|
22 |
+
|
23 |
+
fc = x.mean(3).mean(2).squeeze()
|
24 |
+
att = F.adaptive_avg_pool2d(x,[att_size,att_size]).squeeze().permute(1, 2, 0)
|
25 |
+
|
26 |
+
return fc, att
|
27 |
+
|
28 |
+
|
29 |
+
class ResNetBatch(nn.Module):
|
30 |
+
def __init__(self, resnet):
|
31 |
+
super(ResNetBatch, self).__init__()
|
32 |
+
self.resnet = resnet
|
33 |
+
|
34 |
+
def forward(self, x, att_size=14):
|
35 |
+
# size of x: nimages x nChannel x dim x dim
|
36 |
+
|
37 |
+
x = self.resnet.conv1(x)
|
38 |
+
x = self.resnet.bn1(x)
|
39 |
+
x = self.resnet.relu(x)
|
40 |
+
x = self.resnet.maxpool(x)
|
41 |
+
|
42 |
+
x = self.resnet.layer1(x)
|
43 |
+
x = self.resnet.layer2(x)
|
44 |
+
x = self.resnet.layer3(x)
|
45 |
+
x = self.resnet.layer4(x)
|
46 |
+
|
47 |
+
fc = x.mean(3).mean(2)
|
48 |
+
# att = F.adaptive_avg_pool2d(x, [att_size, att_size]).squeeze().permute(1, 2, 0)
|
49 |
+
att = F.adaptive_avg_pool2d(x, [att_size, att_size]).permute(0, 2, 3, 1)
|
50 |
+
|
51 |
+
return fc, att
|
captioning/utils/rewards.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import
|
2 |
+
from __future__ import division
|
3 |
+
from __future__ import print_function
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import time
|
7 |
+
from collections import OrderedDict
|
8 |
+
import torch
|
9 |
+
|
10 |
+
import sys
|
11 |
+
try:
|
12 |
+
sys.path.append("cider")
|
13 |
+
from pyciderevalcap.ciderD.ciderD import CiderD
|
14 |
+
from pyciderevalcap.cider.cider import Cider
|
15 |
+
sys.path.append("coco-caption")
|
16 |
+
from pycocoevalcap.bleu.bleu import Bleu
|
17 |
+
except:
|
18 |
+
print('cider or coco-caption missing')
|
19 |
+
|
20 |
+
CiderD_scorer = None
|
21 |
+
Cider_scorer = None
|
22 |
+
Bleu_scorer = None
|
23 |
+
#CiderD_scorer = CiderD(df='corpus')
|
24 |
+
|
25 |
+
def init_scorer(cached_tokens):
|
26 |
+
global CiderD_scorer
|
27 |
+
CiderD_scorer = CiderD_scorer or CiderD(df=cached_tokens)
|
28 |
+
global Cider_scorer
|
29 |
+
Cider_scorer = Cider_scorer or Cider(df=cached_tokens)
|
30 |
+
global Bleu_scorer
|
31 |
+
Bleu_scorer = Bleu_scorer or Bleu(4)
|
32 |
+
|
33 |
+
def array_to_str(arr):
|
34 |
+
out = ''
|
35 |
+
for i in range(len(arr)):
|
36 |
+
out += str(arr[i]) + ' '
|
37 |
+
if arr[i] == 0:
|
38 |
+
break
|
39 |
+
return out.strip()
|
40 |
+
|
41 |
+
def get_self_critical_reward(greedy_res, data_gts, gen_result, opt):
|
42 |
+
batch_size = len(data_gts)
|
43 |
+
gen_result_size = gen_result.shape[0]
|
44 |
+
seq_per_img = gen_result_size // len(data_gts) # gen_result_size = batch_size * seq_per_img
|
45 |
+
assert greedy_res.shape[0] == batch_size
|
46 |
+
|
47 |
+
res = OrderedDict()
|
48 |
+
gen_result = gen_result.data.cpu().numpy()
|
49 |
+
greedy_res = greedy_res.data.cpu().numpy()
|
50 |
+
for i in range(gen_result_size):
|
51 |
+
res[i] = [array_to_str(gen_result[i])]
|
52 |
+
for i in range(batch_size):
|
53 |
+
res[gen_result_size + i] = [array_to_str(greedy_res[i])]
|
54 |
+
|
55 |
+
gts = OrderedDict()
|
56 |
+
for i in range(len(data_gts)):
|
57 |
+
gts[i] = [array_to_str(data_gts[i][j]) for j in range(len(data_gts[i]))]
|
58 |
+
|
59 |
+
res_ = [{'image_id':i, 'caption': res[i]} for i in range(len(res))]
|
60 |
+
res__ = {i: res[i] for i in range(len(res_))}
|
61 |
+
gts_ = {i: gts[i // seq_per_img] for i in range(gen_result_size)}
|
62 |
+
gts_.update({i+gen_result_size: gts[i] for i in range(batch_size)})
|
63 |
+
if opt.cider_reward_weight > 0:
|
64 |
+
_, cider_scores = CiderD_scorer.compute_score(gts_, res_)
|
65 |
+
print('Cider scores:', _)
|
66 |
+
else:
|
67 |
+
cider_scores = 0
|
68 |
+
if opt.bleu_reward_weight > 0:
|
69 |
+
_, bleu_scores = Bleu_scorer.compute_score(gts_, res__)
|
70 |
+
bleu_scores = np.array(bleu_scores[3])
|
71 |
+
print('Bleu scores:', _[3])
|
72 |
+
else:
|
73 |
+
bleu_scores = 0
|
74 |
+
scores = opt.cider_reward_weight * cider_scores + opt.bleu_reward_weight * bleu_scores
|
75 |
+
|
76 |
+
scores = scores[:gen_result_size].reshape(batch_size, seq_per_img) - scores[-batch_size:][:, np.newaxis]
|
77 |
+
scores = scores.reshape(gen_result_size)
|
78 |
+
|
79 |
+
rewards = np.repeat(scores[:, np.newaxis], gen_result.shape[1], 1)
|
80 |
+
|
81 |
+
return rewards
|
82 |
+
|
83 |
+
def get_scores(data_gts, gen_result, opt):
|
84 |
+
batch_size = gen_result.size(0)# batch_size = sample_size * seq_per_img
|
85 |
+
seq_per_img = batch_size // len(data_gts)
|
86 |
+
|
87 |
+
res = OrderedDict()
|
88 |
+
|
89 |
+
gen_result = gen_result.data.cpu().numpy()
|
90 |
+
for i in range(batch_size):
|
91 |
+
res[i] = [array_to_str(gen_result[i])]
|
92 |
+
|
93 |
+
gts = OrderedDict()
|
94 |
+
for i in range(len(data_gts)):
|
95 |
+
gts[i] = [array_to_str(data_gts[i][j]) for j in range(len(data_gts[i]))]
|
96 |
+
|
97 |
+
res_ = [{'image_id':i, 'caption': res[i]} for i in range(batch_size)]
|
98 |
+
res__ = {i: res[i] for i in range(batch_size)}
|
99 |
+
gts = {i: gts[i // seq_per_img] for i in range(batch_size)}
|
100 |
+
if opt.cider_reward_weight > 0:
|
101 |
+
_, cider_scores = CiderD_scorer.compute_score(gts, res_)
|
102 |
+
print('Cider scores:', _)
|
103 |
+
else:
|
104 |
+
cider_scores = 0
|
105 |
+
if opt.bleu_reward_weight > 0:
|
106 |
+
_, bleu_scores = Bleu_scorer.compute_score(gts, res__)
|
107 |
+
bleu_scores = np.array(bleu_scores[3])
|
108 |
+
print('Bleu scores:', _[3])
|
109 |
+
else:
|
110 |
+
bleu_scores = 0
|
111 |
+
|
112 |
+
scores = opt.cider_reward_weight * cider_scores + opt.bleu_reward_weight * bleu_scores
|
113 |
+
|
114 |
+
return scores
|
115 |
+
|
116 |
+
def get_self_cider_scores(data_gts, gen_result, opt):
|
117 |
+
batch_size = gen_result.size(0)# batch_size = sample_size * seq_per_img
|
118 |
+
seq_per_img = batch_size // len(data_gts)
|
119 |
+
|
120 |
+
res = []
|
121 |
+
|
122 |
+
gen_result = gen_result.data.cpu().numpy()
|
123 |
+
for i in range(batch_size):
|
124 |
+
res.append(array_to_str(gen_result[i]))
|
125 |
+
|
126 |
+
scores = []
|
127 |
+
for i in range(len(data_gts)):
|
128 |
+
tmp = Cider_scorer.my_self_cider([res[i*seq_per_img:(i+1)*seq_per_img]])
|
129 |
+
def get_div(eigvals):
|
130 |
+
eigvals = np.clip(eigvals, 0, None)
|
131 |
+
return -np.log(np.sqrt(eigvals[-1]) / (np.sqrt(eigvals).sum())) / np.log(len(eigvals))
|
132 |
+
scores.append(get_div(np.linalg.eigvalsh(tmp[0]/10)))
|
133 |
+
|
134 |
+
scores = np.array(scores)
|
135 |
+
|
136 |
+
return scores
|