|
|
|
|
|
|
|
|
|
|
|
import logging |
|
import os |
|
from collections import OrderedDict |
|
|
|
from fairseq import utils |
|
from fairseq.data import ( |
|
BacktranslationDataset, |
|
IndexedCachedDataset, |
|
IndexedDataset, |
|
IndexedRawTextDataset, |
|
LanguagePairDataset, |
|
NoisingDataset, |
|
RoundRobinZipDatasets, |
|
data_utils, |
|
indexed_dataset, |
|
) |
|
from fairseq.models import FairseqMultiModel |
|
from fairseq.sequence_generator import SequenceGenerator |
|
|
|
from . import register_task |
|
from .multilingual_translation import MultilingualTranslationTask |
|
|
|
|
|
# Module-level logger; used below to report how many examples each loaded dataset contains.
logger = logging.getLogger(__name__)
|
|
|
|
|
def _get_bt_dataset_key(lang_pair): |
|
return "bt:" + lang_pair |
|
|
|
|
|
def _get_denoising_dataset_key(lang_pair): |
|
return "denoising:" + lang_pair |
|
|
|
|
|
|
|
def parse_lambda_config(x):
    """
    Parse the configuration of a lambda coefficient (for scheduling).

    The config is either a single number, or a comma-separated list of
    ``step:value`` pairs with strictly increasing integer steps:

        x = "3"                  # lambda will be a constant equal to 3
        x = "0:1,1000:0"         # lambda will start from 1 and linearly decrease
                                 # to 0 during the first 1000 iterations
        x = "0:0,1000:0,2000:1"  # lambda will be equal to 0 for the first 1000
                                 # iterations, then will linearly increase to 1
                                 # until iteration 2000

    Returns:
        A ``(initial_value, steps)`` tuple. For a constant config ``steps`` is
        ``None``; otherwise it is a list of ``(step, value)`` tuples and
        ``initial_value`` is the value at the first breakpoint.

    Raises:
        ValueError: if a scheduled config is malformed (an entry is not a
            ``step:value`` pair, a step is not a non-negative integer, or the
            steps are not strictly increasing), or if a constant config is not
            a number.
    """
    entries = x.split(",")
    if len(entries) == 1:
        return float(x), None

    # Always split on ":" -- the documented separator. os.pathsep is ";" on
    # Windows, which made every scheduled config fail to parse there.
    pairs = [entry.split(":") for entry in entries]
    if not all(len(pair) == 2 for pair in pairs):
        raise ValueError(
            "invalid lambda config {!r}: expected comma-separated "
            "'step:value' pairs".format(x)
        )

    steps = []
    for step, value in pairs:
        step = step.strip()
        if not step.isdigit():
            raise ValueError(
                "invalid lambda config {!r}: step {!r} is not a non-negative "
                "integer".format(x, step)
            )
        steps.append((int(step), float(value)))

    if any(prev[0] >= cur[0] for prev, cur in zip(steps, steps[1:])):
        raise ValueError(
            "invalid lambda config {!r}: steps must be strictly increasing".format(x)
        )

    return steps[0][1], steps
