Spaces:
Sleeping
Sleeping
File size: 4,268 Bytes
9bf9e42 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
# This file is the implementation for ensemble evaluation.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import *
from .CaptionModel import CaptionModel
from .AttModel import pack_wrapper, AttModel
class AttEnsemble(AttModel):
    """Decode with a weighted ensemble of attention caption models.

    At every decoding step each member model produces a word distribution;
    the ensemble returns the log of their weighted average (the weights are
    normalised by their sum). Recurrent state is carried as one flat list
    that concatenates the members' states (see ``pack_state``).
    """

    def __init__(self, models, weights=None):
        """Wrap ``models`` into an ensemble.

        Args:
            models: non-empty sequence of caption models. ``vocab_size``,
                ``seq_length`` and ``bad_endings_ix`` are copied from the first
                one; presumably all members share one vocabulary (the
                ``torch.stack`` in ``get_logprobs_state`` requires equal
                output sizes).
            weights: optional per-model mixing weights (list, tuple or 1-D
                tensor). ``None`` or an empty sequence means equal weights.
                They need not sum to one; they are normalised when decoding.
        """
        # AttModel.__init__ is bypassed on purpose, presumably because it builds
        # layers from a training config the ensemble does not have; embedding,
        # core and logit work is delegated to the member models below.
        CaptionModel.__init__(self)
        self.models = nn.ModuleList(models)
        self.vocab_size = models[0].vocab_size
        self.seq_length = models[0].seq_length
        self.bad_endings_ix = models[0].bad_endings_ix
        self.ss_prob = 0
        # Explicit None/empty check: `weights or ...` raises for tensors/arrays.
        if weights is None or len(weights) == 0:
            weights = [1.0] * len(self.models)
        self.register_buffer('weights', torch.as_tensor(weights, dtype=torch.float).clone())

    def init_hidden(self, batch_size):
        """Return every member's initial state packed into one flat list."""
        return self.pack_state([m.init_hidden(batch_size) for m in self.models])

    def pack_state(self, state):
        """Flatten a sequence of per-model states into a single list.

        Side effect: stores each member's state length in
        ``self.state_lengths`` so ``unpack_state`` can split the list again.
        """
        self.state_lengths = [len(s) for s in state]
        # Flat comprehension instead of sum(..., []), which copies quadratically.
        return [t for s in state for t in s]

    def unpack_state(self, state):
        """Split a flat state list back into one slice per member model.

        Uses the lengths recorded by the most recent ``pack_state`` call.
        """
        out = []
        start = 0
        for length in self.state_lengths:
            out.append(state[start:start + length])
            start += length
        return out

    def embed(self, it):
        """Embed word indices ``it`` with every member; one result per model."""
        return [m.embed(it) for m in self.models]

    def core(self, *args):
        """Run one step of every member's core.

        Every positional argument is a per-model sequence; member ``i`` is
        called with the ``i``-th element of each. Returns an iterator that
        unpacks into (outputs per model, new states per model).
        """
        per_model_args = zip(*args)
        return zip(*[m.core(*a) for m, a in zip(self.models, per_model_args)])

    def get_logprobs_state(self, it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state, output_logsoftmax=1):
        """Advance all members by one word and return ensemble log-probabilities.

        Args:
            it: word indices fed to every member.
            tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks:
                per-model feature sequences (one entry per member).
            state: flat state list as produced by ``pack_state``.
            output_logsoftmax: accepted for interface compatibility but
                ignored; the result is always log-probabilities.

        Returns:
            (logprobs, state): ``log(sum_i w_i / W * softmax_i)`` over the
            members, and the new packed state.
        """
        xt = self.embed(it)
        per_model_state = self.unpack_state(state)
        output, new_state = self.core(xt, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, per_model_state, tmp_att_masks)
        # Member log-probabilities stacked along a new last axis (one slice per model).
        member_logprobs = torch.stack(
            [F.log_softmax(m.logit(output[i]), dim=1) for i, m in enumerate(self.models)], 2)
        # log(sum_i (w_i / W) * p_i) evaluated as logsumexp(log p_i + log(w_i / W)):
        # identical to softmax(...).mul(w).div(W).sum(-1).log() mathematically, but
        # it does not collapse to -inf when every member's probability underflows.
        log_weights = torch.log(self.weights / self.weights.sum())
        logprobs = torch.logsumexp(member_logprobs + log_weights, dim=-1)
        return logprobs, self.pack_state(new_state)

    def _prepare_feature(self, *args):
        """Run every member's feature preparation.

        Returns a tuple of feature kinds (unpacked by callers as fc, att,
        p_att, masks); each kind is a tuple holding one entry per model.
        """
        return tuple(zip(*[m._prepare_feature(*args) for m in self.models]))

    def _old_sample_beam(self, fc_feats, att_feats, att_masks=None, opt=None):
        """Beam-search every image independently with ``old_beam_search``.

        Args:
            fc_feats, att_feats, att_masks: batched inputs; each member
                prepares its own features via ``_prepare_feature``.
            opt: decoding options dict (``beam_size`` defaults to 10); it is
                forwarded to ``old_beam_search``. ``None`` means no options.

        Returns:
            (seq, seqLogprobs): ``seq`` is a long tensor of shape
            (batch, seq_length) holding the top beam per image and
            ``seqLogprobs`` a float32 tensor of shape
            (batch, seq_length, vocab_size + 1) filled from that beam's
            ``'logps'``. Both are created without an explicit device.
        """
        opt = {} if opt is None else opt
        beam_size = opt.get('beam_size', 10)
        batch_size = fc_feats.size(0)

        fc_feats, att_feats, p_att_feats, att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)

        assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
        # Zero-initialised (torch.FloatTensor(sizes) would be uninitialised memory).
        seq = torch.zeros(self.seq_length, batch_size, dtype=torch.long)
        seqLogprobs = torch.zeros(self.seq_length, batch_size, self.vocab_size + 1, dtype=torch.float32)

        def expand_row(x, k):
            # Repeat sample k of one model's feature tensor beam_size times;
            # None entries (no attention mask) pass through unchanged.
            if x is None:
                return None
            return x[k:k + 1].expand(beam_size, *x.shape[1:]).contiguous()

        # Images are processed independently, for simplicity.
        self.done_beams = [[] for _ in range(batch_size)]
        for k in range(batch_size):
            state = self.init_hidden(beam_size)
            tmp_fc_feats = [f[k:k + 1].expand(beam_size, f.size(1)) for f in fc_feats]
            tmp_att_feats = [expand_row(f, k) for f in att_feats]
            tmp_p_att_feats = [expand_row(f, k) for f in p_att_feats]
            tmp_att_masks = [expand_row(mask, k) for mask in att_masks]

            it = fc_feats[0].new_zeros(beam_size, dtype=torch.long)
            logprobs, state = self.get_logprobs_state(it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state)

            self.done_beams[k] = self.old_beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, opt=opt)
            # The first beam has the highest cumulative score.
            seq[:, k] = self.done_beams[k][0]['seq']
            seqLogprobs[:, k] = self.done_beams[k][0]['logps']
        # Return the samples and their log likelihoods, batch-first.
        return seq.transpose(0, 1), seqLogprobs.transpose(0, 1)
|