|
|
|
|
|
|
|
|
|
|
|
from fairseq.tasks.translation import TranslationTask |
|
from fairseq.tasks.language_modeling import LanguageModelingTask |
|
from fairseq import checkpoint_utils |
|
import argparse |
|
from fairseq.tasks import register_task |
|
import torch |
|
|
|
|
|
@register_task("noisy_channel_translation") |
|
class NoisyChannelTranslation(TranslationTask): |
|
""" |
|
Rescore the top k candidates from each beam using noisy channel modeling |
|
""" |
|
|
|
@staticmethod |
|
def add_args(parser): |
|
"""Add task-specific arguments to the parser.""" |
|
TranslationTask.add_args(parser) |
|
|
|
parser.add_argument('--channel-model', metavar='FILE', |
|
help='path to P(S|T) model. P(S|T) and P(T|S) must share source and target dictionaries.') |
|
parser.add_argument('--combine-method', default='lm_only', |
|
choices=['lm_only', 'noisy_channel'], |
|
help="""method for combining direct and channel model scores. |
|
lm_only: decode with P(T|S)P(T) |
|
noisy_channel: decode with 1/t P(T|S) + 1/s(P(S|T)P(T))""") |
|
parser.add_argument('--normalize-lm-scores-by-tgt-len', action='store_true', default=False, |
|
help='normalize lm score by target length instead of source length') |
|
parser.add_argument('--channel-scoring-type', default='log_norm', choices=['unnormalized', 'log_norm', 'k2_separate', 'src_vocab', 'src_vocab_batched'], |
|
help="Normalize bw scores with log softmax or return bw scores without log softmax") |
|
parser.add_argument('--top-k-vocab', default=0, type=int, |
|
help='top k vocab IDs to use with `src_vocab` in channel model scoring') |
|
parser.add_argument('--k2', default=50, type=int, |
|
help='the top k2 candidates to rescore with the noisy channel model for each beam') |
|
parser.add_argument('--ch-wt', default=1, type=float, |
|
help='weight for the channel model') |
|
parser.add_argument('--lm-model', metavar='FILE', |
|
help='path to lm model file, to model P(T). P(T) must share the same vocab as the direct model on the target side') |
|
parser.add_argument('--lm-data', metavar='FILE', |
|
help='path to lm model training data for target language, used to properly load LM with correct dictionary') |
|
parser.add_argument('--lm-wt', default=1, type=float, |
|
help='the weight of the lm in joint decoding') |
|
|
|
|
|
def build_generator( |
|
self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None |
|
): |
|
if getattr(args, "score_reference", False): |
|
raise NotImplementedError() |
|
else: |
|
from .noisy_channel_sequence_generator import NoisyChannelSequenceGenerator |
|
use_cuda = torch.cuda.is_available() and not self.args.cpu |
|
assert self.args.lm_model is not None, '--lm-model required for noisy channel generation!' |
|
assert self.args.lm_data is not None, '--lm-data required for noisy channel generation to map between LM and bitext vocabs' |
|
if self.args.channel_model is not None: |
|
import copy |
|
ch_args_task = copy.deepcopy(self.args) |
|
tmp = ch_args_task.source_lang |
|
ch_args_task.source_lang = ch_args_task.target_lang |
|
ch_args_task.target_lang = tmp |
|
ch_args_task._name = 'translation' |
|
channel_task = TranslationTask.setup_task(ch_args_task) |
|
|
|
arg_dict = {} |
|
arg_dict['task'] = 'language_modeling' |
|
arg_dict['sample_break_mode'] = 'eos' |
|
arg_dict['data'] = self.args.lm_data |
|
arg_dict['output_dictionary_size'] = -1 |
|
lm_args = argparse.Namespace(**arg_dict) |
|
lm_task = LanguageModelingTask.setup_task(lm_args) |
|
lm_dict = lm_task.output_dictionary |
|
|
|
if self.args.channel_model is not None: |
|
channel_models, _ = checkpoint_utils.load_model_ensemble(self.args.channel_model.split(':'), task=channel_task) |
|
|
|
for model in channel_models: |
|
model.make_generation_fast_( |
|
beamable_mm_beam_size=None if args.no_beamable_mm else args.beam, |
|
need_attn=args.print_alignment, |
|
) |
|
if self.args.fp16: |
|
model.half() |
|
if use_cuda: |
|
model.cuda() |
|
else: |
|
channel_models = None |
|
|
|
lm_models, _ = checkpoint_utils.load_model_ensemble(self.args.lm_model.split(':'), task=lm_task) |
|
|
|
for model in lm_models: |
|
model.make_generation_fast_( |
|
beamable_mm_beam_size=None if args.no_beamable_mm else args.beam, |
|
need_attn=args.print_alignment, |
|
) |
|
if self.args.fp16: |
|
model.half() |
|
if use_cuda: |
|
model.cuda() |
|
return NoisyChannelSequenceGenerator( |
|
combine_method=self.args.combine_method, |
|
tgt_dict=self.target_dictionary, |
|
src_dict=self.source_dictionary, |
|
beam_size=getattr(args, 'beam', 5), |
|
max_len_a=getattr(args, 'max_len_a', 0), |
|
max_len_b=getattr(args, 'max_len_b', 200), |
|
min_len=getattr(args, 'min_len', 1), |
|
len_penalty=getattr(args, 'lenpen', 1), |
|
unk_penalty=getattr(args, 'unkpen', 0), |
|
temperature=getattr(args, 'temperature', 1.), |
|
match_source_len=getattr(args, 'match_source_len', False), |
|
no_repeat_ngram_size=getattr(args, 'no_repeat_ngram_size', 0), |
|
normalize_scores=(not getattr(args, 'unnormalized', False)), |
|
channel_models=channel_models, |
|
k2=getattr(self.args, 'k2', 50), |
|
ch_weight=getattr(self.args, 'ch_wt', 1), |
|
channel_scoring_type=self.args.channel_scoring_type, |
|
top_k_vocab=self.args.top_k_vocab, |
|
lm_models=lm_models, |
|
lm_dict=lm_dict, |
|
lm_weight=getattr(self.args, 'lm_wt', 1), |
|
normalize_lm_scores_by_tgt_len=getattr(self.args, 'normalize_lm_scores_by_tgt_len', False), |
|
) |
|
|