|
|
|
|
|
@register_task("semisupervised_translation")
class SemisupervisedTranslationTask(MultilingualTranslationTask):
    """A task for semi-supervised training of multiple translation models.

    On top of the parallel-data objective inherited from
    `MultilingualTranslationTask` (weighted by `--lambda-parallel-config`),
    two objectives that use monolingual target-side data can be enabled:
    on-the-fly back-translation (`--lambda-otf-bt-config`) and denoising
    auto-encoding (`--lambda-denoising-config`). Each weight is either a
    constant or a piecewise-linear schedule over the number of updates.

    We iterate round-robin over batches from multiple language pairs, ordered
    according to the `--lang-pairs` argument.

    The training loop is roughly:

        for i in range(len(epoch)):
            for lang_pair in args.lang_pairs:
                batch = next_batch_for_lang_pair(lang_pair)
                loss = criterion(model_for_lang_pair(lang_pair), batch)
                loss.backward()
            optimizer.step()

    In practice, `next_batch_for_lang_pair` is abstracted in a FairseqDataset
    (e.g., `RoundRobinZipDatasets`) and `model_for_lang_pair` is a model that
    implements the `FairseqMultiModel` interface.

    During inference it is required to specify a single `--source-lang` and
    `--target-lang`, instead of `--lang-pairs`.
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        # fmt: off
        MultilingualTranslationTask.add_args(parser)
        # Schedule strings are parsed by parse_lambda_config: either a single
        # float, or comma-separated "step:weight" pairs.
        parser.add_argument('--lambda-parallel-config', default="1.0", type=str, metavar='CONFIG',
                            help='cross-entropy reconstruction coefficient (parallel data). '
                                 'use fixed weight during training if set to floating point number. '
                                 'use piecewise linear function over number of updates to schedule the '
                                 'weight with the format: step0:w0,step1:w1,...')
        parser.add_argument('--lambda-denoising-config', default="0.0", type=str, metavar='CONFIG',
                            help='Cross-entropy reconstruction coefficient (denoising autoencoding). '
                                 'use fixed weight during training if set to floating point number. '
                                 'use piecewise linear function over number of updates to schedule the '
                                 'weight with the format: step0:w0,step1:w1,...')
        parser.add_argument('--lambda-otf-bt-config', default="0.0", type=str, metavar='CONFIG',
                            help='cross-entropy reconstruction coefficient (on-the-fly back-translation parallel data). '
                                 'use fixed weight during training if set to floating point number. '
                                 'use piecewise linear function over number of updates to schedule the '
                                 'weight with the format: step0:w0,step1:w1,...')
        parser.add_argument('--bt-max-len-a', default=1.1, type=float, metavar='N',
                            help='generate back-translated sequences of maximum length ax + b, where x is the '
                                 'source length')
        parser.add_argument('--bt-max-len-b', default=10.0, type=float, metavar='N',
                            help='generate back-translated sequences of maximum length ax + b, where x is the '
                                 'source length')
        parser.add_argument('--bt-beam-size', default=1, type=int, metavar='N',
                            help='beam size used in beam search of online back-translation')
        parser.add_argument('--max-word-shuffle-distance', default=3.0, type=float, metavar='N',
                            help='maximum word shuffle distance for denoising autoencoding data generation')
        parser.add_argument('--word-dropout-prob', default=0.1, type=float, metavar='N',
                            help='word dropout probability for denoising autoencoding data generation')
        parser.add_argument('--word-blanking-prob', default=0.2, type=float, metavar='N',
                            help='word blanking probability for denoising autoencoding data generation')
        # fmt: on

    def __init__(self, args, dicts, training):
        super().__init__(args, dicts, training)
        # Each (value, steps) pair: steps is None for a constant weight,
        # otherwise a [(step, value), ...] schedule applied in update_step.
        self.lambda_parallel, self.lambda_parallel_steps = parse_lambda_config(
            args.lambda_parallel_config
        )
        self.lambda_otf_bt, self.lambda_otf_bt_steps = parse_lambda_config(
            args.lambda_otf_bt_config
        )
        self.lambda_denoising, self.lambda_denoising_steps = parse_lambda_config(
            args.lambda_denoising_config
        )
        if self.lambda_denoising > 0.0 or self.lambda_denoising_steps is not None:
            # Denoising trains one "tgt-tgt" model per distinct target language.
            # Deduplicate with dict.fromkeys (first-occurrence order) instead of a
            # set: str hashing is randomized per process, so set iteration order
            # could differ between runs/workers. NOTE(review): model_lang_pairs
            # order presumably fixes the order the per-pair models are built in.
            target_langs = dict.fromkeys(
                lang_pair.split("-")[1] for lang_pair in args.lang_pairs
            )
            denoising_lang_pairs = ["%s-%s" % (tgt, tgt) for tgt in target_langs]
            self.model_lang_pairs = self.model_lang_pairs + denoising_lang_pairs
        # Filled by load_dataset (train split) and build_model respectively.
        self.backtranslate_datasets = {}
        self.backtranslators = {}

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Build dictionaries via MultilingualTranslationTask.prepare and create the task."""
        dicts, training = MultilingualTranslationTask.prepare(args, **kwargs)
        return cls(args, dicts, training)

    def load_dataset(self, split, epoch=1, **kwargs):
        """Load a dataset split into ``self.datasets[split]``.

        The result is a RoundRobinZipDatasets over up to three kinds of
        sub-datasets:

        - parallel data keyed by ``"src-tgt"`` (when the parallel objective is
          active, and always for non-train splits);
        - on-the-fly back-translation data keyed by ``"bt:src-tgt"`` (train
          splits only, when the back-translation objective is active);
        - denoising auto-encoding data keyed by ``"denoising:src-tgt"`` (train
          splits only, when the denoising objective is active).

        Data directories rotate across epochs: epoch ``e`` reads from entry
        ``(e - 1) % len(paths)`` of ``self.args.data``.

        Back-translation relies on ``self.backtranslators``, which is populated
        in build_model, so the model must be built before loading a train split
        with back-translation enabled.

        Raises:
            FileNotFoundError: if parallel data is required but no language pair
                has it, or if back-translation is active and the monolingual
                file for some target language is missing.
        """
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        # infer langcode
        def split_exists(split, src, tgt, lang):
            if src is not None:
                filename = os.path.join(
                    data_path, "{}.{}-{}.{}".format(split, src, tgt, lang)
                )
            else:
                # NOTE(review): not reached by the callers below (they pass a
                # non-None src); note this formats src (None) where tgt goes.
                filename = os.path.join(
                    data_path, "{}.{}-None.{}".format(split, src, tgt)
                )
            return indexed_dataset.dataset_exists(filename, impl=self.args.dataset_impl)

        def load_indexed_dataset(path, dictionary):
            return data_utils.load_indexed_dataset(
                path, dictionary, self.args.dataset_impl
            )

        # load parallel datasets
        src_datasets, tgt_datasets = {}, {}
        if (
            self.lambda_parallel > 0.0
            or self.lambda_parallel_steps is not None
            or not split.startswith("train")
        ):
            for lang_pair in self.lang_pairs:
                src, tgt = lang_pair.split("-")
                # Files may be named in either direction ("src-tgt" or "tgt-src").
                if split_exists(split, src, tgt, src):
                    prefix = os.path.join(
                        data_path, "{}.{}-{}.".format(split, src, tgt)
                    )
                elif split_exists(split, tgt, src, src):
                    prefix = os.path.join(
                        data_path, "{}.{}-{}.".format(split, tgt, src)
                    )
                else:
                    # Pairs without parallel data are skipped (train_step tolerates this).
                    continue
                src_datasets[lang_pair] = load_indexed_dataset(
                    prefix + src, self.dicts[src]
                )
                tgt_datasets[lang_pair] = load_indexed_dataset(
                    prefix + tgt, self.dicts[tgt]
                )
                logger.info(
                    "parallel-{} {} {} examples".format(
                        data_path, split, len(src_datasets[lang_pair])
                    )
                )
            if len(src_datasets) == 0:
                raise FileNotFoundError(
                    "Dataset not found: {} ({})".format(split, data_path)
                )

        # back translation datasets
        backtranslate_datasets = {}
        if (
            self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None
        ) and split.startswith("train"):
            for lang_pair in self.lang_pairs:
                src, tgt = lang_pair.split("-")
                # Monolingual target-side data lives in "<split>.<tgt>-None.<tgt>".
                if not split_exists(split, tgt, None, tgt):
                    raise FileNotFoundError(
                        "Dataset not found: backtranslation {} ({})".format(
                            split, data_path
                        )
                    )
                filename = os.path.join(
                    data_path, "{}.{}-None.{}".format(split, tgt, tgt)
                )
                dataset = load_indexed_dataset(filename, self.dicts[tgt])
                # Source-only view of the monolingual data, fed to the tgt->src
                # back-translator.
                lang_pair_dataset_tgt = LanguagePairDataset(
                    dataset,
                    dataset.sizes,
                    self.dicts[tgt],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                )
                # Only its collater is used, to batch the synthetic src->tgt pairs.
                lang_pair_dataset = LanguagePairDataset(
                    dataset,
                    dataset.sizes,
                    src_dict=self.dicts[src],
                    tgt=dataset,
                    tgt_sizes=dataset.sizes,
                    tgt_dict=self.dicts[tgt],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                )
                backtranslate_datasets[lang_pair] = BacktranslationDataset(
                    tgt_dataset=self.alter_dataset_langtok(
                        lang_pair_dataset_tgt,
                        src_eos=self.dicts[tgt].eos(),
                        src_lang=tgt,
                        tgt_lang=src,
                    ),
                    backtranslation_fn=self.backtranslators[lang_pair],
                    src_dict=self.dicts[src],
                    tgt_dict=self.dicts[tgt],
                    output_collater=self.alter_dataset_langtok(
                        lang_pair_dataset=lang_pair_dataset,
                        src_eos=self.dicts[src].eos(),
                        src_lang=src,
                        tgt_eos=self.dicts[tgt].eos(),
                        tgt_lang=tgt,
                    ).collater,
                )
                logger.info(
                    "backtranslate-{}: {} {} {} examples".format(
                        tgt,
                        data_path,
                        split,
                        len(backtranslate_datasets[lang_pair]),
                    )
                )
                self.backtranslate_datasets[lang_pair] = backtranslate_datasets[
                    lang_pair
                ]

        # denoising autoencoder
        noising_datasets = {}
        if (
            self.lambda_denoising > 0.0 or self.lambda_denoising_steps is not None
        ) and split.startswith("train"):
            for lang_pair in self.lang_pairs:
                _, tgt = lang_pair.split("-")
                # Missing monolingual data just disables denoising for this pair.
                if not split_exists(split, tgt, None, tgt):
                    continue
                filename = os.path.join(
                    data_path, "{}.{}-None.{}".format(split, tgt, tgt)
                )
                # Two separate loads of the same file: the first is corrupted by
                # NoisingDataset (source side), the second is the clean target.
                tgt_dataset1 = load_indexed_dataset(filename, self.dicts[tgt])
                tgt_dataset2 = load_indexed_dataset(filename, self.dicts[tgt])
                noising_dataset = NoisingDataset(
                    tgt_dataset1,
                    self.dicts[tgt],
                    seed=1,
                    max_word_shuffle_distance=self.args.max_word_shuffle_distance,
                    word_dropout_prob=self.args.word_dropout_prob,
                    word_blanking_prob=self.args.word_blanking_prob,
                )
                noising_datasets[lang_pair] = self.alter_dataset_langtok(
                    LanguagePairDataset(
                        noising_dataset,
                        tgt_dataset1.sizes,
                        self.dicts[tgt],
                        tgt_dataset2,
                        tgt_dataset2.sizes,
                        self.dicts[tgt],
                        left_pad_source=self.args.left_pad_source,
                        left_pad_target=self.args.left_pad_target,
                    ),
                    src_eos=self.dicts[tgt].eos(),
                    src_lang=tgt,
                    tgt_eos=self.dicts[tgt].eos(),
                    tgt_lang=tgt,
                )
                logger.info(
                    "denoising-{}: {} {} {} examples".format(
                        tgt,
                        data_path,
                        split,
                        len(noising_datasets[lang_pair]),
                    )
                )

        def language_pair_dataset(lang_pair):
            src, tgt = lang_pair.split("-")
            src_dataset, tgt_dataset = src_datasets[lang_pair], tgt_datasets[lang_pair]
            return self.alter_dataset_langtok(
                LanguagePairDataset(
                    src_dataset,
                    src_dataset.sizes,
                    self.dicts[src],
                    tgt_dataset,
                    tgt_dataset.sizes,
                    self.dicts[tgt],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                ),
                self.dicts[src].eos(),
                src,
                self.dicts[tgt].eos(),
                tgt,
            )

        self.datasets[split] = RoundRobinZipDatasets(
            OrderedDict(
                [
                    (lang_pair, language_pair_dataset(lang_pair))
                    for lang_pair in src_datasets.keys()
                ]
                + [
                    (_get_bt_dataset_key(lang_pair), dataset)
                    for lang_pair, dataset in backtranslate_datasets.items()
                ]
                + [
                    (_get_denoising_dataset_key(lang_pair), dataset)
                    for lang_pair, dataset in noising_datasets.items()
                ]
            ),
            # Training zips all sub-datasets; evaluation reads the single
            # --source-lang/--target-lang pair.
            eval_key=None
            if self.training
            else "%s-%s" % (self.args.source_lang, self.args.target_lang),
        )

    def build_model(self, args, from_checkpoint=False):
        """Build the multi-model and, if back-translation is active, its back-translators.

        For every training pair ``src-tgt`` a SequenceGenerator over the reverse
        ``tgt-src`` model is created, and ``self.backtranslators["src-tgt"]`` is
        set to a function that back-translates a batch with it.

        Raises:
            ValueError: if the built model is not a FairseqMultiModel.
        """
        from fairseq import models

        model = models.build_model(args, self, from_checkpoint)
        if not isinstance(model, FairseqMultiModel):
            raise ValueError(
                "SemisupervisedTranslationTask requires a FairseqMultiModel architecture"
            )

        # create SequenceGenerator for each model that has backtranslation dependency on it
        self.sequence_generators = {}
        if (
            self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None
        ) and self.training:
            for lang_pair in self.lang_pairs:
                src, tgt = lang_pair.split("-")
                key = "{}-{}".format(tgt, src)
                self.sequence_generators[key] = SequenceGenerator(
                    [model.models[key]],
                    tgt_dict=self.dicts[src],
                    beam_size=args.bt_beam_size,
                    max_len_a=args.bt_max_len_a,
                    max_len_b=args.bt_max_len_b,
                )
                decoder_lang_tok_idx = self.get_decoder_langtok(src)

                # Loop variables are bound as default arguments so each closure
                # keeps this iteration's values (avoids late binding).
                def backtranslate_fn(
                    sample,
                    model=model.models[key],
                    bos_token=decoder_lang_tok_idx,
                    sequence_generator=self.sequence_generators[key],
                ):
                    return sequence_generator.generate(
                        [model],
                        sample,
                        bos_token=bos_token,
                    )

                self.backtranslators[lang_pair] = backtranslate_fn

        return model

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
    ):
        """Run forward/backward for every active objective on one round-robin batch.

        Each objective's loss is scaled by its current lambda (or by 0 when
        `ignore_grad` is set) before ``optimizer.backward``. Sub-batches that
        are absent from `sample` or empty are skipped.

        Returns:
            ``(agg_loss, agg_sample_size, agg_logging_output)`` summed over all
            processed sub-batches. Logging fields are summed globally under
            their own name and per sub-dataset under ``"<dataset key>:<field>"``.
        """
        model.train()

        if update_num > 0:
            self.update_step(update_num)

        agg_loss, agg_sample_size, agg_logging_output = 0.0, 0.0, {}

        def forward_backward(model, samples, logging_output_key, weight):
            nonlocal agg_loss, agg_sample_size
            if samples is None or len(samples) == 0:
                return
            loss, sample_size, logging_output = criterion(model, samples)
            if ignore_grad:
                loss *= 0
            else:
                loss *= weight
            optimizer.backward(loss)
            agg_loss += loss.detach().item()
            # TODO make summing of the sample sizes configurable
            agg_sample_size += sample_size
            for k, v in logging_output.items():
                # agg_logging_output starts empty, so accumulate with .get()
                # (a bare "+=" raises KeyError on the first key). Per-dataset
                # values are namespaced per field rather than summing every
                # field (loss, ntokens, ...) into a single entry.
                # NOTE(review): confirm "<key>:<field>" matches what the
                # parent task's metric reduction expects.
                agg_logging_output[k] = agg_logging_output.get(k, 0.0) + v
                per_key = "{}:{}".format(logging_output_key, k)
                agg_logging_output[per_key] = agg_logging_output.get(per_key, 0.0) + v

        # sample.get(): load_dataset skips pairs whose parallel or monolingual
        # data is missing, so those keys may be absent from the batch.
        if self.lambda_parallel > 0.0:
            for lang_pair in self.lang_pairs:
                forward_backward(
                    model.models[lang_pair],
                    sample.get(lang_pair),
                    lang_pair,
                    self.lambda_parallel,
                )

        if self.lambda_otf_bt > 0.0:
            for lang_pair in self.lang_pairs:
                sample_key = _get_bt_dataset_key(lang_pair)
                forward_backward(
                    model.models[lang_pair],
                    sample.get(sample_key),
                    sample_key,
                    self.lambda_otf_bt,
                )

        if self.lambda_denoising > 0.0:
            for lang_pair in self.lang_pairs:
                _, tgt = lang_pair.split("-")
                sample_key = _get_denoising_dataset_key(lang_pair)
                forward_backward(
                    model.models["{0}-{0}".format(tgt)],
                    sample.get(sample_key),
                    sample_key,
                    self.lambda_denoising,
                )

        return agg_loss, agg_sample_size, agg_logging_output

    def update_step(self, num_updates):
        """Recompute every scheduled lambda coefficient for `num_updates`."""

        def lambda_step_func(config, n_iter):
            """
            Piecewise-linearly interpolate a lambda schedule at `n_iter`.

            `config` is a list of (step, value) pairs with strictly increasing
            steps, as produced by parse_lambda_config. Before the first step
            the first value is used; from the last step on, the last value.
            """
            if n_iter < config[0][0]:
                # Schedules need not start at step 0; without this clamp such
                # a schedule would fail before reaching its first breakpoint.
                return config[0][1]
            for (x_a, y_a), (x_b, y_b) in zip(config, config[1:]):
                if x_a <= n_iter < x_b:
                    return y_a + (n_iter - x_a) * float(y_b - y_a) / float(x_b - x_a)
            return config[-1][1]

        if self.lambda_parallel_steps is not None:
            self.lambda_parallel = lambda_step_func(
                self.lambda_parallel_steps, num_updates
            )
        if self.lambda_denoising_steps is not None:
            self.lambda_denoising = lambda_step_func(
                self.lambda_denoising_steps, num_updates
            )
        if self.lambda_otf_bt_steps is not None:
            self.lambda_otf_bt = lambda_step_func(self.lambda_otf_bt_steps, num_updates)
|
|