from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import json
from json import encoder
import random
import string
import time
import os
import sys
from . import misc as utils

# The coco-caption toolkit is imported as a regular package; it is a hard
# requirement of this module (no optional-import fallback).
from coco_caption.pycocotools.coco import COCO
from coco_caption.pycocoevalcap.eval import COCOEvalCap

# Words a finished caption should not end with; see count_bad().
bad_endings = ['a','an','the','in','for','at','of','with','before','after','on','upon','near','to','is','are','am']
bad_endings += ['UNK', 'has', 'and', 'more']

# Ordered (substring keys, path) tables.  The first entry with a key that is
# a substring of the dataset name wins, so the order is significant.
_FASHION_FILES = (
    (('dress',), 'data/dress/features_simulator/caption_relative.json'),
    (('shirt',), 'data/shirt/features_simulator/caption_relative.json'),
    (('toptee',), 'data/toptee/features_simulator/caption_relative.json'),
    (('fashion-gen',), 'data/fashion-gen/features_simulator/caption_direct.json'),
    (('shoe',), 'data/shoe/features_simulator/caption_relative.json'),
)

# Annotation files handed to the COCO API by getCOCO().
_ANNOTATION_FILES = (
    (('coco',), 'coco-caption/annotations/captions_val2014.json'),
    (('flickr30k', 'f30k'), 'data/f30k_captions4eval.json'),
) + _FASHION_FILES

# Files from which language_eval() reads the training sentences.
_DATASET_FILES = (
    (('coco',), 'data/dataset_coco.json'),
    (('flickr30k', 'f30k'), 'data/dataset_flickr30k.json'),
) + _FASHION_FILES


def _lookup_dataset_path(dataset, table, what):
    """Return the path of the first `table` entry whose key occurs in `dataset`.

    Raises:
        ValueError: if no key of `table` is a substring of `dataset`.
    """
    for keys, path in table:
        if any(key in dataset for key in keys):
            return path
    raise ValueError('no %s known for dataset %r' % (what, dataset))


def _eval_results_dir(eval_kwargs):
    """Return the directory (no trailing slash) that receives evaluation output.

    Built from the required 'topic' and 'id' entries of `eval_kwargs`.
    """
    topic = eval_kwargs['topic']
    return 'results/log_' + topic + '_' + eval_kwargs['id'] + '/eval_results_' + topic


def count_bad(sen):
    """Return 1 if the last space-separated token of `sen` is a bad ending, else 0."""
    last_word = sen.split(' ')[-1]
    return 1 if last_word in bad_endings else 0


def getCOCO(dataset):
    """Load the ground-truth captions for `dataset` through the COCO API.

    Args:
        dataset: dataset name; the annotation file is chosen by substring
            match ('coco', 'flickr30k'/'f30k', 'dress', 'shirt', 'toptee',
            'fashion-gen', 'shoe'), checked in that order.

    Returns:
        A `COCO` object constructed from the matching annotation file.

    Raises:
        ValueError: if `dataset` matches none of the known names (the
            original code failed with an UnboundLocalError here).
    """
    annFile = _lookup_dataset_path(dataset, _ANNOTATION_FILES, 'annotation file')
    return COCO(annFile)


