|
|
|
|
|
|
|
|
|
|
|
import itertools |
|
import json |
|
import logging |
|
import os |
|
import torch |
|
from argparse import Namespace |
|
|
|
import numpy as np |
|
from fairseq import metrics, options, utils |
|
from fairseq.data import ( |
|
AppendTokenDataset, |
|
ConcatDataset, |
|
LanguagePairDataset, |
|
PrependTokenDataset, |
|
StripTokenDataset, |
|
TruncateDataset, |
|
data_utils, |
|
encoders, |
|
indexed_dataset, |
|
) |
|
from fairseq.tasks.translation import TranslationTask |
|
from fairseq.tasks import register_task, LegacyFairseqTask |
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Number of n-gram orders whose BLEU counts/totals are logged individually
# during validation and summed again when metrics are reduced.
EVAL_BLEU_ORDER = 4
|
|
|
|
|
def load_langpair_dataset(
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=False,
    load_alignments=False,
    truncate_source=False,
    append_source_id=False,
    num_buckets=0,
    shuffle=True,
    pad_to_multiple=1,
):
    """Build a :class:`LanguagePairDataset` from the indexed shards of ``split``.

    Shards are looked up in ``data_path`` as ``<split>``, ``<split>1``,
    ``<split>2``, ... For each shard the file stem ``<shard>.<src>-<tgt>.`` is
    tried first and ``<shard>.<tgt>-<src>.`` second; a stem counts as present
    when its source-language file exists. Only the first shard is read unless
    ``combine`` is true.

    Args:
        data_path (str): directory holding the dataset files
        split (str): base name of the split (shard suffixes are appended)
        src (str): source language code used in the file names
        src_dict: dictionary passed to the source dataset loader
        tgt (str): target language code used in the file names
        tgt_dict: dictionary passed to the target dataset loader
        combine (bool): keep loading numbered shards after the first one
        dataset_impl: implementation selector forwarded to the loaders
        upsample_primary (int): sample ratio given to the first shard when
            several shards are concatenated (the others get ratio 1)
        left_pad_source (bool): forwarded to :class:`LanguagePairDataset`
        left_pad_target (bool): forwarded to :class:`LanguagePairDataset`
        max_source_positions (int): with ``truncate_source``, sources are cut
            to ``max_source_positions - 1`` tokens before EOS is re-appended
        max_target_positions (int): accepted but not used by this function
        prepend_bos (bool): prepend each dictionary's BOS to its side
        load_alignments (bool): load ``<split>.align.<src>-<tgt>`` if present
        truncate_source (bool): strip EOS, truncate, then re-append EOS
        append_source_id (bool): append ``[<src>]`` / ``[<tgt>]`` tokens and
            use the index of ``[<tgt>]`` as ``eos`` for the pair dataset
        num_buckets (int): forwarded to :class:`LanguagePairDataset`
        shuffle (bool): forwarded to :class:`LanguagePairDataset`
        pad_to_multiple (int): forwarded to :class:`LanguagePairDataset`

    Returns:
        LanguagePairDataset: the (possibly concatenated) parallel dataset; its
        target side is ``None`` when no target shard could be loaded.

    Raises:
        FileNotFoundError: if the first (unnumbered) shard is not found.
    """

    def locate_prefix(shard):
        # Return the "<data_path>/<shard>.<a>-<b>." stem whose source-language
        # file exists on disk, trying src-tgt before tgt-src; None if neither.
        for first, second in ((src, tgt), (tgt, src)):
            stem = os.path.join(
                data_path, "{}.{}-{}.".format(shard, first, second)
            )
            if os.path.exists(stem + src):
                return stem
        return None

    src_shards, tgt_shards = [], []

    shard_idx = 0
    while True:
        # The first shard carries no numeric suffix.
        shard_name = split if shard_idx == 0 else split + str(shard_idx)

        stem = locate_prefix(shard_name)
        if stem is None:
            if shard_idx == 0:
                raise FileNotFoundError(
                    "Dataset not found: {} ({})".format(split, data_path)
                )
            # Ran past the last numbered shard.
            break

        src_part = data_utils.load_indexed_dataset(
            stem + src, src_dict, dataset_impl
        )
        if truncate_source:
            # Drop EOS, cut to one token short of the limit, then put EOS back.
            without_eos = StripTokenDataset(src_part, src_dict.eos())
            clipped = TruncateDataset(without_eos, max_source_positions - 1)
            src_part = AppendTokenDataset(clipped, src_dict.eos())
        src_shards.append(src_part)

        tgt_part = data_utils.load_indexed_dataset(
            stem + tgt, tgt_dict, dataset_impl
        )
        if tgt_part is not None:
            tgt_shards.append(tgt_part)

        logger.info(
            "{} {} {}-{} {} examples".format(
                data_path, shard_name, src, tgt, len(src_part)
            )
        )

        if not combine:
            break
        shard_idx += 1

    # Either every shard has a target side or none of them does.
    assert len(src_shards) == len(tgt_shards) or len(tgt_shards) == 0

    has_target = len(tgt_shards) > 0
    if len(src_shards) == 1:
        src_dataset = src_shards[0]
        tgt_dataset = tgt_shards[0] if has_target else None
    else:
        # The first shard is the "primary" one and may be upsampled.
        ratios = [upsample_primary] + [1] * (len(src_shards) - 1)
        src_dataset = ConcatDataset(src_shards, ratios)
        tgt_dataset = ConcatDataset(tgt_shards, ratios) if has_target else None

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())

    eos = None
    if append_source_id:
        src_tag = "[{}]".format(src)
        tgt_tag = "[{}]".format(tgt)
        src_dataset = AppendTokenDataset(src_dataset, src_dict.index(src_tag))
        if tgt_dataset is not None:
            tgt_dataset = AppendTokenDataset(tgt_dataset, tgt_dict.index(tgt_tag))
        eos = tgt_dict.index(tgt_tag)

    align_dataset = None
    if load_alignments:
        # Alignments are looked up under the base split name, not per shard.
        align_path = os.path.join(
            data_path, "{}.align.{}-{}".format(split, src, tgt)
        )
        if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
            align_dataset = data_utils.load_indexed_dataset(
                align_path, None, dataset_impl
            )

    tgt_sizes = None if tgt_dataset is None else tgt_dataset.sizes
    return LanguagePairDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_sizes,
        tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        align_dataset=align_dataset,
        eos=eos,
        num_buckets=num_buckets,
        shuffle=shuffle,
        pad_to_multiple=pad_to_multiple,
    )
