File size: 5,966 Bytes
158b61b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
"""Ensemble decoding.
Decodes using multiple models simultaneously,
combining their prediction distributions by averaging.
All models in the ensemble must share a target vocabulary.
"""
import math

import torch
import torch.nn as nn

from onmt.encoders.encoder import EncoderBase
from onmt.decoders.decoder import DecoderBase
from onmt.models import NMTModel
import onmt.model_builder
class EnsembleDecoderOutput(object):
    """Wrapper around multiple decoder final hidden states."""
    def __init__(self, model_dec_outs):
        # Freeze as a tuple so the per-model outputs cannot be mutated.
        self.model_dec_outs = tuple(model_dec_outs)
    def squeeze(self, dim=None):
        """Delegate squeeze to avoid modifying
        :func:`onmt.translate.translator.Translator.translate_batch()`
        """
        squeezed = (out.squeeze(dim) for out in self.model_dec_outs)
        return EnsembleDecoderOutput(squeezed)
    def __getitem__(self, index):
        # Index into the per-model outputs (also enables iteration).
        return self.model_dec_outs[index]
class EnsembleEncoder(EncoderBase):
    """Dummy Encoder that delegates to individual real Encoders."""
    def __init__(self, model_encoders):
        super(EnsembleEncoder, self).__init__()
        # ModuleList registers each sub-encoder's parameters properly.
        self.model_encoders = nn.ModuleList(model_encoders)
    def forward(self, src, lengths=None):
        """Run every wrapped encoder on the same input and collect the
        per-model hidden states and memory banks as parallel tuples.
        """
        results = [encoder(src, lengths) for encoder in self.model_encoders]
        enc_hidden, memory_bank, _ = zip(*results)
        # lengths is shared by every model, so return the input unchanged.
        return enc_hidden, memory_bank, lengths
class EnsembleDecoder(DecoderBase):
    """Dummy Decoder that delegates to individual real Decoders."""
    def __init__(self, model_decoders):
        model_decoders = nn.ModuleList(model_decoders)
        # The ensemble is attentional if any member decoder is.
        attentional = any(dec.attentional for dec in model_decoders)
        super(EnsembleDecoder, self).__init__(attentional)
        self.model_decoders = model_decoders
    def forward(self, tgt, memory_bank, memory_lengths=None, step=None,
                **kwargs):
        """See :func:`onmt.decoders.decoder.DecoderBase.forward()`."""
        # memory_lengths is a single tensor shared between all models.
        # This assumption will not hold if Translator is modified
        # to calculate memory_lengths as something other than the length
        # of the input.
        per_model = [
            decoder(tgt, memory_bank[i],
                    memory_lengths=memory_lengths, step=step, **kwargs)
            for i, decoder in enumerate(self.model_decoders)]
        dec_outs, attns = zip(*per_model)
        return EnsembleDecoderOutput(dec_outs), self.combine_attns(attns)
    def combine_attns(self, attns):
        """Average each attention type across models, skipping any model
        that did not produce that attention (``None`` entries).
        """
        combined = {}
        for key in attns[0]:
            stacked = torch.stack(
                [attn[key] for attn in attns if attn[key] is not None])
            combined[key] = stacked.mean(0)
        return combined
    def init_state(self, src, memory_bank, enc_hidden):
        """ See :obj:`RNNDecoderBase.init_state()` """
        for i, decoder in enumerate(self.model_decoders):
            decoder.init_state(src, memory_bank[i], enc_hidden[i])
    def map_state(self, fn):
        # Delegate beam-state reordering to every member decoder.
        for decoder in self.model_decoders:
            decoder.map_state(fn)
class EnsembleGenerator(nn.Module):
    """
    Dummy Generator that delegates to individual real Generators,
    and then averages the resulting target distributions.

    Args:
        model_generators: one generator module per ensemble member.
        raw_probs (bool): if True, average in probability space
            (assumes each generator outputs log-probabilities);
            otherwise average the raw generator outputs directly.
    """
    def __init__(self, model_generators, raw_probs=False):
        super(EnsembleGenerator, self).__init__()
        self.model_generators = nn.ModuleList(model_generators)
        self._raw_probs = raw_probs
    def forward(self, hidden, attn=None, src_map=None):
        """
        Compute a distribution over the target dictionary
        by averaging distributions from models in the ensemble.
        All models in the ensemble must share a target vocabulary.
        """
        # Stack per-model outputs along a new leading "model" dimension.
        distributions = torch.stack(
            [mg(h) if attn is None else mg(h, attn, src_map)
             for h, mg in zip(hidden, self.model_generators)]
        )
        if self._raw_probs:
            # log(mean(exp(x))) computed stably: logsumexp avoids
            # overflow/underflow that torch.log(torch.exp(x).mean(0))
            # suffers for large-magnitude log-probabilities.
            return torch.logsumexp(distributions, dim=0) \
                - math.log(distributions.size(0))
        else:
            return distributions.mean(0)
class EnsembleModel(NMTModel):
    """Dummy NMTModel wrapping individual real NMTModels."""
    def __init__(self, models, raw_probs=False):
        # Build delegating encoder/decoder wrappers over the members.
        super(EnsembleModel, self).__init__(
            EnsembleEncoder(m.encoder for m in models),
            EnsembleDecoder(m.decoder for m in models))
        # Average (or log-average, per raw_probs) the output distributions.
        self.generator = EnsembleGenerator(
            [m.generator for m in models], raw_probs)
        self.models = nn.ModuleList(models)
def _assert_same_vocabs(fields, shared_fields):
    """Check that ``fields`` agrees with ``shared_fields`` on every vocab.

    A field may itself be iterable (yielding ``(name, subfield)`` pairs)
    or a single field; both shapes are handled. Raises ``AssertionError``
    on any vocabulary mismatch.
    """
    for key, field in fields.items():
        try:
            f_iter = iter(field)
        except TypeError:
            # Plain (non-iterable) field: treat as a one-element list.
            f_iter = [(key, field)]
        for sn, sf in f_iter:
            if sf is not None and 'vocab' in sf.__dict__:
                sh_field = shared_fields[key]
                try:
                    sh_f_iter = iter(sh_field)
                except TypeError:
                    sh_f_iter = [(key, sh_field)]
                sh_f_dict = dict(sh_f_iter)
                assert sf.vocab.stoi == sh_f_dict[sn].vocab.stoi, \
                    "Ensemble models must use the same " \
                    "preprocessed data"
def load_test_model(opt):
    """Read in multiple models for ensemble."""
    shared_fields = None
    shared_model_opt = None
    models = []
    for model_path in opt.models:
        fields, model, model_opt = \
            onmt.model_builder.load_test_model(opt, model_path=model_path)
        if shared_fields is None:
            # First model defines the fields all others must match.
            shared_fields = fields
        else:
            _assert_same_vocabs(fields, shared_fields)
        models.append(model)
        if shared_model_opt is None:
            # First model's options are returned for the ensemble.
            shared_model_opt = model_opt
    ensemble_model = EnsembleModel(models, opt.avg_raw_probs)
    return shared_fields, ensemble_model, shared_model_opt
|