def language_eval(dataset, preds, preds_n, eval_kwargs, split):
    """Score generated captions with the coco-caption metrics and dump results.

    Args:
        dataset: dataset name, resolved by substring as in getCOCO().
        preds: list of dicts with at least 'image_id', 'caption',
            'perplexity' and 'entropy' (one caption per image).
        preds_n: list of dicts with at least 'caption' holding the multiple
            samples per image; may be empty, which skips novelty, vocabulary
            and diversity statistics.
        eval_kwargs: options dict; 'id' and 'topic' are required,
            'eval_oracle' is optional.
        split: split name, used only in output file names.

    Returns:
        dict mapping metric names to scores.

    Raises:
        ValueError: if the dataset is unknown or no prediction refers to an
            image present in the ground-truth annotations.

    Side effects:
        Writes cache and result JSON files below the evaluation results
        directory, which must already exist (eval_split() creates it).
    """
    model_id = eval_kwargs['id']
    eval_oracle = eval_kwargs.get('eval_oracle', 0)
    results_dir = _eval_results_dir(eval_kwargs)

    # create output dictionary
    out = {}

    if len(preds_n) > 0:
        # vocab size and novel sentences
        # NOTE(review): for the fashion datasets the original assigned the
        # path to `annFile`, leaving `dataset_file` unbound.  The same paths
        # are used here; whether those files carry the 'images' / 'split' /
        # 'sentences' / 'tokens' layout read below cannot be confirmed from
        # this file -- TODO confirm.
        dataset_file = _lookup_dataset_path(dataset, _DATASET_FILES, 'dataset file')
        with open(dataset_file) as infile:
            images = json.load(infile)['images']
        training_sentences = set(
            ' '.join(sentence['tokens'])
            for image in images
            if image['split'] not in ['val', 'test']
            for sentence in image['sentences'])
        generated_sentences = set(entry['caption'] for entry in preds_n)
        novels = generated_sentences - training_sentences
        out['novel_sentences'] = float(len(novels)) / len(preds_n)
        words = set()
        for caption in generated_sentences:
            words.update(caption.split())
        out['vocab_size'] = len(words)

    cache_path = os.path.join(results_dir, '.cache_' + model_id + '_' + split + '.json')

    coco = getCOCO(dataset)
    # A set keeps the membership test below O(1) per prediction.
    valids = set(coco.getImgIds())

    # filter results to only those present in the ground-truth annotations
    preds_filt = [p for p in preds if p['image_id'] in valids]
    if not preds_filt:
        raise ValueError(
            'none of the %d predictions has an image_id known to the '
            'annotations of dataset %r' % (len(preds), dataset))
    mean_perplexity = sum(p['perplexity'] for p in preds_filt) / len(preds_filt)
    mean_entropy = sum(p['entropy'] for p in preds_filt) / len(preds_filt)
    print('using %d/%d predictions' % (len(preds_filt), len(preds)))
    # Serialize to a temporary json file because the COCO API loads results
    # from a path; the file must be closed before loadRes() reads it back.
    with open(cache_path, 'w') as outfile:
        json.dump(preds_filt, outfile)

    cocoRes = coco.loadRes(cache_path)
    cocoEval = COCOEvalCap(coco, cocoRes)
    cocoEval.params['image_id'] = cocoRes.getImgIds()
    cocoEval.evaluate()

    for metric, score in cocoEval.eval.items():
        out[metric] = score
    # Add mean perplexity and entropy
    out['perplexity'] = mean_perplexity
    out['entropy'] = mean_entropy

    imgToEval = cocoEval.imgToEval
    for k in list(imgToEval.values())[0]['SPICE'].keys():
        if k != 'All':
            scores = np.array([v['SPICE'][k]['f'] for v in imgToEval.values()])
            # `scores == scores` is False exactly for NaN, so NaN entries are
            # excluded from the mean.
            out['SPICE_' + k] = scores[scores == scores].mean()
    for p in preds_filt:
        imgToEval[p['image_id']]['caption'] = p['caption']

    if len(preds_n) > 0:
        from . import eval_multi
        cache_path_n = os.path.join(results_dir, '.cache_' + model_id + '_' + split + '_n.json')
        allspice = eval_multi.eval_allspice(dataset, preds_n, model_id, split)
        out.update(allspice['overall'])
        div_stats = eval_multi.eval_div_stats(dataset, preds_n, model_id, split)
        out.update(div_stats['overall'])
        if eval_oracle:
            oracle = eval_multi.eval_oracle(dataset, preds_n, model_id, split)
            out.update(oracle['overall'])
        else:
            oracle = None
        self_cider = eval_multi.eval_self_cider(dataset, preds_n, model_id, split)
        out.update(self_cider['overall'])

        with open(cache_path_n, 'w') as outfile:
            json.dump({'allspice': allspice, 'div_stats': div_stats, 'oracle': oracle, 'self_cider': self_cider}, outfile)

    out['bad_count_rate'] = sum(count_bad(p['caption']) for p in preds_filt) / float(len(preds_filt))
    outfile_path = os.path.join(results_dir, model_id + '_' + split + '.json')
    with open(outfile_path, 'w') as outfile:
        json.dump({'overall': out, 'imgToEval': imgToEval}, outfile)

    return out