|
|
|
|
|
@register_task("translation_w_langtok")
class TranslationWithLangtokTask(LegacyFairseqTask):
    """
    Translate from one (source) language to another (target) language,
    optionally forcing a language token as the first decoder output.

    When ``--lang-prefix-tok`` is given, that token is looked up in the target
    dictionary and handed to the generator as a one-token prefix for every
    sentence in the batch (see :meth:`inference_step`).

    Args:
        src_dict (~fairseq.data.Dictionary): dictionary for the source language
        tgt_dict (~fairseq.data.Dictionary): dictionary for the target language

    .. note::

        The translation task is compatible with :mod:`fairseq-train`,
        :mod:`fairseq-generate` and :mod:`fairseq-interactive`.

    The translation task provides the following additional command-line
    arguments:

    .. argparse::
        :ref: fairseq.tasks.translation_parser
        :prog:
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('data', help='colon separated path to data directories list, \
                            will be iterated upon during epochs in round-robin manner; \
                            however, valid and test data are always in the first directory to \
                            avoid the need for repeating them in all directories')
        parser.add_argument('-s', '--source-lang', default=None, metavar='SRC',
                            help='source language')
        parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
                            help='target language')
        parser.add_argument('--load-alignments', action='store_true',
                            help='load the binarized alignments')
        parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL',
                            help='pad the source on the left')
        parser.add_argument('--left-pad-target', default='False', type=str, metavar='BOOL',
                            help='pad the target on the left')
        parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N',
                            help='max number of tokens in the source sequence')
        parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
                            help='max number of tokens in the target sequence')
        parser.add_argument('--upsample-primary', default=1, type=int,
                            help='amount to upsample primary dataset')
        parser.add_argument('--truncate-source', action='store_true', default=False,
                            help='truncate source to max-source-positions')
        parser.add_argument('--num-batch-buckets', default=0, type=int, metavar='N',
                            help='if >0, then bucket source and target lengths into N '
                                 'buckets and pad accordingly; this is useful on TPUs '
                                 'to minimize the number of compilations')
        parser.add_argument('--lang-prefix-tok', default=None, type=str, help="starting token in decoder")

        # options for reporting BLEU during validation
        parser.add_argument('--eval-bleu', action='store_true',
                            help='evaluation with BLEU scores')
        parser.add_argument('--eval-bleu-detok', type=str, default="space",
                            help='detokenize before computing BLEU (e.g., "moses"); '
                                 'required if using --eval-bleu; use "space" to '
                                 'disable detokenization; see fairseq.data.encoders '
                                 'for other options')
        parser.add_argument('--eval-bleu-detok-args', type=str, metavar='JSON',
                            help='args for building the tokenizer, if needed')
        parser.add_argument('--eval-tokenized-bleu', action='store_true', default=False,
                            help='compute tokenized BLEU instead of sacrebleu')
        parser.add_argument('--eval-bleu-remove-bpe', nargs='?', const='@@ ', default=None,
                            help='remove BPE before computing BLEU')
        parser.add_argument('--eval-bleu-args', type=str, metavar='JSON',
                            help='generation args for BLEU scoring, '
                                 'e.g., \'{"beam": 4, "lenpen": 0.6}\'')
        parser.add_argument('--eval-bleu-print-samples', action='store_true',
                            help='print sample generations during validation')
        # fmt: on

    def __init__(self, args, src_dict, tgt_dict):
        """Store the parsed arguments and the two language dictionaries."""
        super().__init__(args)
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        The language pair is inferred from the first data directory when it is
        not given explicitly, and ``dict.<lang>.txt`` is loaded from that same
        directory for each side.

        Args:
            args (argparse.Namespace): parsed command-line arguments

        Raises:
            Exception: if the language pair is neither given nor inferable.
        """
        # The padding flags arrive as strings ('True'/'False').
        args.left_pad_source = utils.eval_bool(args.left_pad_source)
        args.left_pad_target = utils.eval_bool(args.left_pad_target)

        paths = utils.split_paths(args.data)
        assert len(paths) > 0

        # Find the language pair automatically if it was not specified.
        if args.source_lang is None or args.target_lang is None:
            args.source_lang, args.target_lang = data_utils.infer_language_pair(
                paths[0]
            )
        if args.source_lang is None or args.target_lang is None:
            raise Exception(
                "Could not infer language pair, please provide it explicitly"
            )

        # Load dictionaries from the first data directory.
        src_dict = cls.load_dictionary(
            os.path.join(paths[0], "dict.{}.txt".format(args.source_lang))
        )
        tgt_dict = cls.load_dictionary(
            os.path.join(paths[0], "dict.{}.txt".format(args.target_lang))
        )
        # Both sides must agree on the indices of the special symbols.
        assert src_dict.pad() == tgt_dict.pad()
        assert src_dict.eos() == tgt_dict.eos()
        assert src_dict.unk() == tgt_dict.unk()
        logger.info("[{}] dictionary: {} types".format(args.source_lang, len(src_dict)))
        logger.info("[{}] dictionary: {} types".format(args.target_lang, len(tgt_dict)))

        return cls(args, src_dict, tgt_dict)

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch (int): 1-based epoch, used to rotate through the data
                directories in round-robin order
            combine (bool): also load the numbered shards of ``split``
        """
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        if split != getattr(self.args, "train_subset", None):
            # Anything that is not the training subset always comes from the
            # first data directory.
            paths = paths[:1]
        data_path = paths[(epoch - 1) % len(paths)]

        src, tgt = self.args.source_lang, self.args.target_lang

        self.datasets[split] = load_langpair_dataset(
            data_path,
            split,
            src,
            self.src_dict,
            tgt,
            self.tgt_dict,
            combine=combine,
            dataset_impl=self.args.dataset_impl,
            upsample_primary=self.args.upsample_primary,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=self.args.max_source_positions,
            max_target_positions=self.args.max_target_positions,
            load_alignments=self.args.load_alignments,
            truncate_source=self.args.truncate_source,
            num_buckets=self.args.num_batch_buckets,
            shuffle=(split != "test"),
            pad_to_multiple=self.args.required_seq_len_multiple,
        )

    def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
        """Wrap raw source tokens in a target-less :class:`LanguagePairDataset`."""
        return LanguagePairDataset(
            src_tokens,
            src_lengths,
            self.source_dictionary,
            tgt_dict=self.target_dictionary,
            constraints=constraints,
        )

    def build_model(self, args):
        """Build the model and, with ``--eval-bleu``, the BLEU evaluation tools.

        When BLEU evaluation is enabled this also creates ``self.tokenizer``
        (used for detokenization) and ``self.sequence_generator``.
        """
        model = super().build_model(args)
        if getattr(args, "eval_bleu", False):
            assert getattr(args, "eval_bleu_detok", None) is not None, (
                "--eval-bleu-detok is required if using --eval-bleu; "
                "try --eval-bleu-detok=moses (or --eval-bleu-detok=space "
                "to disable detokenization, e.g., when using sentencepiece)"
            )
            # `or "{}"` covers the option being present but set to None.
            detok_args = json.loads(getattr(args, "eval_bleu_detok_args", "{}") or "{}")
            self.tokenizer = encoders.build_tokenizer(
                Namespace(
                    tokenizer=getattr(args, "eval_bleu_detok", None), **detok_args
                )
            )

            gen_args = json.loads(getattr(args, "eval_bleu_args", "{}") or "{}")
            self.sequence_generator = self.build_generator(
                [model], Namespace(**gen_args)
            )
        return model

    def valid_step(self, sample, model, criterion):
        """Run the base validation step and, if enabled, add BLEU statistics.

        Returns:
            tuple: ``(loss, sample_size, logging_output)``; with ``--eval-bleu``
            the logging output additionally carries ``_bleu_sys_len``,
            ``_bleu_ref_len`` and per-order ``_bleu_counts_<i>`` /
            ``_bleu_totals_<i>`` entries.
        """
        loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
        if self.args.eval_bleu:
            bleu = self._inference_with_bleu(self.sequence_generator, sample, model)
            logging_output["_bleu_sys_len"] = bleu.sys_len
            logging_output["_bleu_ref_len"] = bleu.ref_len
            # Counts are logged as separate scalar entries (one per n-gram
            # order); reduce_metrics() sums them by key.
            assert len(bleu.counts) == EVAL_BLEU_ORDER
            for i in range(EVAL_BLEU_ORDER):
                logging_output["_bleu_counts_" + str(i)] = bleu.counts[i]
                logging_output["_bleu_totals_" + str(i)] = bleu.totals[i]
        return loss, sample_size, logging_output

    def inference_step(
        self, generator, models, sample, prefix_tokens=None, constraints=None
    ):
        """Generate hypotheses for ``sample``, optionally forcing a language token.

        If ``--lang-prefix-tok`` is set, its target-dictionary index is
        expanded to a ``(batch, 1)`` tensor on the device of the source tokens
        and used as ``prefix_tokens``, replacing any caller-supplied value.
        Otherwise the caller's ``prefix_tokens`` is passed through unchanged.
        ``constraints`` is always forwarded to the generator.

        Args:
            generator: object exposing ``generate(models, sample, **kwargs)``
            models (list): models to decode with
            sample (dict): batch whose ``net_input`` holds ``src_tokens`` or
                ``source``
            prefix_tokens: optional prefix to force; overridden when a
                language prefix token is configured
            constraints: optional decoding constraints

        Raises:
            Exception: if ``net_input`` has neither ``src_tokens`` nor ``source``.
        """
        lang_tok_idx = None
        if self.args.lang_prefix_tok is not None:
            lang_tok_idx = self.target_dictionary.index(self.args.lang_prefix_tok)
            assert lang_tok_idx != self.target_dictionary.unk_index, (
                "--lang-prefix-tok {!r} is not in the target dictionary".format(
                    self.args.lang_prefix_tok
                )
            )

        with torch.no_grad():
            net_input = sample["net_input"]
            if "src_tokens" in net_input:
                src_tokens = net_input["src_tokens"]
            elif "source" in net_input:
                src_tokens = net_input["source"]
            else:
                raise Exception("expected src_tokens or source in net input")

            if lang_tok_idx is not None:
                # One forced token per sentence: shape (bsz, 1).
                bsz = src_tokens.size(0)
                prefix_tokens = (
                    torch.LongTensor([lang_tok_idx])
                    .unsqueeze(1)
                    .expand(bsz, -1)
                    .to(src_tokens.device)
                )

            return generator.generate(
                models, sample, prefix_tokens=prefix_tokens, constraints=constraints
            )

    def reduce_metrics(self, logging_outputs, criterion):
        """Aggregate logging outputs and register a derived ``bleu`` metric.

        The BLEU statistics written by :meth:`valid_step` are summed over
        ``logging_outputs``; if any n-gram total is positive they are logged
        and a ``bleu`` metric is derived from the accumulated sums.
        """
        super().reduce_metrics(logging_outputs, criterion)
        if not self.args.eval_bleu:
            return

        def sum_logs(key):
            total = sum(log.get(key, 0) for log in logging_outputs)
            if torch.is_tensor(total):
                # np.array() below cannot consume tensors living on a device.
                total = total.cpu()
            return total

        counts = [sum_logs("_bleu_counts_" + str(i)) for i in range(EVAL_BLEU_ORDER)]
        totals = [sum_logs("_bleu_totals_" + str(i)) for i in range(EVAL_BLEU_ORDER)]

        if max(totals) > 0:
            # Log the per-order statistics as arrays so they are summed
            # element-wise across calls.
            metrics.log_scalar("_bleu_counts", np.array(counts))
            metrics.log_scalar("_bleu_totals", np.array(totals))
            metrics.log_scalar("_bleu_sys_len", sum_logs("_bleu_sys_len"))
            metrics.log_scalar("_bleu_ref_len", sum_logs("_bleu_ref_len"))

            def compute_bleu(meters):
                import inspect

                try:
                    # sacrebleu 2.x exposes compute_bleu on the BLEU metric.
                    from sacrebleu.metrics import BLEU

                    comp_bleu = BLEU.compute_bleu
                except ImportError:
                    # sacrebleu 1.x exposes it at package level.
                    import sacrebleu

                    comp_bleu = sacrebleu.compute_bleu

                # The smoothing keyword was renamed between sacrebleu releases.
                fn_sig = inspect.getfullargspec(comp_bleu)[0]
                if "smooth_method" in fn_sig:
                    smooth = {"smooth_method": "exp"}
                else:
                    smooth = {"smooth": "exp"}
                bleu = comp_bleu(
                    correct=meters["_bleu_counts"].sum,
                    total=meters["_bleu_totals"].sum,
                    sys_len=int(meters["_bleu_sys_len"].sum),
                    ref_len=int(meters["_bleu_ref_len"].sum),
                    **smooth
                )
                return round(bleu.score, 2)

            metrics.log_derived("bleu", compute_bleu)

    def max_positions(self):
        """Return the max sentence length allowed by the task."""
        return (self.args.max_source_positions, self.args.max_target_positions)

    @property
    def source_dictionary(self):
        """Return the source :class:`~fairseq.data.Dictionary`."""
        return self.src_dict

    @property
    def target_dictionary(self):
        """Return the target :class:`~fairseq.data.Dictionary`."""
        return self.tgt_dict

    def _inference_with_bleu(self, generator, sample, model):
        """Decode ``sample`` with ``model`` and score it against its targets.

        Returns:
            The sacrebleu corpus-level BLEU result for the top hypothesis of
            each sentence; tokenization is disabled when
            ``--eval-tokenized-bleu`` is set.
        """
        import sacrebleu

        def decode(toks, escape_unk=False):
            s = self.tgt_dict.string(
                toks.int().cpu(),
                self.args.eval_bleu_remove_bpe,
                # References and hypotheses get different placeholders for
                # unknown tokens, so an unknown on one side can never be
                # counted as a match for an unknown on the other.
                unk_string=("UNKNOWNTOKENINREF" if escape_unk else "UNKNOWNTOKENINHYP"),
            )
            if self.tokenizer:
                s = self.tokenizer.decode(s)
            return s

        gen_out = self.inference_step(generator, [model], sample, prefix_tokens=None)
        hyps, refs = [], []
        for candidates, target in zip(gen_out, sample["target"]):
            # Only the first (best-ranked) hypothesis is scored.
            hyps.append(decode(candidates[0]["tokens"]))
            refs.append(
                decode(
                    utils.strip_pad(target, self.tgt_dict.pad()),
                    escape_unk=True,
                )
            )
        if self.args.eval_bleu_print_samples:
            logger.info("example hypothesis: " + hyps[0])
            logger.info("example reference: " + refs[0])
        if self.args.eval_tokenized_bleu:
            return sacrebleu.corpus_bleu(hyps, [refs], tokenize="none")
        else:
            return sacrebleu.corpus_bleu(hyps, [refs])
|
|