# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from collections import OrderedDict, defaultdict
import json
import os
import logging

from fairseq import options, models
from fairseq.data import (
    data_utils,
    Dictionary,
    LanguagePairDataset,
    IndexedDataset,
    FairseqDataset,
)
from .multitask_data_utils import (
    MultitaskDatasetWrapper,
    MultidatasetEpochBatchIterator,
)

from fairseq.tasks import LegacyFairseqTask, register_task


logger = logging.getLogger(__name__)

@register_task("laser")
class LaserTask(LegacyFairseqTask):
    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument(
            "configfile", metavar="PATH", help="dataset configuration file in json"
        )
        parser.add_argument(
            "--weighting-alpha",
            type=float,
            default=None,
            help="alpha for automatic weighting",
        )
        parser.add_argument(
            "--raw-text", action="store_true", help="load raw text dataset"
        )
        parser.add_argument(
            "--left-pad-source",
            default="True",
            type=str,
            metavar="BOOL",
            help="pad the source on the left (default: True)",
        )
        parser.add_argument(
            "--left-pad-target",
            default="False",
            type=str,
            metavar="BOOL",
            help="pad the target on the left (default: False)",
        )
        parser.add_argument(
            "--max-source-positions",
            default=1024,
            type=int,
            metavar="N",
            help="max number of tokens in the source sequence",
        )
        parser.add_argument(
            "--max-target-positions",
            default=1024,
            type=int,
            metavar="N",
            help="max number of tokens in the target sequence",
        )
    def __init__(self, args, config, src_dictionary, tgt_dictionary, num_tasks):
        super().__init__(args)
        self.config = config
        self.src_dictionary = src_dictionary
        self.tgt_dictionary = tgt_dictionary
        self.num_tasks = num_tasks

    @classmethod
    def setup_task(cls, args, **kwargs):
        with open(args.configfile, "r") as f:
            config = json.load(f)
        num_tasks = max(dataset["id"] for dataset in config["train"]) + 1

        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)

        src_dictionary = Dictionary.load(config["src_vocab"])
        tgt_dictionary = Dictionary.load(config["tgt_vocab"])

        logger.info(
            "| src Dictionary {} : {} types".format(
                config["src_vocab"], len(src_dictionary)
            )
        )
        logger.info(
            "| tgt Dictionary {} : {} types".format(
                config["tgt_vocab"], len(tgt_dictionary)
            )
        )

        return cls(args, config, src_dictionary, tgt_dictionary, num_tasks)
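
    # Illustrative config sketch (an assumption, not taken from the original
    # repo; all paths and names below are placeholders). setup_task and
    # load_dataset only require that the JSON file provide "src_vocab",
    # "tgt_vocab", and one list of dataset entries per split, each entry
    # carrying an integer "id", binarized "src"/"tgt" path prefixes, and an
    # optional "sample" upsampling factor. The "src" path is expected to end in
    # .../<corpus_name>/<language_pair>/<file>, because load_dataset derives
    # its dataset keys from those two directory components.
    #
    # {
    #     "src_vocab": "/data/laser/dict.src.txt",
    #     "tgt_vocab": "/data/laser/dict.tgt.txt",
    #     "train": [
    #         {
    #             "id": 0,
    #             "src": "/data/laser/europarl/en-fr/train.en",
    #             "tgt": "/data/laser/europarl/en-fr/train.fr",
    #             "sample": 1.0
    #         }
    #     ]
    # }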

    # Experimental overriding for backtranslation
    def build_model(self, args):
        model = models.build_model(args, self)
        return model

    def dataset(self, split):
        if split not in self.datasets:
            raise KeyError("Dataset not loaded: " + split)
        return self.datasets[split]

    def load_dataset(self, split, epoch=1, **kwargs):
        """Load a dataset split."""

        def indexed_dataset(path, dictionary):
            if self.args.raw_text:
                raise Exception("Unable to handle raw text.")
            dataset = IndexedDataset(path, fix_lua_indexing=True)
            return dataset
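
        # Note on the helper above (an assumption based on fairseq's
        # IndexedDataset API, not stated in this file): `path` should be the
        # prefix of a .bin/.idx pair produced by fairseq-preprocess, and the
        # `dictionary` argument is accepted for interface symmetry but unused.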
        pair_datasets = OrderedDict()

        if split == "valid":
            self.datasets[split] = pair_datasets
            return

        if split not in self.config:
            raise FileNotFoundError(
                "Dataset not found in config file: {}".format(split)
            )

        size_by_corpus = defaultdict(int)
        size_sum = 0
        size_sum_with_subsampling = 0
        init_pair_datasets = {}

        for dataset_config in self.config[split]:
            src_path = os.path.dirname(dataset_config["src"])
            corpus_name = src_path.split("/")[-2]
            language_pair_name = src_path.split("/")[-1]
            pair_datasets_key = corpus_name + "-" + language_pair_name

            logger.info(f"loading... {pair_datasets_key}")
            if "src" in dataset_config:
                src_dataset = indexed_dataset(
                    dataset_config["src"], self.src_dictionary
                )
            else:
                src_dataset = None

            if "tgt" in dataset_config:
                tgt_dataset = indexed_dataset(
                    dataset_config["tgt"], self.tgt_dictionary
                )
            else:
                tgt_dataset = None

            dataset = LanguagePairDataset(
                src_dataset,
                src_dataset.sizes,
                self.src_dictionary,
                tgt_dataset,
                tgt_dataset.sizes,
                self.tgt_dictionary,
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
            )
            if pair_datasets_key in init_pair_datasets:
                logger.warning(
                    f"Ignoring already added {pair_datasets_key}. "
                    f"Consider using `sample` key in order to upsample."
                )
            else:
                init_pair_datasets[pair_datasets_key] = {
                    "dataset": dataset,
                    "sample": dataset_config.get("sample", None),
                    "id": dataset_config.get("id", None),
                    "len": len(dataset),
                }

        length_sum = 0
        weighted_freqs_sum = 0
        freq_per_dataset = {}
        vmax = 0
        vmin = 1
        weighted_freq_per_dataset = {}

        if self.args.weighting_alpha:
            # Temperature-style weighting over the datasets that do not set an
            # explicit "sample": freq_i = len_i / sum_j len_j and
            # weighted_freq_i = freq_i ** alpha / sum_j freq_j ** alpha.
            for key in init_pair_datasets:
                if init_pair_datasets[key]["sample"] is None:
                    length_sum += len(init_pair_datasets[key]["dataset"])

            for key in init_pair_datasets:
                if init_pair_datasets[key]["sample"] is None:
                    val = float(init_pair_datasets[key]["len"]) / length_sum
                    freq_per_dataset[key] = val
                    weighted_freqs_sum += val ** self.args.weighting_alpha

            for key in freq_per_dataset:
                val = (
                    freq_per_dataset[key] ** self.args.weighting_alpha
                    / weighted_freqs_sum
                )
                vmin = min(vmin, val)
                vmax = max(vmax, val)
                weighted_freq_per_dataset[key] = val
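
        # Worked example with illustrative numbers (not from the original
        # source): two corpora of 900k and 100k pairs and weighting_alpha=0.5
        # give freq = (0.9, 0.1), freq ** 0.5 = (0.949, 0.316), and
        # weighted_freq = (0.75, 0.25); the loop below then keeps one replica
        # of the larger corpus (0.75 / 0.75 = 1) and upsamples the smaller one
        # by round(0.75 / 0.25) = 3 replicas.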
        for pair_datasets_key in init_pair_datasets:
            dataset_config = init_pair_datasets[pair_datasets_key]
            dataset = dataset_config["dataset"]
            sample = dataset_config["sample"]
            if sample is None:
                sample = 1.0

            if pair_datasets_key in weighted_freq_per_dataset:
                w = vmax / weighted_freq_per_dataset[pair_datasets_key]
                sample = w

            sample = round(sample)

            initial_sample = sample
            initial_pair_datasets_key = pair_datasets_key

            while sample >= 1.0:
                # one wrapper per replica; additional replicas are registered
                # under keys suffixed with "-up"
                assert (
                    pair_datasets_key not in pair_datasets
                ), f"{pair_datasets_key} already in"
                size_sum_with_subsampling += len(dataset)
                pair_datasets[pair_datasets_key] = MultitaskDatasetWrapper(
                    dataset, dataset_config.get("id", 0), 1.0, name=pair_datasets_key
                )
                size_sum += len(dataset)
                sample -= 1.0
                pair_datasets_key += "-up"

            assert sample < 1e-6, f"sample remains > 0 {pair_datasets_key}"

            logger.info(
                f"added pair {initial_pair_datasets_key} length {len(dataset)} new_length = {len(dataset)*initial_sample}"
            )
            size_by_corpus[corpus_name] += len(dataset)

        self.datasets[split] = pair_datasets
        logger.info(
            f"Datasets number = {len(self.datasets[split])} size = {size_sum} size_sum_with_subsampling = {size_sum_with_subsampling}"
        )

    @property
    def source_dictionary(self):
        return self.src_dictionary

    @property
    def target_dictionary(self):
        return self.tgt_dictionary
    def get_batch_iterator(
        self,
        dataset,
        max_tokens=None,
        max_sentences=None,
        max_positions=None,
        ignore_invalid_inputs=False,
        required_batch_size_multiple=1,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=1,
        data_buffer_size=0,
        disable_iterator_cache=False,
    ):

        assert isinstance(dataset, OrderedDict)
        assert len(dataset)
        assert isinstance(dataset[next(iter(dataset))], FairseqDataset)

        # initialize the dataset with the correct starting epoch
        for _, dt in dataset.items():
            dt.set_epoch(epoch)

        indices = OrderedDict()
        batch_sampler = OrderedDict()

        with data_utils.numpy_seed(seed + epoch):
            for key, dt in dataset.items():
                logger.info(f"\t ordered_indices {key}")
                indices[key] = dt.ordered_indices()

        # filter examples that are too large
        if max_positions is not None:
            for key, dt in dataset.items():
                logger.info(f"\t filter_by_size {key}")
                indices[key], ignored = dt.filter_indices_by_size(
                    indices[key], max_positions
                )

        for key, dt in dataset.items():
            logger.info(f"\t batch_by_size {key}")
            batch_sampler[key] = data_utils.batch_by_size(
                indices[key],
                dt.num_tokens,
                max_tokens=max_tokens,
                max_sentences=max_sentences,
                required_batch_size_multiple=required_batch_size_multiple,
            )

        epoch_iter = MultidatasetEpochBatchIterator(
            dataset=dataset,
            batch_sampler=batch_sampler,
            seed=seed,
            num_shards=num_shards,
            shard_id=shard_id,
            num_workers=num_workers,
            epoch=epoch,
        )

        return epoch_iter
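

# Hedged usage sketch (not part of the original file; paths and values are
# placeholders). With this module importable by fairseq (e.g. via --user-dir),
# the task registered above is selected with --task laser, and the positional
# `configfile` plus the options defined in add_args apply. Architecture and
# optimization flags are omitted because they are not defined in this file:
#
#   fairseq-train /path/to/laser_config.json \
#       --task laser \
#       --user-dir /path/to/laser_src \
#       --weighting-alpha 0.7 \
#       --left-pad-source True --left-pad-target False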