diff --git a/fairseq/fairseq/optim/__pycache__/__init__.cpython-310.pyc b/fairseq/fairseq/optim/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83654b05326a95f66614abb44900ec3fb660c38f Binary files /dev/null and b/fairseq/fairseq/optim/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/fairseq/optim/__pycache__/adadelta.cpython-310.pyc b/fairseq/fairseq/optim/__pycache__/adadelta.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a244e6b908b52a54a2051804151db6e6953eaace Binary files /dev/null and b/fairseq/fairseq/optim/__pycache__/adadelta.cpython-310.pyc differ diff --git a/fairseq/fairseq/optim/__pycache__/adafactor.cpython-310.pyc b/fairseq/fairseq/optim/__pycache__/adafactor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..311ae4c6d7e18b4377bd1ac365e804031574bce2 Binary files /dev/null and b/fairseq/fairseq/optim/__pycache__/adafactor.cpython-310.pyc differ diff --git a/fairseq/fairseq/optim/__pycache__/adagrad.cpython-310.pyc b/fairseq/fairseq/optim/__pycache__/adagrad.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71a61e075fe6ddf596a85842b9bc23900e630c1c Binary files /dev/null and b/fairseq/fairseq/optim/__pycache__/adagrad.cpython-310.pyc differ diff --git a/fairseq/fairseq/optim/__pycache__/adam.cpython-310.pyc b/fairseq/fairseq/optim/__pycache__/adam.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6e2634383e3bca6e17bf7db20a6cf25f43537e8 Binary files /dev/null and b/fairseq/fairseq/optim/__pycache__/adam.cpython-310.pyc differ diff --git a/fairseq/fairseq/optim/__pycache__/adamax.cpython-310.pyc b/fairseq/fairseq/optim/__pycache__/adamax.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b8092333678d71dc42f589f254426b9eda849e9 Binary files /dev/null and b/fairseq/fairseq/optim/__pycache__/adamax.cpython-310.pyc differ diff --git a/fairseq/fairseq/optim/__pycache__/amp_optimizer.cpython-310.pyc b/fairseq/fairseq/optim/__pycache__/amp_optimizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c760a0ed3d9a4621e8a91276e78b8833cd24be1b Binary files /dev/null and b/fairseq/fairseq/optim/__pycache__/amp_optimizer.cpython-310.pyc differ diff --git a/fairseq/fairseq/optim/__pycache__/bmuf.cpython-310.pyc b/fairseq/fairseq/optim/__pycache__/bmuf.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..547b4ef210e327325a3d5a01d33fc2189ed1e02e Binary files /dev/null and b/fairseq/fairseq/optim/__pycache__/bmuf.cpython-310.pyc differ diff --git a/fairseq/fairseq/optim/__pycache__/composite.cpython-310.pyc b/fairseq/fairseq/optim/__pycache__/composite.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fdf225ef95e57686d4b6cb6a38041d5c24f39b3b Binary files /dev/null and b/fairseq/fairseq/optim/__pycache__/composite.cpython-310.pyc differ diff --git a/fairseq/fairseq/optim/__pycache__/cpu_adam.cpython-310.pyc b/fairseq/fairseq/optim/__pycache__/cpu_adam.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88f2b697c91d025826f1039afb46f375b21679da Binary files /dev/null and b/fairseq/fairseq/optim/__pycache__/cpu_adam.cpython-310.pyc differ diff --git a/fairseq/fairseq/optim/__pycache__/dynamic_loss_scaler.cpython-310.pyc b/fairseq/fairseq/optim/__pycache__/dynamic_loss_scaler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85eb5af63ae8eb455e726a5d2845ac0d98f8a1d6 Binary files /dev/null and b/fairseq/fairseq/optim/__pycache__/dynamic_loss_scaler.cpython-310.pyc differ diff --git a/fairseq/fairseq/optim/__pycache__/fairseq_optimizer.cpython-310.pyc b/fairseq/fairseq/optim/__pycache__/fairseq_optimizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7fe6e1b3420395c9e326eaf74e876a59c5ba0330 Binary files /dev/null and b/fairseq/fairseq/optim/__pycache__/fairseq_optimizer.cpython-310.pyc differ diff --git a/fairseq/fairseq/optim/__pycache__/fp16_optimizer.cpython-310.pyc b/fairseq/fairseq/optim/__pycache__/fp16_optimizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0ac903ac45c43247ea9664d08a3bab5fa470aea Binary files /dev/null and b/fairseq/fairseq/optim/__pycache__/fp16_optimizer.cpython-310.pyc differ diff --git a/fairseq/fairseq/optim/__pycache__/fused_adam.cpython-310.pyc b/fairseq/fairseq/optim/__pycache__/fused_adam.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4aa211d48c127a8d6f67c04ff5532a7cda635b1d Binary files /dev/null and b/fairseq/fairseq/optim/__pycache__/fused_adam.cpython-310.pyc differ diff --git a/fairseq/fairseq/optim/__pycache__/fused_lamb.cpython-310.pyc b/fairseq/fairseq/optim/__pycache__/fused_lamb.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..956ed95fce4cf5ef24ef2b502bd79f995aab3a5e Binary files /dev/null and b/fairseq/fairseq/optim/__pycache__/fused_lamb.cpython-310.pyc differ diff --git a/fairseq/fairseq/optim/__pycache__/nag.cpython-310.pyc b/fairseq/fairseq/optim/__pycache__/nag.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4bb7a78db984d98db5878e5eae898af06b0c939e Binary files /dev/null and b/fairseq/fairseq/optim/__pycache__/nag.cpython-310.pyc differ diff --git a/fairseq/fairseq/optim/__pycache__/sgd.cpython-310.pyc b/fairseq/fairseq/optim/__pycache__/sgd.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..565f7baa6c359ef64e02251cf0e3092daba2e7ca Binary files /dev/null and b/fairseq/fairseq/optim/__pycache__/sgd.cpython-310.pyc differ diff --git a/fairseq/fairseq/optim/__pycache__/shard.cpython-310.pyc b/fairseq/fairseq/optim/__pycache__/shard.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ff60c2ca8f75d29507cc3a7c33633898d822ed5 Binary files /dev/null and b/fairseq/fairseq/optim/__pycache__/shard.cpython-310.pyc differ diff --git a/fairseq/fairseq/optim/bmuf.py b/fairseq/fairseq/optim/bmuf.py new file mode 100644 index 0000000000000000000000000000000000000000..d6d0e04e86eb894efe59e13a78843d01ca9e651d --- /dev/null +++ b/fairseq/fairseq/optim/bmuf.py @@ -0,0 +1,200 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field + +import torch +import torch.distributed as dist +from fairseq.dataclass.configs import FairseqBMUFConfig +from fairseq.dataclass.utils import gen_parser_from_dataclass +from fairseq.optim.fairseq_optimizer import FairseqOptimizer + + +class FairseqBMUF(FairseqOptimizer): + """ + Implements incremental block distributed data parallelism similar to + https://ieeexplore.ieee.org/document/7472805 + + Paper title: Scalable training of deep learning machines by incremental + block training with intra-block parallel optimization and blockwise + model-update filtering + """ + + def __init__(self, cfg: FairseqBMUFConfig, optimizer): + super().__init__(cfg) + self._optimizer = optimizer + self._num_updates = 0 + self.sync_iter = cfg.global_sync_iter + self.block_momentum = cfg.block_momentum + self.block_lr = cfg.block_lr + self._reset_local_data() + self.warmup_iteration = cfg.warmup_iterations + self.use_nbm = cfg.use_nbm + self.initial_state = self._optimizer.state_dict() + self.average_sync = self.cfg.average_sync + self.world_size = self.cfg.distributed_world_size + + @staticmethod + def add_args(parser): + """Add optimizer-specific arguments to the parser.""" + gen_parser_from_dataclass(parser, FairseqBMUFConfig()) + + @property + def optimizer(self): + return self._optimizer.optimizer + + @property + def optimizer_config(self): + return self._optimizer.optimizer_config + + def get_lr(self): + return self._optimizer.get_lr() + + def set_lr(self, lr): + self._optimizer.set_lr(lr) + + def state_dict(self): + return self._optimizer.state_dict() + + def load_state_dict(self, state_dict, optimizer_overrides=None): + self._optimizer.load_state_dict(state_dict, optimizer_overrides) + self.initial_state = self._optimizer.state_dict() + + def multiply_grads(self, c): + """Multiplies grads by a constant *c*.""" + self._optimizer.multiply_grads(c) + + def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): + """Clips gradient norm.""" + return self._optimizer.clip_grad_norm(max_norm, aggregate_norm_fn) + + def average_params(self): + self._optimizer.average_params() + + def _block_sync(self): + if self.world_size <= 1: + return + # Update the global model using local models from all GPUs + # (Step-1) Calculate grad between previously synced model and + # currrent local model + if self.block_momentum != 0: + self._calc_grad() + + # (Step-2) Average gradient from all GPUs + self._avg_grad_from_all_gpus() + + # (Step-3) Calculate global momentum and update the global model + if self.block_momentum != 0: + self._update_global_model() + + # (Step-4) Average local optimizer params + if self.average_sync: + self.average_params() + + def _is_warmup_end(self): + # Check whether train iterations is equal to warmup iter + if self.get_num_updates() == self.warmup_iteration: + return True + return False + + def _is_bmuf_iter(self): + # Check whether train iterations is equal to bmuf sync iter + if (self.get_num_updates() > self.warmup_iteration) and ( + self.get_num_updates() % self.sync_iter == 0 + ): + return True + return False + + def _warmup_sync(self, root_rank=0): + if self.world_size <= 1: + return + # Broadcast the local model to all gpus + for param in self.params: + dist.broadcast(param.data, src=root_rank) + + # Update local optimizer state + if self.average_sync: + self._optimizer.average_params() + else: + self._optimizer.load_state_dict(self.initial_state) + + self._reset_local_data() + + def step(self, closure=None): + """Performs a single optimization step.""" + self._optimizer.step(closure) + self.set_num_updates(self.get_num_updates() + 1) + if self._is_warmup_end(): + self._warmup_sync() + elif self._is_bmuf_iter(): + self._block_sync() + + def zero_grad(self): + """Clears the gradients of all optimized parameters.""" + self._optimizer.zero_grad() + + def get_num_updates(self): + """Get the number of parameters updates.""" + return self._num_updates + + def set_num_updates(self, num_updates): + """Set the number of parameters updates.""" + self._num_updates = num_updates + + @torch.no_grad() + def _reset_local_data(self): + # (Step-0) Initialize global momentum parameters and store global copy on each gpu + self.global_params = [torch.zeros_like(p.data) for p in self.params] + self.smoothed_grads = [p.data.new_zeros(p.data.size()) for p in self.params] + self.grads = [p.data.new_zeros(p.data.size()) for p in self.params] + + # saving the global model locally for calculating gradient during bmuf sync + for param, global_param in zip(self.params, self.global_params): + global_param.copy_(param.data) + + @torch.no_grad() + def _calc_grad(self): + # global_params is basically the global copy from the previously finished + # synchronisation. param.data is local parameter after block_sync_freq + # for the local gpu. so grad is difference between previously synced + # model and currrent local model. + for index, (param, global_param) in enumerate( + zip(self.params, self.global_params) + ): + self.grads[index] = global_param - param.data + + def _avg_grad_from_all_gpus(self): + for index, param in enumerate(self.params): + sync_para = param.data if self.block_momentum == 0 else self.grads[index] + sync_para /= float(dist.get_world_size()) + dist.all_reduce(sync_para, op=dist.ReduceOp.SUM) + + @torch.no_grad() + def _update_global_model(self): + for index, (param, global_param, smoothed_grad, grad) in enumerate( + zip( + self.params, + self.global_params, + self.smoothed_grads, + # all gpus would share the same value of smoothed_grad, since it is + # always computed on synchronized gradients. + self.grads, + ) + ): + # global_param is basically last syncrhornized parameter. though + # smoothed_grad is local, all processes will have same value of + # smoothed_grad and hence param is globally synchronized copy. + # smoothed_grad(t) = BM * smoothed_grad(t-1) + BM_lr * grad(t) + smoothed_grad = self.block_momentum * smoothed_grad + self.block_lr * grad + param.data.copy_(global_param - smoothed_grad) + + # A Nesterov momentum here is to do a partial weight update before + # calculating the gradient + if self.use_nbm: + param.data.copy_(param.data - self.block_momentum * smoothed_grad) + + # backup for the next synchronization. + self.smoothed_grads[index] = smoothed_grad + global_param.copy_(param.data) diff --git a/fairseq/fairseq/optim/composite.py b/fairseq/fairseq/optim/composite.py new file mode 100644 index 0000000000000000000000000000000000000000..1ef0114ed63dcfbdf2898a257f38453ad2c69748 --- /dev/null +++ b/fairseq/fairseq/optim/composite.py @@ -0,0 +1,273 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Dict, Any, List, Optional + +import torch.optim +from fairseq.dataclass import FairseqDataclass +from fairseq.optim import FairseqOptimizer, register_optimizer, _build_optimizer +from fairseq.optim.lr_scheduler import FairseqLRScheduler, build_lr_scheduler +from omegaconf import II, open_dict +import copy + + +logger = logging.getLogger(__name__) + + +@dataclass +class OptimizerAndSchedulerConfig(FairseqDataclass): + optimizer: Any = None + lr_scheduler: Optional[Any] = None + lr: List = II("optimization.lr") + lr_float: Optional[ + float + ] = None # this makes it easier to sweep on learning rate with auto sweepers + + +@dataclass +class CompositeOptimizerConfig(FairseqDataclass): + groups: Dict[str, Any] = field( + default_factory=lambda: {}, + metadata={ + "help": "optimizer name -> optimizer OptimizerAndSchedulerConfig. " + "Configures a different optimizer and (optionally) lr scheduler for each parameter group" + }, + ) + dynamic_groups: bool = field( + default=False, + metadata={ + "help": "create groups dynamically based on parameters, if set to False, all parameters needs to have group_names" + }, + ) + + +@register_optimizer("composite", dataclass=CompositeOptimizerConfig) +class FairseqCompositeOptimizer(FairseqOptimizer): + + optimizers: Dict[str, FairseqOptimizer] = {} + lr_schedulers: Dict[str, FairseqLRScheduler] = {} + lr_scheduler: FairseqLRScheduler = None + _optimizer: torch.optim.Optimizer + + def __init__(self, cfg: CompositeOptimizerConfig, params): + super().__init__(cfg) + + assert ( + len(params) > 1 + ), "Composite optimizer only works when there are multiple parameter groups (try fp16_no_flatten_grads: true)" + + def dict_hash(dictionary: Dict[str, Any]) -> str: + import hashlib + import json + + dhash = hashlib.md5() + encoded = json.dumps(dictionary, sort_keys=True).encode() + dhash.update(encoded) + return dhash.hexdigest() + + groupped_params = defaultdict(list) + overrides = defaultdict(dict) + if not cfg.dynamic_groups: + for p in params: + group = getattr(p, "param_group", "default") + override_config = getattr(p, "optim_overrides", None) + if override_config is not None and bool(override_config): + overrides[group] = override_config + else: + assert ( + override_config == None or override_config == overrides[group] + ), f"For group {group}, different overrides found {override_config} v/s {overrides[group]}" + groupped_params[group].append(p) + + for p, params in groupped_params.items(): + override_config = getattr(params[0], "optim_overrides", None) + if override_config is not None: + for pp in params[1:]: + assert override_config == getattr( + pp, "optim_overrides", None + ), f" {str(override_config)} != {str(getattr(pp, 'optim_overrides', None))}" + else: + for p in params: + group = getattr(p, "param_group", "default") + override_config = getattr(p, "optim_overrides", None) + if override_config is not None: + override_config["group_name"] = group + group_name = dict_hash(override_config) + overrides[group_name] = override_config + else: + group_name = group + groupped_params[group_name].append(p) + + self.optimizers_config = {} + for group, group_params in groupped_params.items(): + p_group = group + if group in overrides and "group_name" in overrides[group]: + p_group = overrides[group]["group_name"] + if group in cfg.groups: + group_cfg = cfg.groups[group] + optimizer_config = copy.deepcopy(group_cfg.optimizer) + scheduler_config = copy.deepcopy(group_cfg.lr_scheduler) + explicit_group_present = True + else: + group_cfg = cfg.groups[p_group] + optimizer_config = copy.deepcopy(group_cfg.optimizer) + scheduler_config = copy.deepcopy(group_cfg.lr_scheduler) + explicit_group_present = False + + if getattr(group_cfg, "lr_float", None) is not None: + with open_dict(optimizer_config): + optimizer_config.lr = [group_cfg.lr_float] + + if group in overrides and "optimizer" in overrides[group]: + with open_dict(optimizer_config): + if "lr_scale" in overrides[group]["optimizer"]: + lr_scale = overrides[group]["optimizer"]["lr_scale"] + optimizer_config.lr = [ + lr * lr_scale for lr in optimizer_config.lr + ] + + if explicit_group_present: + logger.info( + f"For group:{group}, config as well as override present for lr" + ) + + if ( + "weight_decay_scale" in overrides[group]["optimizer"] + and "optimizer_config" in optimizer_config + ): + weight_decay_scale = overrides[group]["optimizer"][ + "weight_decay_scale" + ] + optimizer_config.weight_decay = ( + optimizer_config.weight_decay * weight_decay_scale + ) + if explicit_group_present: + logger.info( + f"For group:{group}, config as well as override present for weight_decay" + ) + + with open_dict(scheduler_config): + scheduler_config.lr = optimizer_config.lr + self.optimizers[group] = _build_optimizer(optimizer_config, group_params) + self.optimizers_config[group] = optimizer_config + if scheduler_config is not None: + self.lr_schedulers[group] = build_lr_scheduler( + scheduler_config, self.optimizers[group] + ) + logger.info("Optimizers for different groups are as below") + for group in self.optimizers_config.keys(): + logger.info(f"Group : {group}:{self.optimizers_config[group]}") + if len(self.lr_schedulers) > 0: + assert len(self.lr_schedulers) == len(self.optimizers), ( + f"Please provide an lr scheduler for each optimizer to use pass_through scheduler. " + f"Optimizers: {self.optimizers}; Lr scheds: {self.lr_schedulers}" + ) + self.lr_scheduler = CompositeLRScheduler(self.lr_schedulers) + + self._optimizer = CompositeOptimizer(self.optimizers) + + @property + def supports_groups(self): + return True + + @property + def param_groups(self): + for opt in self.optimizers.values(): + for group in opt.param_groups: + yield group + + def get_lr(self): + """Return the current learning rate.""" + k = ( + "default" + if "default" in self.optimizers + else next(iter(self.optimizers.keys())) + ) + return self.optimizers[k].param_groups[0]["lr"] + + def state_dict(self): + """Return the LR scheduler state dict.""" + return {k: s.state_dict() for k, s in self.optimizers.items()} + + def load_state_dict(self, state_dict, optimizer_overrides=None): + """Load an LR scheduler state dict.""" + for k, state in state_dict.items(): + if k not in self.optimizers: + # skip extra keys like "loss_scale" added by fp16 optimizer + continue + + overrides = ( + optimizer_overrides[k] + if isinstance(optimizer_overrides, dict) and k in optimizer_overrides + else None + ) + self.optimizers[k].load_state_dict(state, optimizer_overrides=overrides) + + +class CompositeOptimizer(torch.optim.Optimizer): + def __init__(self, optimizers: Dict[str, FairseqOptimizer]): + self.optimizers = optimizers + + @property + def supports_memory_efficient_fp16(self): + return all(o.supports_memory_efficient_fp16 for o in self.optimizers.values()) + + @property + def supports_flat_params(self): + return all(o.supports_flat_params for o in self.optimizers.values()) + + def step(self, closure=None, groups=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for k, opt in self.optimizers.items(): + if groups is None or k in groups: + opt.step() + + return loss + + def zero_grad(self): + for opt in self.optimizers.values(): + opt.zero_grad() + + +class CompositeLRScheduler(FairseqLRScheduler): + def __init__(self, lr_schedulers): + super().__init__(None, None) + + self.lr_schedulers = lr_schedulers + + def state_dict(self): + """Return the LR scheduler state dict.""" + return {k: s.state_dict() for k, s in self.lr_schedulers.items()} + + def load_state_dict(self, state_dict): + """Load an LR scheduler state dict.""" + for k, state in state_dict.items(): + self.lr_schedulers[k].load_state_dict(state) + + def step_begin_epoch(self, epoch): + """Update the learning rate at the beginning of the given epoch.""" + for s in self.lr_schedulers.values(): + s.step_begin_epoch(epoch) + + def step(self, epoch, val_loss=None): + """Update the learning rate at the end of the given epoch.""" + for s in self.lr_schedulers.values(): + s.step(epoch) + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + return {k: s.step_update(num_updates) for k, s in self.lr_schedulers.items()} diff --git a/fairseq/fairseq/optim/fairseq_optimizer.py b/fairseq/fairseq/optim/fairseq_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..73c7c695ee5ce71e1dd44b4deb0990f57858cb38 --- /dev/null +++ b/fairseq/fairseq/optim/fairseq_optimizer.py @@ -0,0 +1,187 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from fairseq import utils +from fairseq.dataclass.utils import gen_parser_from_dataclass +from collections import defaultdict + + +class FairseqOptimizer(object): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + + @classmethod + def add_args(cls, parser): + """Add optimizer-specific arguments to the parser.""" + dc = getattr(cls, "__dataclass", None) + if dc is not None: + gen_parser_from_dataclass(parser, dc()) + + @property + def optimizer(self): + """Return a torch.optim.optimizer.Optimizer instance.""" + if not hasattr(self, "_optimizer"): + raise NotImplementedError + if not isinstance(self._optimizer, torch.optim.Optimizer): + raise ValueError("_optimizer must be an instance of torch.optim.Optimizer") + return self._optimizer + + @optimizer.setter + def optimizer(self, optimizer): + """Reset optimizer instance.""" + if not hasattr(self, "_optimizer"): + raise NotImplementedError + if not isinstance(self._optimizer, torch.optim.Optimizer): + raise ValueError("_optimizer must be an instance of torch.optim.Optimizer") + self._optimizer = optimizer + + @property + def optimizer_config(self): + """ + Return a kwarg dictionary that will be used to override optimizer + args stored in checkpoints. This allows us to load a checkpoint and + resume training using a different set of optimizer args, e.g., with a + different learning rate. + """ + raise NotImplementedError + + @property + def params(self): + """Return an iterable of the parameters held by the optimizer.""" + for param_group in self.param_groups: + for p in param_group["params"]: + yield p + + @property + def param_groups(self): + return self.optimizer.param_groups + + def __getstate__(self): + return self._optimizer.__getstate__() + + def get_lr(self): + """Return the current learning rate.""" + return self.param_groups[0]["lr"] + + def set_lr(self, lr): + """Set the learning rate.""" + for param_group in self.param_groups: + param_group["lr"] = lr + + def state_dict(self): + """Return the optimizer's state dict.""" + return self.optimizer.state_dict() + + def load_state_dict(self, state_dict, optimizer_overrides=None): + """Load an optimizer state dict. + + In general we should prefer the configuration of the existing optimizer + instance (e.g., learning rate) over that found in the state_dict. This + allows us to resume training from a checkpoint using a new set of + optimizer args. + """ + self.optimizer.load_state_dict(state_dict) + + if optimizer_overrides is not None and len(optimizer_overrides) > 0: + # override learning rate, momentum, etc. with latest values + for group in self.param_groups: + group.update(optimizer_overrides) + + def backward(self, loss): + """Computes the sum of gradients of the given tensor w.r.t. graph leaves.""" + loss.backward() + + def all_reduce_grads(self, module): + """Manually all-reduce gradients (if required).""" + if hasattr(module, "all_reduce_grads"): + module.all_reduce_grads() + + def multiply_grads(self, c): + """Multiplies grads by a constant *c*.""" + per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) + for p in self.params: + if p.grad is not None: + if p.grad.is_sparse: + p.grad.data.mul_(c.to(p.grad.device) if torch.is_tensor(c) else c) + else: + per_device_and_dtype_grads[p.grad.device][p.grad.dtype].append( + p.grad.data + ) + for device, per_dtype_grads in per_device_and_dtype_grads.items(): + for grads in per_dtype_grads.values(): + torch._foreach_mul_(grads, c.to(device) if torch.is_tensor(c) else c) + + def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): + """Clips gradient norm.""" + return utils.clip_grad_norm_(self.params, max_norm, aggregate_norm_fn) + + def step(self, closure=None, scale=1.0, groups=None): + """Performs a single optimization step.""" + if self.supports_step_with_scale: + if self.supports_groups: + self.optimizer.step(closure, scale=scale, groups=groups) + else: + self.optimizer.step(closure, scale=scale) + else: + if scale != 1.0: + self.multiply_grads(1.0 / scale) + if self.supports_groups: + self.optimizer.step(closure, groups=groups) + else: + self.optimizer.step(closure) + + def zero_grad(self): + """Clears the gradients of all optimized parameters.""" + for p in self.params: + p.grad = None + self.optimizer.zero_grad() + + @property + def supports_memory_efficient_fp16(self): + if hasattr(self.optimizer, "supports_memory_efficient_fp16"): + return self.optimizer.supports_memory_efficient_fp16 + return False + + @property + def supports_step_with_scale(self): + if hasattr(self.optimizer, "supports_step_with_scale"): + return self.optimizer.supports_step_with_scale + return False + + @property + def supports_groups(self): + if hasattr(self.optimizer, "supports_groups"): + return self.optimizer.supports_groups + return False + + @property + def supports_flat_params(self): + """ + Whether the optimizer supports collapsing of the model + parameters/gradients into a single contiguous Tensor. + """ + if hasattr(self.optimizer, "supports_flat_params"): + return self.optimizer.supports_flat_params + return False + + def average_params(self): + pass + + def broadcast_global_state_dict(self, state_dict): + """ + Broadcasts a global state dict to all ranks. + Useful for optimizers that shard state between ranks. + """ + if hasattr(self.optimizer, "broadcast_global_state_dict"): + return self.optimizer.broadcast_global_state_dict(state_dict) + else: + return state_dict + + +class LegacyFairseqOptimizer(FairseqOptimizer): + def __init__(self, args): + self.args = args diff --git a/fairseq/fairseq/optim/fp16_optimizer.py b/fairseq/fairseq/optim/fp16_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..6a4da342cab8526ca09893cfd1fbee82be955faa --- /dev/null +++ b/fairseq/fairseq/optim/fp16_optimizer.py @@ -0,0 +1,558 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from collections import defaultdict +from itertools import chain + +import torch +from omegaconf import DictConfig + +from fairseq import optim + +from .dynamic_loss_scaler import DynamicLossScaler + + +class _FP16OptimizerMixin(object): + def __init__(self, *args, **kwargs): + # forward __init__ call to the next class in mro(method resolution order) + super().__init__(*args, **kwargs) + self._multiply_factor = 1.0 + + @property + def has_flat_params(self): + return torch.is_tensor(self.fp32_params) or ( + isinstance(self.fp32_params, dict) + and all(torch.is_tensor(t) for t in self.fp32_params.values()) + ) + + @classmethod + def build_fp32_params(cls, args, params, flatten=True): + # create FP32 copy of parameters and grads + if flatten: + is_pipeline_parallel = getattr( + args, "pipeline_model_parallel", False + ) and getattr(args, "distributed_no_spawn", False) + total_param_size = sum(p.data.numel() for p in params) + devices = [torch.cuda.current_device()] + if is_pipeline_parallel: + devices = list(set(args.pipeline_devices)) + fp32_params = {} + for device in devices: + if is_pipeline_parallel: + device_param_size = sum( + p.data.numel() for p in params if p.device.index == device + ) + device_params = [p for p in params if p.device.index == device] + else: + device_param_size = total_param_size + device_params = params + fp32_params[device] = ( + device_params[0].new(0).float().new(device_param_size) + ) + offset = 0 + for p in device_params: + numel = p.data.numel() + fp32_params[device][offset : offset + numel].copy_(p.data.view(-1)) + offset += numel + fp32_params[device] = torch.nn.Parameter(fp32_params[device]) + fp32_params[device].grad = fp32_params[device].data.new( + device_param_size + ) + return fp32_params + else: + fp32_params = [] + for p in params: + p32 = torch.nn.Parameter(p.data.float()) + if hasattr(p, "expert"): + p32.expert = True + elif hasattr(p, "base_expert"): + p32.base_expert = True + p32.grad = torch.zeros_like(p32.data) + if hasattr(p, "param_group"): + p32.param_group = p.param_group + if hasattr(p, "optim_overrides"): + p32.optim_overrides = p.optim_overrides + fp32_params.append(p32) + return fp32_params + + def state_dict(self): + """Return the optimizer's state dict.""" + state_dict = self.fp32_optimizer.state_dict() + if self.scaler is not None: + state_dict["loss_scale"] = self.scaler.loss_scale + return state_dict + + def load_state_dict(self, state_dict, optimizer_overrides=None): + """Load an optimizer state dict. + + In general we should prefer the configuration of the existing optimizer + instance (e.g., learning rate) over that found in the state_dict. This + allows us to resume training from a checkpoint using a new set of + optimizer args. + """ + if "loss_scale" in state_dict and self.scaler is not None: + self.scaler.loss_scale = state_dict["loss_scale"] + self.fp32_optimizer.load_state_dict(state_dict, optimizer_overrides) + + def backward(self, loss): + """Computes the sum of gradients of the given tensor w.r.t. graph leaves. + + Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this + function additionally dynamically scales the loss to avoid gradient + underflow. + """ + if self.scaler is not None: + loss = self.scaler.scale(loss) + loss.backward() + self._needs_sync = True + + def _sync_fp16_grads_to_fp32(self): + if self._needs_sync: + # copy FP16 grads to FP32 + if self.has_flat_params: + devices = list(self.fp32_params.keys()) + device_params_dict = defaultdict(list) + for p in self.fp16_params: + if p.requires_grad: + device_params_dict[p.device.index].append(p) + for device in devices: + device_params = device_params_dict[device] + offset = 0 + for p in device_params: + grad_data = ( + p.grad.data + if p.grad is not None + else p.data.new_zeros(p.data.shape) + ) + numel = grad_data.numel() + self.fp32_params[device].grad.data[ + offset : offset + numel + ].copy_(grad_data.view(-1)) + offset += numel + else: + for p, p32 in zip(self.fp16_params, self.fp32_params): + if not p.requires_grad: + continue + if p.grad is not None: + if p32.grad is None: + p32.grad = p.grad.data.float() + else: + p32.grad.data.copy_(p.grad.data) + else: + p32.grad = torch.zeros_like(p.data, dtype=torch.float) + + self._needs_sync = False + + def _sync_fp32_params_to_fp16(self): + # copy FP32 params back into FP16 model + if self.has_flat_params: + devices = list(self.fp32_params.keys()) + device_params_dict = defaultdict(list) + for p in self.fp16_params: + device_params_dict[p.device.index].append(p) + for device in devices: + device_params = device_params_dict[device] + offset = 0 + for p in device_params: + numel = p.data.numel() + p.data.copy_( + self.fp32_params[device] + .data[offset : offset + numel] + .view_as(p.data) + ) + offset += numel + else: + for p, p32 in zip(self.fp16_params, self.fp32_params): + if not p.requires_grad: + continue + p.data.copy_(p32.data) + + def _unscale_grads(self): + self._sync_fp16_grads_to_fp32() + if ( + # Skip the multiplication if it's a no-op (i.e., if _multiply_factor + # is 1.0). At the same time, we want to avoid the device-to-host + # transfer by comparing it to 1.0. Since _multiply_factor starts as + # a Python float, we roughly assume that if it's a tensor then it's + # probably not =1.0 anymore and we do the multiplication. Otherwise + # we can safely check the value without a D2H transfer. + torch.is_tensor(self._multiply_factor) + or self._multiply_factor != 1.0 + ): + self.fp32_optimizer.multiply_grads(self._multiply_factor) + self._multiply_factor = 1.0 + + def multiply_grads(self, c): + """Multiplies grads by a constant ``c``.""" + self._multiply_factor *= c + + def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): + """Clips gradient norm and updates dynamic loss scaler.""" + self._sync_fp16_grads_to_fp32() + + grad_norm = self._multiply_factor * self.fp32_optimizer.clip_grad_norm( + 0, aggregate_norm_fn + ) + + if torch.is_tensor(self._multiply_factor): + self._multiply_factor = self._multiply_factor.to(grad_norm.device) + + if self.scaler is not None: + if grad_norm > max_norm > 0.0: + self._multiply_factor *= max_norm / grad_norm + + self.scaler.check_overflow(grad_norm) + elif max_norm > 0.0: + clip_coef = (max_norm / (grad_norm + 1e-6)).clamp_(max=1) + self._multiply_factor *= clip_coef + + return grad_norm + + def step(self, closure=None, groups=None): + """Performs a single optimization step.""" + self._sync_fp16_grads_to_fp32() + + if getattr(self, "supports_step_with_scale", False): + self.fp32_optimizer.step( + closure, scale=(1.0 / self._multiply_factor), groups=groups + ) + else: + self._unscale_grads() + self.fp32_optimizer.step(closure, groups=groups) + + if self.scaler is not None: + self.scaler.update() + + self._sync_fp32_params_to_fp16() + + def zero_grad(self): + """Clears the gradients of all optimized parameters.""" + for p in self.fp16_params: + p.grad = None + if self.has_flat_params: + if torch.is_tensor(self.fp32_params): + self.fp32_params.grad.zero_() + elif isinstance(self.fp32_params, dict): + for fp32_params in self.fp32_params.values(): + fp32_params.grad.zero_() + else: + raise RuntimeError("self.fp32_params must be a tensor or dict") + else: + for p32 in self.fp32_params: + if p32.grad is not None: + p32.grad.zero_() + self._needs_sync = False + + if self.scaler is not None: + self._multiply_factor = 1.0 / float(self.scaler.loss_scale) + + +class FP16Optimizer(_FP16OptimizerMixin, optim.FairseqOptimizer): + """ + Wrap an *optimizer* to support FP16 (mixed precision) training. + """ + + def __init__(self, cfg: DictConfig, params, fp32_optimizer, fp32_params, **kwargs): + super().__init__(cfg.optimizer) + self.fp16_params = params + self.fp32_optimizer = fp32_optimizer + self.fp32_params = fp32_params + + if getattr(cfg.common, "fp16_scale_window", None) is None: + if len(cfg.optimization.update_freq) > 1: + raise ValueError( + "--fp16-scale-window must be given explicitly when using a " + "custom --update-freq schedule" + ) + data_parallel_size = int( + cfg.distributed_training.distributed_world_size + / cfg.common.model_parallel_size + ) + scale_window = int( + 2**14 / data_parallel_size / cfg.optimization.update_freq[0] + ) + else: + scale_window = cfg.common.fp16_scale_window + + if not getattr(cfg.common, "bf16", False): + self.scaler = DynamicLossScaler( + init_scale=cfg.common.fp16_init_scale, + scale_window=scale_window, + tolerance=cfg.common.fp16_scale_tolerance, + threshold=cfg.common.threshold_loss_scale, + min_loss_scale=cfg.common.min_loss_scale, + ) + else: + # disable loss scaling for bfloat16 + self.scaler = None + + @classmethod + def build_optimizer(cls, cfg: DictConfig, params, **kwargs): + """ + Args: + cfg (omegaconf.DictConfig): fairseq args + params (iterable): iterable of parameters to optimize + """ + flatten = not getattr(cfg.common, "fp16_no_flatten_grads", False) + if getattr(cfg.common, "bf16", False): + flatten = False # mixed precision is faster on TPUs without flat grads + fp32_params = cls.build_fp32_params(cfg.optimizer, params, flatten=flatten) + if flatten: + fp32_optimizer = optim.build_optimizer(cfg.optimizer, [fp32_params]) + else: + fp32_optimizer = optim.build_optimizer(cfg.optimizer, fp32_params) + if flatten and not fp32_optimizer.supports_flat_params: + raise RuntimeError( + f"chosen optimizer {fp32_optimizer.__class__.__name__} does not support flat params, please set --fp16-no-flatten-grads" + ) + return cls(cfg, params, fp32_optimizer, fp32_params, **kwargs) + + @property + def optimizer(self): + return self.fp32_optimizer.optimizer + + @optimizer.setter + def optimizer(self, optimizer): + self.fp32_optimizer.optimizer = optimizer + + @property + def lr_scheduler(self): + return getattr(self.fp32_optimizer, "lr_scheduler", None) + + @property + def optimizer_config(self): + return self.fp32_optimizer.optimizer_config + + def get_lr(self): + return self.fp32_optimizer.get_lr() + + def set_lr(self, lr): + self.fp32_optimizer.set_lr(lr) + + def all_reduce_grads(self, module): + self.fp32_optimizer.all_reduce_grads(module) + + @property + def supports_flat_params(self): + return self.fp32_optimizer.supports_flat_params + + +class _MemoryEfficientFP16OptimizerMixin(object): + def __init__(self, *args, **kwargs): + # forward __init__ call to the next class in MRO (method resolution order) + super().__init__(*args, **kwargs) + self._multiply_factor = 1.0 + + @property + def has_flat_params(self): + return False + + def state_dict(self): + """Return the optimizer's state dict.""" + state_dict = self.wrapped_optimizer.state_dict() + if self.scaler is not None: + state_dict["loss_scale"] = self.scaler.loss_scale + return state_dict + + def load_state_dict(self, state_dict, optimizer_overrides=None): + """Load an optimizer state dict. + + In general we should prefer the configuration of the existing optimizer + instance (e.g., learning rate) over that found in the state_dict. This + allows us to resume training from a checkpoint using a new set of + optimizer args. + """ + if "loss_scale" in state_dict and self.scaler is not None: + self.scaler.loss_scale = state_dict["loss_scale"] + + self.wrapped_optimizer.load_state_dict(state_dict, optimizer_overrides) + + # Hack: PyTorch automatically casts the optimizer state to match the + # type of the current parameters. But with --memory-efficient-fp16 the + # params are FP16 while the optimizer state is FP32 and we don't want + # to cast. A workaround is to manually copy back the original state + # after the optimizer has been loaded. + if not getattr(self.optimizer, "disable_mem_eff_fp16_loading_hack", False): + groups = self.optimizer.param_groups + saved_groups = state_dict["param_groups"] + id_map = { + old_id: p + for old_id, p in zip( + chain(*(g["params"] for g in saved_groups)), + chain(*(g["params"] for g in groups)), + ) + } + for k, v in state_dict["state"].items(): + if k in id_map: + param = id_map[k] + self.optimizer.state[param] = v + + def backward(self, loss): + """Computes the sum of gradients of the given tensor w.r.t. graph leaves. + + Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this + function additionally dynamically scales the loss to avoid gradient + underflow. + """ + if self.scaler is not None: + loss = self.scaler.scale(loss) + loss.backward() + + def _unscale_grads(self): + if ( + # Skip the multiplication if it's a no-op (i.e., if _multiply_factor + # is 1.0). At the same time, we want to avoid the device-to-host + # transfer by comparing it to 1.0. Since _multiply_factor starts as + # a Python float, we roughly assume that if it's a tensor then it's + # probably not =1.0 anymore and we do the multiplication. Otherwise + # we can safely check the value without a D2H transfer. + torch.is_tensor(self._multiply_factor) + or self._multiply_factor != 1.0 + ): + self.wrapped_optimizer.multiply_grads(self._multiply_factor) + self._multiply_factor = 1.0 + + def multiply_grads(self, c): + """Multiplies grads by a constant *c*.""" + self._multiply_factor *= c + + def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): + """Clips gradient norm and updates dynamic loss scaler.""" + max_norm = float(max_norm) + grad_norm = self._multiply_factor * self.wrapped_optimizer.clip_grad_norm( + 0, aggregate_norm_fn + ) + + if self.scaler is not None: + grad_norm_cpu = float(grad_norm) + if grad_norm_cpu > max_norm > 0.0: + self._multiply_factor *= max_norm / grad_norm_cpu + + # detect overflow and adjust loss scale + self.scaler.check_overflow(grad_norm_cpu) + elif max_norm > 0.0: + clip_coef = (max_norm / (grad_norm + 1e-6)).clamp_(max=1) + self._multiply_factor *= clip_coef + + return grad_norm + + def step(self, closure=None, groups=None): + """Performs a single optimization step.""" + if getattr(self, "supports_step_with_scale", False): + # NOTE(msb) optimizer divides by scale factor + self.wrapped_optimizer.step( + closure, scale=(1.0 / self._multiply_factor), groups=groups + ) + else: + self._unscale_grads() + self.wrapped_optimizer.step(closure, groups=groups) + + if self.scaler is not None: + self.scaler.update() + + def zero_grad(self): + """Clears the gradients of all optimized parameters.""" + self.wrapped_optimizer.zero_grad() + if self.scaler is not None: + self._multiply_factor = 1.0 / float(self.scaler.loss_scale) + else: + self._multiply_factor = 1.0 + + @property + def supports_flat_params(self): + return self.wrapped_optimizer.supports_flat_params + + +class MemoryEfficientFP16Optimizer( + _MemoryEfficientFP16OptimizerMixin, optim.FairseqOptimizer +): + """ + Wrap an *optimizer* to support FP16 (mixed precision) training. + + Compared to :class:`fairseq.optim.FP16Optimizer`, this version does not + maintain an FP32 copy of the model. We instead expect the optimizer to + convert the gradients to FP32 internally and sync the results back to the + FP16 model params. This significantly reduces memory usage but slightly + increases the time spent in the optimizer. + + Since this wrapper depends on specific functionality in the wrapped + optimizer (i.e., on-the-fly conversion of grads to FP32), only certain + optimizers can be wrapped. This is determined by the + *supports_memory_efficient_fp16* property. + """ + + def __init__( + self, cfg: DictConfig, params, optimizer, allow_unsupported=False, **kwargs + ): + if not allow_unsupported and not optimizer.supports_memory_efficient_fp16: + raise ValueError( + "Unsupported optimizer: {}".format(optimizer.__class__.__name__) + ) + + super().__init__(getattr(cfg, "optimizer", None)) + self.wrapped_optimizer = optimizer + + if getattr(cfg.common, "fp16_scale_window", None) is None: + if len(cfg.optimization.update_freq) > 1: + raise ValueError( + "--fp16-scale-window must be given explicitly when using a " + "custom --update-freq schedule" + ) + data_parallel_size = int( + cfg.distributed_training.distributed_world_size + / cfg.common.model_parallel_size + ) + scale_window = int( + 2**14 / data_parallel_size / cfg.optimization.update_freq[0] + ) + else: + scale_window = cfg.common.fp16_scale_window + + if not getattr(cfg.common, "bf16", False): + self.scaler = DynamicLossScaler( + init_scale=cfg.common.fp16_init_scale, + scale_window=scale_window, + tolerance=cfg.common.fp16_scale_tolerance, + threshold=cfg.common.threshold_loss_scale, + min_loss_scale=cfg.common.min_loss_scale, + ) + else: + # disable loss scaling for bfloat16 + self.scaler = None + + @classmethod + def build_optimizer(cls, cfg: DictConfig, params, **kwargs): + """ + Args: + args (argparse.Namespace): fairseq args + params (iterable): iterable of parameters to optimize + """ + fp16_optimizer = optim.build_optimizer(cfg.optimizer, params) + return cls(cfg, params, fp16_optimizer, **kwargs) + + @property + def optimizer(self): + return self.wrapped_optimizer.optimizer + + @optimizer.setter + def optimizer(self, optimizer): + self.wrapped_optimizer.optimizer = optimizer + + @property + def optimizer_config(self): + return self.wrapped_optimizer.optimizer_config + + @property + def lr_scheduler(self): + return getattr(self.wrapped_optimizer, "lr_scheduler", None) + + def get_lr(self): + return self.wrapped_optimizer.get_lr() + + def set_lr(self, lr): + self.wrapped_optimizer.set_lr(lr) + + def all_reduce_grads(self, module): + self.wrapped_optimizer.all_reduce_grads(module) diff --git a/fairseq/fairseq/optim/fused_lamb.py b/fairseq/fairseq/optim/fused_lamb.py new file mode 100644 index 0000000000000000000000000000000000000000..f4f2bdb0c6c65f7758509b6d4d2f2c48cb6e8b4f --- /dev/null +++ b/fairseq/fairseq/optim/fused_lamb.py @@ -0,0 +1,51 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq.optim import LegacyFairseqOptimizer, register_optimizer + + +@register_optimizer("lamb") +class FairseqLAMB(LegacyFairseqOptimizer): + """LAMB optimizer.""" + + def __init__(self, args, params): + super().__init__(args) + try: + from apex.optimizers import FusedLAMB + + self._optimizer = FusedLAMB(params, **self.optimizer_config) + except ImportError: + raise ImportError("Please install apex to use LAMB optimizer") + + @staticmethod + def add_args(parser): + """Add optimizer-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--lamb-betas', default='(0.9, 0.999)', metavar='B', + help='betas for LAMB optimizer') + parser.add_argument('--lamb-eps', type=float, default=1e-8, metavar='D', + help='epsilon for LAMB optimizer') + parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', + help='weight decay') + # fmt: on + + @property + def optimizer_config(self): + """ + Return a kwarg dictionary that will be used to override optimizer + args stored in checkpoints. This allows us to load a checkpoint and + resume training using a different set of optimizer args, e.g., with a + different learning rate. + """ + return { + "lr": self.args.lr[0], + "betas": eval(self.args.lamb_betas), + "eps": self.args.lamb_eps, + "weight_decay": self.args.weight_decay, + } + + @property + def supports_flat_params(self): + return False diff --git a/fairseq/fairseq/optim/lr_scheduler/__init__.py b/fairseq/fairseq/optim/lr_scheduler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5b3dbc023aa4a6f7bfb8403b8204d71ca432f79c --- /dev/null +++ b/fairseq/fairseq/optim/lr_scheduler/__init__.py @@ -0,0 +1,36 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""isort:skip_file""" + +import importlib +import os + +from fairseq import registry +from fairseq.optim.lr_scheduler.fairseq_lr_scheduler import ( # noqa + FairseqLRScheduler, + LegacyFairseqLRScheduler, +) +from omegaconf import DictConfig + + +( + build_lr_scheduler_, + register_lr_scheduler, + LR_SCHEDULER_REGISTRY, + LR_SCHEDULER_DATACLASS_REGISTRY, +) = registry.setup_registry( + "--lr-scheduler", base_class=FairseqLRScheduler, default="fixed" +) + + +def build_lr_scheduler(cfg: DictConfig, optimizer): + return build_lr_scheduler_(cfg, optimizer) + + +# automatically import any Python files in the optim/lr_scheduler/ directory +for file in sorted(os.listdir(os.path.dirname(__file__))): + if file.endswith(".py") and not file.startswith("_"): + file_name = file[: file.find(".py")] + importlib.import_module("fairseq.optim.lr_scheduler." + file_name) diff --git a/fairseq/fairseq/optim/lr_scheduler/__pycache__/cosine_lr_scheduler.cpython-310.pyc b/fairseq/fairseq/optim/lr_scheduler/__pycache__/cosine_lr_scheduler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9042723f64840a2314b34d0dbfd5f9dc3be9f3c2 Binary files /dev/null and b/fairseq/fairseq/optim/lr_scheduler/__pycache__/cosine_lr_scheduler.cpython-310.pyc differ diff --git a/fairseq/fairseq/optim/lr_scheduler/__pycache__/inverse_square_root_schedule.cpython-310.pyc b/fairseq/fairseq/optim/lr_scheduler/__pycache__/inverse_square_root_schedule.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9a431db368bb49229741ad3b1f1e2d90a50f37c Binary files /dev/null and b/fairseq/fairseq/optim/lr_scheduler/__pycache__/inverse_square_root_schedule.cpython-310.pyc differ diff --git a/fairseq/fairseq/optim/lr_scheduler/__pycache__/polynomial_decay_schedule.cpython-310.pyc b/fairseq/fairseq/optim/lr_scheduler/__pycache__/polynomial_decay_schedule.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96de6fdc49e8d0e9d7c45877c097a593de3f91bd Binary files /dev/null and b/fairseq/fairseq/optim/lr_scheduler/__pycache__/polynomial_decay_schedule.cpython-310.pyc differ diff --git a/fairseq/fairseq/optim/lr_scheduler/__pycache__/reduce_lr_on_plateau.cpython-310.pyc b/fairseq/fairseq/optim/lr_scheduler/__pycache__/reduce_lr_on_plateau.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b361413d68f4bba159805d18a6d81055bccd249 Binary files /dev/null and b/fairseq/fairseq/optim/lr_scheduler/__pycache__/reduce_lr_on_plateau.cpython-310.pyc differ diff --git a/fairseq/fairseq/optim/lr_scheduler/__pycache__/step_lr_scheduler.cpython-310.pyc b/fairseq/fairseq/optim/lr_scheduler/__pycache__/step_lr_scheduler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92e2972ea1eb26c03ed053497b6bea8fb22a5d2f Binary files /dev/null and b/fairseq/fairseq/optim/lr_scheduler/__pycache__/step_lr_scheduler.cpython-310.pyc differ diff --git a/fairseq/fairseq/optim/lr_scheduler/__pycache__/tri_stage_lr_scheduler.cpython-310.pyc b/fairseq/fairseq/optim/lr_scheduler/__pycache__/tri_stage_lr_scheduler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a7524d0059ba6a789c2989456e93ac71868d437 Binary files /dev/null and b/fairseq/fairseq/optim/lr_scheduler/__pycache__/tri_stage_lr_scheduler.cpython-310.pyc differ diff --git a/fairseq/fairseq/optim/lr_scheduler/__pycache__/triangular_lr_scheduler.cpython-310.pyc b/fairseq/fairseq/optim/lr_scheduler/__pycache__/triangular_lr_scheduler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..748d37476d29b4c828709e5c91d4fe59ba62fa9f Binary files /dev/null and b/fairseq/fairseq/optim/lr_scheduler/__pycache__/triangular_lr_scheduler.cpython-310.pyc differ diff --git a/fairseq/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py b/fairseq/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..5fcaea25d493c3f33bab6b9aa65d50450ca446e3 --- /dev/null +++ b/fairseq/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py @@ -0,0 +1,146 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from collections.abc import Collection +from dataclasses import dataclass, field +from typing import List + +from omegaconf import II + +from fairseq.dataclass import FairseqDataclass +from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler + + +@dataclass +class CosineLRScheduleConfig(FairseqDataclass): + warmup_updates: int = field( + default=0, + metadata={"help": "warmup the learning rate linearly for the first N updates"}, + ) + warmup_init_lr: float = field( + default=-1, + metadata={ + "help": "initial learning rate during warmup phase; default is cfg.lr" + }, + ) + lr: List[float] = field( + default=II("optimization.lr"), + metadata={"help": "max learning rate, must be more than cfg.min_lr"}, + ) + min_lr: float = field(default=0.0, metadata={"help": "min learning rate"}) + t_mult: float = field( + default=1.0, metadata={"help": "factor to grow the length of each period"} + ) + lr_period_updates: float = field( + default=-1, metadata={"help": "initial number of updates per period"} + ) + lr_shrink: float = field( + default=0.1, metadata={"help": "shrink factor for annealing"} + ) + # This is not required, but is for convenience in inferring lr_period_updates + max_update: int = II("optimization.max_update") + + +@register_lr_scheduler("cosine", dataclass=CosineLRScheduleConfig) +class CosineLRSchedule(FairseqLRScheduler): + """Assign LR based on a cyclical schedule that follows the cosine function. + + See https://arxiv.org/pdf/1608.03983.pdf for details. + + We also support a warmup phase where we linearly increase the learning rate + from some initial learning rate (``--warmup-init-lr``) until the configured + max learning rate (``--lr``). + + During warmup:: + + lrs = torch.linspace(cfg.warmup_init_lr, cfg.lr, cfg.warmup_updates) + lr = lrs[update_num] + + After warmup:: + + lr = cfg.min_lr + 0.5*(cfg.lr - cfg.min_lr)*(1 + cos(t_curr / t_i)) + + where ``t_curr`` is current percentage of updates within the current period + range and ``t_i`` is the current period range, which is scaled by ``t_mul`` + after every iteration. + """ + + def __init__(self, cfg: CosineLRScheduleConfig, fairseq_optimizer): + super().__init__(cfg, fairseq_optimizer) + if isinstance(cfg.lr, Collection) and len(cfg.lr) > 1: + raise ValueError( + "Cannot use a fixed learning rate schedule with cosine." + f" Consider --lr-scheduler=fixed instead. ({cfg.lr})" + ) + + self.max_lr = cfg.lr[0] if isinstance(cfg.lr, Collection) else cfg.lr + if self.max_lr < cfg.min_lr: + cfg.min_lr = self.max_lr + + warmup_end_lr = self.max_lr + if cfg.warmup_init_lr < 0: + cfg.warmup_init_lr = cfg.min_lr + + self.t_mult = cfg.t_mult + self.period = cfg.lr_period_updates + + if self.period <= 0: + assert ( + cfg.max_update > 0 + ), "Either --max_update or --lr-period-updates must be set" + self.period = cfg.max_update - cfg.warmup_updates + + if cfg.warmup_updates > 0: + # linearly warmup for the first cfg.warmup_updates + self.lr_step = (warmup_end_lr - cfg.warmup_init_lr) / cfg.warmup_updates + else: + self.lr_step = 1 + + self.warmup_updates = cfg.warmup_updates + self.lr_shrink = cfg.lr_shrink + + # initial learning rate + self.lr = cfg.warmup_init_lr + self.optimizer.set_lr(self.lr) + + def step(self, epoch, val_loss=None): + """Update the learning rate at the end of the given epoch.""" + super().step(epoch, val_loss) + # we don't change the learning rate at epoch boundaries + return self.optimizer.get_lr() + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + if num_updates < self.cfg.warmup_updates: + self.lr = self.cfg.warmup_init_lr + num_updates * self.lr_step + else: + curr_updates = num_updates - self.cfg.warmup_updates + if self.t_mult != 1: + i = math.floor( + math.log( + 1 - curr_updates / self.period * (1 - self.t_mult), self.t_mult + ) + ) + t_i = self.t_mult**i * self.period + t_curr = ( + curr_updates + - (1 - self.t_mult**i) / (1 - self.t_mult) * self.period + ) + else: + i = math.floor(curr_updates / self.period) + t_i = self.period + t_curr = curr_updates - (self.period * i) + + lr_shrink = self.lr_shrink**i + min_lr = self.cfg.min_lr * lr_shrink + max_lr = self.max_lr * lr_shrink + + self.lr = min_lr + 0.5 * (max_lr - min_lr) * ( + 1 + math.cos(math.pi * t_curr / t_i) + ) + + self.optimizer.set_lr(self.lr) + return self.lr diff --git a/fairseq/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py b/fairseq/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..6c12fa56b825e81bcc3fc7a97d206777418260ef --- /dev/null +++ b/fairseq/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py @@ -0,0 +1,59 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from argparse import Namespace + +from fairseq.dataclass.utils import gen_parser_from_dataclass +from fairseq.optim import FairseqOptimizer + + +class FairseqLRScheduler(object): + def __init__(self, cfg, optimizer): + super().__init__() + if optimizer is not None and not isinstance(optimizer, FairseqOptimizer): + raise ValueError("optimizer must be an instance of FairseqOptimizer") + self.cfg = cfg + self.optimizer = optimizer + self.best = None + + @classmethod + def add_args(cls, parser): + """Add arguments to the parser for this LR scheduler.""" + dc = getattr(cls, "__dataclass", None) + if dc is not None: + gen_parser_from_dataclass(parser, dc()) + + def state_dict(self): + """Return the LR scheduler state dict.""" + return {"best": self.best} + + def load_state_dict(self, state_dict): + """Load an LR scheduler state dict.""" + self.best = state_dict["best"] + + def step_begin_epoch(self, epoch): + """Update the learning rate at the beginning of the given epoch.""" + pass + + def step(self, epoch, val_loss=None): + """Update the learning rate at the end of the given epoch.""" + if val_loss is not None: + if self.best is None: + self.best = val_loss + else: + self.best = min(self.best, val_loss) + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + return self.optimizer.get_lr() + + +class LegacyFairseqLRScheduler(FairseqLRScheduler): + def __init__(self, args: Namespace, optimizer): + if not isinstance(optimizer, FairseqOptimizer): + raise ValueError("optimizer must be an instance of FairseqOptimizer") + self.args = args + self.optimizer = optimizer + self.best = None diff --git a/fairseq/fairseq/optim/lr_scheduler/fixed_schedule.py b/fairseq/fairseq/optim/lr_scheduler/fixed_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..d0e7e14b7e72b1151f7d7f19094430bbab64f8f0 --- /dev/null +++ b/fairseq/fairseq/optim/lr_scheduler/fixed_schedule.py @@ -0,0 +1,76 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field +from typing import Optional, List +from omegaconf import II + +from fairseq.dataclass import FairseqDataclass +from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler + + +@dataclass +class FixedLRScheduleConfig(FairseqDataclass): + force_anneal: Optional[int] = field( + default=None, + metadata={"help": "force annealing at specified epoch"}, + ) + lr_shrink: float = field( + default=0.1, + metadata={"help": "shrink factor for annealing, lr_new = (lr * lr_shrink)"}, + ) + warmup_updates: int = field( + default=0, + metadata={"help": "warmup the learning rate linearly for the first N updates"}, + ) + lr: List[float] = II("optimization.lr") + + +@register_lr_scheduler("fixed", dataclass=FixedLRScheduleConfig) +class FixedLRSchedule(FairseqLRScheduler): + """Decay the LR on a fixed schedule.""" + + def __init__(self, cfg: FixedLRScheduleConfig, optimizer): + super().__init__(cfg, optimizer) + + self.lr = cfg.lr[0] + if cfg.warmup_updates > 0: + self.warmup_factor = 1.0 / cfg.warmup_updates + else: + self.warmup_factor = 1 + + def state_dict(self): + return {"lr": self.lr} + + def load_state_dict(self, state_dict): + if "lr" in state_dict: + self.lr = state_dict["lr"] + + def get_next_lr(self, epoch): + lrs = self.cfg.lr + if self.cfg.force_anneal is None or epoch < self.cfg.force_anneal: + # use fixed LR schedule + next_lr = lrs[min(epoch - 1, len(lrs) - 1)] + else: + # annneal based on lr_shrink + next_lr = lrs[-1] * self.cfg.lr_shrink ** ( + epoch + 1 - self.cfg.force_anneal + ) + return next_lr + + def step_begin_epoch(self, epoch): + """Update the learning rate at the beginning of the given epoch.""" + self.lr = self.get_next_lr(epoch) + self.optimizer.set_lr(self.warmup_factor * self.lr) + return self.optimizer.get_lr() + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + if self.cfg.warmup_updates > 0 and num_updates < self.cfg.warmup_updates: + self.warmup_factor = (num_updates + 1) / float(self.cfg.warmup_updates) + self.optimizer.set_lr(self.warmup_factor * self.lr) + else: + self.optimizer.set_lr(self.lr) + return self.optimizer.get_lr() diff --git a/fairseq/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py b/fairseq/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..987c905a23d50342dd7e809e0eddb5a6df2ebe90 --- /dev/null +++ b/fairseq/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py @@ -0,0 +1,85 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from collections.abc import Collection +from dataclasses import dataclass, field +from typing import List + +from omegaconf import II + +from fairseq.dataclass import FairseqDataclass +from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler + + +@dataclass +class InverseSquareRootLRScheduleConfig(FairseqDataclass): + warmup_updates: int = field( + default=4000, + metadata={"help": "warmup the learning rate linearly for the first N updates"}, + ) + warmup_init_lr: float = field( + default=-1, + metadata={ + "help": "initial learning rate during warmup phase; default is cfg.lr" + }, + ) + lr: List[float] = II("optimization.lr") + + +@register_lr_scheduler("inverse_sqrt", dataclass=InverseSquareRootLRScheduleConfig) +class InverseSquareRootSchedule(FairseqLRScheduler): + """Decay the LR based on the inverse square root of the update number. + + We also support a warmup phase where we linearly increase the learning rate + from some initial learning rate (``--warmup-init-lr``) until the configured + learning rate (``--lr``). Thereafter we decay proportional to the number of + updates, with a decay factor set to align with the configured learning rate. + + During warmup:: + + lrs = torch.linspace(cfg.warmup_init_lr, cfg.lr, cfg.warmup_updates) + lr = lrs[update_num] + + After warmup:: + + decay_factor = cfg.lr * sqrt(cfg.warmup_updates) + lr = decay_factor / sqrt(update_num) + """ + + def __init__(self, cfg: InverseSquareRootLRScheduleConfig, optimizer): + super().__init__(cfg, optimizer) + if isinstance(cfg.lr, Collection) and len(cfg.lr) > 1: + raise ValueError( + "Cannot use a fixed learning rate schedule with inverse_sqrt." + " Consider --lr-scheduler=fixed instead." + ) + warmup_end_lr = cfg.lr[0] if isinstance(cfg.lr, Collection) else cfg.lr + if cfg.warmup_init_lr < 0: + cfg.warmup_init_lr = 0 if cfg.warmup_updates > 0 else warmup_end_lr + + # linearly warmup for the first cfg.warmup_updates + self.lr_step = (warmup_end_lr - cfg.warmup_init_lr) / cfg.warmup_updates + + # then, decay prop. to the inverse square root of the update number + self.decay_factor = warmup_end_lr * cfg.warmup_updates**0.5 + + # initial learning rate + self.lr = cfg.warmup_init_lr + self.optimizer.set_lr(self.lr) + + def step(self, epoch, val_loss=None): + """Update the learning rate at the end of the given epoch.""" + super().step(epoch, val_loss) + # we don't change the learning rate at epoch boundaries + return self.optimizer.get_lr() + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + if num_updates < self.cfg.warmup_updates: + self.lr = self.cfg.warmup_init_lr + num_updates * self.lr_step + else: + self.lr = self.decay_factor * num_updates**-0.5 + self.optimizer.set_lr(self.lr) + return self.lr diff --git a/fairseq/fairseq/optim/lr_scheduler/manual_lr_scheduler.py b/fairseq/fairseq/optim/lr_scheduler/manual_lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..57edc256fd98e1671502ba18def54eae82a4a3ab --- /dev/null +++ b/fairseq/fairseq/optim/lr_scheduler/manual_lr_scheduler.py @@ -0,0 +1,121 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from . import LegacyFairseqLRScheduler, register_lr_scheduler +import logging +import ast + +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) + + +@register_lr_scheduler("manual") +class ManualSchedule(LegacyFairseqLRScheduler): + """Decay the LR on a manual schedule.""" + + def __init__(self, args, optimizer): + super().__init__(args, optimizer) + + self.epoch2lr = self.parse_manuallr_args(args.epoch2lr) + self.update2lr = self.parse_manuallr_args(args.update2lr) + logger.info("@@@ ManualSchedule epoch2lr={}".format(self.epoch2lr)) + logger.info("@@@ ManualSchedule update2lr={}".format(self.update2lr)) + + if 1 in self.epoch2lr: + self.lr = self.epoch2lr[1] + elif 1 in self.update2lr: + self.lr = self.update2lr[1] + else: + self.lr = args.lr[0] + self.optimizer.set_lr(self.lr) # Set the beginning of the epoch. + + def parse_manuallr_args(self, lr_args_str): + lr_dict = ast.literal_eval(lr_args_str.replace(" ", "")) + if not isinstance(lr_dict, dict): + raise ValueError("epoch2lr/update2lr must be abel to evaluated to a dict") + + lr_args = {} + logger.info("@@@ after parsing input dictionary lr_dict = {}".format(lr_dict)) + for key, val in lr_dict.items(): + if "," in key: + for k in key.split(","): + lr_args[int(k)] = float(val) + elif "-" in key: + s = int(key.split("-")[0]) + e = int(key.split("-")[1]) + for k in range(s, e + 1, 1): + lr_args[k] = float(val) + else: + lr_args[int(key)] = float(val) + + return lr_args + + @staticmethod + def add_args(parser): + """Add arguments to the parser for this LR scheduler.""" + # fmt: off + parser.add_argument( + "--epoch2lr", + type=str, + metavar="DICT", + default="{}", + help="a dictionary used to set lr for each epoch manually", + ) + parser.add_argument( + "--update2lr", + type=str, + metavar="DICT", + default="{}", + help="a dictionary used to set lr for each update manually", + ) + # fmt: on + + def state_dict(self): + return {"lr": self.lr} + + def load_state_dict(self, state_dict): + if "lr" in state_dict: + self.lr = state_dict["lr"] + + def get_next_lr(self, epoch): + manual_keys = [k for k in self.epoch2lr if k <= epoch] + if manual_keys: + manual_lr = self.epoch2lr[max(manual_keys)] + else: + logger.warning( + "@@@ epoch={} does not exist in manual lr input. epoch2lr={}...".format( + epoch, + list(self.epoch2lr.items())[ + : min(10, len(self.epoch2lr.keys()) - 1) + ], + ) + ) + manual_lr = self.optimizer.get_lr() + return manual_lr + + def step_begin_epoch(self, epoch): + """Update the learning rate at the beginning of the given epoch.""" + self.lr = self.get_next_lr(epoch) + self.optimizer.set_lr(self.lr) + return self.optimizer.get_lr() + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + manual_keys = [k for k in self.update2lr if k <= num_updates] + if manual_keys: + manual_lr = self.update2lr[max(manual_keys)] + else: + logger.warning( + "epoch={} does not exist in manual lr input update2lr={}...".format( + num_updates, + list(self.update2lr.items())[ + : min(10, len(self.update2lr.keys()) - 1) + ], + ) + ) + manual_lr = self.optimizer.get_lr() + + self.optimizer.set_lr(manual_lr) + return self.optimizer.get_lr() diff --git a/fairseq/fairseq/optim/lr_scheduler/pass_through.py b/fairseq/fairseq/optim/lr_scheduler/pass_through.py new file mode 100644 index 0000000000000000000000000000000000000000..2f93db328c1de9b268e8ee1c0c1cad558fd089aa --- /dev/null +++ b/fairseq/fairseq/optim/lr_scheduler/pass_through.py @@ -0,0 +1,39 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass + +from fairseq.dataclass import FairseqDataclass +from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler + + +@dataclass +class PassThroughScheduleConfig(FairseqDataclass): + pass + + +@register_lr_scheduler("pass_through", dataclass=PassThroughScheduleConfig) +class PassThroughScheduleSchedule(FairseqLRScheduler): + """Delegate lr scheduling to the optimizer.""" + + def __init__(self, cfg: PassThroughScheduleConfig, optimizer): + super().__init__(cfg, optimizer) + assert ( + hasattr(optimizer, "lr_scheduler") and optimizer.lr_scheduler is not None + ), "Pass-through schedule can only be used with optimizers with their own schedulers" + + def state_dict(self): + return self.optimizer.lr_scheduler.state_dict() + + def load_state_dict(self, state_dict): + self.optimizer.lr_scheduler.load_state_dict(state_dict) + + def step_begin_epoch(self, epoch): + """Update the learning rate at the beginning of the given epoch.""" + return self.optimizer.lr_scheduler.step_begin_epoch(epoch) + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + return self.optimizer.lr_scheduler.step_update(num_updates) diff --git a/fairseq/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py b/fairseq/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..b8109a7c1e79cd057c355504d07bac5615c02ea9 --- /dev/null +++ b/fairseq/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py @@ -0,0 +1,89 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field +from typing import Optional, List +from omegaconf import II + +from fairseq.dataclass import FairseqDataclass +from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler + + +@dataclass +class PolynomialDecayLRScheduleConfig(FairseqDataclass): + warmup_updates: int = field( + default=0, + metadata={"help": "warmup the learning rate linearly for the first N updates"}, + ) + force_anneal: Optional[int] = field( + default=None, + metadata={"help": "force annealing at specified epoch"}, + ) + end_learning_rate: float = field( + default=0.0, + metadata={"help": "learning rate to decay to"}, + ) + power: float = field( + default=1.0, + metadata={"help": "decay exponent"}, + ) + total_num_update: float = field( + default=II("optimization.max_update"), + metadata={"help": "total number of updates over which to decay learning rate"}, + ) + lr: List[float] = II("optimization.lr") + + +@register_lr_scheduler("polynomial_decay", dataclass=PolynomialDecayLRScheduleConfig) +class PolynomialDecayLRSchedule(FairseqLRScheduler): + """Decay the LR on a fixed schedule.""" + + def __init__(self, cfg: PolynomialDecayLRScheduleConfig, optimizer): + super().__init__(cfg, optimizer) + + assert cfg.total_num_update > 0 + + self.lr = cfg.lr[0] + if cfg.warmup_updates > 0: + self.warmup_factor = 1.0 / cfg.warmup_updates + else: + self.warmup_factor = 1 + self.end_learning_rate = cfg.end_learning_rate + self.total_num_update = cfg.total_num_update + self.power = cfg.power + self.optimizer.set_lr(self.warmup_factor * self.lr) + + def get_next_lr(self, epoch): + lrs = self.cfg.lr + if self.cfg.force_anneal is None or epoch < self.cfg.force_anneal: + # use fixed LR schedule + next_lr = lrs[min(epoch, len(lrs) - 1)] + else: + # annneal based on lr_shrink + next_lr = self.optimizer.get_lr() + return next_lr + + def step_begin_epoch(self, epoch): + """Update the learning rate at the beginning of the given epoch.""" + self.lr = self.get_next_lr(epoch) + self.optimizer.set_lr(self.warmup_factor * self.lr) + return self.optimizer.get_lr() + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + if self.cfg.warmup_updates > 0 and num_updates <= self.cfg.warmup_updates: + self.warmup_factor = num_updates / float(self.cfg.warmup_updates) + lr = self.warmup_factor * self.lr + elif num_updates >= self.total_num_update: + lr = self.end_learning_rate + else: + warmup = self.cfg.warmup_updates + lr_range = self.lr - self.end_learning_rate + pct_remaining = 1 - (num_updates - warmup) / ( + self.total_num_update - warmup + ) + lr = lr_range * pct_remaining ** (self.power) + self.end_learning_rate + self.optimizer.set_lr(lr) + return self.optimizer.get_lr() diff --git a/fairseq/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py b/fairseq/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py new file mode 100644 index 0000000000000000000000000000000000000000..5ee9c1be4a59ad3d072412827ab4e9b62dc7434e --- /dev/null +++ b/fairseq/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py @@ -0,0 +1,143 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field +from typing import List + +import torch.optim.lr_scheduler +from omegaconf import II + +from fairseq.dataclass import FairseqDataclass +from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler + + +@dataclass +class ReduceLROnPlateauLRScheduleConfig(FairseqDataclass): + lr_shrink: float = field( + default=0.1, metadata={"help": "shrink factor for annealing"} + ) + lr_threshold: float = field( + default=1e-4, + metadata={ + "help": ( + "threshold for measuring the new optimum, to only focus on " + "significant changes" + ) + }, + ) + lr_patience: int = field( + default=0, + metadata={ + "help": ( + "number of epochs with no improvement after which learning rate will " + "be reduced" + ) + }, + ) + warmup_updates: int = field( + default=0, + metadata={"help": "warmup the learning rate linearly for the first N updates"}, + ) + warmup_init_lr: float = field( + default=-1, + metadata={ + "help": "initial learning rate during warmup phase; default is cfg.lr" + }, + ) + lr: List[float] = II("optimization.lr") + maximize_best_checkpoint_metric: bool = II( + "checkpoint.maximize_best_checkpoint_metric" + ) + + +@register_lr_scheduler( + "reduce_lr_on_plateau", dataclass=ReduceLROnPlateauLRScheduleConfig +) +class ReduceLROnPlateauLRSchedule(FairseqLRScheduler): + """ + Decay the LR by a factor every time the validation loss plateaus. + Also comes with optional warmup phase, where we linearly increase + the learning rate from some initial learning rate + (``--warmup-init-lr``) until the configured learning rate + (``--lr``). Thereafter the lr is adjusted according to original + reduce_on_plateau scheme. + + During warmup:: + + lrs = torch.linspace( + cfg.warmup_init_lr, cfg.lr, cfg.warmup_updates + ) + lr = lrs[update_num] + """ + + def __init__(self, cfg: ReduceLROnPlateauLRScheduleConfig, optimizer): + super().__init__(cfg, optimizer) + if len(cfg.lr) > 1: + raise ValueError( + "Cannot use a fixed learning rate schedule with reduce_lr_on_plateau." + " Consider --lr-scheduler=fixed instead." + ) + self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + self.optimizer.optimizer, + patience=cfg.lr_patience, + factor=cfg.lr_shrink, + mode="max" if cfg.maximize_best_checkpoint_metric else "min", + threshold=cfg.lr_threshold, + ) + warmup_end_lr = cfg.lr[0] + # if no warm up, sets initial lr to be cfg.lr[0] + if cfg.warmup_init_lr < 0: + cfg.warmup_init_lr = 0 if cfg.warmup_updates > 0 else warmup_end_lr + + # linearly warmup for the first cfg.warmup_updates + if cfg.warmup_updates > 0: + self.lr_step = (warmup_end_lr - cfg.warmup_init_lr) / cfg.warmup_updates + + # this flag is either set from arg when no warm up, or set by + # step_update() when warmup finishes + self.warmup_end = True if cfg.warmup_updates <= 0 else False + + # initial learning rate + # this self.lr is used only during init and/or warm up period + self.lr = warmup_end_lr if self.warmup_end else cfg.warmup_init_lr + self.optimizer.set_lr(self.lr) + + def state_dict(self): + """Return the LR scheduler state dict.""" + return { + "best": self.lr_scheduler.best, + "last_epoch": self.lr_scheduler.last_epoch, + } + + def load_state_dict(self, state_dict): + """Load an LR scheduler state dict.""" + self.lr_scheduler.best = state_dict["best"] + if "last_epoch" in state_dict: + self.lr_scheduler.last_epoch = state_dict["last_epoch"] + + def step(self, epoch, val_loss=None): + """ + Update the learning rate at the end of the given epoch if warmup + finishes otherwise no update of lr on epoch boundaries + """ + if val_loss is not None and self.warmup_end is True: + self.lr_scheduler.step(val_loss) + else: + self.lr_scheduler.last_epoch = epoch + return self.optimizer.get_lr() + + def step_update(self, num_updates): + """ + Update the learning rate after each update.""" + # if there is warmup + if self.cfg.warmup_updates > 0: + if num_updates <= self.cfg.warmup_updates: + self.lr = self.cfg.warmup_init_lr + num_updates * self.lr_step + self.optimizer.set_lr(self.lr) + else: + if self.warmup_end is False: + self.warmup_end = True + # else do nothing + return self.optimizer.get_lr() diff --git a/fairseq/fairseq/optim/lr_scheduler/step_lr_scheduler.py b/fairseq/fairseq/optim/lr_scheduler/step_lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..db99d4eee84475d7eaef625e4c72c71de847db42 --- /dev/null +++ b/fairseq/fairseq/optim/lr_scheduler/step_lr_scheduler.py @@ -0,0 +1,85 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from collections.abc import Collection +from dataclasses import dataclass, field +from typing import List + +from omegaconf import II + +from fairseq.dataclass import FairseqDataclass +from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler + + +@dataclass +class StepLRScheduleConfig(FairseqDataclass): + warmup_updates: int = field( + default=0, + metadata={"help": "warmup the learning rate linearly for the first N updates"}, + ) + warmup_init_lr: float = field( + default=-1, + metadata={ + "help": "initial learning rate during warmup phase; default is cfg.lr" + }, + ) + lr: List[float] = field( + default=II("optimization.lr"), + metadata={"help": "max learning rate, must be more than cfg.min_lr"}, + ) + min_lr: float = field(default=0.0, metadata={"help": "min learning rate"}) + lr_deacy_period: int = field(default=25000, metadata={"help": "decay period"}) + lr_decay: float = field(default=0.5, metadata={"help": "decay factor"}) + + +@register_lr_scheduler("step", dataclass=StepLRScheduleConfig) +class StepLRSchedule(FairseqLRScheduler): + """Decay learning rate every k updates by a fixed factor""" + + def __init__(self, cfg: StepLRScheduleConfig, fairseq_optimizer): + super().__init__(cfg, fairseq_optimizer) + self.max_lr = cfg.lr[0] if isinstance(cfg.lr, Collection) else cfg.lr + self.min_lr = cfg.min_lr + self.lr_deacy_period = cfg.lr_deacy_period + self.lr_decay = cfg.lr_decay + self.warmup_updates = cfg.warmup_updates + self.warmup_init_lr = ( + cfg.warmup_init_lr if cfg.warmup_init_lr >= 0 else self.min_lr + ) + + assert self.lr_deacy_period > 0 + assert self.lr_decay <= 1 + assert self.min_lr >= 0 + assert self.max_lr > self.min_lr + + if cfg.warmup_updates > 0: + # linearly warmup for the first cfg.warmup_updates + self.warmup_lr_step = ( + self.max_lr - self.warmup_init_lr + ) / self.warmup_updates + else: + self.warmup_lr_step = 1 + + # initial learning rate + self.lr = self.warmup_init_lr + self.optimizer.set_lr(self.lr) + + def step(self, epoch, val_loss=None): + """Update the learning rate at the end of the given epoch.""" + super().step(epoch, val_loss) + # we don't change the learning rate at epoch boundaries + return self.optimizer.get_lr() + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + if num_updates < self.cfg.warmup_updates: + self.lr = self.warmup_init_lr + num_updates * self.warmup_lr_step + else: + curr_updates = num_updates - self.cfg.warmup_updates + lr_mult = self.lr_decay ** (curr_updates // self.lr_deacy_period) + self.lr = max(self.max_lr * lr_mult, self.min_lr) + + self.optimizer.set_lr(self.lr) + return self.lr diff --git a/fairseq/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py b/fairseq/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..4d5547c39b14f62acbd4f4b9ab3abfb3009c0e6d --- /dev/null +++ b/fairseq/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py @@ -0,0 +1,175 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from dataclasses import dataclass, field +from typing import Optional, List, Tuple +from omegaconf import II + +from fairseq.dataclass import FairseqDataclass +from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler + + +@dataclass +class TriStageLRScheduleConfig(FairseqDataclass): + warmup_steps: int = field( + default=0, + metadata={"help": "warmup the learning rate linearly for the first N updates"}, + ) + hold_steps: int = field( + default=0, + metadata={"help": "steps in hold stage"}, + ) + decay_steps: int = field( + default=0, + metadata={"help": "steps in decay stages"}, + ) + phase_ratio: Optional[Tuple[float, float, float]] = field( + default=None, + metadata={ + "help": ( + "if set, automatically sets warmup/hold/decay steps to the ratio " + "specified here from max_updates. the ratios must add up to 1.0" + ) + }, + ) + init_lr_scale: float = field( + default=0.01, + metadata={"help": "initial learning rate scale during warmup phase"}, + ) + final_lr_scale: float = field( + default=0.01, + metadata={"help": "final learning rate scale"}, + ) + max_update: float = II("optimization.max_update") + lr: List[float] = II("optimization.lr") + + +@register_lr_scheduler("tri_stage", dataclass=TriStageLRScheduleConfig) +class TriStageLRSchedule(FairseqLRScheduler): + """Tristage learning rate schedulr + + Implement the learning rate scheduler in https://arxiv.org/pdf/1904.08779.pdf + + Similar to inverse_squre_root scheduler, but tri_stage learning rate employs + three stages LR scheduling: + + - warmup stage, starting from `lr` * `init_lr_scale`, linearly + increased to `lr` in `warmup_steps` iterations + + - hold stage, after `warmup_steps`, keep the LR as `lr` for `hold_steps` + iterations + + - decay stage, after hold stage, decay LR exponetially to + `lr` * `final_lr_scale` in `decay_steps`; + after that LR is keep as `final_lr_scale` * `lr` + + During warmup:: + + init_lr = cfg.init_lr_scale * cfg.lr + lrs = torch.linspace(init_lr, cfg.lr, cfg.warmup_steps) + lr = lrs[update_num] + + During hold:: + + lr = cfg.lr + + During decay:: + + decay_factor = - math.log(cfg.final_lr_scale) / cfg.decay_steps + lr = cfg.lr * exp(- (update_num - warmup_steps - decay_steps) * decay_factor) + + After that:: + + lr = cfg.lr * cfg.final_lr_scale + """ + + def __init__(self, cfg: TriStageLRScheduleConfig, optimizer): + super().__init__(cfg, optimizer) + if len(cfg.lr) > 1: + raise ValueError( + "Cannot use a fixed learning rate schedule with tri-stage lr." + " Consider --lr-scheduler=fixed instead." + ) + + # calculate LR at each point + self.peak_lr = cfg.lr[0] + self.init_lr = cfg.init_lr_scale * cfg.lr[0] + self.final_lr = cfg.final_lr_scale * cfg.lr[0] + + if cfg.phase_ratio is not None: + assert cfg.max_update > 0 + assert sum(cfg.phase_ratio) == 1, "phase ratios must add up to 1" + self.warmup_steps = int(cfg.max_update * cfg.phase_ratio[0]) + self.hold_steps = int(cfg.max_update * cfg.phase_ratio[1]) + self.decay_steps = int(cfg.max_update * cfg.phase_ratio[2]) + else: + self.warmup_steps = cfg.warmup_steps + self.hold_steps = cfg.hold_steps + self.decay_steps = cfg.decay_steps + + assert ( + self.warmup_steps + self.hold_steps + self.decay_steps > 0 + ), "please specify steps or phase_ratio" + + self.warmup_rate = ( + (self.peak_lr - self.init_lr) / self.warmup_steps + if self.warmup_steps != 0 + else 0 + ) + self.decay_factor = -math.log(cfg.final_lr_scale) / self.decay_steps + + # initial learning rate + self.lr = self.init_lr + self.optimizer.set_lr(self.lr) + + def _decide_stage(self, update_step): + """ + return stage, and the corresponding steps within the current stage + """ + if update_step < self.warmup_steps: + # warmup state + return 0, update_step + + offset = self.warmup_steps + + if update_step < offset + self.hold_steps: + # hold stage + return 1, update_step - offset + + offset += self.hold_steps + + if update_step <= offset + self.decay_steps: + # decay stage + return 2, update_step - offset + + offset += self.decay_steps + + # still here ? constant lr stage + return 3, update_step - offset + + def step(self, epoch, val_loss=None): + """Update the learning rate at the end of the given epoch.""" + super().step(epoch, val_loss) + # we don't change the learning rate at epoch boundaries + return self.optimizer.get_lr() + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + stage, steps_in_stage = self._decide_stage(num_updates) + if stage == 0: + self.lr = self.init_lr + self.warmup_rate * steps_in_stage + elif stage == 1: + self.lr = self.peak_lr + elif stage == 2: + self.lr = self.peak_lr * math.exp(-self.decay_factor * steps_in_stage) + elif stage == 3: + self.lr = self.final_lr + else: + raise ValueError("Undefined stage") + + self.optimizer.set_lr(self.lr) + + return self.lr diff --git a/fairseq/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py b/fairseq/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..2a32bd10f213e4d8408e529ac20fd4cac7a91704 --- /dev/null +++ b/fairseq/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py @@ -0,0 +1,83 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from dataclasses import dataclass, field +from typing import List + +from omegaconf import II + +from fairseq.dataclass import FairseqDataclass +from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler + + +@dataclass +class TriangularLRScheduleConfig(FairseqDataclass): + max_lr: float = field( + default="???", metadata={"help": "max learning rate, must be more than cfg.lr"} + ) + lr_period_updates: float = field( + default=5000, + metadata={"help": "initial number of updates per period (cycle length)"}, + ) + lr_shrink: float = field( + default=0.1, metadata={"help": "shrink factor for annealing"} + ) + shrink_min: bool = field( + default=False, metadata={"help": "if set, also shrinks min lr"} + ) + lr: List[float] = II("optimization.lr") + + +@register_lr_scheduler("triangular", dataclass=TriangularLRScheduleConfig) +class TriangularLRSchedule(FairseqLRScheduler): + """Assign LR based on a triangular cyclical schedule. + + See https://arxiv.org/pdf/1506.01186.pdf for details. + """ + + def __init__(self, cfg: TriangularLRScheduleConfig, optimizer): + super().__init__(cfg, optimizer) + if len(cfg.lr) > 1: + raise ValueError( + "Cannot use a fixed learning rate schedule with triangular." + " Consider --lr-scheduler=fixed instead." + ) + + lr = cfg.lr[0] + + assert cfg.max_lr > lr, "max_lr must be more than lr" + self.min_lr = lr + self.max_lr = cfg.max_lr + self.stepsize = cfg.lr_period_updates // 2 + self.lr_shrink = cfg.lr_shrink + self.shrink_min = cfg.shrink_min + + # initial learning rate + self.lr = self.min_lr + self.optimizer.set_lr(self.lr) + + def step(self, epoch, val_loss=None): + """Update the learning rate at the end of the given epoch.""" + super().step(epoch, val_loss) + # we don't change the learning rate at epoch boundaries + return self.optimizer.get_lr() + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + cycle = math.floor(num_updates / (2 * self.stepsize)) + + lr_shrink = self.lr_shrink**cycle + max_lr = self.max_lr * lr_shrink + if self.shrink_min: + min_lr = self.min_lr * lr_shrink + else: + min_lr = self.min_lr + + x = abs(num_updates / self.stepsize - 2 * (cycle + 1) + 1) + self.lr = min_lr + (max_lr - min_lr) * max(0, (1 - x)) + + self.optimizer.set_lr(self.lr) + return self.lr diff --git a/fairseq/fairseq/optim/nag.py b/fairseq/fairseq/optim/nag.py new file mode 100644 index 0000000000000000000000000000000000000000..c30a6c0fb1e8d5dc7edd5b53ba15a6acd46ecbff --- /dev/null +++ b/fairseq/fairseq/optim/nag.py @@ -0,0 +1,111 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from collections.abc import Collection +from dataclasses import dataclass, field +from typing import List + +import torch +from fairseq.dataclass import FairseqDataclass +from omegaconf import II, DictConfig +from torch.optim.optimizer import Optimizer, required + +from . import FairseqOptimizer, register_optimizer + + +@dataclass +class FairseqNAGConfig(FairseqDataclass): + momentum: float = field(default=0.99, metadata={"help": "momentum factor"}) + weight_decay: float = field(default=0.0, metadata={"help": "weight decay"}) + # TODO common vars in parent class + lr: List[float] = II("optimization.lr") + + +@register_optimizer("nag", dataclass=FairseqNAGConfig) +class FairseqNAG(FairseqOptimizer): + def __init__(self, cfg: DictConfig, params): + super().__init__(cfg) + self._optimizer = NAG(params, **self.optimizer_config) + + @property + def optimizer_config(self): + """ + Return a kwarg dictionary that will be used to override optimizer + args stored in checkpoints. This allows us to load a checkpoint and + resume training using a different set of optimizer args, e.g., with a + different learning rate. + """ + return { + "lr": self.cfg.lr[0] + if isinstance(self.cfg.lr, Collection) + else self.cfg.lr, + "momentum": self.cfg.momentum, + "weight_decay": self.cfg.weight_decay, + } + + +class NAG(Optimizer): + def __init__(self, params, lr=required, momentum=0, weight_decay=0): + defaults = dict(lr=lr, lr_old=lr, momentum=momentum, weight_decay=weight_decay) + super(NAG, self).__init__(params, defaults) + + @property + def supports_memory_efficient_fp16(self): + return True + + @property + def supports_flat_params(self): + return True + + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + weight_decay = group["weight_decay"] + momentum = group["momentum"] + lr = group["lr"] + lr_old = group.get("lr_old", lr) + lr_correct = lr / lr_old if lr_old > 0 else lr + + for p in group["params"]: + if p.grad is None: + continue + + p_data_fp32 = p.data + if p_data_fp32.dtype in {torch.float16, torch.bfloat16}: + p_data_fp32 = p_data_fp32.float() + + d_p = p.grad.data.float() + param_state = self.state[p] + if "momentum_buffer" not in param_state: + param_state["momentum_buffer"] = torch.zeros_like(d_p) + else: + param_state["momentum_buffer"] = param_state["momentum_buffer"].to( + d_p + ) + + buf = param_state["momentum_buffer"] + + if weight_decay != 0: + p_data_fp32.mul_(1 - lr * weight_decay) + p_data_fp32.add_(buf, alpha=momentum * momentum * lr_correct) + p_data_fp32.add_(d_p, alpha=-(1 + momentum) * lr) + + buf.mul_(momentum * lr_correct).add_(d_p, alpha=-lr) + + if p.data.dtype in {torch.float16, torch.bfloat16}: + p.data.copy_(p_data_fp32) + + group["lr_old"] = lr + + return loss diff --git a/fairseq/fairseq/optim/sgd.py b/fairseq/fairseq/optim/sgd.py new file mode 100644 index 0000000000000000000000000000000000000000..8e34fb99a18fff12ab76be5894a84cbbb2f48176 --- /dev/null +++ b/fairseq/fairseq/optim/sgd.py @@ -0,0 +1,43 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch.optim + +from . import LegacyFairseqOptimizer, register_optimizer + + +@register_optimizer("sgd") +class SGD(LegacyFairseqOptimizer): + def __init__(self, args, params): + super().__init__(args) + self._optimizer = torch.optim.SGD(params, **self.optimizer_config) + + @staticmethod + def add_args(parser): + """Add optimizer-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--momentum', default=0.0, type=float, metavar='M', + help='momentum factor') + parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', + help='weight decay') + # fmt: on + + @property + def optimizer_config(self): + """ + Return a kwarg dictionary that will be used to override optimizer + args stored in checkpoints. This allows us to load a checkpoint and + resume training using a different set of optimizer args, e.g., with a + different learning rate. + """ + return { + "lr": self.args.lr[0], + "momentum": self.args.momentum, + "weight_decay": self.args.weight_decay, + } + + @property + def supports_flat_params(self): + return True diff --git a/fairseq/fairseq/optim/shard.py b/fairseq/fairseq/optim/shard.py new file mode 100644 index 0000000000000000000000000000000000000000..9d7f2eb9e5de6086fe2435d432bde7521ebb8155 --- /dev/null +++ b/fairseq/fairseq/optim/shard.py @@ -0,0 +1,58 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Dict + +from fairseq.distributed import utils + + +try: + from fairscale.optim import OSS + + _has_fairscale = True +except ImportError: + _has_fairscale = False + + +def shard_(optimizer, group): + if not _has_fairscale: + raise ImportError( + "\n\nPlease install the fairscale package:" "\n\n pip install fairscale" + ) + + class FairseqOSS(OSS): + @property + def disable_mem_eff_fp16_loading_hack(self): + return True + + def __getattr__(self, name): + if name.startswith("supports") and hasattr(self.optim, name): + return getattr(self.optim, name) + raise AttributeError( + "'FairseqOSS' object has no attribute {0!r}".format(name) + ) + + def broadcast_global_state_dict( + self, state_dict: Dict[str, Any] + ) -> Dict[str, Any]: + """ + Broadcasts the entire state_dict to all other ranks + each rank is responsible to load their own partition of data + """ + return utils.broadcast_object( + state_dict, + src_rank=0, + group=self.group, + ) + + torch_optimizer = optimizer.optimizer + optim_cls = type(torch_optimizer) + + optimizer.optimizer = FairseqOSS( + torch_optimizer.param_groups, + optim_cls, + group=group, + **optimizer.optimizer_config + ) diff --git a/fairseq/fairseq/scoring/__init__.py b/fairseq/fairseq/scoring/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..58f2f563e493327394dff1265030d18f0814b5a2 --- /dev/null +++ b/fairseq/fairseq/scoring/__init__.py @@ -0,0 +1,55 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import importlib +import os +from abc import ABC, abstractmethod + +from fairseq import registry +from omegaconf import DictConfig + + +class BaseScorer(ABC): + def __init__(self, cfg): + self.cfg = cfg + self.ref = [] + self.pred = [] + + def add_string(self, ref, pred): + self.ref.append(ref) + self.pred.append(pred) + + @abstractmethod + def score(self) -> float: + pass + + @abstractmethod + def result_string(self) -> str: + pass + + +_build_scorer, register_scorer, SCORER_REGISTRY, _ = registry.setup_registry( + "--scoring", default="bleu" +) + + +def build_scorer(choice, tgt_dict): + _choice = choice._name if isinstance(choice, DictConfig) else choice + + if _choice == "bleu": + from fairseq.scoring import bleu + + return bleu.Scorer( + bleu.BleuConfig(pad=tgt_dict.pad(), eos=tgt_dict.eos(), unk=tgt_dict.unk()) + ) + return _build_scorer(choice) + + +# automatically import any Python files in the current directory +for file in sorted(os.listdir(os.path.dirname(__file__))): + if file.endswith(".py") and not file.startswith("_"): + module = file[: file.find(".py")] + importlib.import_module("fairseq.scoring." + module) diff --git a/fairseq/fairseq/scoring/__pycache__/__init__.cpython-310.pyc b/fairseq/fairseq/scoring/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c376f1d4c3ef14c4a8e1b90adbc6f92ebedcf66b Binary files /dev/null and b/fairseq/fairseq/scoring/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/fairseq/scoring/__pycache__/bertscore.cpython-310.pyc b/fairseq/fairseq/scoring/__pycache__/bertscore.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..58402e6ec77e1067b4c9d5065406798ea688a8f5 Binary files /dev/null and b/fairseq/fairseq/scoring/__pycache__/bertscore.cpython-310.pyc differ diff --git a/fairseq/fairseq/scoring/__pycache__/bleu.cpython-310.pyc b/fairseq/fairseq/scoring/__pycache__/bleu.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d282a5640b480073eafaf3cf88caf6d25fd6289f Binary files /dev/null and b/fairseq/fairseq/scoring/__pycache__/bleu.cpython-310.pyc differ diff --git a/fairseq/fairseq/scoring/__pycache__/chrf.cpython-310.pyc b/fairseq/fairseq/scoring/__pycache__/chrf.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0401443763906215bf1f582be40513918382e5dc Binary files /dev/null and b/fairseq/fairseq/scoring/__pycache__/chrf.cpython-310.pyc differ diff --git a/fairseq/fairseq/scoring/__pycache__/meteor.cpython-310.pyc b/fairseq/fairseq/scoring/__pycache__/meteor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74b6e0bed2a46d025ac22e5e90259b22b95d3089 Binary files /dev/null and b/fairseq/fairseq/scoring/__pycache__/meteor.cpython-310.pyc differ diff --git a/fairseq/fairseq/scoring/__pycache__/tokenizer.cpython-310.pyc b/fairseq/fairseq/scoring/__pycache__/tokenizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..962197de7746bb9aa8f3080af7b19c9bf98de412 Binary files /dev/null and b/fairseq/fairseq/scoring/__pycache__/tokenizer.cpython-310.pyc differ diff --git a/fairseq/fairseq/scoring/__pycache__/wer.cpython-310.pyc b/fairseq/fairseq/scoring/__pycache__/wer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..58ccf140e6f76841a9b7ef9355dbe7fb0f298e3b Binary files /dev/null and b/fairseq/fairseq/scoring/__pycache__/wer.cpython-310.pyc differ diff --git a/fairseq/fairseq/scoring/bertscore.py b/fairseq/fairseq/scoring/bertscore.py new file mode 100644 index 0000000000000000000000000000000000000000..6d5a8450d34a7f76ba08748f61f5d84fcf5570ee --- /dev/null +++ b/fairseq/fairseq/scoring/bertscore.py @@ -0,0 +1,44 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field + +import numpy as np + +from fairseq.dataclass import FairseqDataclass +from fairseq.scoring import BaseScorer, register_scorer + + +@dataclass +class BertScoreScorerConfig(FairseqDataclass): + bert_score_lang: str = field(default="en", metadata={"help": "BERTScore language"}) + + +@register_scorer("bert_score", dataclass=BertScoreScorerConfig) +class BertScoreScorer(BaseScorer): + def __init__(self, cfg): + super(BertScoreScorer, self).__init__(cfg) + try: + import bert_score as _bert_score + except ImportError: + raise ImportError("Please install BERTScore: pip install bert-score") + + self.cfg = cfg + self._bert_score = _bert_score + self.scores = None + + def add_string(self, ref, pred): + self.ref.append(ref) + self.pred.append(pred) + + def score(self, order=4): + _, _, self.scores = self._bert_score.score( + self.pred, self.ref, lang=self.cfg.bert_score_lang + ) + self.scores = self.scores.numpy() + return np.mean(self.scores) + + def result_string(self, order=4): + return f"BERTScore: {self.score():.4f}" diff --git a/fairseq/fairseq/scoring/bleu.py b/fairseq/fairseq/scoring/bleu.py new file mode 100644 index 0000000000000000000000000000000000000000..e55bd2f393e279d318e08d680c41cc20e8d49931 --- /dev/null +++ b/fairseq/fairseq/scoring/bleu.py @@ -0,0 +1,168 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import ctypes +import math +import sys +from dataclasses import dataclass, field + +import torch +from fairseq.dataclass import FairseqDataclass +from fairseq.scoring import BaseScorer, register_scorer +from fairseq.scoring.tokenizer import EvaluationTokenizer + + +class BleuStat(ctypes.Structure): + _fields_ = [ + ("reflen", ctypes.c_size_t), + ("predlen", ctypes.c_size_t), + ("match1", ctypes.c_size_t), + ("count1", ctypes.c_size_t), + ("match2", ctypes.c_size_t), + ("count2", ctypes.c_size_t), + ("match3", ctypes.c_size_t), + ("count3", ctypes.c_size_t), + ("match4", ctypes.c_size_t), + ("count4", ctypes.c_size_t), + ] + + +@dataclass +class SacrebleuConfig(FairseqDataclass): + sacrebleu_tokenizer: EvaluationTokenizer.ALL_TOKENIZER_TYPES = field( + default="13a", metadata={"help": "tokenizer"} + ) + sacrebleu_lowercase: bool = field( + default=False, metadata={"help": "apply lowercasing"} + ) + sacrebleu_char_level: bool = field( + default=False, metadata={"help": "evaluate at character level"} + ) + + +@register_scorer("sacrebleu", dataclass=SacrebleuConfig) +class SacrebleuScorer(BaseScorer): + def __init__(self, cfg): + super(SacrebleuScorer, self).__init__(cfg) + import sacrebleu + + self.sacrebleu = sacrebleu + self.tokenizer = EvaluationTokenizer( + tokenizer_type=cfg.sacrebleu_tokenizer, + lowercase=cfg.sacrebleu_lowercase, + character_tokenization=cfg.sacrebleu_char_level, + ) + + def add_string(self, ref, pred): + self.ref.append(self.tokenizer.tokenize(ref)) + self.pred.append(self.tokenizer.tokenize(pred)) + + def _score(self, order=4): + if order != 4: + raise NotImplementedError + # tokenization and lowercasing are performed by self.tokenizer instead. + return self.sacrebleu.corpus_bleu(self.pred, [self.ref], tokenize="none") + + def score(self, order=4): + return self._score(order).score + + def result_string(self, order=4): + return self._score(order).format() + + +@dataclass +class BleuConfig(FairseqDataclass): + pad: int = field(default=1, metadata={"help": "padding index"}) + eos: int = field(default=2, metadata={"help": "eos index"}) + unk: int = field(default=3, metadata={"help": "unk index"}) + + +@register_scorer("bleu", dataclass=BleuConfig) +class Scorer(object): + def __init__(self, cfg): + self.stat = BleuStat() + self.pad = cfg.pad + self.eos = cfg.eos + self.unk = cfg.unk + + try: + from fairseq import libbleu + except ImportError as e: + sys.stderr.write( + "ERROR: missing libbleu.so. run `pip install --editable .`\n" + ) + raise e + + self.C = ctypes.cdll.LoadLibrary(libbleu.__file__) + + self.reset() + + def reset(self, one_init=False): + if one_init: + self.C.bleu_one_init(ctypes.byref(self.stat)) + else: + self.C.bleu_zero_init(ctypes.byref(self.stat)) + + def add(self, ref, pred): + if not isinstance(ref, torch.IntTensor): + raise TypeError("ref must be a torch.IntTensor (got {})".format(type(ref))) + if not isinstance(pred, torch.IntTensor): + raise TypeError("pred must be a torch.IntTensor(got {})".format(type(pred))) + + # don't match unknown words + rref = ref.clone() + assert not rref.lt(0).any() + rref[rref.eq(self.unk)] = -999 + + rref = rref.contiguous().view(-1) + pred = pred.contiguous().view(-1) + + self.C.bleu_add( + ctypes.byref(self.stat), + ctypes.c_size_t(rref.size(0)), + ctypes.c_void_p(rref.data_ptr()), + ctypes.c_size_t(pred.size(0)), + ctypes.c_void_p(pred.data_ptr()), + ctypes.c_int(self.pad), + ctypes.c_int(self.eos), + ) + + def score(self, order=4): + psum = sum( + math.log(p) if p > 0 else float("-Inf") for p in self.precision()[:order] + ) + return self.brevity() * math.exp(psum / order) * 100 + + def precision(self): + def ratio(a, b): + return a / b if b > 0 else 0 + + return [ + ratio(self.stat.match1, self.stat.count1), + ratio(self.stat.match2, self.stat.count2), + ratio(self.stat.match3, self.stat.count3), + ratio(self.stat.match4, self.stat.count4), + ] + + def brevity(self): + r = self.stat.reflen / self.stat.predlen + return min(1, math.exp(1 - r)) + + def result_string(self, order=4): + assert order <= 4, "BLEU scores for order > 4 aren't supported" + fmt = "BLEU{} = {:2.2f}, {:2.1f}" + for _ in range(1, order): + fmt += "/{:2.1f}" + fmt += " (BP={:.3f}, ratio={:.3f}, syslen={}, reflen={})" + bleup = [p * 100 for p in self.precision()[:order]] + return fmt.format( + order, + self.score(order=order), + *bleup, + self.brevity(), + self.stat.predlen / self.stat.reflen, + self.stat.predlen, + self.stat.reflen + ) diff --git a/fairseq/fairseq/scoring/chrf.py b/fairseq/fairseq/scoring/chrf.py new file mode 100644 index 0000000000000000000000000000000000000000..5df5a1c011243fe2e836c38a5f8459aeb824f0e7 --- /dev/null +++ b/fairseq/fairseq/scoring/chrf.py @@ -0,0 +1,36 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from dataclasses import dataclass + +from fairseq.dataclass import FairseqDataclass +from fairseq.scoring import BaseScorer, register_scorer + + +@dataclass +class ChrFScorerConfig(FairseqDataclass): + pass + + +@register_scorer("chrf", dataclass=ChrFScorerConfig) +class ChrFScorer(BaseScorer): + def __init__(self, args): + super(ChrFScorer, self).__init__(args) + import sacrebleu + + self.sacrebleu = sacrebleu + + def add_string(self, ref, pred): + self.ref.append(ref) + self.pred.append(pred) + + def score(self, order=4): + return self.result_string(order).score + + def result_string(self, order=4): + if order != 4: + raise NotImplementedError + return self.sacrebleu.corpus_chrf(self.pred, [self.ref]).format() diff --git a/fairseq/fairseq/scoring/meteor.py b/fairseq/fairseq/scoring/meteor.py new file mode 100644 index 0000000000000000000000000000000000000000..32719956fe32cc25ab071aa7ff6abf84c6533fe8 --- /dev/null +++ b/fairseq/fairseq/scoring/meteor.py @@ -0,0 +1,42 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +from dataclasses import dataclass + +from fairseq.dataclass import FairseqDataclass +from fairseq.scoring import BaseScorer, register_scorer + + +@dataclass +class MeteorScorerConfig(FairseqDataclass): + pass + + +@register_scorer("meteor", dataclass=MeteorScorerConfig) +class MeteorScorer(BaseScorer): + def __init__(self, args): + super(MeteorScorer, self).__init__(args) + try: + import nltk + except ImportError: + raise ImportError("Please install nltk to use METEOR scorer") + + self.nltk = nltk + self.scores = [] + + def add_string(self, ref, pred): + self.ref.append(ref) + self.pred.append(pred) + + def score(self, order=4): + self.scores = [ + self.nltk.translate.meteor_score.single_meteor_score(r, p) + for r, p in zip(self.ref, self.pred) + ] + return np.mean(self.scores) + + def result_string(self, order=4): + return f"METEOR: {self.score():.4f}" diff --git a/fairseq/fairseq/scoring/tokenizer.py b/fairseq/fairseq/scoring/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..b0cedd5099b9ddc7278205e1a4f2aa18359a14bf --- /dev/null +++ b/fairseq/fairseq/scoring/tokenizer.py @@ -0,0 +1,80 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import unicodedata + +import sacrebleu as sb + +from fairseq.dataclass import ChoiceEnum + +SACREBLEU_V2_ABOVE = int(sb.__version__[0]) >= 2 + + +class EvaluationTokenizer(object): + """A generic evaluation-time tokenizer, which leverages built-in tokenizers + in sacreBLEU (https://github.com/mjpost/sacrebleu). It additionally provides + lowercasing, punctuation removal and character tokenization, which are + applied after sacreBLEU tokenization. + + Args: + tokenizer_type (str): the type of sacreBLEU tokenizer to apply. + lowercase (bool): lowercase the text. + punctuation_removal (bool): remove punctuation (based on unicode + category) from text. + character_tokenization (bool): tokenize the text to characters. + """ + + SPACE = chr(32) + SPACE_ESCAPE = chr(9601) + _ALL_TOKENIZER_TYPES = ( + sb.BLEU.TOKENIZERS + if SACREBLEU_V2_ABOVE + else ["none", "13a", "intl", "zh", "ja-mecab"] + ) + ALL_TOKENIZER_TYPES = ChoiceEnum(_ALL_TOKENIZER_TYPES) + + def __init__( + self, + tokenizer_type: str = "13a", + lowercase: bool = False, + punctuation_removal: bool = False, + character_tokenization: bool = False, + ): + + assert ( + tokenizer_type in self._ALL_TOKENIZER_TYPES + ), f"{tokenizer_type}, {self._ALL_TOKENIZER_TYPES}" + self.lowercase = lowercase + self.punctuation_removal = punctuation_removal + self.character_tokenization = character_tokenization + if SACREBLEU_V2_ABOVE: + self.tokenizer = sb.BLEU(tokenize=str(tokenizer_type)).tokenizer + else: + self.tokenizer = sb.tokenizers.TOKENIZERS[tokenizer_type]() + + @classmethod + def remove_punctuation(cls, sent: str): + """Remove punctuation based on Unicode category.""" + return cls.SPACE.join( + t + for t in sent.split(cls.SPACE) + if not all(unicodedata.category(c)[0] == "P" for c in t) + ) + + def tokenize(self, sent: str): + tokenized = self.tokenizer(sent) + + if self.punctuation_removal: + tokenized = self.remove_punctuation(tokenized) + + if self.character_tokenization: + tokenized = self.SPACE.join( + list(tokenized.replace(self.SPACE, self.SPACE_ESCAPE)) + ) + + if self.lowercase: + tokenized = tokenized.lower() + + return tokenized diff --git a/fairseq/fairseq/scoring/wer.py b/fairseq/fairseq/scoring/wer.py new file mode 100644 index 0000000000000000000000000000000000000000..633dc47c247691c4c9e36cbdbab7d7cb74b38452 --- /dev/null +++ b/fairseq/fairseq/scoring/wer.py @@ -0,0 +1,58 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field + +from fairseq.dataclass import FairseqDataclass +from fairseq.scoring import BaseScorer, register_scorer +from fairseq.scoring.tokenizer import EvaluationTokenizer + + +@dataclass +class WerScorerConfig(FairseqDataclass): + wer_tokenizer: EvaluationTokenizer.ALL_TOKENIZER_TYPES = field( + default="none", metadata={"help": "sacreBLEU tokenizer to use for evaluation"} + ) + wer_remove_punct: bool = field( + default=False, metadata={"help": "remove punctuation"} + ) + wer_char_level: bool = field( + default=False, metadata={"help": "evaluate at character level"} + ) + wer_lowercase: bool = field(default=False, metadata={"help": "lowercasing"}) + + +@register_scorer("wer", dataclass=WerScorerConfig) +class WerScorer(BaseScorer): + def __init__(self, cfg): + super().__init__(cfg) + self.reset() + try: + import editdistance as ed + except ImportError: + raise ImportError("Please install editdistance to use WER scorer") + self.ed = ed + self.tokenizer = EvaluationTokenizer( + tokenizer_type=self.cfg.wer_tokenizer, + lowercase=self.cfg.wer_lowercase, + punctuation_removal=self.cfg.wer_remove_punct, + character_tokenization=self.cfg.wer_char_level, + ) + + def reset(self): + self.distance = 0 + self.ref_length = 0 + + def add_string(self, ref, pred): + ref_items = self.tokenizer.tokenize(ref).split() + pred_items = self.tokenizer.tokenize(pred).split() + self.distance += self.ed.eval(ref_items, pred_items) + self.ref_length += len(ref_items) + + def result_string(self): + return f"WER: {self.score():.2f}" + + def score(self): + return 100.0 * self.distance / self.ref_length if self.ref_length > 0 else 0 diff --git a/fairseq/fairseq/tasks/__init__.py b/fairseq/fairseq/tasks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6da1f001f095ad60557440ae39068f469b08ce1e --- /dev/null +++ b/fairseq/fairseq/tasks/__init__.py @@ -0,0 +1,138 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""isort:skip_file""" + +import argparse +import importlib +import os + +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.utils import merge_with_parent +from hydra.core.config_store import ConfigStore + +from .fairseq_task import FairseqTask, LegacyFairseqTask # noqa + + +# register dataclass +TASK_DATACLASS_REGISTRY = {} +TASK_REGISTRY = {} +TASK_CLASS_NAMES = set() + + +def setup_task(cfg: FairseqDataclass, **kwargs): + task = None + task_name = getattr(cfg, "task", None) + + if isinstance(task_name, str): + # legacy tasks + task = TASK_REGISTRY[task_name] + if task_name in TASK_DATACLASS_REGISTRY: + dc = TASK_DATACLASS_REGISTRY[task_name] + cfg = dc.from_namespace(cfg) + else: + task_name = getattr(cfg, "_name", None) + + if task_name and task_name in TASK_DATACLASS_REGISTRY: + remove_missing = "from_checkpoint" in kwargs and kwargs["from_checkpoint"] + dc = TASK_DATACLASS_REGISTRY[task_name] + cfg = merge_with_parent(dc(), cfg, remove_missing=remove_missing) + task = TASK_REGISTRY[task_name] + + assert ( + task is not None + ), f"Could not infer task type from {cfg}. Available argparse tasks: {TASK_REGISTRY.keys()}. Available hydra tasks: {TASK_DATACLASS_REGISTRY.keys()}" + + return task.setup_task(cfg, **kwargs) + + +def register_task(name, dataclass=None): + """ + New tasks can be added to fairseq with the + :func:`~fairseq.tasks.register_task` function decorator. + + For example:: + + @register_task('classification') + class ClassificationTask(FairseqTask): + (...) + + .. note:: + + All Tasks must implement the :class:`~fairseq.tasks.FairseqTask` + interface. + + Args: + name (str): the name of the task + """ + + def register_task_cls(cls): + if name in TASK_REGISTRY: + return TASK_REGISTRY[name] + + if not issubclass(cls, FairseqTask): + raise ValueError( + "Task ({}: {}) must extend FairseqTask".format(name, cls.__name__) + ) + if cls.__name__ in TASK_CLASS_NAMES: + raise ValueError( + "Cannot register task with duplicate class name ({})".format( + cls.__name__ + ) + ) + TASK_REGISTRY[name] = cls + TASK_CLASS_NAMES.add(cls.__name__) + + if dataclass is not None and not issubclass(dataclass, FairseqDataclass): + raise ValueError( + "Dataclass {} must extend FairseqDataclass".format(dataclass) + ) + + cls.__dataclass = dataclass + if dataclass is not None: + TASK_DATACLASS_REGISTRY[name] = dataclass + + cs = ConfigStore.instance() + node = dataclass() + node._name = name + cs.store(name=name, group="task", node=node, provider="fairseq") + + return cls + + return register_task_cls + + +def get_task(name): + return TASK_REGISTRY[name] + + +def import_tasks(tasks_dir, namespace): + for file in os.listdir(tasks_dir): + path = os.path.join(tasks_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + task_name = file[: file.find(".py")] if file.endswith(".py") else file + importlib.import_module(namespace + "." + task_name) + + # expose `task_parser` for sphinx + if task_name in TASK_REGISTRY: + parser = argparse.ArgumentParser(add_help=False) + group_task = parser.add_argument_group("Task name") + # fmt: off + group_task.add_argument('--task', metavar=task_name, + help='Enable this task with: ``--task=' + task_name + '``') + # fmt: on + group_args = parser.add_argument_group( + "Additional command-line arguments" + ) + TASK_REGISTRY[task_name].add_args(group_args) + globals()[task_name + "_parser"] = parser + + +# automatically import any Python files in the tasks/ directory +tasks_dir = os.path.dirname(__file__) +import_tasks(tasks_dir, "fairseq.tasks") diff --git a/fairseq/fairseq/tasks/audio_classification.py b/fairseq/fairseq/tasks/audio_classification.py new file mode 100644 index 0000000000000000000000000000000000000000..4c21d23b696f7bf358145db32943074cf1b01626 --- /dev/null +++ b/fairseq/fairseq/tasks/audio_classification.py @@ -0,0 +1,269 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +from collections import OrderedDict +import itertools +import logging +import os +import sys +from dataclasses import dataclass, field +from typing import Optional + +import numpy as np +import torch +from omegaconf import II, MISSING +from sklearn import metrics as sklearn_metrics + +from fairseq.data import AddTargetDataset, Dictionary, FileAudioDataset +from fairseq.data.multi_corpus_dataset import MultiCorpusDataset +from fairseq.data.text_compressor import TextCompressionLevel, TextCompressor +from fairseq.dataclass import FairseqDataclass +from fairseq.tasks.audio_pretraining import AudioPretrainingConfig, AudioPretrainingTask +from fairseq.tasks.audio_finetuning import label_len_fn, LabelEncoder + +from .. import utils +from ..logging import metrics +from . import FairseqTask, register_task + +logger = logging.getLogger(__name__) + +@dataclass +class AudioClassificationConfig(AudioPretrainingConfig): + target_dictionary: Optional[str] = field( + default=None, metadata={"help": "override default dictionary location"} + ) + + +@register_task("audio_classification", dataclass=AudioClassificationConfig) +class AudioClassificationTask(AudioPretrainingTask): + """Task for audio classification tasks.""" + + cfg: AudioClassificationConfig + + def __init__( + self, + cfg: AudioClassificationConfig, + ): + super().__init__(cfg) + self.state.add_factory("target_dictionary", self.load_target_dictionary) + logging.info(f"=== Number of labels = {len(self.target_dictionary)}") + + def load_target_dictionary(self): + if self.cfg.labels: + target_dictionary = self.cfg.data + if self.cfg.target_dictionary: # override dict + target_dictionary = self.cfg.target_dictionary + dict_path = os.path.join(target_dictionary, f"dict.{self.cfg.labels}.txt") + logger.info("Using dict_path : {}".format(dict_path)) + return Dictionary.load(dict_path, add_special_symbols=False) + return None + + def load_dataset( + self, split: str, task_cfg: AudioClassificationConfig = None, **kwargs + ): + super().load_dataset(split, task_cfg, **kwargs) + task_cfg = task_cfg or self.cfg + assert task_cfg.labels is not None + text_compression_level = getattr( + TextCompressionLevel, str(self.cfg.text_compression_level) + ) + data_path = self.cfg.data + if task_cfg.multi_corpus_keys is None: + label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}") + skipped_indices = getattr(self.datasets[split], "skipped_indices", set()) + text_compressor = TextCompressor(level=text_compression_level) + with open(label_path, "r") as f: + labels = [ + text_compressor.compress(l) + for i, l in enumerate(f) + if i not in skipped_indices + ] + + assert len(labels) == len(self.datasets[split]), ( + f"labels length ({len(labels)}) and dataset length " + f"({len(self.datasets[split])}) do not match" + ) + + process_label = LabelEncoder(self.target_dictionary) + + self.datasets[split] = AddTargetDataset( + self.datasets[split], + labels, + pad=self.target_dictionary.pad(), + eos=self.target_dictionary.eos(), + batch_targets=True, + process_label=process_label, + label_len_fn=label_len_fn, + add_to_input=False, + # text_compression_level=text_compression_level, + ) + else: + target_dataset_map = OrderedDict() + + multi_corpus_keys = [ + k.strip() for k in task_cfg.multi_corpus_keys.split(",") + ] + corpus_idx_map = {k: idx for idx, k in enumerate(multi_corpus_keys)} + + data_keys = [k.split(":") for k in split.split(",")] + + multi_corpus_sampling_weights = [ + float(val.strip()) + for val in task_cfg.multi_corpus_sampling_weights.split(",") + ] + data_weights = [] + for key, file_name in data_keys: + k = key.strip() + label_path = os.path.join( + data_path, f"{file_name.strip()}.{task_cfg.labels}" + ) + skipped_indices = getattr( + self.dataset_map[split][k], "skipped_indices", set() + ) + text_compressor = TextCompressor(level=text_compression_level) + with open(label_path, "r") as f: + labels = [ + text_compressor.compress(l) + for i, l in enumerate(f) + if i not in skipped_indices + ] + + assert len(labels) == len(self.dataset_map[split][k]), ( + f"labels length ({len(labels)}) and dataset length " + f"({len(self.dataset_map[split][k])}) do not match" + ) + + process_label = LabelEncoder(self.target_dictionary) + + # TODO: Remove duplication of code from the if block above + target_dataset_map[k] = AddTargetDataset( + self.dataset_map[split][k], + labels, + pad=self.target_dictionary.pad(), + eos=self.target_dictionary.eos(), + batch_targets=True, + process_label=process_label, + label_len_fn=label_len_fn, + add_to_input=False, + # text_compression_level=text_compression_level, + ) + + data_weights.append(multi_corpus_sampling_weights[corpus_idx_map[k]]) + + if len(target_dataset_map) == 1: + self.datasets[split] = list(target_dataset_map.values())[0] + else: + self.datasets[split] = MultiCorpusDataset( + target_dataset_map, + distribution=data_weights, + seed=0, + sort_indices=True, + ) + + @property + def source_dictionary(self): + return None + + @property + def target_dictionary(self): + """Return the :class:`~fairseq.data.Dictionary` for the language + model.""" + return self.state.target_dictionary + + def train_step(self, sample, model, *args, **kwargs): + sample["target"] = sample["target"].to(dtype=torch.long) + loss, sample_size, logging_output = super().train_step( + sample, model, *args, **kwargs + ) + self._log_metrics(sample, model, logging_output) + return loss, sample_size, logging_output + + def valid_step(self, sample, model, criterion): + sample["target"] = sample["target"].to(dtype=torch.long) + loss, sample_size, logging_output = super().valid_step(sample, model, criterion) + self._log_metrics(sample, model, logging_output) + return loss, sample_size, logging_output + + def _log_metrics(self, sample, model, logging_output): + metrics = self._inference_with_metrics( + sample, + model, + ) + """ + logging_output["_precision"] = metrics["precision"] + logging_output["_recall"] = metrics["recall"] + logging_output["_f1"] = metrics["f1"] + logging_output["_eer"] = metrics["eer"] + logging_output["_accuracy"] = metrics["accuracy"] + """ + logging_output["_correct"] = metrics["correct"] + logging_output["_total"] = metrics["total"] + + def _inference_with_metrics(self, sample, model): + def _compute_eer(target_list, lprobs): + # from scipy.optimize import brentq + # from scipy.interpolate import interp1d + + y_one_hot = np.eye(len(self.state.target_dictionary))[target_list] + fpr, tpr, thresholds = sklearn_metrics.roc_curve( + y_one_hot.ravel(), lprobs.ravel() + ) + # Revisit the interpolation approach. + # eer = brentq(lambda x: 1.0 - x - interp1d(fpr, tpr)(x), 0.0, 1.0) + + fnr = 1 - tpr + eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))] + + return eer + + with torch.no_grad(): + net_output = model(**sample["net_input"]) + lprobs = ( + model.get_normalized_probs(net_output, log_probs=True).cpu().detach() + ) + target_list = sample["target"][:, 0].detach().cpu() + predicted_list = torch.argmax(lprobs, 1).detach().cpu() # B,C->B + + metrics = { + "correct": torch.sum(target_list == predicted_list).item(), + "total": len(target_list), + } + return metrics + + def reduce_metrics(self, logging_outputs, criterion): + super().reduce_metrics(logging_outputs, criterion) + + zero = torch.scalar_tensor(0.0) + correct, total = 0, 0 + for log in logging_outputs: + correct += log.get("_correct", zero) + total += log.get("_total", zero) + metrics.log_scalar("_correct", correct) + metrics.log_scalar("_total", total) + + if total > 0: + def _fn_accuracy(meters): + if meters["_total"].sum > 0: + return utils.item(meters["_correct"].sum / meters["_total"].sum) + return float("nan") + + metrics.log_derived("accuracy", _fn_accuracy) + """ + prec_sum, recall_sum, f1_sum, acc_sum, eer_sum = 0.0, 0.0, 0.0, 0.0, 0.0 + for log in logging_outputs: + prec_sum += log.get("_precision", zero).item() + recall_sum += log.get("_recall", zero).item() + f1_sum += log.get("_f1", zero).item() + acc_sum += log.get("_accuracy", zero).item() + eer_sum += log.get("_eer", zero).item() + + metrics.log_scalar("avg_precision", prec_sum / len(logging_outputs)) + metrics.log_scalar("avg_recall", recall_sum / len(logging_outputs)) + metrics.log_scalar("avg_f1", f1_sum / len(logging_outputs)) + metrics.log_scalar("avg_accuracy", acc_sum / len(logging_outputs)) + metrics.log_scalar("avg_eer", eer_sum / len(logging_outputs)) + """ \ No newline at end of file diff --git a/fairseq/fairseq/tasks/audio_finetuning.py b/fairseq/fairseq/tasks/audio_finetuning.py new file mode 100644 index 0000000000000000000000000000000000000000..d79553cb86e830b12ed1b1affcb537cc4aec3630 --- /dev/null +++ b/fairseq/fairseq/tasks/audio_finetuning.py @@ -0,0 +1,404 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import logging +import os +from fairseq.data.multi_corpus_dataset import MultiCorpusDataset +import torch +import json + +from argparse import Namespace +from dataclasses import dataclass, field +from typing import Optional, Any, OrderedDict + +from fairseq.data import AddTargetDataset, Dictionary, encoders +from fairseq.tasks.audio_pretraining import AudioPretrainingTask, AudioPretrainingConfig +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.configs import GenerationConfig +from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel + +from . import register_task +from .. import utils +from ..logging import metrics + + +logger = logging.getLogger(__name__) + + +class LabelEncoder(object): + def __init__(self, dictionary): + self.dictionary = dictionary + + def __call__(self, label): + return self.dictionary.encode_line( + label, append_eos=False, add_if_not_exist=False + ) + + +def label_len_fn(label): + return len(label.split(" ")) + + +@dataclass +class AudioFinetuningConfig(AudioPretrainingConfig): + # Options for reporting WER metrics during validation. Only applicable to + # Seq2Seq models during fine-tuning + eval_wer: bool = field( + default=False, metadata={"help": "compute WER for Seq2Seq models"} + ) + eval_wer_config: GenerationConfig = field( + default_factory=lambda: GenerationConfig(), + metadata={"help": "beam search config for evaluating wer during training"}, + ) + eval_wer_tokenizer: Any = field( + default=None, + metadata={"help": "tokenizer config for evaluating wer during training"}, + ) + eval_wer_post_process: str = field( + default="letter", + metadata={ + "help": "remove BPE tokens before scoring (can be sentencepiece, letter, and more)" + }, + ) + eval_bleu: bool = field( + default=False, metadata={"help": "evaluation with BLEU scores"} + ) + eval_bleu_detok: Optional[str] = field( + default=None, + metadata={ + "help": "detokenize before computing BLEU (e.g., 'moses'); " + "required if using --eval-bleu; use 'space' to disable " + "detokenization; see fairseq.data.encoders for other options" + }, + ) + eval_bleu_detok_args: str = field( + default="{}", metadata={"help": "args for building the tokenizer, if needed"} + ) + eval_tokenized_bleu: bool = field( + default=False, metadata={"help": "compute tokenized BLEU instead of sacrebleu"} + ) + eval_bleu_remove_bpe: Optional[str] = field( + default=None, metadata={"help": "remove BPE before computing BLEU"} + ) + eval_bleu_args: str = field( + default="{}", + metadata={ + "help": "generation args for BLUE scoring, e.g., " + '\'{"beam": 4, "lenpen": 0.6}\'' + }, + ) + eval_bleu_print_samples: bool = field( + default=False, metadata={"help": "print sample generations during validation"} + ) + autoregressive: bool = field( + default=False, + metadata={ + "help": "required for autoregressive decoders (like seq2seq models); " + "adds 'prev_output_tokens' to input and appends eos to target" + }, + ) + rebuild_batches: bool = True + target_dictionary: Optional[str] = field( + default=None, + metadata={ + "help": "override default dictionary location" + } + ) + +@register_task("audio_finetuning", dataclass=AudioFinetuningConfig) +class AudioFinetuningTask(AudioPretrainingTask): + """ """ + + cfg: AudioFinetuningConfig + + def __init__( + self, + cfg: AudioFinetuningConfig, + ): + super().__init__(cfg) + self.blank_symbol = "" + + self.state.add_factory("target_dictionary", self.load_target_dictionary) + + def load_target_dictionary(self): + if self.cfg.labels: + target_dictionary = self.cfg.data + if self.cfg.target_dictionary: # override dict + target_dictionary = self.cfg.target_dictionary + dict_path = os.path.join(target_dictionary, f"dict.{self.cfg.labels}.txt") + logger.info('Using dict_path : {}'.format(dict_path)) + return Dictionary.load(dict_path) + return None + + def load_dataset( + self, split: str, task_cfg: AudioFinetuningConfig = None, **kwargs + ): + super().load_dataset(split, task_cfg, **kwargs) + + task_cfg = task_cfg or self.cfg + assert task_cfg.labels is not None + text_compression_level = getattr( + TextCompressionLevel, str(self.cfg.text_compression_level) + ) + data_path = self.cfg.data + if task_cfg.multi_corpus_keys is None: + label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}") + skipped_indices = getattr(self.datasets[split], "skipped_indices", set()) + text_compressor = TextCompressor(level=text_compression_level) + with open(label_path, "r") as f: + labels = [ + text_compressor.compress(l) + for i, l in enumerate(f) + if i not in skipped_indices + ] + + assert len(labels) == len(self.datasets[split]), ( + f"labels length ({len(labels)}) and dataset length " + f"({len(self.datasets[split])}) do not match" + ) + + process_label = LabelEncoder(self.target_dictionary) + + self.datasets[split] = AddTargetDataset( + self.datasets[split], + labels, + pad=self.target_dictionary.pad(), + eos=self.target_dictionary.eos(), + batch_targets=True, + process_label=process_label, + label_len_fn=label_len_fn, + add_to_input=task_cfg.get("autoregressive", False), + text_compression_level=text_compression_level, + ) + else: + + target_dataset_map = OrderedDict() + + multi_corpus_keys = [k.strip() for k in task_cfg.multi_corpus_keys.split(",")] + corpus_idx_map = {k: idx for idx, k in enumerate(multi_corpus_keys)} + + data_keys = [k.split(":") for k in split.split(",")] + + multi_corpus_sampling_weights = [float(val.strip()) for val in task_cfg.multi_corpus_sampling_weights.split(",")] + data_weights = [] + for key, file_name in data_keys: + k = key.strip() + label_path = os.path.join(data_path, f"{file_name.strip()}.{task_cfg.labels}") + skipped_indices = getattr(self.dataset_map[split][k], "skipped_indices", set()) + text_compressor = TextCompressor(level=text_compression_level) + with open(label_path, "r") as f: + labels = [ + text_compressor.compress(l) + for i, l in enumerate(f) + if i not in skipped_indices + ] + + assert len(labels) == len(self.dataset_map[split][k]), ( + f"labels length ({len(labels)}) and dataset length " + f"({len(self.dataset_map[split][k])}) do not match" + ) + + process_label = LabelEncoder(self.target_dictionary) + + # TODO: Remove duplication of code from the if block above + target_dataset_map[k] = AddTargetDataset( + self.dataset_map[split][k], + labels, + pad=self.target_dictionary.pad(), + eos=self.target_dictionary.eos(), + batch_targets=True, + process_label=process_label, + label_len_fn=label_len_fn, + add_to_input=task_cfg.get("autoregressive", False), + text_compression_level=text_compression_level, + ) + + data_weights.append(multi_corpus_sampling_weights[corpus_idx_map[k]]) + + if len(target_dataset_map) == 1: + self.datasets[split] = list(target_dataset_map.values())[0] + else: + self.datasets[split] = MultiCorpusDataset(target_dataset_map, distribution=data_weights, seed=0, sort_indices=True) + + @property + def target_dictionary(self): + """Return the :class:`~fairseq.data.Dictionary` for the language + model.""" + return self.state.target_dictionary + + def valid_step(self, sample, model, criterion): + loss, sample_size, logging_output = super().valid_step(sample, model, criterion) + if self.cfg.eval_wer and self.cfg.autoregressive: + metrics = self._inference_with_wer(self.sequence_generator, sample, model) + logging_output["_num_char_errors"] = metrics["num_char_errors"] + logging_output["_num_chars"] = metrics["num_chars"] + logging_output["_num_word_errors"] = metrics["num_word_errors"] + logging_output["_num_words"] = metrics["num_words"] + if self.cfg.eval_bleu and self.cfg.autoregressive: + metrics = self._inference_with_bleu(self.sequence_generator, sample, model) + logging_output["_bleu_sys_len"] = metrics.sys_len + logging_output["_bleu_ref_len"] = metrics.ref_len + # we split counts into separate entries so that they can be + # summed efficiently across workers using fast-stat-sync + assert len(metrics.counts) == 4 + for i in range(4): + logging_output[f"_bleu_counts_{i}"] = metrics.counts[i] + logging_output[f"_bleu_totals_{i}"] = metrics.totals[i] + return loss, sample_size, logging_output + + def build_model(self, model_cfg: FairseqDataclass, from_checkpoint=False): + model = super().build_model(model_cfg, from_checkpoint) + + if self.cfg.eval_wer and self.cfg.autoregressive: + self.sequence_generator = self.build_generator( + [model], + self.cfg.eval_wer_config, + ) + if self.cfg.eval_wer_tokenizer: + self.tokenizer = encoders.build_tokenizer(self.cfg.eval_wer_tokenizer) + else: + self.tokenizer = None + if self.cfg.eval_bleu and self.cfg.autoregressive: + assert self.cfg.eval_bleu_detok is not None, ( + "--eval-bleu-detok is required if using --eval-bleu; " + "try --eval-bleu-detok=moses (or --eval-bleu-detok=space " + "to disable detokenization, e.g., when using sentencepiece)" + ) + detok_args = json.loads(self.cfg.eval_bleu_detok_args) + self.tokenizer = encoders.build_tokenizer( + Namespace(tokenizer=self.cfg.eval_bleu_detok, **detok_args) + ) + gen_args = json.loads(self.cfg.eval_bleu_args) + gen_args = Namespace(**gen_args) + self.sequence_generator = self.build_generator([model], gen_args) + + return model + + def _inference_with_wer(self, generator, sample, model): + import editdistance + + def decode(toks): + s = self.target_dictionary.string( + toks.int().cpu(), + self.cfg.eval_wer_post_process, + escape_unk=True, + ) + if self.tokenizer: + s = self.tokenizer.decode(s) + return s + + num_word_errors, num_char_errors = 0, 0 + num_chars, num_words = 0, 0 + gen_out = self.inference_step(generator, [model], sample, None) + for i in range(len(gen_out)): + hyp = decode(gen_out[i][0]["tokens"]) + ref = decode( + utils.strip_pad(sample["target"][i], self.target_dictionary.pad()), + ) + num_char_errors += editdistance.eval(hyp, ref) + num_chars += len(ref) + hyp_words = hyp.split() + ref_words = ref.split() + num_word_errors += editdistance.eval(hyp_words, ref_words) + num_words += len(ref_words) + + return { + "num_char_errors": num_char_errors, + "num_chars": num_chars, + "num_word_errors": num_word_errors, + "num_words": num_words, + } + + def _inference_with_bleu(self, generator, sample, model): + import sacrebleu + + def decode(toks, is_ref): + s = self.target_dictionary.string( + toks.int().cpu(), + self.cfg.eval_bleu_remove_bpe, + # The default unknown string in fairseq is ``, but + # this is tokenized by sacrebleu as `< unk >`, inflating + # BLEU scores. Instead, we use a somewhat more verbose + # alternative that is unlikely to appear in the real + # reference, but doesn't get split into multiple tokens. + unk_string=("UNKNOWNTOKENINREF" if is_ref else "UNKNOWNTOKENINHYP"), + ) + if self.tokenizer: + s = self.tokenizer.decode(s) + return s + + gen_out = self.inference_step(generator, [model], sample) + hyps, refs = [], [] + for i in range(len(gen_out)): + hyps.append(decode(gen_out[i][0]["tokens"], is_ref=False)) + refs.append( + decode( + utils.strip_pad(sample["target"][i], self.target_dictionary.pad()), + is_ref=True, # don't count as matches to the hypo + ) + ) + if self.cfg.eval_bleu_print_samples: + logger.info("H-{} {}".format(sample["id"][0], hyps[0])) + logger.info("T-{} {}".format(sample["id"][0], refs[0])) + + eval_tokenization = "none" if self.cfg.eval_tokenized_bleu else "13a" + return sacrebleu.corpus_bleu(hyps, [refs], tokenize=eval_tokenization) + + def reduce_metrics(self, logging_outputs, criterion): + super().reduce_metrics(logging_outputs, criterion) + + if self.cfg.eval_wer: + zero = torch.scalar_tensor(0.0) + num_char_errors = sum( + log.get("_num_char_errors", zero) for log in logging_outputs + ) + num_chars = sum(log.get("_num_chars", zero) for log in logging_outputs) + num_word_errors = sum( + log.get("_num_word_errors", zero) for log in logging_outputs + ) + num_words = sum(log.get("_num_words", zero) for log in logging_outputs) + metrics.log_scalar("_num_char_errors", num_char_errors) + metrics.log_scalar("_num_chars", num_chars) + metrics.log_scalar("_num_word_errors", num_word_errors) + metrics.log_scalar("_num_words", num_words) + if num_chars > 0: + metrics.log_derived( + "uer", + lambda meters: meters["_num_char_errors"].sum + * 100.0 + / meters["_num_chars"].sum + if meters["_num_chars"].sum > 0 + else float("nan"), + ) + if num_words > 0: + metrics.log_derived( + "wer", + lambda meters: meters["_num_word_errors"].sum + * 100.0 + / meters["_num_words"].sum + if meters["_num_words"].sum > 0 + else float("nan"), + ) + if self.cfg.eval_bleu: + len_keys = ["_bleu_sys_len", "_bleu_ref_len"] + count_keys = [f"_bleu_counts_{i}" for i in range(4)] + total_keys = [f"_bleu_totals_{i}" for i in range(4)] + for k in len_keys + count_keys + total_keys: + metrics.log_scalar(k, sum(log.get(k, 0) for log in logging_outputs)) + + import sacrebleu + + metrics.log_derived( + "bleu", + lambda meters: sacrebleu.compute_bleu( + correct=[meters[k].sum for k in count_keys], + total=[meters[k].sum for k in total_keys], + sys_len=meters["_bleu_sys_len"].sum, + ref_len=meters["_bleu_ref_len"].sum, + smooth_method="exp", + ).score, + ) diff --git a/fairseq/fairseq/tasks/audio_pretraining.py b/fairseq/fairseq/tasks/audio_pretraining.py new file mode 100644 index 0000000000000000000000000000000000000000..3e91303b6964627426f8b4791cdd2fc4f8676e3e --- /dev/null +++ b/fairseq/fairseq/tasks/audio_pretraining.py @@ -0,0 +1,253 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import logging +import os +import sys + +from argparse import Namespace +from dataclasses import dataclass, field +from typing import Optional, OrderedDict +from fairseq.data.multi_corpus_dataset import MultiCorpusDataset +from omegaconf import MISSING, II, OmegaConf + +from fairseq.data import BinarizedAudioDataset, FileAudioDataset, SubsampleDataset +from fairseq.dataclass import FairseqDataclass, ChoiceEnum +from fairseq.data.text_compressor import TextCompressionLevel + +from . import FairseqTask, register_task + + +logger = logging.getLogger(__name__) + + +@dataclass +class AudioMaskingConfig: + feature_encoder_spec: str = II("model.modalities.audio.feature_encoder_spec") + mask_prob: float = II("model.modalities.audio.mask_prob") + mask_prob_adjust: float = II("model.modalities.audio.mask_prob_adjust") + mask_length: int = II("model.modalities.audio.mask_length") + inverse_mask: bool = II("model.modalities.audio.inverse_mask") + mask_dropout: float = II("model.modalities.audio.mask_dropout") + clone_batch: int = II("model.clone_batch") + expand_adjacent: bool = False + non_overlapping: bool = False + + +@dataclass +class AudioPretrainingConfig(FairseqDataclass): + data: str = field(default=MISSING, metadata={"help": "path to data directory"}) + labels: Optional[str] = field( + default=None, + metadata={"help": "extension of the label file to load, used for fine-tuning"}, + ) + multi_corpus_keys: Optional[str] = field( + default=None, + metadata={"help": "Comma separated names for loading multi corpus datasets"}) + multi_corpus_sampling_weights: Optional[str] = field( + default=None, + metadata={"help": "Comma separated string of sampling weights corresponding to the multi_corpus_keys"}) + binarized_dataset: bool = field( + default=False, + metadata={ + "help": "if true, loads binarized dataset (useful for very large datasets). " + "See examples/wav2vec/scripts/binarize_manifest.sh" + }, + ) + sample_rate: int = field( + default=16_000, + metadata={ + "help": "target sample rate. audio files will be up/down sampled to this rate" + }, + ) + normalize: bool = field( + default=False, + metadata={"help": "if set, normalizes input to have 0 mean and unit variance"}, + ) + enable_padding: bool = field( + default=False, metadata={"help": "pad shorter samples instead of cropping"} + ) + max_sample_size: Optional[int] = field( + default=None, metadata={"help": "max sample size to crop to for batching"} + ) + min_sample_size: Optional[int] = field( + default=None, metadata={"help": "min sample size to skip small examples"} + ) + num_batch_buckets: int = field( + default=0, + metadata={"help": "number of buckets"}, + ) + tpu: bool = II("common.tpu") + text_compression_level: ChoiceEnum([x.name for x in TextCompressionLevel]) = field( + default="none", + metadata={ + "help": "compression level for texts (e.g. audio filenames, " + "target texts): none/low/high (default: none). " + }, + ) + + rebuild_batches: bool = True + precompute_mask_config: Optional[AudioMaskingConfig] = None + + post_save_script: Optional[str] = None + + subsample: float = 1 + seed: int = II("common.seed") + + +@register_task("audio_pretraining", dataclass=AudioPretrainingConfig) +class AudioPretrainingTask(FairseqTask): + """ """ + + cfg: AudioPretrainingConfig + + @classmethod + def setup_task(cls, cfg: AudioPretrainingConfig, **kwargs): + """Setup the task (e.g., load dictionaries). + + Args: + cfg (AudioPretrainingConfig): configuration of this task + """ + + return cls(cfg) + + def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): + data_path = self.cfg.data + task_cfg = task_cfg or self.cfg + + # upgrade old task + if isinstance(task_cfg, Namespace): + if not hasattr(task_cfg, "autoregressive"): + task_cfg.autoregressive = not task_cfg.criterion == "ctc" + + text_compression_level = getattr( + TextCompressionLevel, str(self.cfg.text_compression_level) + ) + + compute_mask = getattr(task_cfg, "precompute_mask_config", None) is not None + mask_args = {} + if compute_mask: + mask_args = task_cfg.precompute_mask_config + + if getattr(task_cfg, "binarized_dataset", False): + self.datasets[split] = BinarizedAudioDataset( + data_path, + split=split, + sample_rate=task_cfg.get("sample_rate", self.cfg.sample_rate), + max_sample_size=self.cfg.max_sample_size, + min_sample_size=self.cfg.min_sample_size, + pad=task_cfg.labels is not None or task_cfg.enable_padding, + normalize=task_cfg.normalize, + num_buckets=self.cfg.num_batch_buckets or int(self.cfg.tpu), + compute_mask=compute_mask, + **mask_args, + ) + else: + if task_cfg.multi_corpus_keys is None: + manifest_path = os.path.join(data_path, "{}.tsv".format(split)) + + self.datasets[split] = FileAudioDataset( + manifest_path=manifest_path, + sample_rate=task_cfg.get("sample_rate", self.cfg.sample_rate), + max_sample_size=self.cfg.max_sample_size, + min_sample_size=self.cfg.min_sample_size, + pad=task_cfg.labels is not None or task_cfg.enable_padding, + normalize=task_cfg.normalize, + num_buckets=self.cfg.num_batch_buckets or int(self.cfg.tpu), + text_compression_level=text_compression_level, + compute_mask=compute_mask, + **mask_args, + ) + else: + dataset_map = OrderedDict() + self.dataset_map = {} + multi_corpus_keys = [k.strip() for k in task_cfg.multi_corpus_keys.split(",")] + corpus_idx_map = {k: idx for idx, k in enumerate(multi_corpus_keys)} + data_keys = [k.split(":") for k in split.split(",")] + + multi_corpus_sampling_weights = [float(val.strip()) for val in task_cfg.multi_corpus_sampling_weights.split(",")] + data_weights = [] + + for key, file_name in data_keys: + + k = key.strip() + manifest_path = os.path.join(data_path, "{}.tsv".format(file_name.strip())) + + # TODO: Remove duplication of code from the if block above + dataset_map[k] = FileAudioDataset( + manifest_path=manifest_path, + sample_rate=task_cfg.get("sample_rate", self.cfg.sample_rate), + max_sample_size=self.cfg.max_sample_size, + min_sample_size=self.cfg.min_sample_size, + pad=task_cfg.labels is not None or task_cfg.enable_padding, + normalize=task_cfg.normalize, + num_buckets=self.cfg.num_batch_buckets or int(self.cfg.tpu), + text_compression_level=text_compression_level, + compute_mask=compute_mask, + corpus_key=corpus_idx_map[k], + **mask_args, + ) + + data_weights.append(multi_corpus_sampling_weights[corpus_idx_map[k]]) + + self.dataset_map[split] = dataset_map + + if len(dataset_map) == 1: + self.datasets[split] = list(dataset_map.values())[0] + else: + self.datasets[split] = MultiCorpusDataset(dataset_map, distribution=data_weights, seed=0, sort_indices=True) + + if getattr(task_cfg, "subsample", 1) < 1: + self.datasets[split] = SubsampleDataset( + self.datasets[split], + task_cfg.subsample, + shuffle=True, + seed=task_cfg.seed, + ) + + if self.cfg.tpu and task_cfg.inferred_w2v_config.mask_channel_prob == 0.0: + logger.info( + "Pretraining on TPUs may suffer convergence " + "issues when training with `mask_channel_prob` value of " + "0. You may want to set this to a low value close to 0." + ) + + def max_positions(self): + """Maximum input length supported by the encoder.""" + return sys.maxsize, sys.maxsize + + def build_model(self, model_cfg: FairseqDataclass, from_checkpoint=False): + model = super().build_model(model_cfg, from_checkpoint) + + actualized_cfg = getattr(model, "cfg", None) + if actualized_cfg is not None: + # if "w2v_args" in actualized_cfg: + if hasattr(actualized_cfg, "w2v_args"): + model_cfg.w2v_args = actualized_cfg.w2v_args + + return model + + def post_save(self, cp_path, num_updates): + if self.cfg.post_save_script is not None: + logger.info(f"launching {self.cfg.post_save_script}") + import os.path as osp + from fairseq.file_io import PathManager + + eval_cp_path = osp.join( + osp.dirname(cp_path), f"checkpoint_eval_{num_updates}.pt" + ) + + print(cp_path, eval_cp_path, osp.dirname(cp_path)) + + assert PathManager.copy( + cp_path, eval_cp_path, overwrite=True + ), f"Failed to copy {cp_path} to {eval_cp_path}" + + import subprocess + import shlex + + subprocess.call(shlex.split(f"{self.cfg.post_save_script} {eval_cp_path}")) diff --git a/fairseq/fairseq/tasks/cross_lingual_lm.py b/fairseq/fairseq/tasks/cross_lingual_lm.py new file mode 100644 index 0000000000000000000000000000000000000000..8f8fe7e2de181e41bd0e6a2bf96948ee78de5ae8 --- /dev/null +++ b/fairseq/fairseq/tasks/cross_lingual_lm.py @@ -0,0 +1,191 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import itertools +import logging +import os +from collections import OrderedDict + +import numpy as np +from fairseq import tokenizer, utils +from fairseq.data import ConcatDataset, Dictionary, TokenBlockDataset, data_utils +from fairseq.data.legacy.masked_lm_dataset import MaskedLMDataset +from fairseq.data.legacy.masked_lm_dictionary import MaskedLMDictionary +from fairseq.data.multi_corpus_sampled_dataset import MultiCorpusSampledDataset +from fairseq.tasks import LegacyFairseqTask, register_task + + +logger = logging.getLogger(__name__) + + +@register_task("cross_lingual_lm") +class CrossLingualLMTask(LegacyFairseqTask): + """ + Task for training cross-lingual language models. + + For more details look at: https://arxiv.org/pdf/1901.07291.pdf + + Args: + dictionary (Dictionary): the dictionary for the input of the task + """ + + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + parser.add_argument( + "data", + help="colon separated path to data directories list, \ + will be iterated upon during epochs in round-robin manner", + ) + parser.add_argument( + "--tokens-per-sample", + default=512, + type=int, + help="max number of total tokens over all segments" " per sample", + ) + parser.add_argument( + "--monolingual-langs", + default="en", + type=str, + help="comma separated list of languages for which we" + " want to train XLM on", + ) + parser.add_argument( + "--shuffle", + action="store_true", + help="shuffle each monolingual dataset while" " training", + ) + + def __init__(self, args, dictionary): + super().__init__(args) + self.dictionary = dictionary + self.seed = args.seed + self.distributed_world_size = args.distributed_world_size + self.langs2id = self._lang_to_id(args.monolingual_langs) + + def _lang_to_id(self, languages: str): + """ + Build a map from languages to ids. These ids are used as segment labels + for cross-lingual LM training. + """ + lang2id = {} + langs = [l.strip() for l in languages.split(",")] + for id, lang in enumerate(langs): + lang2id[lang] = id + return lang2id + + @classmethod + def load_dictionary(cls, filename): + return MaskedLMDictionary.load(filename) + + @classmethod + def build_dictionary( + cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8 + ): + d = MaskedLMDictionary() + for filename in filenames: + Dictionary.add_file_to_dictionary( + filename, d, tokenizer.tokenize_line, workers + ) + d.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor) + return d + + @property + def target_dictionary(self): + return self.dictionary + + @classmethod + def setup_task(cls, args, **kwargs): + """Setup the task.""" + dictionary = MaskedLMDictionary.load(os.path.join(args.data, "dict.txt")) + logger.info("dictionary: {} types".format(len(dictionary))) + return cls(args, dictionary) + + def _load_single_lang_dataset(self, split, epoch): + loaded_datasets = [] + + paths = utils.split_paths(self.args.data) + assert len(paths) > 0 + data_path = paths[(epoch - 1) % len(paths)] + + for k in itertools.count(): + split_k = split + (str(k) if k > 0 else "") + path = os.path.join(data_path, split_k) + + ds = data_utils.load_indexed_dataset( + path, self.dictionary, self.args.dataset_impl + ) + if ds is None: + if k > 0: + break + else: + raise FileNotFoundError( + "Dataset not found: {} ({})".format(split, data_path) + ) + + # Since we append each block with the classification_token, + # we need to effectively create blocks of length + # tokens_per_sample-1 + loaded_datasets.append( + TokenBlockDataset( + ds, + ds.sizes, + self.args.tokens_per_sample - 1, + pad=self.dictionary.pad(), + eos=self.dictionary.eos(), + ) + ) + + logger.info( + "{} {} {} examples".format(data_path, split_k, len(loaded_datasets[-1])) + ) + + if len(loaded_datasets) == 1: + dataset = loaded_datasets[0] + sizes = dataset.sizes + else: + dataset = ConcatDataset(loaded_datasets) + sizes = np.concatenate([ds.sizes for ds in loaded_datasets]) + + return dataset, sizes + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + """Load a given dataset split. + + Args: + split (str): name of the split (e.g., train, valid, test) + """ + dataset_map = OrderedDict() + + for lang in self.langs2id.keys(): + # Datasets are expected to be in "split.lang" format (Eg: train.en) + language_split = "{}.{}".format(split, lang) + + block_dataset, sizes = self._load_single_lang_dataset( + split=language_split, epoch=epoch + ) + + dataset_map[lang] = MaskedLMDataset( + dataset=block_dataset, + sizes=sizes, + vocab=self.dictionary, + pad_idx=self.dictionary.pad(), + mask_idx=self.dictionary.mask(), + classif_token_idx=self.dictionary.eos(), + sep_token_idx=self.dictionary.eos(), + shuffle=getattr(self.args, "shuffle", False), + has_pairs=False, + segment_id=self.langs2id[lang], + seed=self.seed, + ) + + self.datasets[split] = MultiCorpusSampledDataset(dataset_map) + logger.info( + "{} {} {} examples".format( + utils.split_paths(self.args.data)[epoch - 1], + split, + len(self.datasets[split]), + ) + ) diff --git a/fairseq/fairseq/tasks/denoising.py b/fairseq/fairseq/tasks/denoising.py new file mode 100644 index 0000000000000000000000000000000000000000..57b824d5812535d2a2f83292bad4f707e14ec618 --- /dev/null +++ b/fairseq/fairseq/tasks/denoising.py @@ -0,0 +1,296 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +from dataclasses import dataclass, field +from typing import Any, Optional + +import numpy as np +from omegaconf import II, MISSING + +from fairseq import utils +from fairseq.data import ( + AppendTokenDataset, + DenoisingDataset, + Dictionary, + IdDataset, + NestedDictionaryDataset, + NumelDataset, + PadDataset, + PrependTokenDataset, + StripTokenDataset, + TokenBlockDataset, + data_utils, +) +from fairseq.data.encoders.utils import get_whole_word_mask +from fairseq.data.shorten_dataset import maybe_shorten_dataset +from fairseq.dataclass import ChoiceEnum, FairseqDataclass +from fairseq.tasks import FairseqTask, register_task + +from ..data.indexed_dataset import get_available_dataset_impl + +logger = logging.getLogger(__name__) + +SAMPLE_BREAK_MODE_CHOICES = ChoiceEnum(["none", "complete", "complete_doc", "eos"]) +SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"]) +MASK_LENGTH_CHOICES = ChoiceEnum(["subword", "word", "span-poisson"]) + + +@dataclass +class DenoisingConfig(FairseqDataclass): + data: str = field( + default=MISSING, + metadata={"help": "path to data directory"}, + ) + bpe: Optional[str] = field( + default=None, + metadata={"help": "TODO"}, + ) + tokens_per_sample: int = field( + default=512, + metadata={ + "help": "max number of total tokens over all segments " + "per sample for dataset" + }, + ) + sample_break_mode: SAMPLE_BREAK_MODE_CHOICES = field( + default="complete_doc", + metadata={ + "help": 'If omitted or "none", fills each sample with tokens-per-sample ' + 'tokens. If set to "complete", splits samples only at the end ' + "of sentence, but may include multiple sentences per sample. " + '"complete_doc" is similar but respects doc boundaries. ' + 'If set to "eos", includes only one sentence per sample.' + }, + ) + replace_length: int = field( + default=0, + metadata={"help": "TODO, should only allow -1, 0 and 1"}, + ) + mask: float = field( + default=0.0, + metadata={"help": "fraction of words/subwords that will be masked"}, + ) + mask_random: float = field( + default=0.0, + metadata={"help": "instead of using [MASK], use random token this often"}, + ) + insert: float = field( + default=0.0, + metadata={"help": "insert this percentage of additional random tokens"}, + ) + permute: float = field( + default=0.0, + metadata={"help": "take this proportion of subwords and permute them"}, + ) + rotate: float = field( + default=0.5, + metadata={"help": "rotate this proportion of inputs"}, + ) + poisson_lambda: float = field( + default=3.0, + metadata={"help": "randomly shuffle sentences for this proportion of inputs"}, + ) + shuffle_instance: float = field( + default=0.0, + metadata={"help": "shuffle this proportion of sentences in all inputs"}, + ) + mask_length: MASK_LENGTH_CHOICES = field( + default="subword", + metadata={"help": "mask length to choose"}, + ) + permute_sentences: int = field( + default=-1, + metadata={ + "help": "when masking N tokens, replace with 0, 1, or N tokens (use -1 for N)" + }, + ) + seed: int = II("common.seed") + shorten_method: SHORTEN_METHOD_CHOICES = field( + default="none", + metadata={ + "help": "if not none, shorten sequences that exceed --tokens-per-sample" + }, + ) + shorten_data_split_list: str = field( + default="", + metadata={ + "help": "comma-separated list of dataset splits to apply shortening to, " + 'e.g., "train,valid" (default: all dataset splits)' + }, + ) + max_source_positions: int = field( + default=1024, + metadata={"help": "max number of tokens in the source sequence"}, + ) + max_target_positions: int = field( + default=1024, + metadata={"help": "max number of tokens in the target sequence"}, + ) + dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II( + "dataset.dataset_impl" + ) + + +@register_task("denoising", dataclass=DenoisingConfig) +class DenoisingTask(FairseqTask): + """ + Denoising task for applying sequence to sequence denoising. (ie. BART) + """ + + cfg: DenoisingConfig + + def __init__(self, cfg, dictionary): + super().__init__(cfg) + self.dictionary = dictionary + + # add mask token + self.mask_idx = self.dictionary.add_symbol("") + + @classmethod + def setup_task(cls, cfg: DenoisingConfig, **kwargs): + """Setup the task.""" + paths = utils.split_paths(cfg.data) + assert len(paths) > 0 + dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt")) + logger.info("dictionary: {} types".format(len(dictionary))) + if not hasattr(cfg, "shuffle_instance"): + cfg.shuffle_instance = False + return cls(cfg, dictionary) + + def _load_dataset_split(self, split, epoch, combine): + paths = utils.split_paths(self.cfg.data) + assert len(paths) > 0 + data_path = paths[(epoch - 1) % len(paths)] + split_path = os.path.join(data_path, split) + + dataset = data_utils.load_indexed_dataset( + split_path, + self.dictionary, + self.cfg.dataset_impl, + combine=combine, + ) + if dataset is None: + raise FileNotFoundError( + "Dataset not found: {} ({})".format(split, split_path) + ) + + dataset = StripTokenDataset(dataset, self.dictionary.eos()) + + dataset = maybe_shorten_dataset( + dataset, + split, + self.cfg.shorten_data_split_list, + self.cfg.shorten_method, + self.cfg.tokens_per_sample, + self.cfg.seed, + ) + + # create continuous blocks of tokens + dataset = TokenBlockDataset( + dataset, + dataset.sizes, + self.cfg.tokens_per_sample - 2, + # one less for and one for + pad=self.dictionary.pad(), + eos=self.dictionary.eos(), + break_mode=self.cfg.sample_break_mode, + document_sep_len=0, + ) + logger.info("loaded {} blocks from: {}".format(len(dataset), split_path)) + + # prepend beginning-of-sentence token (, equiv. to [CLS] in BERT) + dataset = PrependTokenDataset(dataset, self.source_dictionary.bos()) + dataset = AppendTokenDataset(dataset, self.source_dictionary.eos()) + return dataset + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + """Load a given dataset split. + + Args: + split (str): name of the split (e.g., train, valid, test) + """ + dataset = self._load_dataset_split(split, epoch, combine) + + mask_whole_words = ( + get_whole_word_mask(self.cfg.bpe, self.source_dictionary) + if self.cfg.mask_length != "subword" + else None + ) + + self.datasets[split] = DenoisingDataset( + dataset, + dataset.sizes, + self.dictionary, + self.mask_idx, + mask_whole_words, + shuffle=self.cfg.shuffle_instance, + seed=self.cfg.seed, + mask=self.cfg.mask, + mask_random=self.cfg.mask_random, + insert=self.cfg.insert, + rotate=self.cfg.rotate, + permute_sentences=self.cfg.permute_sentences, + bpe=self.cfg.bpe, + replace_length=self.cfg.replace_length, + mask_length=self.cfg.mask_length, + poisson_lambda=self.cfg.poisson_lambda, + ) + logger.info( + "Split: {0}, Loaded {1} samples of denoising_dataset".format( + split, + len(self.datasets[split]), + ) + ) + + def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs): + """ + Generate batches for inference. We assume that the input begins with a + bos symbol (``) and ends with an eos symbol (``). + """ + pad = self.source_dictionary.pad() + eos = self.source_dictionary.eos() + src_dataset = TokenBlockDataset( + src_tokens, + src_lengths, + block_size=self.cfg.tokens_per_sample - 2, # for and + pad=pad, + eos=eos, + break_mode=self.cfg.sample_break_mode, + document_sep_len=0, + ) + prev_output_tokens = PrependTokenDataset( + StripTokenDataset(src_dataset, eos), eos + ) + src_dataset = PadDataset(src_dataset, pad_idx=pad, left_pad=False) + return NestedDictionaryDataset( + { + "id": IdDataset(), + "net_input": { + "src_tokens": src_dataset, + "src_lengths": NumelDataset(src_dataset, reduce=False), + "prev_output_tokens": PadDataset( + prev_output_tokens, pad_idx=pad, left_pad=False + ), + }, + "target": src_dataset, + }, + sizes=[np.array(src_lengths)], + ) + + def max_positions(self): + """Return the max sentence length allowed by the task.""" + return (self.cfg.max_source_positions, self.cfg.max_target_positions) + + @property + def source_dictionary(self): + """Return the source :class:`~fairseq.data.Dictionary`.""" + return self.dictionary + + @property + def target_dictionary(self): + """Return the target :class:`~fairseq.data.Dictionary`.""" + return self.dictionary diff --git a/fairseq/fairseq/tasks/fairseq_task.py b/fairseq/fairseq/tasks/fairseq_task.py new file mode 100644 index 0000000000000000000000000000000000000000..e39d1d684882b3269e1fb97ebb6db61ef1a97df9 --- /dev/null +++ b/fairseq/fairseq/tasks/fairseq_task.py @@ -0,0 +1,708 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import warnings +from argparse import Namespace +from typing import Any, Callable, Dict, List + +import torch +from fairseq import search, tokenizer, utils +from fairseq.logging import metrics +from fairseq.data import Dictionary, FairseqDataset, data_utils, encoders, iterators +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.utils import gen_parser_from_dataclass +from fairseq.optim.amp_optimizer import AMPOptimizer +from omegaconf import DictConfig + + +logger = logging.getLogger(__name__) + + +class StatefulContainer(object): + def __init__(self): + self._state = dict() + self._factories = dict() + + def add_factory(self, name, factory: Callable[[], Any]): + self._factories[name] = factory + + def merge_state_dict(self, state_dict: Dict[str, Any]): + self._state.update(state_dict) + + @property + def state_dict(self) -> Dict[str, Any]: + return self._state + + def __getattr__(self, name): + if name not in self._state and name in self._factories: + self._state[name] = self._factories[name]() + + if name in self._state: + return self._state[name] + + raise AttributeError(f"Task state has no factory for attribute {name}") + + +class FairseqTask(object): + """ + Tasks store dictionaries and provide helpers for loading/iterating over + Datasets, initializing the Model/Criterion and calculating the loss. + + Tasks have limited statefulness. In particular, state that needs to be + saved to/loaded from checkpoints needs to be stored in the `self.state` + :class:`StatefulContainer` object. For example:: + + self.state.add_factory("dictionary", self.load_dictionary) + print(self.state.dictionary) # calls self.load_dictionary() + + This is necessary so that when loading checkpoints, we can properly + recreate the task state after initializing the task instance. + """ + + @classmethod + def add_args(cls, parser): + """Add task-specific arguments to the parser.""" + dc = getattr(cls, "__dataclass", None) + if dc is not None: + gen_parser_from_dataclass(parser, dc()) + + @staticmethod + def logging_outputs_can_be_summed(criterion) -> bool: + """ + Whether the logging outputs returned by `train_step` and `valid_step` can + be summed across workers prior to calling `aggregate_logging_outputs`. + Setting this to True will improves distributed training speed. + """ + return criterion.logging_outputs_can_be_summed() + + def __init__(self, cfg: FairseqDataclass, **kwargs): + self.cfg = cfg + self.datasets = dict() + self.dataset_to_epoch_iter = dict() + self.state = StatefulContainer() + + @classmethod + def load_dictionary(cls, filename): + """Load the dictionary from the filename + + Args: + filename (str): the filename + """ + return Dictionary.load(filename) + + @classmethod + def build_dictionary( + cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8 + ): + """Build the dictionary + + Args: + filenames (list): list of filenames + workers (int): number of concurrent workers + threshold (int): defines the minimum word count + nwords (int): defines the total number of words in the final dictionary, + including special symbols + padding_factor (int): can be used to pad the dictionary size to be a + multiple of 8, which is important on some hardware (e.g., Nvidia + Tensor Cores). + """ + d = Dictionary() + for filename in filenames: + Dictionary.add_file_to_dictionary( + filename, d, tokenizer.tokenize_line, workers + ) + d.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor) + return d + + @classmethod + def setup_task(cls, cfg: DictConfig, **kwargs): + """Setup the task (e.g., load dictionaries). + + Args: + cfg (omegaconf.DictConfig): parsed command-line arguments + """ + return cls(cfg, **kwargs) + + def has_sharded_data(self, split): + return os.pathsep in getattr(self.cfg, "data", "") + + def load_dataset( + self, + split: str, + combine: bool = False, + task_cfg: FairseqDataclass = None, + **kwargs, + ): + """Load a given dataset split. + + Args: + split (str): name of the split (e.g., train, valid, test) + combine (bool): combines a split segmented into pieces into one dataset + task_cfg (FairseqDataclass): optional task configuration stored in the checkpoint that can be used + to load datasets + """ + raise NotImplementedError + + def dataset(self, split): + """ + Return a loaded dataset split. + + Args: + split (str): name of the split (e.g., train, valid, test) + + Returns: + a :class:`~fairseq.data.FairseqDataset` corresponding to *split* + """ + from fairseq.data import FairseqDataset + + if split not in self.datasets: + raise KeyError("Dataset not loaded: " + split) + if not isinstance(self.datasets[split], FairseqDataset): + raise TypeError("Datasets are expected to be of type FairseqDataset") + return self.datasets[split] + + def filter_indices_by_size( + self, indices, dataset, max_positions=None, ignore_invalid_inputs=False + ): + """ + Filter examples that are too large + + Args: + indices (np.array): original array of sample indices + dataset (~fairseq.data.FairseqDataset): dataset to batch + max_positions (optional): max sentence length supported by the + model (default: None). + ignore_invalid_inputs (bool, optional): don't raise Exception for + sentences that are too long (default: False). + Returns: + np.array: array of filtered sample indices + """ + indices, ignored = dataset.filter_indices_by_size(indices, max_positions) + if len(ignored) > 0: + if not ignore_invalid_inputs: + raise Exception( + ( + "Size of sample #{} is invalid (={}) since max_positions={}, " + "skip this example with --skip-invalid-size-inputs-valid-test" + ).format(ignored[0], dataset.size(ignored[0]), max_positions) + ) + logger.warning( + ( + "{:,} samples have invalid sizes and will be skipped, " + "max_positions={}, first few sample ids={}" + ).format(len(ignored), max_positions, ignored[:10]) + ) + return indices + + def can_reuse_epoch_itr(self, dataset): + # We can reuse the epoch iterator across epochs as long as the dataset + # hasn't disabled it. We default to ``False`` here, although in practice + # this will be ``True`` for most datasets that inherit from + # ``FairseqDataset`` due to the base implementation there. + return getattr(dataset, "can_reuse_epoch_itr_across_epochs", False) + + def get_batch_iterator( + self, + dataset, + max_tokens=None, + max_sentences=None, + max_positions=None, + ignore_invalid_inputs=False, + required_batch_size_multiple=1, + seed=1, + num_shards=1, + shard_id=0, + num_workers=0, + epoch=1, + data_buffer_size=0, + disable_iterator_cache=False, + skip_remainder_batch=False, + grouped_shuffling=False, + update_epoch_batch_itr=False, + ): + """ + Get an iterator that yields batches of data from the given dataset. + + Args: + dataset (~fairseq.data.FairseqDataset): dataset to batch + max_tokens (int, optional): max number of tokens in each batch + (default: None). + max_sentences (int, optional): max number of sentences in each + batch (default: None). + max_positions (optional): max sentence length supported by the + model (default: None). + ignore_invalid_inputs (bool, optional): don't raise Exception for + sentences that are too long (default: False). + required_batch_size_multiple (int, optional): require batch size to + be a multiple of N (default: 1). + seed (int, optional): seed for random number generator for + reproducibility (default: 1). + num_shards (int, optional): shard the data iterator into N + shards (default: 1). + shard_id (int, optional): which shard of the data iterator to + return (default: 0). + num_workers (int, optional): how many subprocesses to use for data + loading. 0 means the data will be loaded in the main process + (default: 0). + epoch (int, optional): the epoch to start the iterator from + (default: 1). + data_buffer_size (int, optional): number of batches to + preload (default: 0). + disable_iterator_cache (bool, optional): don't cache the + EpochBatchIterator (ignores `FairseqTask::can_reuse_epoch_itr`) + (default: False). + skip_remainder_batch (bool, optional): if set, discard the last + batch in each training epoch, as the last batch is often smaller than + local_batch_size * distributed_word_size (default: ``True``). + grouped_shuffling (bool, optional): group batches with each groups + containing num_shards batches and shuffle groups. Reduces difference + between sequence lengths among workers for batches sorted by length. + update_epoch_batch_itr (bool optional): if true then donot use the cached + batch iterator for the epoch + + Returns: + ~fairseq.iterators.EpochBatchIterator: a batched iterator over the + given dataset split + """ + can_reuse_epoch_itr = ( + not disable_iterator_cache + and not update_epoch_batch_itr + and self.can_reuse_epoch_itr(dataset) + ) + logger.info(f"can_reuse_epoch_itr = {can_reuse_epoch_itr}") + if can_reuse_epoch_itr and dataset in self.dataset_to_epoch_iter: + logger.debug("reusing EpochBatchIterator for epoch {}".format(epoch)) + return self.dataset_to_epoch_iter[dataset] + + assert isinstance(dataset, FairseqDataset) + + # initialize the dataset with the correct starting epoch + dataset.set_epoch(epoch) + + def make_batches(dataset, epoch): + logger.info(f"creating new batches for epoch {epoch}") + + # get indices ordered by example size + with data_utils.numpy_seed(seed + epoch): + indices = dataset.ordered_indices() + + # filter examples that are too large + if max_positions is not None: + indices = self.filter_indices_by_size( + indices, dataset, max_positions, ignore_invalid_inputs + ) + + # create mini-batches with given size constraints + batches = dataset.batch_by_size( + indices, + max_tokens=max_tokens, + max_sentences=max_sentences, + required_batch_size_multiple=required_batch_size_multiple, + ) + return batches + + reuse_dataloader = getattr(self.cfg, "reuse_dataloader", True) + persistent_workers = getattr(self.cfg, "persistent_workers", True) + rebuild_batches = getattr(self.cfg, "rebuild_batches", False) + logger.info(f"reuse_dataloader = {reuse_dataloader}") + logger.info(f"rebuild_batches = {rebuild_batches}") + + if rebuild_batches: + logger.info("batches will be rebuilt for each epoch") + batch_sampler = make_batches + else: + batch_sampler = make_batches(dataset, epoch) + + # return a reusable, sharded iterator + epoch_iter = iterators.EpochBatchIterator( + dataset=dataset, + collate_fn=dataset.collater, + batch_sampler=batch_sampler, + seed=seed, + num_shards=num_shards, + shard_id=shard_id, + num_workers=num_workers, + epoch=epoch, + buffer_size=data_buffer_size, + skip_remainder_batch=skip_remainder_batch, + grouped_shuffling=grouped_shuffling, + reuse_dataloader=reuse_dataloader, + persistent_workers=persistent_workers, + ) + + if can_reuse_epoch_itr: + self.dataset_to_epoch_iter[dataset] = epoch_iter + + return epoch_iter + + def build_model(self, cfg: FairseqDataclass, from_checkpoint=False): + """ + Build the :class:`~fairseq.models.BaseFairseqModel` instance for this + task. + + Args: + cfg (FairseqDataclass): configuration object + + Returns: + a :class:`~fairseq.models.BaseFairseqModel` instance + """ + from fairseq import models, quantization_utils + + model = models.build_model(cfg, self, from_checkpoint) + model = quantization_utils.quantize_model_scalar(model, cfg) + return model + + def build_criterion(self, cfg: DictConfig, from_checkpoint=False): + """ + Build the :class:`~fairseq.criterions.FairseqCriterion` instance for + this task. + + Args: + cfg (omegaconf.DictConfig): configration object + + Returns: + a :class:`~fairseq.criterions.FairseqCriterion` instance + """ + from fairseq import criterions + + return criterions.build_criterion(cfg, self, from_checkpoint=from_checkpoint) + + def build_generator( + self, + models, + args, + seq_gen_cls=None, + extra_gen_cls_kwargs=None, + prefix_allowed_tokens_fn=None, + ): + """ + Build a :class:`~fairseq.SequenceGenerator` instance for this + task. + + Args: + models (List[~fairseq.models.FairseqModel]): ensemble of models + args (fairseq.dataclass.configs.GenerationConfig): + configuration object (dataclass) for generation + extra_gen_cls_kwargs (Dict[str, Any]): extra options to pass + through to SequenceGenerator + prefix_allowed_tokens_fn (Callable[[int, torch.Tensor], List[int]]): + If provided, this function constrains the beam search to + allowed tokens only at each step. The provided function + should take 2 arguments: the batch ID (`batch_id: int`) + and a unidimensional tensor of token ids (`inputs_ids: + torch.Tensor`). It has to return a `List[int]` with the + allowed tokens for the next generation step conditioned + on the previously generated tokens (`inputs_ids`) and + the batch ID (`batch_id`). This argument is useful for + constrained generation conditioned on the prefix, as + described in "Autoregressive Entity Retrieval" + (https://arxiv.org/abs/2010.00904) and + https://github.com/facebookresearch/GENRE. + """ + if getattr(args, "score_reference", False): + from fairseq.sequence_scorer import SequenceScorer + + return SequenceScorer( + self.target_dictionary, + compute_alignment=getattr(args, "print_alignment", False), + ) + + from fairseq.sequence_generator import ( + SequenceGenerator, + SequenceGeneratorWithAlignment, + ) + + # Choose search strategy. Defaults to Beam Search. + sampling = getattr(args, "sampling", False) + sampling_topk = getattr(args, "sampling_topk", -1) + sampling_topp = getattr(args, "sampling_topp", -1.0) + diverse_beam_groups = getattr(args, "diverse_beam_groups", -1) + diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5) + match_source_len = getattr(args, "match_source_len", False) + diversity_rate = getattr(args, "diversity_rate", -1) + constrained = getattr(args, "constraints", False) + if prefix_allowed_tokens_fn is None: + prefix_allowed_tokens_fn = getattr(args, "prefix_allowed_tokens_fn", None) + if ( + sum( + int(cond) + for cond in [ + sampling, + diverse_beam_groups > 0, + match_source_len, + diversity_rate > 0, + ] + ) + > 1 + ): + raise ValueError("Provided Search parameters are mutually exclusive.") + assert sampling_topk < 0 or sampling, "--sampling-topk requires --sampling" + assert sampling_topp < 0 or sampling, "--sampling-topp requires --sampling" + + if sampling: + search_strategy = search.Sampling( + self.target_dictionary, sampling_topk, sampling_topp + ) + elif diverse_beam_groups > 0: + search_strategy = search.DiverseBeamSearch( + self.target_dictionary, diverse_beam_groups, diverse_beam_strength + ) + elif match_source_len: + # this is useful for tagging applications where the output + # length should match the input length, so we hardcode the + # length constraints for simplicity + search_strategy = search.LengthConstrainedBeamSearch( + self.target_dictionary, + min_len_a=1, + min_len_b=0, + max_len_a=1, + max_len_b=0, + ) + elif diversity_rate > -1: + search_strategy = search.DiverseSiblingsSearch( + self.target_dictionary, diversity_rate + ) + elif constrained: + search_strategy = search.LexicallyConstrainedBeamSearch( + self.target_dictionary, args.constraints + ) + elif prefix_allowed_tokens_fn: + search_strategy = search.PrefixConstrainedBeamSearch( + self.target_dictionary, prefix_allowed_tokens_fn + ) + else: + search_strategy = search.BeamSearch(self.target_dictionary) + + extra_gen_cls_kwargs = extra_gen_cls_kwargs or {} + if seq_gen_cls is None: + if getattr(args, "print_alignment", False): + seq_gen_cls = SequenceGeneratorWithAlignment + extra_gen_cls_kwargs["print_alignment"] = args.print_alignment + else: + seq_gen_cls = SequenceGenerator + + return seq_gen_cls( + models, + self.target_dictionary, + beam_size=getattr(args, "beam", 5), + max_len_a=getattr(args, "max_len_a", 0), + max_len_b=getattr(args, "max_len_b", 200), + min_len=getattr(args, "min_len", 1), + normalize_scores=(not getattr(args, "unnormalized", False)), + len_penalty=getattr(args, "lenpen", 1), + unk_penalty=getattr(args, "unkpen", 0), + temperature=getattr(args, "temperature", 1.0), + match_source_len=getattr(args, "match_source_len", False), + no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0), + search_strategy=search_strategy, + **extra_gen_cls_kwargs, + ) + + def train_step( + self, sample, model, criterion, optimizer, update_num, ignore_grad=False + ): + """ + Do forward and backward, and return the loss as computed by *criterion* + for the given *model* and *sample*. + + Args: + sample (dict): the mini-batch. The format is defined by the + :class:`~fairseq.data.FairseqDataset`. + model (~fairseq.models.BaseFairseqModel): the model + criterion (~fairseq.criterions.FairseqCriterion): the criterion + optimizer (~fairseq.optim.FairseqOptimizer): the optimizer + update_num (int): the current update + ignore_grad (bool): multiply loss by 0 if this is set to True + + Returns: + tuple: + - the loss + - the sample size, which is used as the denominator for the + gradient + - logging outputs to display while training + """ + model.train() + model.set_num_updates(update_num) + with torch.autograd.profiler.record_function("forward"): + with torch.cuda.amp.autocast(enabled=(isinstance(optimizer, AMPOptimizer))): + loss, sample_size, logging_output = criterion(model, sample) + if ignore_grad: + loss *= 0 + with torch.autograd.profiler.record_function("backward"): + optimizer.backward(loss) + return loss, sample_size, logging_output + + def valid_step(self, sample, model, criterion): + model.eval() + with torch.no_grad(): + loss, sample_size, logging_output = criterion(model, sample) + return loss, sample_size, logging_output + + def optimizer_step(self, optimizer, model, update_num): + optimizer.step() + + def build_dataset_for_inference( + self, src_tokens: List[torch.Tensor], src_lengths: List[int], **kwargs + ) -> torch.utils.data.Dataset: + raise NotImplementedError + + def inference_step( + self, generator, models, sample, prefix_tokens=None, constraints=None + ): + with torch.no_grad(): + return generator.generate( + models, sample, prefix_tokens=prefix_tokens, constraints=constraints + ) + + def begin_epoch(self, epoch, model): + """Hook function called before the start of each epoch.""" + pass + + def begin_valid_epoch(self, epoch, model): + """Hook function called before the start of each validation epoch.""" + pass + + def aggregate_logging_outputs(self, logging_outputs, criterion): + """[deprecated] Aggregate logging outputs from data parallel training.""" + utils.deprecation_warning( + "The aggregate_logging_outputs API is deprecated. " + "Please use the reduce_metrics API instead." + ) + with metrics.aggregate() as agg: + self.reduce_metrics(logging_outputs, criterion) + return agg.get_smoothed_values() + + def reduce_metrics(self, logging_outputs, criterion): + """Aggregate logging outputs from data parallel training.""" + # backward compatibility for tasks that override aggregate_logging_outputs + base_func = FairseqTask.aggregate_logging_outputs + self_func = getattr(self, "aggregate_logging_outputs").__func__ + if self_func is not base_func: + utils.deprecation_warning( + "Tasks should implement the reduce_metrics API. " + "Falling back to deprecated aggregate_logging_outputs API." + ) + agg_logging_outputs = self.aggregate_logging_outputs( + logging_outputs, criterion + ) + for k, v in agg_logging_outputs.items(): + metrics.log_scalar(k, v) + return + + if not any("ntokens" in log for log in logging_outputs): + warnings.warn( + "ntokens not found in Criterion logging outputs, cannot log wpb or wps" + ) + else: + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + metrics.log_scalar("wpb", ntokens, priority=180, round=1) + metrics.log_speed("wps", ntokens, priority=90, round=1) + + if not any("nsentences" in log for log in logging_outputs): + warnings.warn( + "nsentences not found in Criterion logging outputs, cannot log bsz" + ) + else: + nsentences = sum(log.get("nsentences", 0) for log in logging_outputs) + metrics.log_scalar("bsz", nsentences, priority=190, round=1) + + criterion.__class__.reduce_metrics(logging_outputs) + + def state_dict(self): + if self.state is not None: + return self.state.state_dict + return {} + + def load_state_dict(self, state_dict: Dict[str, Any]): + if self.state is not None: + self.state.merge_state_dict(state_dict) + + def max_positions(self): + """Return the max input length allowed by the task.""" + return None + + @property + def source_dictionary(self): + """Return the source :class:`~fairseq.data.Dictionary` (if applicable + for this task).""" + return None + + @property + def target_dictionary(self): + """Return the target :class:`~fairseq.data.Dictionary` (if applicable + for this task).""" + return None + + def build_tokenizer(self, args): + """Build the pre-tokenizer for this task.""" + return encoders.build_tokenizer(args) + + def build_bpe(self, args): + """Build the tokenizer for this task.""" + return encoders.build_bpe(args) + + def get_interactive_tokens_and_lengths(self, lines, encode_fn): + tokens = [ + self.source_dictionary.encode_line( + encode_fn(src_str), add_if_not_exist=False + ).long() + for src_str in lines + ] + lengths = [t.numel() for t in tokens] + return tokens, lengths + + +class LegacyFairseqTask(FairseqTask): + def __init__(self, args: Namespace): + super().__init__(None) + self.args = args + self.datasets = {} + self.dataset_to_epoch_iter = {} + + @classmethod + def setup_task(cls, args: Namespace, **kwargs): + """Setup the task (e.g., load dictionaries). + + Args: + args (argparse.Namespace): parsed command-line arguments + """ + return cls(args, **kwargs) + + def has_sharded_data(self, split): + return os.pathsep in getattr(self.args, "data", "") + + def build_model(self, args: Namespace, from_checkpoint=False): + """ + Build the :class:`~fairseq.models.BaseFairseqModel` instance for this + task. + + Args: + args (argparse.Namespace): parsed command-line arguments + + Returns: + a :class:`~fairseq.models.BaseFairseqModel` instance + """ + from fairseq import models, quantization_utils + + model = models.build_model(args, self, from_checkpoint) + model = quantization_utils.quantize_model_scalar(model, args) + return model + + def build_criterion(self, args: Namespace): + """ + Build the :class:`~fairseq.criterions.FairseqCriterion` instance for + this task. + + Args: + args (argparse.Namespace): parsed command-line arguments + + Returns: + a :class:`~fairseq.criterions.FairseqCriterion` instance + """ + from fairseq import criterions + + return criterions.build_criterion(args, self) diff --git a/fairseq/fairseq/tasks/frm_text_to_speech.py b/fairseq/fairseq/tasks/frm_text_to_speech.py new file mode 100644 index 0000000000000000000000000000000000000000..667f5f8ee45694a992a44a8c975e98f34c5b207a --- /dev/null +++ b/fairseq/fairseq/tasks/frm_text_to_speech.py @@ -0,0 +1,55 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +from fairseq.data.audio.frm_text_to_speech_dataset import FrmTextToSpeechDatasetCreator +from fairseq.tasks import register_task +from fairseq.tasks.text_to_speech import TextToSpeechTask + + +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO, +) +logger = logging.getLogger(__name__) + + +@register_task("frm_text_to_speech") +class FrmTextToSpeechTask(TextToSpeechTask): + @staticmethod + def add_args(parser): + TextToSpeechTask.add_args(parser) + parser.add_argument("--do_chunk", action="store_true", help="train on chunks") + parser.add_argument("--chunk_bound", default=-1, type=int) + parser.add_argument("--chunk_init", default=50, type=int) + parser.add_argument("--chunk_incr", default=5, type=int) + parser.add_argument("--add_eos", action="store_true") + parser.add_argument("--dedup", action="store_true") + parser.add_argument("--ref_fpu", default=-1, type=float) + + def load_dataset(self, split, **unused_kwargs): + is_train_split = split.startswith("train") + pre_tokenizer = self.build_tokenizer(self.args) + bpe_tokenizer = self.build_bpe(self.args) + self.datasets[split] = FrmTextToSpeechDatasetCreator.from_tsv( + self.args.data, + self.data_cfg, + split, + self.src_dict, + pre_tokenizer, + bpe_tokenizer, + is_train_split=is_train_split, + n_frames_per_step=self.args.n_frames_per_step, + speaker_to_id=self.speaker_to_id, + do_chunk=self.args.do_chunk, + chunk_bound=self.args.chunk_bound, + chunk_init=self.args.chunk_init, + chunk_incr=self.args.chunk_incr, + add_eos=self.args.add_eos, + dedup=self.args.dedup, + ref_fpu=self.args.ref_fpu, + ) diff --git a/fairseq/fairseq/tasks/hubert_pretraining.py b/fairseq/fairseq/tasks/hubert_pretraining.py new file mode 100644 index 0000000000000000000000000000000000000000..1a3605f14dec1a9d1893cc03891991235b78199c --- /dev/null +++ b/fairseq/fairseq/tasks/hubert_pretraining.py @@ -0,0 +1,191 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import logging +import os +import sys +from typing import Dict, List, Optional, Tuple + +import numpy as np + +from dataclasses import dataclass, field +from fairseq.data import Dictionary, HubertDataset +from fairseq.dataclass.configs import FairseqDataclass +from fairseq.tasks import register_task +from fairseq.tasks.fairseq_task import FairseqTask +from omegaconf import MISSING + +logger = logging.getLogger(__name__) + + +class LabelEncoder(object): + def __init__(self, dictionary: Dictionary) -> None: + self.dictionary = dictionary + + def __call__(self, label: str) -> List[str]: + return self.dictionary.encode_line( + label, + append_eos=False, + add_if_not_exist=False, + ) + + +@dataclass +class HubertPretrainingConfig(FairseqDataclass): + data: str = field(default=MISSING, metadata={"help": "path to data directory"}) + fine_tuning: bool = field( + default=False, metadata={"help": "set to true if fine-tuning Hubert"} + ) + labels: List[str] = field( + default_factory=lambda: ["ltr"], + metadata={ + "help": ( + "extension of the label files to load, frame-level labels for" + " pre-training, and sequence-level label for fine-tuning" + ) + }, + ) + label_dir: Optional[str] = field( + default=None, + metadata={ + "help": "if set, looks for labels in this directory instead", + }, + ) + label_rate: float = field( + default=-1.0, + metadata={"help": "label frame rate. -1.0 for sequence label"}, + ) + sample_rate: int = field( + default=16_000, + metadata={ + "help": "target sample rate. audio files will be up/down " + "sampled to this rate" + }, + ) + normalize: bool = field( + default=False, + metadata={"help": "if set, normalizes input to have 0 mean and unit variance"}, + ) + enable_padding: bool = field( + default=False, + metadata={"help": "pad shorter samples instead of cropping"}, + ) + max_keep_size: Optional[int] = field( + default=None, + metadata={"help": "exclude sample longer than this"}, + ) + max_sample_size: Optional[int] = field( + default=None, + metadata={"help": "max sample size to crop to for batching"}, + ) + min_sample_size: Optional[int] = field( + default=None, + metadata={"help": "min sample size to crop to for batching"}, + ) + single_target: Optional[bool] = field( + default=False, + metadata={ + "help": "if set, AddTargetDatasets outputs same keys " "as AddTargetDataset" + }, + ) + random_crop: Optional[bool] = field( + default=True, + metadata={"help": "always crop from the beginning if false"}, + ) + pad_audio: Optional[bool] = field( + default=False, + metadata={"help": "pad audio to the longest one in the batch if true"}, + ) + + +@register_task("hubert_pretraining", dataclass=HubertPretrainingConfig) +class HubertPretrainingTask(FairseqTask): + + cfg: HubertPretrainingConfig + + def __init__( + self, + cfg: HubertPretrainingConfig, + ) -> None: + super().__init__(cfg) + + logger.info(f"current directory is {os.getcwd()}") + logger.info(f"HubertPretrainingTask Config {cfg}") + + self.cfg = cfg + self.fine_tuning = cfg.fine_tuning + + if cfg.fine_tuning: + self.state.add_factory("target_dictionary", self.load_dictionaries) + else: + self.state.add_factory("dictionaries", self.load_dictionaries) + + self.blank_symbol = "" + + @property + def source_dictionary(self) -> Optional[Dictionary]: + return None + + @property + def target_dictionary(self) -> Optional[Dictionary]: + return self.state.target_dictionary + + @property + def dictionaries(self) -> List[Dictionary]: + return self.state.dictionaries + + @classmethod + def setup_task( + cls, cfg: HubertPretrainingConfig, **kwargs + ) -> "HubertPretrainingTask": + return cls(cfg) + + def load_dictionaries(self): + label_dir = self.cfg.data if self.cfg.label_dir is None else self.cfg.label_dir + dictionaries = [ + Dictionary.load(f"{label_dir}/dict.{label}.txt") + for label in self.cfg.labels + ] + return dictionaries[0] if self.cfg.fine_tuning else dictionaries + + def get_label_dir(self) -> str: + if self.cfg.label_dir is None: + return self.cfg.data + return self.cfg.label_dir + + def load_dataset(self, split: str, **kwargs) -> None: + manifest = f"{self.cfg.data}/{split}.tsv" + dicts = [self.target_dictionary] if self.cfg.fine_tuning else self.dictionaries + pad_list = [dict.pad() for dict in dicts] + eos_list = [dict.eos() for dict in dicts] + procs = [LabelEncoder(dict) for dict in dicts] + paths = [f"{self.get_label_dir()}/{split}.{l}" for l in self.cfg.labels] + + # hubert v1: pad_audio=True, random_crop=False; + self.datasets[split] = HubertDataset( + manifest, + sample_rate=self.cfg.sample_rate, + label_paths=paths, + label_rates=self.cfg.label_rate, + pad_list=pad_list, + eos_list=eos_list, + label_processors=procs, + max_keep_sample_size=self.cfg.max_keep_size, + min_keep_sample_size=self.cfg.min_sample_size, + max_sample_size=self.cfg.max_sample_size, + pad_audio=self.cfg.pad_audio, + normalize=self.cfg.normalize, + store_labels=False, + random_crop=self.cfg.random_crop, + single_target=self.cfg.single_target, + ) + + def max_positions(self) -> Tuple[int, int]: + return (sys.maxsize, sys.maxsize) + + def filter_indices_by_size(self, indices: np.array, *args, **kwargs) -> np.array: + return indices diff --git a/fairseq/fairseq/tasks/language_modeling.py b/fairseq/fairseq/tasks/language_modeling.py new file mode 100644 index 0000000000000000000000000000000000000000..44d5324b3d9bbc131a95ab4763fccdd8884729f6 --- /dev/null +++ b/fairseq/fairseq/tasks/language_modeling.py @@ -0,0 +1,383 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +from dataclasses import dataclass, field +from typing import Optional + +import numpy as np +import torch +from fairseq import utils +from fairseq.data import ( + AppendTokenDataset, + Dictionary, + IdDataset, + LMContextWindowDataset, + MonolingualDataset, + NestedDictionaryDataset, + NumelDataset, + PadDataset, + PrependTokenDataset, + StripTokenDataset, + TokenBlockDataset, + TruncatedDictionary, + data_utils, +) +from fairseq.data.indexed_dataset import get_available_dataset_impl +from fairseq.data.shorten_dataset import maybe_shorten_dataset +from fairseq.dataclass import ChoiceEnum, FairseqDataclass +from fairseq.tasks import LegacyFairseqTask, register_task +from omegaconf import II + + +SAMPLE_BREAK_MODE_CHOICES = ChoiceEnum(["none", "complete", "complete_doc", "eos"]) +SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"]) +logger = logging.getLogger(__name__) + + +@dataclass +class LanguageModelingConfig(FairseqDataclass): + data: Optional[str] = field( + default=None, metadata={"help": "path to data directory"} + ) + sample_break_mode: SAMPLE_BREAK_MODE_CHOICES = field( + default="none", + metadata={ + "help": 'If omitted or "none", fills each sample with tokens-per-sample ' + 'tokens. If set to "complete", splits samples only at the end ' + "of sentence, but may include multiple sentences per sample. " + '"complete_doc" is similar but respects doc boundaries. ' + 'If set to "eos", includes only one sentence per sample.' + }, + ) + tokens_per_sample: int = field( + default=1024, + metadata={"help": "max number of tokens per sample for LM dataset"}, + ) + output_dictionary_size: int = field( + default=-1, metadata={"help": "limit the size of output dictionary"} + ) + self_target: bool = field(default=False, metadata={"help": "include self target"}) + future_target: bool = field( + default=False, metadata={"help": "include future target"} + ) + past_target: bool = field(default=False, metadata={"help": "include past target"}) + add_bos_token: bool = field( + default=False, metadata={"help": "prepend beginning of sentence token ()"} + ) + max_target_positions: Optional[int] = field( + default=None, metadata={"help": "max number of tokens in the target sequence"} + ) + shorten_method: SHORTEN_METHOD_CHOICES = field( + default="none", + metadata={ + "help": "if not none, shorten sequences that exceed --tokens-per-sample" + }, + ) + shorten_data_split_list: str = field( + default="", + metadata={ + "help": "comma-separated list of dataset splits to apply shortening to, " + 'e.g., "train,valid" (default: all dataset splits)' + }, + ) + pad_to_fixed_length: Optional[bool] = field( + default=False, + metadata={"help": "pad to fixed length"}, + ) + pad_to_fixed_bsz: Optional[bool] = field( + default=False, + metadata={"help": "boolean to pad to fixed batch size"}, + ) + + # TODO common vars below add to parent + seed: int = II("common.seed") + batch_size: Optional[int] = II("dataset.batch_size") + batch_size_valid: Optional[int] = II("dataset.batch_size_valid") + dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II( + "dataset.dataset_impl" + ) + data_buffer_size: int = II("dataset.data_buffer_size") + tpu: bool = II("common.tpu") + use_plasma_view: bool = II("common.use_plasma_view") + plasma_path: str = II("common.plasma_path") + + +@register_task("language_modeling", dataclass=LanguageModelingConfig) +class LanguageModelingTask(LegacyFairseqTask): + """ + Train a language model. + + Args: + dictionary (~fairseq.data.Dictionary): the dictionary for the input of + the language model + output_dictionary (~fairseq.data.Dictionary): the dictionary for the + output of the language model. In most cases it will be the same as + *dictionary*, but could possibly be a more limited version of the + dictionary (if ``--output-dictionary-size`` is used). + targets (List[str]): list of the target types that the language model + should predict. Can be one of "self", "future", and "past". + Defaults to "future". + + .. note:: + + The language modeling task is compatible with :mod:`fairseq-train`, + :mod:`fairseq-generate`, :mod:`fairseq-interactive` and + :mod:`fairseq-eval-lm`. + + The language modeling task provides the following additional command-line + arguments: + + .. argparse:: + :ref: fairseq.tasks.language_modeling_parser + :prog: + """ + + def __init__(self, args, dictionary, output_dictionary=None, targets=None): + super().__init__(args) + self.dictionary = dictionary + self.output_dictionary = output_dictionary or dictionary + + if targets is None: + targets = ["future"] + self.targets = targets + + @classmethod + def setup_dictionary(cls, args, **kwargs): + dictionary = None + output_dictionary = None + if args.data: + paths = utils.split_paths(args.data) + assert len(paths) > 0 + dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt")) + logger.info("dictionary: {} types".format(len(dictionary))) + output_dictionary = dictionary + if args.output_dictionary_size >= 0: + output_dictionary = TruncatedDictionary( + dictionary, args.output_dictionary_size + ) + return (dictionary, output_dictionary) + + @classmethod + def setup_task(cls, args, **kwargs): + """Setup the task (e.g., load dictionaries). + + Args: + args (argparse.Namespace): parsed command-line arguments + """ + dictionary, output_dictionary = cls.setup_dictionary(args, **kwargs) + + # upgrade old checkpoints + if getattr(args, "exclude_self_target", False): + args.self_target = False + + targets = [] + if getattr(args, "self_target", False): + targets.append("self") + if getattr(args, "future_target", False): + targets.append("future") + if getattr(args, "past_target", False): + targets.append("past") + if len(targets) == 0: + # standard language modeling + targets = ["future"] + + return cls(args, dictionary, output_dictionary, targets=targets) + + def build_model(self, args, from_checkpoint=False): + model = super().build_model(args, from_checkpoint) + for target in self.targets: + if target not in model.supported_targets: + raise ValueError( + "Unsupported language modeling target: {}".format(target) + ) + + return model + + def load_dataset( + self, split: str, epoch=1, combine=False, **kwargs + ) -> MonolingualDataset: + """Load a given dataset split. + + Args: + split (str): name of the split (e.g., train, valid, valid1, test) + """ + paths = utils.split_paths(self.args.data) + assert len(paths) > 0 + + data_path = paths[(epoch - 1) % len(paths)] + split_path = os.path.join(data_path, split) + + # each process has its own copy of the raw data (likely to be an np.memmap) + dataset = data_utils.load_indexed_dataset( + split_path, self.dictionary, self.args.dataset_impl, combine=combine + ) + if dataset is None: + raise FileNotFoundError(f"Dataset not found: {split} ({split_path})") + + dataset = maybe_shorten_dataset( + dataset, + split, + self.args.shorten_data_split_list, + self.args.shorten_method, + self.args.tokens_per_sample, + self.args.seed, + ) + dataset = TokenBlockDataset( + dataset, + dataset.sizes, + self.args.tokens_per_sample, + pad=self.dictionary.pad(), + eos=self.dictionary.eos(), + break_mode=self.args.sample_break_mode, + include_targets=True, + use_plasma_view=self.args.use_plasma_view, + split_path=split_path, + plasma_path=self.args.plasma_path, + ) + + add_eos_for_other_targets = ( + self.args.sample_break_mode is not None + and self.args.sample_break_mode != "none" + ) + fixed_pad_length = None + if self.args.pad_to_fixed_length: + fixed_pad_length = self.args.tokens_per_sample + + pad_to_bsz = None + if self.args.pad_to_fixed_bsz: + pad_to_bsz = ( + self.args.batch_size_valid if "valid" in split else self.args.batch_size + ) + + self.datasets[split] = MonolingualDataset( + dataset=dataset, + sizes=dataset.sizes, + src_vocab=self.dictionary, + tgt_vocab=self.output_dictionary, + add_eos_for_other_targets=add_eos_for_other_targets, + shuffle=True, + targets=self.targets, + add_bos_token=self.args.add_bos_token, + fixed_pad_length=fixed_pad_length, + pad_to_bsz=pad_to_bsz, + ) + + def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs): + """ + Generate batches for inference. We prepend an eos token to src_tokens + (or bos if `--add-bos-token` is set) and we append a to target. + This is convenient both for generation with a prefix and LM scoring. + """ + dataset = StripTokenDataset( + TokenBlockDataset( + src_tokens, + src_lengths, + block_size=None, # ignored for "eos" break mode + pad=self.source_dictionary.pad(), + eos=self.source_dictionary.eos(), + break_mode="eos", + ), + # remove eos from (end of) target sequence + self.source_dictionary.eos(), + ) + src_dataset = PrependTokenDataset( + dataset, + token=( + self.source_dictionary.bos() + if getattr(self.args, "add_bos_token", False) + else self.source_dictionary.eos() + ), + ) + tgt_dataset = AppendTokenDataset(dataset, token=self.source_dictionary.pad()) + return NestedDictionaryDataset( + { + "id": IdDataset(), + "net_input": { + "src_tokens": PadDataset( + src_dataset, + pad_idx=self.source_dictionary.pad(), + left_pad=False, + ), + "src_lengths": NumelDataset(src_dataset, reduce=False), + }, + "target": PadDataset( + tgt_dataset, pad_idx=self.source_dictionary.pad(), left_pad=False + ), + }, + sizes=[np.array(src_lengths)], + ) + + def inference_step( + self, generator, models, sample, prefix_tokens=None, constraints=None + ): + with torch.no_grad(): + # Generation will always be conditioned on bos_token + if getattr(self.args, "add_bos_token", False): + bos_token = self.source_dictionary.bos() + else: + bos_token = self.source_dictionary.eos() + + if constraints is not None: + raise NotImplementedError( + "Constrained decoding with the language_modeling task is not supported" + ) + + # SequenceGenerator doesn't use src_tokens directly, we need to + # pass the `prefix_tokens` argument instead + if prefix_tokens is None and sample["net_input"]["src_tokens"].nelement(): + prefix_tokens = sample["net_input"]["src_tokens"] + if prefix_tokens[:, 0].eq(bos_token).all(): + prefix_tokens = prefix_tokens[:, 1:] + + return generator.generate( + models, sample, prefix_tokens=prefix_tokens, bos_token=bos_token + ) + + def eval_lm_dataloader( + self, + dataset, + max_tokens: Optional[int] = 36000, + batch_size: Optional[int] = None, + max_positions: Optional[int] = None, + num_shards: int = 1, + shard_id: int = 0, + num_workers: int = 1, + data_buffer_size: int = 10, + # ensures that every evaluated token has access to a context of at least + # this size, if possible + context_window: int = 0, + ): + if context_window > 0: + dataset = LMContextWindowDataset( + dataset=dataset, + tokens_per_sample=self.args.tokens_per_sample, + context_window=context_window, + pad_idx=self.source_dictionary.pad(), + ) + return self.get_batch_iterator( + dataset=dataset, + max_tokens=max_tokens, + max_sentences=batch_size, + max_positions=max_positions, + ignore_invalid_inputs=True, + num_shards=num_shards, + shard_id=shard_id, + num_workers=num_workers, + data_buffer_size=data_buffer_size, + ).next_epoch_itr(shuffle=False) + + @property + def source_dictionary(self): + """Return the :class:`~fairseq.data.Dictionary` for the language + model.""" + return self.dictionary + + @property + def target_dictionary(self): + """Return the :class:`~fairseq.data.Dictionary` for the language + model.""" + return self.output_dictionary diff --git a/fairseq/fairseq/tasks/legacy_masked_lm.py b/fairseq/fairseq/tasks/legacy_masked_lm.py new file mode 100644 index 0000000000000000000000000000000000000000..975497654926b64fff6c4960f54c4e6932e7fce1 --- /dev/null +++ b/fairseq/fairseq/tasks/legacy_masked_lm.py @@ -0,0 +1,152 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import itertools +import logging +import os + +import numpy as np +from fairseq import tokenizer, utils +from fairseq.data import ConcatDataset, Dictionary, data_utils, indexed_dataset +from fairseq.data.legacy.block_pair_dataset import BlockPairDataset +from fairseq.data.legacy.masked_lm_dataset import MaskedLMDataset +from fairseq.data.legacy.masked_lm_dictionary import BertDictionary +from fairseq.tasks import LegacyFairseqTask, register_task + + +logger = logging.getLogger(__name__) + + +@register_task("legacy_masked_lm") +class LegacyMaskedLMTask(LegacyFairseqTask): + """ + Task for training Masked LM (BERT) model. + Args: + dictionary (Dictionary): the dictionary for the input of the task + """ + + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + parser.add_argument( + "data", + help="colon separated path to data directories list, \ + will be iterated upon during epochs in round-robin manner", + ) + parser.add_argument( + "--tokens-per-sample", + default=512, + type=int, + help="max number of total tokens over all segments" + " per sample for BERT dataset", + ) + parser.add_argument( + "--break-mode", default="doc", type=str, help="mode for breaking sentence" + ) + parser.add_argument("--shuffle-dataset", action="store_true", default=False) + + def __init__(self, args, dictionary): + super().__init__(args) + self.dictionary = dictionary + self.seed = args.seed + + @classmethod + def load_dictionary(cls, filename): + return BertDictionary.load(filename) + + @classmethod + def build_dictionary( + cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8 + ): + d = BertDictionary() + for filename in filenames: + Dictionary.add_file_to_dictionary( + filename, d, tokenizer.tokenize_line, workers + ) + d.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor) + return d + + @property + def target_dictionary(self): + return self.dictionary + + @classmethod + def setup_task(cls, args, **kwargs): + """Setup the task.""" + paths = utils.split_paths(args.data) + assert len(paths) > 0 + dictionary = BertDictionary.load(os.path.join(paths[0], "dict.txt")) + logger.info("dictionary: {} types".format(len(dictionary))) + + return cls(args, dictionary) + + def load_dataset(self, split, epoch=1, combine=False): + """Load a given dataset split. + + Args: + split (str): name of the split (e.g., train, valid, test) + """ + loaded_datasets = [] + + paths = utils.split_paths(self.args.data) + assert len(paths) > 0 + data_path = paths[(epoch - 1) % len(paths)] + logger.info("data_path", data_path) + + for k in itertools.count(): + split_k = split + (str(k) if k > 0 else "") + path = os.path.join(data_path, split_k) + ds = indexed_dataset.make_dataset( + path, + impl=self.args.dataset_impl, + fix_lua_indexing=True, + dictionary=self.dictionary, + ) + + if ds is None: + if k > 0: + break + else: + raise FileNotFoundError( + "Dataset not found: {} ({})".format(split, data_path) + ) + + with data_utils.numpy_seed(self.seed + k): + loaded_datasets.append( + BlockPairDataset( + ds, + self.dictionary, + ds.sizes, + self.args.tokens_per_sample, + break_mode=self.args.break_mode, + doc_break_size=1, + ) + ) + + logger.info( + "{} {} {} examples".format(data_path, split_k, len(loaded_datasets[-1])) + ) + + if not combine: + break + + if len(loaded_datasets) == 1: + dataset = loaded_datasets[0] + sizes = dataset.sizes + else: + dataset = ConcatDataset(loaded_datasets) + sizes = np.concatenate([ds.sizes for ds in loaded_datasets]) + + self.datasets[split] = MaskedLMDataset( + dataset=dataset, + sizes=sizes, + vocab=self.dictionary, + pad_idx=self.dictionary.pad(), + mask_idx=self.dictionary.mask(), + classif_token_idx=self.dictionary.cls(), + sep_token_idx=self.dictionary.sep(), + shuffle=self.args.shuffle_dataset, + seed=self.seed, + ) diff --git a/fairseq/fairseq/tasks/masked_lm.py b/fairseq/fairseq/tasks/masked_lm.py new file mode 100644 index 0000000000000000000000000000000000000000..b064907a50d69f925e424a5045e71608de881f13 --- /dev/null +++ b/fairseq/fairseq/tasks/masked_lm.py @@ -0,0 +1,327 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +from dataclasses import dataclass, field + +import numpy as np +from omegaconf import II, MISSING, OmegaConf + +from fairseq import utils +from fairseq.data import ( + Dictionary, + IdDataset, + MaskTokensDataset, + NestedDictionaryDataset, + NumelDataset, + NumSamplesDataset, + PrependTokenDataset, + RightPadDataset, + RightPaddingMaskDataset, + SortDataset, + TokenBlockDataset, + data_utils, +) +from fairseq.data.encoders.utils import get_whole_word_mask +from fairseq.data.shorten_dataset import maybe_shorten_dataset +from fairseq.dataclass import FairseqDataclass +from fairseq.tasks import FairseqTask, register_task + +from .language_modeling import SAMPLE_BREAK_MODE_CHOICES, SHORTEN_METHOD_CHOICES + +logger = logging.getLogger(__name__) + + +@dataclass +class MaskedLMConfig(FairseqDataclass): + data: str = field( + default=MISSING, + metadata={ + "help": "colon separated path to data directories list, \ + will be iterated upon during epochs in round-robin manner" + }, + ) + sample_break_mode: SAMPLE_BREAK_MODE_CHOICES = field( + default="none", + metadata={ + "help": 'If omitted or "none", fills each sample with tokens-per-sample ' + 'tokens. If set to "complete", splits samples only at the end ' + "of sentence, but may include multiple sentences per sample. " + '"complete_doc" is similar but respects doc boundaries. ' + 'If set to "eos", includes only one sentence per sample.' + }, + ) + tokens_per_sample: int = field( + default=1024, + metadata={"help": "max number of tokens per sample for LM dataset"}, + ) + mask_prob: float = field( + default=0.15, + metadata={"help": "probability of replacing a token with mask"}, + ) + leave_unmasked_prob: float = field( + default=0.1, + metadata={"help": "probability that a masked token is unmasked"}, + ) + random_token_prob: float = field( + default=0.1, + metadata={"help": "probability of replacing a token with a random token"}, + ) + freq_weighted_replacement: bool = field( + default=False, + metadata={"help": "sample random replacement words based on word frequencies"}, + ) + mask_whole_words: bool = field( + default=False, + metadata={"help": "mask whole words; you may also want to set --bpe"}, + ) + mask_multiple_length: int = field( + default=1, + metadata={"help": "repeat the mask indices multiple times"}, + ) + mask_stdev: float = field( + default=0.0, + metadata={"help": "stdev of the mask length"}, + ) + shorten_method: SHORTEN_METHOD_CHOICES = field( + default="none", + metadata={ + "help": "if not none, shorten sequences that exceed --tokens-per-sample" + }, + ) + shorten_data_split_list: str = field( + default="", + metadata={ + "help": "comma-separated list of dataset splits to apply shortening to, " + 'e.g., "train,valid" (default: all dataset splits)' + }, + ) + seed: int = II("common.seed") + + include_target_tokens: bool = field( + default=False, + metadata={ + "help": "include target tokens in model input. this is used for data2vec" + }, + ) + include_index: bool = field( + default=True, + metadata={"help": "include index in model input. this is used for data2vec"}, + ) + skip_masking: bool = field( + default=False, + metadata={"help": "skip masking at dataset"}, + ) + # subsample_train: float = field( + # default=1, + # metadata={"help": "shorten training set for debugging"}, + # ) + d2v2_multi: bool = field( + default=False, + metadata={"help": "prepare dataset for data2vec_multi"}, + ) + + +@register_task("masked_lm", dataclass=MaskedLMConfig) +class MaskedLMTask(FairseqTask): + + cfg: MaskedLMConfig + + """Task for training masked language models (e.g., BERT, RoBERTa).""" + + def __init__(self, cfg: MaskedLMConfig, dictionary=None): + super().__init__(cfg) + self.dictionary = dictionary or self.load_dict(cfg) + + # add mask token + self.mask_idx = self.dictionary.add_symbol("") + + @classmethod + def setup_task(cls, cfg: MaskedLMConfig, **kwargs): + dictionary = cls.load_dict(cfg) + return cls(cfg, dictionary) + + @classmethod + def load_dict(cls, cfg): + paths = utils.split_paths(cfg.data) + assert len(paths) > 0 + dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt")) + logger.info("dictionary: {} types".format(len(dictionary))) + return dictionary + + def _load_dataset_split(self, split, epoch, combine): + paths = utils.split_paths(self.cfg.data) + assert len(paths) > 0 + data_path = paths[(epoch - 1) % len(paths)] + split_path = os.path.join(data_path, split) + + dataset = data_utils.load_indexed_dataset( + split_path, + self.source_dictionary, + combine=combine, + ) + if dataset is None: + raise FileNotFoundError( + "Dataset not found: {} ({})".format(split, split_path) + ) + + dataset = maybe_shorten_dataset( + dataset, + split, + self.cfg.shorten_data_split_list, + self.cfg.shorten_method, + self.cfg.tokens_per_sample, + self.cfg.seed, + ) + + # create continuous blocks of tokens + dataset = TokenBlockDataset( + dataset, + dataset.sizes, + self.cfg.tokens_per_sample - 1, # one less for + pad=self.source_dictionary.pad(), + eos=self.source_dictionary.eos(), + break_mode=self.cfg.sample_break_mode, + ) + logger.info("loaded {} blocks from: {}".format(len(dataset), split_path)) + + # prepend beginning-of-sentence token (, equiv. to [CLS] in BERT) + return PrependTokenDataset(dataset, self.source_dictionary.bos()) + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + """Load a given dataset split. + + Args: + split (str): name of the split (e.g., train, valid, test) + """ + dataset = self._load_dataset_split(split, epoch, combine) + + # create masked input and targets + mask_whole_words = ( + get_whole_word_mask(self.args, self.source_dictionary) + if self.cfg.mask_whole_words + else None + ) + + src_dataset, tgt_dataset = MaskTokensDataset.apply_mask( + dataset, + self.source_dictionary, + pad_idx=self.source_dictionary.pad(), + mask_idx=self.mask_idx, + seed=self.cfg.seed, + mask_prob=self.cfg.mask_prob, + leave_unmasked_prob=self.cfg.leave_unmasked_prob, + random_token_prob=self.cfg.random_token_prob, + freq_weighted_replacement=self.cfg.freq_weighted_replacement, + mask_whole_words=mask_whole_words, + mask_multiple_length=self.cfg.mask_multiple_length, + mask_stdev=self.cfg.mask_stdev, + skip_masking=self.cfg.skip_masking, + ) + + with data_utils.numpy_seed(self.cfg.seed): + shuffle = np.random.permutation(len(src_dataset)) + + target_dataset = RightPadDataset( + tgt_dataset, + pad_idx=self.source_dictionary.pad(), + ) + + if self.cfg.d2v2_multi: + dataset = self._d2v2_multi_dataset(src_dataset) + else: + dataset = self._regular_dataset(src_dataset, target_dataset) + + self.datasets[split] = SortDataset( + dataset, sort_order=[shuffle, src_dataset.sizes] + ) + + def _regular_dataset(self, src_dataset, target_dataset): + input_dict = { + "src_tokens": RightPadDataset( + src_dataset, + pad_idx=self.source_dictionary.pad(), + ), + "src_lengths": NumelDataset(src_dataset, reduce=False), + } + if self.cfg.include_target_tokens: + input_dict["target_tokens"] = target_dataset + if self.cfg.include_index: + input_dict["src_id"] = IdDataset() + + dataset = NestedDictionaryDataset( + { + "id": IdDataset(), + "net_input": input_dict, + "target": target_dataset, + "nsentences": NumSamplesDataset(), + "ntokens": NumelDataset(src_dataset, reduce=True), + }, + sizes=[src_dataset.sizes], + ) + return dataset + + def _d2v2_multi_dataset(self, src_dataset): + input_dict = { + "source": RightPadDataset( + src_dataset, + pad_idx=self.source_dictionary.pad(), + ), + "id": IdDataset(), + "padding_mask": RightPaddingMaskDataset(src_dataset), + } + + dataset = NestedDictionaryDataset( + { + "id": IdDataset(), + "net_input": input_dict, + "nsentences": NumSamplesDataset(), + "ntokens": NumelDataset(src_dataset, reduce=True), + }, + sizes=[src_dataset.sizes], + ) + return dataset + + def build_dataset_for_inference(self, src_tokens, src_lengths, sort=True): + src_dataset = RightPadDataset( + TokenBlockDataset( + src_tokens, + src_lengths, + self.cfg.tokens_per_sample - 1, # one less for + pad=self.source_dictionary.pad(), + eos=self.source_dictionary.eos(), + break_mode="eos", + ), + pad_idx=self.source_dictionary.pad(), + ) + src_dataset = PrependTokenDataset(src_dataset, self.source_dictionary.bos()) + src_dataset = NestedDictionaryDataset( + { + "id": IdDataset(), + "net_input": { + "src_tokens": src_dataset, + "src_lengths": NumelDataset(src_dataset, reduce=False), + }, + }, + sizes=src_lengths, + ) + if sort: + src_dataset = SortDataset(src_dataset, sort_order=[src_lengths]) + return src_dataset + + @property + def source_dictionary(self): + return self.dictionary + + @property + def target_dictionary(self): + return self.dictionary + + def begin_epoch(self, epoch, model): + model.set_epoch(epoch) + + def max_positions(self): + return self.cfg.tokens_per_sample diff --git a/fairseq/fairseq/tasks/multilingual_denoising.py b/fairseq/fairseq/tasks/multilingual_denoising.py new file mode 100644 index 0000000000000000000000000000000000000000..cb5ee345542940fedc899bdaa2947994b0b59554 --- /dev/null +++ b/fairseq/fairseq/tasks/multilingual_denoising.py @@ -0,0 +1,268 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import logging +import os +from dataclasses import dataclass, field +from typing import Optional + +import numpy as np +from omegaconf import II + +from fairseq.data import ( + AppendTokenDataset, + ConcatDataset, + DenoisingDataset, + Dictionary, + PrependTokenDataset, + ResamplingDataset, + SortDataset, + TokenBlockDataset, + data_utils, +) +from fairseq.data.encoders.utils import get_whole_word_mask +from fairseq.tasks import register_task + +from .denoising import DenoisingConfig, DenoisingTask + +logger = logging.getLogger(__name__) + + +@dataclass +class MultilingualDenoisingConfig(DenoisingConfig): + multilang_sampling_alpha: float = field( + default=1.0, + metadata={"help": "smoothing alpha for sample ratios across multiple datasets"}, + ) + add_lang_token: bool = field( + default=False, + metadata={"help": ""}, + ) + langs: Optional[str] = field( + default=None, + metadata={"help": "language ids we are considering"}, + ) + no_whole_word_mask_langs: str = field( + default="", + metadata={ + "help": "languages without spacing between words don't support whole word masking" + }, + ) + train_subset: str = II("common.train_subset") + valid_subset: str = II("common.valid_subset") + + +@register_task("multilingual_denoising", dataclass=MultilingualDenoisingConfig) +class MultilingualDenoisingTask(DenoisingTask): + + cfg: MultilingualDenoisingConfig + + @classmethod + def setup_task(cls, cfg: MultilingualDenoisingConfig, **kwargs): + """Setup the task.""" + paths = cfg.data.split(":") + assert len(paths) > 0 + dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt")) + + data_path = paths[0] + if cfg.langs is None: + languages = sorted( + [ + name + for name in os.listdir(data_path) + if os.path.isdir(os.path.join(data_path, name)) + ] + ) + else: + languages = cfg.langs.split(",") + + if cfg.add_lang_token: + for lang in languages: + dictionary.add_symbol("[{}]".format(lang)) + + logger.info("dictionary: {} types".format(len(dictionary))) + if not hasattr(cfg, "shuffle_instance"): + cfg.shuffle_instance = False + return cls(cfg, dictionary) + + def __init__(self, cfg: MultilingualDenoisingConfig, dictionary): + super().__init__(cfg, dictionary) + self.dictionary = dictionary + + # add mask token + self.mask_idx = self.dictionary.add_symbol("") + self.cfg = cfg + + def _get_sample_prob(self, dataset_lens): + """ + Get smoothed sampling probability by languages. This helps low resource + languages by upsampling them. + """ + prob = dataset_lens / dataset_lens.sum() + smoothed_prob = prob**self.cfg.multilang_sampling_alpha + smoothed_prob = smoothed_prob / smoothed_prob.sum() + return smoothed_prob + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + """Load a given dataset split. + + Args: + split (str): name of the split (e.g., train, valid, test) + """ + paths = self.cfg.data.split(":") + assert len(paths) > 0 + data_path = paths[(epoch - 1) % len(paths)] + split_path = os.path.join(data_path, split) + + if self.cfg.langs is None: + languages = sorted( + [ + name + for name in os.listdir(data_path) + if os.path.isdir(os.path.join(data_path, name)) + ] + ) + else: + languages = self.cfg.langs.split(",") + for name in languages: + p = os.path.join(data_path, name) + assert os.path.exists(p), "data not found: {}".format(p) + + logger.info("Training on {0} languages: {1}".format(len(languages), languages)) + logger.info( + "Language to id mapping: ", {lang: id for id, lang in enumerate(languages)} + ) + + mask_whole_words = get_whole_word_mask(self.cfg.bpe, self.dictionary) + language_without_segmentations = self.cfg.no_whole_word_mask_langs.split(",") + lang_datasets = [] + for language in languages: + split_path = os.path.join(data_path, language, split) + + dataset = data_utils.load_indexed_dataset( + split_path, + self.source_dictionary, + self.cfg.dataset_impl, + combine=combine, + ) + if dataset is None: + raise FileNotFoundError( + "Dataset not found: {} ({})".format(split, split_path) + ) + + end_token = ( + self.source_dictionary.index("[{}]".format(language)) + if self.cfg.add_lang_token + else self.source_dictionary.eos() + ) + + # create continuous blocks of tokens + dataset = TokenBlockDataset( + dataset, + dataset.sizes, + self.cfg.tokens_per_sample - 2, # one less for + pad=self.source_dictionary.pad(), + eos=end_token, + break_mode=self.cfg.sample_break_mode, + ) + logger.info("loaded {} blocks from: {}".format(len(dataset), split_path)) + + # prepend beginning-of-sentence token (, equiv. to [CLS] in BERT) + dataset = PrependTokenDataset(dataset, self.source_dictionary.bos()) + dataset = AppendTokenDataset(dataset, end_token) + + lang_mask_whole_words = ( + mask_whole_words + if language not in language_without_segmentations + else None + ) + lang_dataset = DenoisingDataset( + dataset, + dataset.sizes, + self.dictionary, + self.mask_idx, + lang_mask_whole_words, + shuffle=self.cfg.shuffle_instance, + seed=self.cfg.seed, + mask=self.cfg.mask, + mask_random=self.cfg.mask_random, + insert=self.cfg.insert, + rotate=self.cfg.rotate, + permute_sentences=self.cfg.permute_sentences, + bpe=self.cfg.bpe, + replace_length=self.cfg.replace_length, + mask_length=self.cfg.mask_length, + poisson_lambda=self.cfg.poisson_lambda, + eos=None + if not self.cfg.add_lang_token + else self.source_dictionary.index("[{}]".format(language)), + ) + lang_datasets.append(lang_dataset) + + dataset_lengths = np.array( + [len(d) for d in lang_datasets], + dtype=float, + ) + logger.info( + "loaded total {} blocks for all languages".format( + int(dataset_lengths.sum()), + ) + ) + if split == self.cfg.train_subset: + # For train subset, additionally up or down sample languages. + sample_probs = self._get_sample_prob(dataset_lengths) + logger.info( + "Sample probability by language: {}".format( + { + lang: "{0:.4f}".format(sample_probs[id]) + for id, lang in enumerate(languages) + } + ) + ) + size_ratio = (sample_probs * dataset_lengths.sum()) / dataset_lengths + logger.info( + "Up/Down Sampling ratio by language: {}".format( + { + lang: "{0:.2f}".format(size_ratio[id]) + for id, lang in enumerate(languages) + } + ) + ) + + resampled_lang_datasets = [ + ResamplingDataset( + lang_datasets[i], + size_ratio=size_ratio[i], + seed=self.cfg.seed, + epoch=epoch, + replace=size_ratio[i] >= 1.0, + ) + for i, d in enumerate(lang_datasets) + ] + dataset = ConcatDataset( + resampled_lang_datasets, + ) + else: + dataset = ConcatDataset(lang_datasets) + lang_splits = [split] + for lang_id, lang_dataset in enumerate(lang_datasets): + split_name = split + "_" + languages[lang_id] + lang_splits.append(split_name) + self.datasets[split_name] = lang_dataset + + if split in self.cfg.valid_subset: + self.cfg.valid_subset = self.cfg.valid_subset.replace( + split, ",".join(lang_splits) + ) + + with data_utils.numpy_seed(self.cfg.seed + epoch): + shuffle = np.random.permutation(len(dataset)) + + self.datasets[split] = SortDataset( + dataset, + sort_order=[ + shuffle, + dataset.sizes, + ], + ) diff --git a/fairseq/fairseq/tasks/multilingual_language_modeling.py b/fairseq/fairseq/tasks/multilingual_language_modeling.py new file mode 100644 index 0000000000000000000000000000000000000000..8fd5e5954d6692e5772dbda63f57f5c9669a575c --- /dev/null +++ b/fairseq/fairseq/tasks/multilingual_language_modeling.py @@ -0,0 +1,627 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +from dataclasses import dataclass, field +from typing import Optional + +import numpy as np +import torch +from omegaconf import II + +from fairseq import utils +from fairseq.data import ( + AppendTokenDataset, + ConcatDataset, + Dictionary, + IdDataset, + LMContextWindowDataset, + MonolingualDataset, + NestedDictionaryDataset, + NumelDataset, + PadDataset, + PrependTokenDataset, + ResamplingDataset, + SortDataset, + StripTokenDataset, + TokenBlockDataset, + TruncatedDictionary, + data_utils, +) +from fairseq.data.indexed_dataset import get_available_dataset_impl +from fairseq.data.shorten_dataset import maybe_shorten_dataset +from fairseq.dataclass import ChoiceEnum, FairseqDataclass +from fairseq.tasks import LegacyFairseqTask, register_task + +SAMPLE_BREAK_MODE_CHOICES = ChoiceEnum(["none", "complete", "complete_doc", "eos"]) +SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"]) +logger = logging.getLogger(__name__) + + +def lang_token(lang): + return f"<{lang}>" + + +@dataclass +class MultilingualLanguageModelingConfig(FairseqDataclass): + # TODO common var add to parent + data: Optional[str] = field( + default=None, metadata={"help": "path to data directory"} + ) + sample_break_mode: SAMPLE_BREAK_MODE_CHOICES = field( + default="none", + metadata={ + "help": 'If omitted or "none", fills each sample with tokens-per-sample ' + 'tokens. If set to "complete", splits samples only at the end ' + "of sentence, but may include multiple sentences per sample. " + '"complete_doc" is similar but respects doc boundaries. ' + 'If set to "eos", includes only one sentence per sample.' + }, + ) + tokens_per_sample: int = field( + default=1024, + metadata={"help": "max number of tokens per sample for LM dataset"}, + ) + output_dictionary_size: int = field( + default=-1, metadata={"help": "limit the size of output dictionary"} + ) + self_target: bool = field(default=False, metadata={"help": "include self target"}) + future_target: bool = field( + default=False, metadata={"help": "include future target"} + ) + past_target: bool = field(default=False, metadata={"help": "include past target"}) + add_bos_token: bool = field( + default=False, metadata={"help": "prepend lang id token "} + ) + max_source_positions: Optional[int] = field( + default=None, metadata={"help": "max number of tokens in the source sequence"} + ) + max_target_positions: Optional[int] = field( + default=None, metadata={"help": "max number of tokens in the target sequence"} + ) + pad_to_fixed_length: Optional[bool] = field( + default=False, metadata={"help": "pad to fixed length"} + ) + pad_to_fixed_bsz: Optional[bool] = field( + default=False, metadata={"help": "boolean to pad to fixed batch size"} + ) + + multilang_sampling_alpha: Optional[float] = field( + default=1.0, + metadata={ + "help": "smoothing alpha for sample rations across multiple datasets" + }, + ) + + shorten_method: SHORTEN_METHOD_CHOICES = field( + default="none", + metadata={ + "help": "if not none, shorten sequences that exceed --tokens-per-sample" + }, + ) + shorten_data_split_list: str = field( + default="", + metadata={ + "help": "comma-separated list of dataset splits to apply shortening to, " + 'e.g., "train,valid" (default: all dataset splits)' + }, + ) + + langs: str = field( + default="", + metadata={ + "help": "comma-separated list of languages (default: all directories in data path)" + }, + ) + baseline_model_langs: str = field( + default="", + metadata={ + "help": "comma-separated list of languages in the baseline model (default: none)" + }, + ) + # TODO: legacy parameter kept for compatibility + baseline_model: str = field( + default="", + metadata={"help": "path to the baseline model (default: none)"}, + ) + + lang_to_offline_shard_ratio: str = field( + default="", + metadata={ + "help": "absolute path of tsv file location to indicate lang to offline shard ratio.", + }, + ) + # TODO common vars below add to parent + seed: int = II("common.seed") + dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II( + "dataset.dataset_impl" + ) + data_buffer_size: int = II("dataset.data_buffer_size") + tpu: bool = II("common.tpu") + batch_size: Optional[int] = II("dataset.batch_size") + batch_size_valid: Optional[int] = II("dataset.batch_size_valid") + train_subset: str = II("common.train_subset") + valid_subset: str = II("common.valid_subset") + + +@register_task( + "multilingual_language_modeling", dataclass=MultilingualLanguageModelingConfig +) +class MultilingualLanguageModelingTask(LegacyFairseqTask): + """ + Train a language model. + + Args: + dictionary (~fairseq.data.Dictionary): the dictionary for the input of + the language model + output_dictionary (~fairseq.data.Dictionary): the dictionary for the + output of the language model. In most cases it will be the same as + *dictionary*, but could possibly be a more limited version of the + dictionary (if ``--output-dictionary-size`` is used). + targets (List[str]): list of the target types that the language model + should predict. Can be one of "self", "future", and "past". + Defaults to "future". + + .. note:: + + The language modeling task is compatible with :mod:`fairseq-train`, + :mod:`fairseq-generate`, :mod:`fairseq-interactive` and + :mod:`fairseq-eval-lm`. + + The language modeling task provides the following additional command-line + arguments: + + .. argparse:: + :ref: fairseq.tasks.language_modeling_parser + :prog: + """ + + def __init__(self, args, dictionary, output_dictionary=None, targets=None): + super().__init__(args) + self.dictionary = dictionary + self.output_dictionary = output_dictionary or dictionary + + if targets is None: + targets = ["future"] + self.targets = targets + + @staticmethod + def _get_langs(args, epoch=1): + paths = utils.split_paths(args.data) + assert len(paths) > 0 + data_path = paths[(epoch - 1) % len(paths)] + + languages = sorted( + name + for name in os.listdir(data_path) + if os.path.isdir(os.path.join(data_path, name)) + ) + if args.langs: + keep_langs = set(args.langs.split(",")) + languages = [lang for lang in languages if lang in keep_langs] + assert len(languages) == len(keep_langs) + + return languages, data_path + + @classmethod + def setup_dictionary(cls, args, **kwargs): + dictionary = None + output_dictionary = None + if args.data: + paths = utils.split_paths(args.data) + assert len(paths) > 0 + dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt")) + if args.add_bos_token: + languages, _ = cls._get_langs(args) + logger.info("----------------") + for lang in languages: + dictionary.add_symbol(lang_token(lang)) + logger.info(f"add language token: {lang_token(lang)}") + logger.info("----------------") + + logger.info("dictionary: {} types".format(len(dictionary))) + output_dictionary = dictionary + if args.output_dictionary_size >= 0: + output_dictionary = TruncatedDictionary( + dictionary, args.output_dictionary_size + ) + return (dictionary, output_dictionary) + + @classmethod + def setup_task(cls, args, **kwargs): + """Setup the task (e.g., load dictionaries). + + Args: + args (argparse.Namespace): parsed command-line arguments + """ + dictionary, output_dictionary = cls.setup_dictionary(args, **kwargs) + + # upgrade old checkpoints + if hasattr(args, "exclude_self_target"): + args.self_target = not args.exclude_self_target + + targets = [] + if getattr(args, "self_target", False): + targets.append("self") + if getattr(args, "future_target", False): + targets.append("future") + if getattr(args, "past_target", False): + targets.append("past") + if len(targets) == 0: + # standard language modeling + targets = ["future"] + + return cls(args, dictionary, output_dictionary, targets=targets) + + def build_model(self, args, from_checkpoint=False): + model = super().build_model(args, from_checkpoint) + for target in self.targets: + if target not in model.supported_targets: + raise ValueError( + f"Unsupported language modeling target: {target} not in {model.supported_targets}" + ) + + return model + + def _get_sample_prob(self, dataset_lens): + """ + Get smoothed sampling porbability by languages. This helps low resource + languages by upsampling them. + """ + prob = dataset_lens / dataset_lens.sum() + smoothed_prob = prob**self.args.multilang_sampling_alpha + smoothed_prob = smoothed_prob / smoothed_prob.sum() + return smoothed_prob + + def load_dataset(self, split: str, epoch=1, combine=False, **kwargs): + """Load a given dataset split. + + Args: + split (str): name of the split (e.g., train, valid, test) + """ + languages, data_path = MultilingualLanguageModelingTask._get_langs( + self.args, epoch + ) + lang_to_offline_shard_ratio = None + if self.args.lang_to_offline_shard_ratio != "": + lang_to_offline_shard_ratio = {} + assert os.path.exists( + self.args.lang_to_offline_shard_ratio + ), "provided offline shard ratio file doesn't exist: {0}".format( + self.args.lang_to_offline_shard_ratio + ) + with open(self.args.lang_to_offline_shard_ratio) as fin: + for line in fin: + lang, ratio = line.strip().split("\t") + ratio = float(ratio) + lang_to_offline_shard_ratio[lang] = ratio + + logger.info( + "Found offline sharded ratio: %s", + lang_to_offline_shard_ratio, + ) + + if split == self.args.train_subset: + logger.info( + "Training on {0} languages: {1}".format(len(languages), languages) + ) + else: + logger.info( + "Evaluating on {0} languages: {1}".format(len(languages), languages) + ) + + tokens_per_sample = self.args.tokens_per_sample - int(self.args.add_bos_token) + + fixed_pad_length = None + if self.args.pad_to_fixed_length: + fixed_pad_length = self.args.tokens_per_sample + + pad_to_bsz = None + if self.args.pad_to_fixed_bsz: + pad_to_bsz = ( + self.args.batch_size_valid if "valid" in split else self.args.batch_size + ) + + lang_datasets = [] + for lang_id, language in enumerate(languages): + split_path = os.path.join(data_path, language, split) + dataset = data_utils.load_indexed_dataset( + split_path, self.dictionary, self.args.dataset_impl, combine=combine + ) + # print('len(dataset) =', len(dataset)) + if dataset is None: + raise FileNotFoundError( + "Dataset not found: {} ({})".format(split, split_path) + ) + + dataset = maybe_shorten_dataset( + dataset, + split, + self.args.shorten_data_split_list, + self.args.shorten_method, + tokens_per_sample, + self.args.seed, + ) + + dataset = TokenBlockDataset( + dataset, + dataset.sizes, + tokens_per_sample, + pad=self.dictionary.pad(), + eos=self.dictionary.eos(), + break_mode=self.args.sample_break_mode, + include_targets=True, + ) + + add_eos_for_other_targets = ( + self.args.sample_break_mode is not None + and self.args.sample_break_mode != "none" + ) + src_lang_idx, tgt_lang_idx = None, None + if self.args.add_bos_token: + src_lang_idx = self.dictionary.index(lang_token(language)) + tgt_lang_idx = self.output_dictionary.index(lang_token(language)) + + lang_datasets.append( + MonolingualDataset( + dataset=dataset, + sizes=dataset.sizes, + src_vocab=self.dictionary, + tgt_vocab=self.output_dictionary, + add_eos_for_other_targets=add_eos_for_other_targets, + shuffle=True, + targets=self.targets, + fixed_pad_length=fixed_pad_length, + pad_to_bsz=pad_to_bsz, + add_bos_token=self.args.add_bos_token, + src_lang_idx=src_lang_idx, + tgt_lang_idx=tgt_lang_idx, + ) + ) + + dataset_lengths = np.array( + [len(d) for d in lang_datasets], + dtype=float, + ) + logger.info( + "loaded total {} blocks for all languages".format( + dataset_lengths.sum(), + ) + ) + if split == self.args.train_subset: + dataset_lengths_ratio_multiplier = np.ones(len(dataset_lengths)) + if lang_to_offline_shard_ratio is not None: + dataset_lengths_ratio_multiplier = [] + for lang in languages: + assert ( + lang in lang_to_offline_shard_ratio + ), "Lang: {0} missing in offline shard ratio file: {1}".format( + lang, + self.args.lang_to_offline_shard_ratio, + ) + dataset_lengths_ratio_multiplier.append( + lang_to_offline_shard_ratio[lang] + ) + dataset_lengths_ratio_multiplier = np.array( + dataset_lengths_ratio_multiplier + ) + true_dataset_lengths = ( + dataset_lengths * dataset_lengths_ratio_multiplier + ) + else: + true_dataset_lengths = dataset_lengths + # For train subset, additionally up or down sample languages. + sample_probs = self._get_sample_prob(true_dataset_lengths) + + logger.info( + "Sample probability by language: %s", + { + lang: "{0:.4f}".format(sample_probs[id]) + for id, lang in enumerate(languages) + }, + ) + size_ratio = (sample_probs * true_dataset_lengths.sum()) / dataset_lengths + # TODO: add an option for shrinking all size ratios to below 1 + # if self.args.multilang_sampling_alpha != 1: + # size_ratio /= size_ratio.max() + + # Fix numeric errors in size ratio computation + # 0.999999999999999999 -> 1 + # 1.000000000000000002 -> 1 + for i in range(len(size_ratio)): + size_ratio[i] = round(size_ratio[i], 8) + + logger.info( + "Up/Down Sampling ratio by language: %s", + { + lang: "{0:.2f}".format(size_ratio[id]) + for id, lang in enumerate(languages) + }, + ) + logger.info( + "Actual dataset size by language: %s", + { + lang: "{0:.2f}".format(len(lang_datasets[id])) + for id, lang in enumerate(languages) + }, + ) + resampled_lang_datasets = [ + ResamplingDataset( + lang_datasets[i], + size_ratio=size_ratio[i], + seed=self.args.seed, + epoch=epoch, + replace=size_ratio[i] > 1.0, + ) + for i, d in enumerate(lang_datasets) + ] + logger.info( + "Resampled dataset size by language: %s", + { + lang: "{0:.2f}".format(len(resampled_lang_datasets[id])) + for id, lang in enumerate(languages) + }, + ) + dataset = ConcatDataset(resampled_lang_datasets) + else: + dataset = ConcatDataset(lang_datasets) + lang_splits = [split] + for lang_id, lang_dataset in enumerate(lang_datasets): + split_name = split + "_" + languages[lang_id] + lang_splits.append(split_name) + self.datasets[split_name] = lang_dataset + + # [TODO]: This is hacky for now to print validation ppl for each + # language individually. Maybe need task API changes to allow it + # in more generic ways. + if split in self.args.valid_subset: + self.args.valid_subset = self.args.valid_subset.replace( + split, ",".join(lang_splits) + ) + + with data_utils.numpy_seed(self.args.seed + epoch): + shuffle = np.random.permutation(len(dataset)) + + self.datasets[split] = SortDataset( + dataset, + sort_order=[ + shuffle, + dataset.sizes, + ], + ) + + def build_dataset_for_inference( + self, src_tokens, src_lengths, language="en_XX", **kwargs + ): + """ + Generate batches for inference. We prepend an eos token to src_tokens + (or bos if `--add-bos-token` is set) and we append a to target. + This is convenient both for generation with a prefix and LM scoring. + """ + dataset = StripTokenDataset( + TokenBlockDataset( + src_tokens, + src_lengths, + block_size=None, # ignored for "eos" break mode + pad=self.source_dictionary.pad(), + eos=self.source_dictionary.eos(), + break_mode="eos", + ), + # remove eos from (end of) target sequence + self.source_dictionary.eos(), + ) + + src_lang_idx = self.dictionary.index(lang_token(language)) + src_dataset = PrependTokenDataset( + dataset, + token=( + (src_lang_idx or self.source_dictionary.bos()) + if getattr(self.args, "add_bos_token", False) + else self.source_dictionary.eos() + ), + ) + + max_seq_len = max(src_lengths) + 1 + tgt_dataset = AppendTokenDataset(dataset, token=self.source_dictionary.pad()) + return NestedDictionaryDataset( + { + "id": IdDataset(), + "net_input": { + "src_tokens": PadDataset( + src_dataset, + pad_idx=self.source_dictionary.pad(), + left_pad=False, + pad_length=max_seq_len, + ), + "src_lengths": NumelDataset(src_dataset, reduce=False), + }, + "target": PadDataset( + tgt_dataset, + pad_idx=self.source_dictionary.pad(), + left_pad=False, + pad_length=max_seq_len, + ), + }, + sizes=[np.array(src_lengths)], + ) + + @torch.no_grad() + def inference_step( + self, + generator, + models, + sample, + language="en_XX", + prefix_tokens=None, + constraints=None, + ): + # Generation will always be conditioned on bos_token + if getattr(self.args, "add_bos_token", False): + src_lang_idx = self.dictionary.index(lang_token(language)) + bos_token = src_lang_idx or self.source_dictionary.bos() + else: + bos_token = self.source_dictionary.eos() + + if constraints is not None: + raise NotImplementedError( + "Constrained decoding with the language_modeling task is not supported" + ) + + # SequenceGenerator doesn't use src_tokens directly, we need to + # pass the `prefix_tokens` argument instead + if prefix_tokens is None and sample["net_input"]["src_tokens"].nelement(): + prefix_tokens = sample["net_input"]["src_tokens"] + if prefix_tokens[:, 0].eq(bos_token).all(): + prefix_tokens = prefix_tokens[:, 1:] + + return generator.generate( + models, sample, prefix_tokens=prefix_tokens, bos_token=bos_token + ) + + def eval_lm_dataloader( + self, + dataset, + max_tokens: Optional[int] = 36000, + batch_size: Optional[int] = None, + max_positions: Optional[int] = None, + num_shards: int = 1, + shard_id: int = 0, + num_workers: int = 1, + data_buffer_size: int = 10, + # ensures that every evaluated token has access to a context of at least + # this size, if possible + context_window: int = 0, + ): + if context_window > 0: + dataset = LMContextWindowDataset( + dataset=dataset, + tokens_per_sample=self.args.tokens_per_sample, + context_window=context_window, + pad_idx=self.source_dictionary.pad(), + ) + return self.get_batch_iterator( + dataset=dataset, + max_tokens=max_tokens, + max_sentences=batch_size, + max_positions=max_positions, + ignore_invalid_inputs=True, + num_shards=num_shards, + shard_id=shard_id, + num_workers=num_workers, + data_buffer_size=data_buffer_size, + ) + + @property + def source_dictionary(self): + """Return the :class:`~fairseq.data.Dictionary` for the language + model.""" + return self.dictionary + + @property + def target_dictionary(self): + """Return the :class:`~fairseq.data.Dictionary` for the language + model.""" + return self.output_dictionary diff --git a/fairseq/fairseq/tasks/multilingual_masked_lm.py b/fairseq/fairseq/tasks/multilingual_masked_lm.py new file mode 100644 index 0000000000000000000000000000000000000000..156d085aa4e65292d28f0d5c8e3fefd0e7d60e17 --- /dev/null +++ b/fairseq/fairseq/tasks/multilingual_masked_lm.py @@ -0,0 +1,338 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os + +import numpy as np +import torch + +from fairseq import utils +from fairseq.data import ( + ConcatDataset, + Dictionary, + IdDataset, + MaskTokensDataset, + NestedDictionaryDataset, + NumelDataset, + NumSamplesDataset, + PadDataset, + PrependTokenDataset, + RawLabelDataset, + ResamplingDataset, + SortDataset, + TokenBlockDataset, + data_utils, + encoders, +) +from fairseq.tasks import LegacyFairseqTask, register_task + +logger = logging.getLogger(__name__) + + +@register_task("multilingual_masked_lm") +class MultiLingualMaskedLMTask(LegacyFairseqTask): + """Task for training masked language models (e.g., BERT, RoBERTa).""" + + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + parser.add_argument( + "data", + help="colon separated path to data directories list, \ + will be iterated upon during epochs in round-robin manner", + ) + parser.add_argument( + "--sample-break-mode", + default="complete", + choices=["none", "complete", "complete_doc", "eos"], + help='If omitted or "none", fills each sample with tokens-per-sample ' + 'tokens. If set to "complete", splits samples only at the end ' + "of sentence, but may include multiple sentences per sample. " + '"complete_doc" is similar but respects doc boundaries. ' + 'If set to "eos", includes only one sentence per sample.', + ) + parser.add_argument( + "--tokens-per-sample", + default=512, + type=int, + help="max number of total tokens over all segments " + "per sample for BERT dataset", + ) + parser.add_argument( + "--mask-prob", + default=0.15, + type=float, + help="probability of replacing a token with mask", + ) + parser.add_argument( + "--leave-unmasked-prob", + default=0.1, + type=float, + help="probability that a masked token is unmasked", + ) + parser.add_argument( + "--random-token-prob", + default=0.1, + type=float, + help="probability of replacing a token with a random token", + ) + parser.add_argument( + "--freq-weighted-replacement", + action="store_true", + help="sample random replacement words based on word frequencies", + ) + parser.add_argument( + "--mask-whole-words", + default=False, + action="store_true", + help="mask whole words; you may also want to set --bpe", + ) + parser.add_argument( + "--multilang-sampling-alpha", + type=float, + default=1.0, + help="smoothing alpha for sample rations across multiple datasets", + ) + + def __init__(self, args, dictionary): + super().__init__(args) + self.dictionary = dictionary + self.seed = args.seed + + # add mask token + self.mask_idx = dictionary.add_symbol("") + + @classmethod + def setup_task(cls, args, **kwargs): + paths = utils.split_paths(args.data) + assert len(paths) > 0 + dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt")) + logger.info("dictionary: {} types".format(len(dictionary))) + return cls(args, dictionary) + + def _get_whole_word_mask(self): + # create masked input and targets + if self.args.mask_whole_words: + bpe = encoders.build_bpe(self.args) + if bpe is not None: + + def is_beginning_of_word(i): + if i < self.source_dictionary.nspecial: + # special elements are always considered beginnings + return True + tok = self.source_dictionary[i] + if tok.startswith("madeupword"): + return True + try: + return bpe.is_beginning_of_word(tok) + except ValueError: + return True + + mask_whole_words = torch.ByteTensor( + list(map(is_beginning_of_word, range(len(self.source_dictionary)))) + ) + else: + mask_whole_words = None + return mask_whole_words + + def _get_sample_prob(self, dataset_lens): + """ + Get smoothed sampling porbability by languages. This helps low resource + languages by upsampling them. + """ + prob = dataset_lens / dataset_lens.sum() + smoothed_prob = prob**self.args.multilang_sampling_alpha + smoothed_prob = smoothed_prob / smoothed_prob.sum() + return smoothed_prob + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + """Load a given dataset split. + + Args: + split (str): name of the split (e.g., train, valid, test) + """ + paths = utils.split_paths(self.args.data) + assert len(paths) > 0 + data_path = paths[(epoch - 1) % len(paths)] + + languages = sorted( + name + for name in os.listdir(data_path) + if os.path.isdir(os.path.join(data_path, name)) + ) + + logger.info("Training on {0} languages: {1}".format(len(languages), languages)) + logger.info( + "Language to id mapping: ", {lang: id for id, lang in enumerate(languages)} + ) + + mask_whole_words = self._get_whole_word_mask() + lang_datasets = [] + for lang_id, language in enumerate(languages): + split_path = os.path.join(data_path, language, split) + + dataset = data_utils.load_indexed_dataset( + split_path, + self.source_dictionary, + self.args.dataset_impl, + combine=combine, + ) + if dataset is None: + raise FileNotFoundError( + "Dataset not found: {} ({})".format(split, split_path) + ) + + # create continuous blocks of tokens + dataset = TokenBlockDataset( + dataset, + dataset.sizes, + self.args.tokens_per_sample - 1, # one less for + pad=self.source_dictionary.pad(), + eos=self.source_dictionary.eos(), + break_mode=self.args.sample_break_mode, + ) + logger.info("loaded {} blocks from: {}".format(len(dataset), split_path)) + + # prepend beginning-of-sentence token (, equiv. to [CLS] in BERT) + dataset = PrependTokenDataset(dataset, self.source_dictionary.bos()) + + src_dataset, tgt_dataset = MaskTokensDataset.apply_mask( + dataset, + self.source_dictionary, + pad_idx=self.source_dictionary.pad(), + mask_idx=self.mask_idx, + seed=self.args.seed, + mask_prob=self.args.mask_prob, + leave_unmasked_prob=self.args.leave_unmasked_prob, + random_token_prob=self.args.random_token_prob, + freq_weighted_replacement=self.args.freq_weighted_replacement, + mask_whole_words=mask_whole_words, + ) + + lang_dataset = NestedDictionaryDataset( + { + "net_input": { + "src_tokens": PadDataset( + src_dataset, + pad_idx=self.source_dictionary.pad(), + left_pad=False, + ), + "src_lengths": NumelDataset(src_dataset, reduce=False), + }, + "target": PadDataset( + tgt_dataset, + pad_idx=self.source_dictionary.pad(), + left_pad=False, + ), + "nsentences": NumSamplesDataset(), + "ntokens": NumelDataset(src_dataset, reduce=True), + "lang_id": RawLabelDataset([lang_id] * src_dataset.sizes.shape[0]), + }, + sizes=[src_dataset.sizes], + ) + lang_datasets.append(lang_dataset) + + dataset_lengths = np.array( + [len(d) for d in lang_datasets], + dtype=float, + ) + logger.info( + "loaded total {} blocks for all languages".format( + dataset_lengths.sum(), + ) + ) + if split == self.args.train_subset: + # For train subset, additionally up or down sample languages. + sample_probs = self._get_sample_prob(dataset_lengths) + logger.info( + "Sample probability by language: ", + { + lang: "{0:.4f}".format(sample_probs[id]) + for id, lang in enumerate(languages) + }, + ) + size_ratio = (sample_probs * dataset_lengths.sum()) / dataset_lengths + logger.info( + "Up/Down Sampling ratio by language: ", + { + lang: "{0:.2f}".format(size_ratio[id]) + for id, lang in enumerate(languages) + }, + ) + + resampled_lang_datasets = [ + ResamplingDataset( + lang_datasets[i], + size_ratio=size_ratio[i], + seed=self.args.seed, + epoch=epoch, + replace=size_ratio[i] >= 1.0, + ) + for i, d in enumerate(lang_datasets) + ] + dataset = ConcatDataset(resampled_lang_datasets) + else: + dataset = ConcatDataset(lang_datasets) + lang_splits = [split] + for lang_id, lang_dataset in enumerate(lang_datasets): + split_name = split + "_" + languages[lang_id] + lang_splits.append(split_name) + self.datasets[split_name] = lang_dataset + + # [TODO]: This is hacky for now to print validation ppl for each + # language individually. Maybe need task API changes to allow it + # in more generic ways. + if split in self.args.valid_subset: + self.args.valid_subset = self.args.valid_subset.replace( + split, ",".join(lang_splits) + ) + + with data_utils.numpy_seed(self.args.seed + epoch): + shuffle = np.random.permutation(len(dataset)) + + self.datasets[split] = SortDataset( + dataset, + sort_order=[ + shuffle, + dataset.sizes, + ], + ) + + def build_dataset_for_inference(self, src_tokens, src_lengths, sort=True): + src_dataset = PadDataset( + TokenBlockDataset( + src_tokens, + src_lengths, + self.args.tokens_per_sample - 1, # one less for + pad=self.source_dictionary.pad(), + eos=self.source_dictionary.eos(), + break_mode="eos", + ), + pad_idx=self.source_dictionary.pad(), + left_pad=False, + ) + src_dataset = PrependTokenDataset(src_dataset, self.source_dictionary.bos()) + src_dataset = NestedDictionaryDataset( + { + "id": IdDataset(), + "net_input": { + "src_tokens": src_dataset, + "src_lengths": NumelDataset(src_dataset, reduce=False), + }, + }, + sizes=src_lengths, + ) + if sort: + src_dataset = SortDataset(src_dataset, sort_order=[src_lengths]) + return src_dataset + + @property + def source_dictionary(self): + return self.dictionary + + @property + def target_dictionary(self): + return self.dictionary diff --git a/fairseq/fairseq/tasks/multilingual_translation.py b/fairseq/fairseq/tasks/multilingual_translation.py new file mode 100644 index 0000000000000000000000000000000000000000..cef7656691aeb8042e42d24eccd328a85fe32e27 --- /dev/null +++ b/fairseq/fairseq/tasks/multilingual_translation.py @@ -0,0 +1,463 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import logging +import os +from collections import OrderedDict +from argparse import ArgumentError + +import torch +from fairseq import options, utils +from fairseq.logging import metrics +from fairseq.data import ( + Dictionary, + LanguagePairDataset, + RoundRobinZipDatasets, + TransformEosLangPairDataset, +) +from fairseq.models import FairseqMultiModel +from fairseq.tasks.translation import load_langpair_dataset + +from . import LegacyFairseqTask, register_task + + +logger = logging.getLogger(__name__) + + +def _lang_token(lang: str): + return "__{}__".format(lang) + + +def _lang_token_index(dic: Dictionary, lang: str): + """Return language token index.""" + idx = dic.index(_lang_token(lang)) + assert idx != dic.unk_index, "cannot find language token for lang {}".format(lang) + return idx + + +@register_task("multilingual_translation") +class MultilingualTranslationTask(LegacyFairseqTask): + """A task for training multiple translation models simultaneously. + + We iterate round-robin over batches from multiple language pairs, ordered + according to the `--lang-pairs` argument. + + The training loop is roughly: + + for i in range(len(epoch)): + for lang_pair in args.lang_pairs: + batch = next_batch_for_lang_pair(lang_pair) + loss = criterion(model_for_lang_pair(lang_pair), batch) + loss.backward() + optimizer.step() + + In practice, `next_batch_for_lang_pair` is abstracted in a FairseqDataset + (e.g., `RoundRobinZipDatasets`) and `model_for_lang_pair` is a model that + implements the `FairseqMultiModel` interface. + + During inference it is required to specify a single `--source-lang` and + `--target-lang`, which indicates the inference langauge direction. + `--lang-pairs`, `--encoder-langtok`, `--decoder-langtok` have to be set to + the same value as training. + """ + + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + # fmt: off + parser.add_argument('data', metavar='DIR', help='path to data directory') + parser.add_argument('--lang-pairs', default=None, metavar='PAIRS', + help='comma-separated list of language pairs (in training order): en-de,en-fr,de-fr') + parser.add_argument('-s', '--source-lang', default=None, metavar='SRC', + help='source language (only needed for inference)') + parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET', + help='target language (only needed for inference)') + parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL', + help='pad the source on the left (default: True)') + parser.add_argument('--left-pad-target', default='False', type=str, metavar='BOOL', + help='pad the target on the left (default: False)') + try: + parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N', + help='max number of tokens in the source sequence') + parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N', + help='max number of tokens in the target sequence') + except ArgumentError: + # this might have already been defined. Once we transition this to hydra it should be fine to add it here. + pass + parser.add_argument('--upsample-primary', default=1, type=int, + help='amount to upsample primary dataset') + parser.add_argument('--encoder-langtok', default=None, type=str, choices=['src', 'tgt'], + metavar='SRCTGT', + help='replace beginning-of-sentence in source sentence with source or target ' + 'language token. (src/tgt)') + parser.add_argument('--decoder-langtok', action='store_true', + help='replace beginning-of-sentence in target sentence with target language token') + # fmt: on + + def __init__(self, args, dicts, training): + super().__init__(args) + self.dicts = dicts + self.training = training + if training: + self.lang_pairs = args.lang_pairs + else: + self.lang_pairs = ["{}-{}".format(args.source_lang, args.target_lang)] + # eval_lang_pairs for multilingual translation is usually all of the + # lang_pairs. However for other multitask settings or when we want to + # optimize for certain languages we want to use a different subset. Thus + # the eval_lang_pairs class variable is provided for classes that extend + # this class. + self.eval_lang_pairs = self.lang_pairs + # model_lang_pairs will be used to build encoder-decoder model pairs in + # models.build_model(). This allows multitask type of sub-class can + # build models other than the input lang_pairs + self.model_lang_pairs = self.lang_pairs + self.langs = list(dicts.keys()) + + @classmethod + def setup_task(cls, args, **kwargs): + dicts, training = cls.prepare(args, **kwargs) + return cls(args, dicts, training) + + @classmethod + def update_args(cls, args): + args.left_pad_source = utils.eval_bool(args.left_pad_source) + args.left_pad_target = utils.eval_bool(args.left_pad_target) + + if args.lang_pairs is None: + raise ValueError( + "--lang-pairs is required. List all the language pairs in the training objective." + ) + if isinstance(args.lang_pairs, str): + args.lang_pairs = args.lang_pairs.split(",") + + @classmethod + def prepare(cls, args, **kargs): + cls.update_args(args) + sorted_langs = sorted( + list({x for lang_pair in args.lang_pairs for x in lang_pair.split("-")}) + ) + if args.source_lang is not None or args.target_lang is not None: + training = False + else: + training = True + + # load dictionaries + dicts = OrderedDict() + for lang in sorted_langs: + paths = utils.split_paths(args.data) + assert len(paths) > 0 + dicts[lang] = cls.load_dictionary( + os.path.join(paths[0], "dict.{}.txt".format(lang)) + ) + if len(dicts) > 0: + assert dicts[lang].pad() == dicts[sorted_langs[0]].pad() + assert dicts[lang].eos() == dicts[sorted_langs[0]].eos() + assert dicts[lang].unk() == dicts[sorted_langs[0]].unk() + if args.encoder_langtok is not None or args.decoder_langtok: + for lang_to_add in sorted_langs: + dicts[lang].add_symbol(_lang_token(lang_to_add)) + logger.info("[{}] dictionary: {} types".format(lang, len(dicts[lang]))) + return dicts, training + + def get_encoder_langtok(self, src_lang, tgt_lang): + if self.args.encoder_langtok is None: + return self.dicts[src_lang].eos() + if self.args.encoder_langtok == "src": + return _lang_token_index(self.dicts[src_lang], src_lang) + else: + return _lang_token_index(self.dicts[src_lang], tgt_lang) + + def get_decoder_langtok(self, tgt_lang): + if not self.args.decoder_langtok: + return self.dicts[tgt_lang].eos() + return _lang_token_index(self.dicts[tgt_lang], tgt_lang) + + def alter_dataset_langtok( + self, + lang_pair_dataset, + src_eos=None, + src_lang=None, + tgt_eos=None, + tgt_lang=None, + ): + if self.args.encoder_langtok is None and not self.args.decoder_langtok: + return lang_pair_dataset + + new_src_eos = None + if ( + self.args.encoder_langtok is not None + and src_eos is not None + and src_lang is not None + and tgt_lang is not None + ): + new_src_eos = self.get_encoder_langtok(src_lang, tgt_lang) + else: + src_eos = None + + new_tgt_bos = None + if self.args.decoder_langtok and tgt_eos is not None and tgt_lang is not None: + new_tgt_bos = self.get_decoder_langtok(tgt_lang) + else: + tgt_eos = None + + return TransformEosLangPairDataset( + lang_pair_dataset, + src_eos=src_eos, + new_src_eos=new_src_eos, + tgt_bos=tgt_eos, + new_tgt_bos=new_tgt_bos, + ) + + def load_dataset(self, split, epoch=1, **kwargs): + """Load a dataset split.""" + paths = utils.split_paths(self.args.data) + assert len(paths) > 0 + data_path = paths[(epoch - 1) % len(paths)] + + def language_pair_dataset(lang_pair): + src, tgt = lang_pair.split("-") + langpair_dataset = load_langpair_dataset( + data_path, + split, + src, + self.dicts[src], + tgt, + self.dicts[tgt], + combine=True, + dataset_impl=self.args.dataset_impl, + upsample_primary=self.args.upsample_primary, + left_pad_source=self.args.left_pad_source, + left_pad_target=self.args.left_pad_target, + max_source_positions=self.args.max_source_positions, + max_target_positions=self.args.max_target_positions, + ) + return self.alter_dataset_langtok( + langpair_dataset, + src_eos=self.dicts[src].eos(), + src_lang=src, + tgt_eos=self.dicts[tgt].eos(), + tgt_lang=tgt, + ) + + self.datasets[split] = RoundRobinZipDatasets( + OrderedDict( + [ + (lang_pair, language_pair_dataset(lang_pair)) + for lang_pair in self.lang_pairs + ] + ), + eval_key=None + if self.training + else "%s-%s" % (self.args.source_lang, self.args.target_lang), + ) + + def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None): + if constraints is not None: + raise NotImplementedError( + "Constrained decoding with the multilingual_translation task is not supported" + ) + + lang_pair = "%s-%s" % (self.args.source_lang, self.args.target_lang) + return RoundRobinZipDatasets( + OrderedDict( + [ + ( + lang_pair, + self.alter_dataset_langtok( + LanguagePairDataset( + src_tokens, src_lengths, self.source_dictionary + ), + src_eos=self.source_dictionary.eos(), + src_lang=self.args.source_lang, + tgt_eos=self.target_dictionary.eos(), + tgt_lang=self.args.target_lang, + ), + ) + ] + ), + eval_key=lang_pair, + ) + + def build_model(self, args, from_checkpoint=False): + def check_args(): + messages = [] + if ( + len(set(self.args.lang_pairs).symmetric_difference(args.lang_pairs)) + != 0 + ): + messages.append( + "--lang-pairs should include all the language pairs {}.".format( + args.lang_pairs + ) + ) + if self.args.encoder_langtok != args.encoder_langtok: + messages.append( + "--encoder-langtok should be {}.".format(args.encoder_langtok) + ) + if self.args.decoder_langtok != args.decoder_langtok: + messages.append( + "--decoder-langtok should {} be set.".format( + "" if args.decoder_langtok else "not" + ) + ) + + if len(messages) > 0: + raise ValueError(" ".join(messages)) + + # Update args -> the fact that the constructor here + # changes the args object doesn't mean you get the same one here + self.update_args(args) + + # Check if task args are consistant with model args + check_args() + + from fairseq import models + + model = models.build_model(args, self, from_checkpoint) + if not isinstance(model, FairseqMultiModel): + raise ValueError( + "MultilingualTranslationTask requires a FairseqMultiModel architecture" + ) + return model + + def _per_lang_pair_train_loss( + self, lang_pair, model, update_num, criterion, sample, optimizer, ignore_grad + ): + loss, sample_size, logging_output = criterion( + model.models[lang_pair], sample[lang_pair] + ) + if ignore_grad: + loss *= 0 + optimizer.backward(loss) + return loss, sample_size, logging_output + + def train_step( + self, sample, model, criterion, optimizer, update_num, ignore_grad=False + ): + model.train() + from collections import defaultdict + + agg_loss, agg_sample_size, agg_logging_output = 0.0, 0.0, defaultdict(float) + curr_lang_pairs = [ + lang_pair + for lang_pair in self.model_lang_pairs + if sample[lang_pair] is not None and len(sample[lang_pair]) != 0 + ] + + for idx, lang_pair in enumerate(curr_lang_pairs): + + def maybe_no_sync(): + if ( + self.args.distributed_world_size > 1 + and hasattr(model, "no_sync") + and idx < len(curr_lang_pairs) - 1 + ): + return model.no_sync() + else: + return contextlib.ExitStack() # dummy contextmanager + + with maybe_no_sync(): + loss, sample_size, logging_output = self._per_lang_pair_train_loss( + lang_pair, + model, + update_num, + criterion, + sample, + optimizer, + ignore_grad, + ) + agg_loss += loss.detach().item() + # TODO make summing of the sample sizes configurable + agg_sample_size += sample_size + for k in logging_output: + agg_logging_output[k] += logging_output[k] + agg_logging_output[f"{lang_pair}:{k}"] += logging_output[k] + return agg_loss, agg_sample_size, agg_logging_output + + def _per_lang_pair_valid_loss(self, lang_pair, model, criterion, sample): + return criterion(model.models[lang_pair], sample[lang_pair]) + + def valid_step(self, sample, model, criterion): + model.eval() + with torch.no_grad(): + from collections import defaultdict + + agg_loss, agg_sample_size, agg_logging_output = 0.0, 0.0, defaultdict(float) + for lang_pair in self.eval_lang_pairs: + if ( + lang_pair not in sample + or sample[lang_pair] is None + or len(sample[lang_pair]) == 0 + ): + continue + loss, sample_size, logging_output = self._per_lang_pair_valid_loss( + lang_pair, model, criterion, sample + ) + agg_loss += loss.data.item() + # TODO make summing of the sample sizes configurable + agg_sample_size += sample_size + for k in logging_output: + agg_logging_output[k] += logging_output[k] + agg_logging_output[f"{lang_pair}:{k}"] += logging_output[k] + return agg_loss, agg_sample_size, agg_logging_output + + def inference_step( + self, generator, models, sample, prefix_tokens=None, constraints=None + ): + with torch.no_grad(): + if self.args.decoder_langtok: + bos_token = _lang_token_index( + self.target_dictionary, self.args.target_lang + ) + else: + bos_token = self.target_dictionary.eos() + return generator.generate( + models, + sample, + prefix_tokens=prefix_tokens, + constraints=constraints, + bos_token=bos_token, + ) + + def reduce_metrics(self, logging_outputs, criterion): + with metrics.aggregate(): + # pass 'sample_size', 'nsentences', 'ntokens' stats to fairseq_task + super().reduce_metrics(logging_outputs, criterion) + for k in ["sample_size", "nsentences", "ntokens"]: + metrics.log_scalar(k, sum(l[k] for l in logging_outputs)) + + @property + def source_dictionary(self): + if self.training: + return next(iter(self.dicts.values())) + else: + return self.dicts[self.args.source_lang] + + @property + def target_dictionary(self): + if self.training: + return next(iter(self.dicts.values())) + else: + return self.dicts[self.args.target_lang] + + def max_positions(self): + """Return the max sentence length allowed by the task.""" + if len(self.datasets.values()) == 0: + return { + "%s-%s" + % (self.args.source_lang, self.args.target_lang): ( + self.args.max_source_positions, + self.args.max_target_positions, + ) + } + return OrderedDict( + [ + (key, (self.args.max_source_positions, self.args.max_target_positions)) + for split in self.datasets.keys() + for key in self.datasets[split].datasets.keys() + ] + )