import torch

from fairseq import utils
from fairseq.data import LanguagePairDataset

from . import register_task
from .translation import TranslationTask, load_langpair_dataset
|
|
@register_task("translation_from_pretrained_bart")
class TranslationFromPretrainedBARTTask(TranslationTask):
    """
    Translate from source language to target language with a model initialized
    from a multilingual pretrained model (e.g., mBART).

    Args:
        src_dict (~fairseq.data.Dictionary): dictionary for the source language
        tgt_dict (~fairseq.data.Dictionary): dictionary for the target language

    .. note::

        The translation task is compatible with :mod:`fairseq-train`,
        :mod:`fairseq-generate` and :mod:`fairseq-interactive`.

    The translation task provides the following additional command-line
    arguments:

    .. argparse::
        :ref: fairseq.tasks.translation_parser
        :prog:
    """
|
    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        TranslationTask.add_args(parser)
        parser.add_argument('--langs', type=str, metavar='LANG',
                            help='comma-separated list of monolingual languages, '
                                 'for example, "en,de,fr". These should match the '
                                 'langs from pretraining (and be in the same order). '
                                 'You should always include all pretraining language '
                                 'indices during fine-tuning.')
        parser.add_argument('--prepend-bos', action='store_true',
                            help='prepend bos token to each sentence, which matches '
                                 'mBART pretraining')
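
    # A hedged usage sketch (paths and the abbreviated language list are
    # illustrative, not prescriptive): fine-tuning an mBART checkpoint with
    # this task might be invoked as
    #
    #   fairseq-train data-bin/en_ro \
    #       --task translation_from_pretrained_bart \
    #       --langs ar_AR,cs_CZ,de_DE,en_XX,...,ro_RO,...,zh_CN \
    #       --source-lang en_XX --target-lang ro_RO \
    #       --arch mbart_large --restore-file mbart.cc25/model.pt
    #
    # where --langs must list every pretraining language, in pretraining order.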
|
    def __init__(self, args, src_dict, tgt_dict):
        super().__init__(args, src_dict, tgt_dict)
        self.langs = args.langs.split(",")
        # Extend both vocabularies with the language-id tokens (e.g., "[en]")
        # and the "<mask>" token so they match the dictionary used during
        # multilingual pretraining.
        for d in [src_dict, tgt_dict]:
            for l in self.langs:
                d.add_symbol("[{}]".format(l))
            d.add_symbol("<mask>")
|
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        src, tgt = self.args.source_lang, self.args.target_lang

        self.datasets[split] = load_langpair_dataset(
            data_path,
            split,
            src,
            self.src_dict,
            tgt,
            self.tgt_dict,
            combine=combine,
            dataset_impl=self.args.dataset_impl,
            upsample_primary=self.args.upsample_primary,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=getattr(self.args, "max_source_positions", 1024),
            max_target_positions=getattr(self.args, "max_target_positions", 1024),
            load_alignments=self.args.load_alignments,
            prepend_bos=getattr(self.args, "prepend_bos", False),
            # append the language-id token (e.g., "[en_XX]") after EOS on both
            # source and target, matching the mBART input format
            append_source_id=True,
        )
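
    # Illustrative (token names assumed): with append_source_id=True, a
    # source/target pair is laid out roughly as
    #
    #   source: w_1 ... w_n </s> [en_XX]
    #   target: y_1 ... y_m </s> [ro_RO]
    #
    # i.e., the language-id token follows the usual end-of-sentence marker.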
|
    def build_generator(self, models, args, **unused):
        if getattr(args, "score_reference", False):
            from fairseq.sequence_scorer import SequenceScorer

            return SequenceScorer(
                self.target_dictionary,
                eos=self.tgt_dict.index("[{}]".format(self.args.target_lang)),
            )
        else:
            from fairseq.sequence_generator import SequenceGenerator

            return SequenceGenerator(
                models,
                self.target_dictionary,
                beam_size=getattr(args, "beam", 5),
                max_len_a=getattr(args, "max_len_a", 0),
                max_len_b=getattr(args, "max_len_b", 200),
                min_len=getattr(args, "min_len", 1),
                normalize_scores=(not getattr(args, "unnormalized", False)),
                len_penalty=getattr(args, "lenpen", 1),
                unk_penalty=getattr(args, "unkpen", 0),
                temperature=getattr(args, "temperature", 1.0),
                match_source_len=getattr(args, "match_source_len", False),
                no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
                # stop decoding at the target-language token (e.g., "[ro_RO]")
                # rather than the regular EOS, matching the target layout above
                eos=self.tgt_dict.index("[{}]".format(self.args.target_lang)),
            )
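
    # A hedged decoding sketch (setup elided; `task`, `models`, `args`, and
    # `sample` are assumed to already exist): the generator built above is
    # normally driven through the task's standard inference entry point, e.g.
    #
    #   generator = task.build_generator(models, args)
    #   hypos = task.inference_step(generator, models, sample)
    #
    # and each hypothesis then ends at the "[<target_lang>]" token.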
|
    def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
        src_lang_id = self.source_dictionary.index("[{}]".format(self.args.source_lang))
        source_tokens = []
        # append the source-language id token to every input sequence, as
        # load_langpair_dataset does for training data via append_source_id
        for s_t in src_tokens:
            s_t = torch.cat([s_t, s_t.new(1).fill_(src_lang_id)])
            source_tokens.append(s_t)
        dataset = LanguagePairDataset(
            source_tokens,
            src_lengths,
            self.source_dictionary,
            tgt_dict=self.target_dictionary,
            constraints=constraints,
        )
        return dataset
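
    # Illustrative (ids invented for the example): for --source-lang en_XX, an
    # already-indexed input such as
    #
    #   [4, 17, 2]                  # w_1 w_2 </s>
    #
    # becomes [4, 17, 2, src_lang_id] before batching, so interactive and
    # file-based generation see the same input layout as training.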
|
|