def eval_split(model, crit, loader, eval_kwargs=None):
    """Run `model` over one split of `loader`, collecting loss and captions.

    Args:
        model: captioning model; called both for the training-style forward
            pass and with mode='sample'.  Must expose `vocab` and, for beam
            search output, `done_beams`.
        crit: loss criterion called as crit(output, labels, masks).
        loader: data loader providing reset_iterator(split) and
            get_batch(split).
        eval_kwargs: options dict.  'topic' and 'id' are required (they name
            the output directory); other keys are optional and read with the
            defaults visible below.

    Returns:
        Tuple (mean loss, predictions, lang_stats).  `lang_stats` is None
        unless eval_kwargs['language_eval'] == 1.

    Side effects:
        Sets os.environ['REMOVE_BAD_ENDINGS'], switches the model to eval
        mode and back to train mode, creates the results directory and saves
        the predictions there with torch.save.
    """
    if eval_kwargs is None:
        eval_kwargs = {}
    verbose = eval_kwargs.get('verbose', True)
    verbose_beam = eval_kwargs.get('verbose_beam', 0)
    verbose_loss = eval_kwargs.get('verbose_loss', 1)
    num_images = eval_kwargs.get('num_images', eval_kwargs.get('val_images_use', -1))
    split = eval_kwargs.get('split', 'val')
    lang_eval = eval_kwargs.get('language_eval', 0)
    dataset = eval_kwargs.get('dataset', 'coco')
    beam_size = eval_kwargs.get('beam_size', 1)
    sample_n = eval_kwargs.get('sample_n', 1)
    remove_bad_endings = eval_kwargs.get('remove_bad_endings', 1)
    # Use this nasty way to make other code clean since it's a global configuration
    os.environ["REMOVE_BAD_ENDINGS"] = str(remove_bad_endings)
    device = eval_kwargs.get('device', 'cuda')

    # Make sure in the evaluation mode
    model.eval()

    loader.reset_iterator(split)

    n = 0
    loss = 0
    loss_sum = 0
    loss_evals = 1e-8  # avoids division by zero when no loss was evaluated
    predictions = []
    n_predictions = []  # when sample_n > 1
    while True:
        data = loader.get_batch(split)
        n = n + len(data['infos'])

        tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']]
        tmp = [_.to(device) if _ is not None else _ for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp
        if labels is not None and verbose_loss:
            # forward the model to get loss
            with torch.no_grad():
                loss = crit(model(fc_feats, att_feats, labels[..., :-1], att_masks), labels[..., 1:], masks[..., 1:]).item()
            loss_sum = loss_sum + loss
            loss_evals = loss_evals + 1

        # forward the model to also get generated samples for each image
        with torch.no_grad():
            tmp_eval_kwargs = eval_kwargs.copy()
            tmp_eval_kwargs.update({'sample_n': 1})
            seq, seq_logprobs = model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
            seq = seq.data
            # Both statistics are summed over dims 2/1 and divided by the
            # number of non-zero tokens in `seq` plus one.
            # (presumably token 0 is padding/end -- TODO confirm)
            entropy = - (F.softmax(seq_logprobs, dim=2) * seq_logprobs).sum(2).sum(1) / ((seq>0).to(seq_logprobs).sum(1)+1)
            perplexity = - seq_logprobs.gather(2, seq.unsqueeze(2)).squeeze(2).sum(1) / ((seq>0).to(seq_logprobs).sum(1)+1)

        # Print beam search
        if beam_size > 1 and verbose_beam:
            for i in range(fc_feats.shape[0]):
                print('\n'.join([utils.decode_sequence(model.vocab, _['seq'].unsqueeze(0))[0] for _ in model.done_beams[i]]))
                print('--' * 10)
        sents = utils.decode_sequence(model.vocab, seq)

        for k, sent in enumerate(sents):
            entry = {'image_id': data['infos'][k]['id'], 'caption': sent, 'perplexity': perplexity[k].item(), 'entropy': entropy[k].item()}
            if eval_kwargs.get('dump_path', 0) == 1:
                entry['file_name'] = data['infos'][k]['file_path']
            predictions.append(entry)
            if eval_kwargs.get('dump_images', 0) == 1:
                # dump the raw image to vis/ folder
                # NOTE(review): the path is interpolated into a shell command;
                # file names containing quotes or shell metacharacters break it.
                cmd = 'cp "' + os.path.join(eval_kwargs['image_root'], data['infos'][k]['file_path']) + '" vis/imgs/img' + str(len(predictions)) + '.jpg' # bit gross
                print(cmd)
                os.system(cmd)

            if verbose:
                print('image %s: %s' %(entry['image_id'], entry['caption']))

        if sample_n > 1:
            eval_split_n(model, n_predictions, [fc_feats, att_feats, att_masks, data], eval_kwargs)

        ix1 = data['bounds']['it_max']
        if num_images != -1:
            ix1 = min(ix1, num_images)
        else:
            num_images = ix1
        # Drop the predictions collected beyond the requested image count.
        for i in range(n - ix1):
            predictions.pop()

        if verbose:
            print('evaluating validation preformance... %d/%d (%f)' %(n, ix1, loss))

        if num_images >= 0 and n >= num_images:
            break

    lang_stats = None
    if len(n_predictions) > 0 and 'perplexity' in n_predictions[0]:
        n_predictions = sorted(n_predictions, key=lambda x: x['perplexity'])
    results_dir = _eval_results_dir(eval_kwargs)
    # makedirs also creates missing parents; plain mkdir failed when the
    # 'results/log_*' parent directory did not exist yet.
    if not os.path.isdir(results_dir):
        os.makedirs(results_dir)
    torch.save((predictions, n_predictions), os.path.join(results_dir, '.saved_pred_' + eval_kwargs['id'] + '_' + split + '.pth'))
    if lang_eval == 1:
        lang_stats = language_eval(dataset, predictions, n_predictions, eval_kwargs, split)

    # Switch back to training mode
    model.train()
    return loss_sum/loss_evals, predictions, lang_stats


# Only run when sample_n > 0
def eval_split_n(model, n_predictions, input_data, eval_kwargs=None):
    """Generate `sample_n` captions per image and append them to `n_predictions`.

    Args:
        model: captioning model, called with mode='sample'; must expose
            `vocab` and, for the beam-search based methods, `done_beams`.
        n_predictions: list extended in place with dicts holding 'image_id'
            and 'caption' (plus 'perplexity' for the sampling methods).
        input_data: sequence (fc_feats, att_feats, att_masks, data) for the
            current batch, as assembled by eval_split().
        eval_kwargs: options dict; reads 'verbose', 'beam_size', 'sample_n'
            and 'sample_n_method' ('bs', 'sample', 'gumbel', 'top*', 'dbs',
            or any other value, which takes the final fallback branch).

    Returns:
        None.
    """
    if eval_kwargs is None:
        eval_kwargs = {}
    verbose = eval_kwargs.get('verbose', True)
    beam_size = eval_kwargs.get('beam_size', 1)
    sample_n = eval_kwargs.get('sample_n', 1)
    sample_n_method = eval_kwargs.get('sample_n_method', 'sample')

    fc_feats, att_feats, att_masks, data = input_data
    batch_size = fc_feats.shape[0]

    tmp_eval_kwargs = eval_kwargs.copy()
    if sample_n_method == 'bs':
        # case 1 sample_n == beam size
        tmp_eval_kwargs.update({'sample_n': 1, 'beam_size': sample_n, 'group_size': 1}) # randomness from softmax
        with torch.no_grad():
            model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
        for k in range(batch_size):
            _sents = utils.decode_sequence(model.vocab, torch.stack([model.done_beams[k][_]['seq'] for _ in range(sample_n)]))
            for sent in _sents:
                entry = {'image_id': data['infos'][k]['id'], 'caption': sent}
                n_predictions.append(entry)
    # case 2 sample / gumbel / topk sampling/ nucleus sampling
    elif sample_n_method == 'sample' or \
            sample_n_method == 'gumbel' or \
            sample_n_method.startswith('top'):
        tmp_eval_kwargs.update({'sample_n': sample_n, 'sample_method': sample_n_method, 'beam_size': 1}) # randomness from sample
        with torch.no_grad():
            _seq, _sampleLogprobs = model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
        _sents = utils.decode_sequence(model.vocab, _seq)
        _perplexity = - _sampleLogprobs.gather(2, _seq.unsqueeze(2)).squeeze(2).sum(1) / ((_seq>0).to(_sampleLogprobs).sum(1)+1)
        for k, sent in enumerate(_sents):
            # Output row k is attributed to image k // sample_n.
            entry = {'image_id': data['infos'][k // sample_n]['id'], 'caption': sent, 'perplexity': _perplexity[k].item()}
            n_predictions.append(entry)
    elif sample_n_method == 'dbs':
        # Use diverse beam search
        tmp_eval_kwargs.update({'beam_size': sample_n * beam_size, 'group_size': sample_n}) # randomness from softmax
        with torch.no_grad():
            model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
        # The original iterated over `loader.batch_size`, but `loader` is not
        # defined in this function; use the actual batch size instead.
        for k in range(batch_size):
            # Take every beam_size-th finished beam, sample_n beams in total.
            _sents = utils.decode_sequence(model.vocab, torch.stack([model.done_beams[k][_]['seq'] for _ in range(0, sample_n*beam_size, beam_size)]))
            for sent in _sents:
                entry = {'image_id': data['infos'][k]['id'], 'caption': sent}
                n_predictions.append(entry)
    else:
        # Fallback: the method name minus its first character is passed on
        # as the sample method, with sample_n groups.
        tmp_eval_kwargs.update({'sample_method': sample_n_method[1:], 'group_size': sample_n, 'beam_size':1}) # randomness from softmax
        with torch.no_grad():
            _seq, _sampleLogprobs = model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
        _sents = utils.decode_sequence(model.vocab, _seq)
        for k, sent in enumerate(_sents):
            entry = {'image_id': data['infos'][k // sample_n]['id'], 'caption': sent}
            n_predictions.append(entry)
    if verbose:
        # Print only the entries added for this batch.
        for entry in sorted(n_predictions[-batch_size * sample_n:], key=lambda x: x['image_id']):
            print('image %s: %s' %(entry['image_id'], entry['caption']))