diff --git a/.gitattributes b/.gitattributes index c9695f200404aeb644165b61cf026806a2ef64b2..b3da55b7d543c9d49c86c7bd90095cf53a8223cd 100644 --- a/.gitattributes +++ b/.gitattributes @@ -39,3 +39,4 @@ fairseq/alignment_train_cuda_binding.cpython-310-x86_64-linux-gnu.so filter=lfs fairseq/alignment_train_cpu_binding.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text fairseq/docs/fairseq.gif filter=lfs diff=lfs merge=lfs -text fairseq/examples/hubert/tests/6313-76958-0021.flac filter=lfs diff=lfs merge=lfs -text +fairseq/examples/textless_nlp/speech-resynth/img/fig.png filter=lfs diff=lfs merge=lfs -text diff --git a/fairseq/examples/textless_nlp/speech-resynth/img/fig.png b/fairseq/examples/textless_nlp/speech-resynth/img/fig.png new file mode 100644 index 0000000000000000000000000000000000000000..f92509bcfef84a7f5ead19f3b224ff89f9ee0a30 --- /dev/null +++ b/fairseq/examples/textless_nlp/speech-resynth/img/fig.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c19c570f3671d88551f5d5b908e270e69bd75d304c5d358868fa19f342979c17 +size 307833 diff --git a/fairseq/fairseq/benchmark/__init__.py b/fairseq/fairseq/benchmark/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0317d5c623778fe40b7bf07b77769cd10c243244 --- /dev/null +++ b/fairseq/fairseq/benchmark/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# import models/tasks to register them +from . import dummy_dataset, dummy_lm, dummy_masked_lm, dummy_model, dummy_mt # noqa diff --git a/fairseq/fairseq/benchmark/__pycache__/__init__.cpython-310.pyc b/fairseq/fairseq/benchmark/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd19ba1402efce56a8838e25e30087094ddd8749 Binary files /dev/null and b/fairseq/fairseq/benchmark/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/fairseq/benchmark/__pycache__/dummy_dataset.cpython-310.pyc b/fairseq/fairseq/benchmark/__pycache__/dummy_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a95d2a9e6e8484c9ab346a08974db606452e1acc Binary files /dev/null and b/fairseq/fairseq/benchmark/__pycache__/dummy_dataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/benchmark/__pycache__/dummy_lm.cpython-310.pyc b/fairseq/fairseq/benchmark/__pycache__/dummy_lm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..430a09b269c3999a9bcc5f1a653e0be6021e6e4f Binary files /dev/null and b/fairseq/fairseq/benchmark/__pycache__/dummy_lm.cpython-310.pyc differ diff --git a/fairseq/fairseq/benchmark/__pycache__/dummy_masked_lm.cpython-310.pyc b/fairseq/fairseq/benchmark/__pycache__/dummy_masked_lm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c253230d37a354435002a6b86307f90b4733f77e Binary files /dev/null and b/fairseq/fairseq/benchmark/__pycache__/dummy_masked_lm.cpython-310.pyc differ diff --git a/fairseq/fairseq/benchmark/__pycache__/dummy_model.cpython-310.pyc b/fairseq/fairseq/benchmark/__pycache__/dummy_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ce1789d7acb2bd9c23226150a06454ba946f825 Binary files /dev/null and b/fairseq/fairseq/benchmark/__pycache__/dummy_model.cpython-310.pyc differ diff --git a/fairseq/fairseq/benchmark/__pycache__/dummy_mt.cpython-310.pyc 
b/fairseq/fairseq/benchmark/__pycache__/dummy_mt.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d472f4229968aad2d314d87ba480179308e29e2f Binary files /dev/null and b/fairseq/fairseq/benchmark/__pycache__/dummy_mt.cpython-310.pyc differ diff --git a/fairseq/fairseq/benchmark/benchmark_multihead_attention.py b/fairseq/fairseq/benchmark/benchmark_multihead_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..a44847f25031ff2e4490ca47d560167af786f64d --- /dev/null +++ b/fairseq/fairseq/benchmark/benchmark_multihead_attention.py @@ -0,0 +1,172 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import itertools +import random + +import torch +from torch.utils import benchmark + +from fairseq.modules.multihead_attention import MultiheadAttention + +BATCH = [20, 41, 97] +SEQ = 64 +EMB = 48 +HEADS = 4 +DROP = 0.1 +DEVICE = torch.device("cuda") +ATTN_MASK_DTYPE = [torch.uint8, torch.bool, torch.float] +KEY_PADDING_MASK_DTYPE = [torch.uint8, torch.bool] + + +def _reset_seeds(): + torch.manual_seed(0) + random.seed(0) + + +def _get_mask(to_dtype: torch.dtype, dim0: int, dim1: int): + if to_dtype == torch.float: + mask = torch.randint(0, 2, (dim0, dim1)).to(dtype=torch.bool) + return mask.to(dtype=to_dtype).masked_fill(mask, -float("inf")) + return torch.randint(0, 2, (dim0, dim1)).to(dtype=to_dtype) + + +def benchmark_multihead_attention( + label="", + attn_dtype=torch.uint8, + key_padding_dtype=torch.uint8, + add_bias_kv=False, + add_zero_attn=False, + static_kv=False, + batch_size=20, + embedding=EMB, + seq_len=SEQ, + num_heads=HEADS, +): + + results = [] + # device = torch.device("cuda") + + xformers_att_config = '{"name": "scaled_dot_product"}' + + attn_mask = _get_mask(to_dtype=attn_dtype, dim0=seq_len, dim1=seq_len) + key_padding_mask = _get_mask( + to_dtype=key_padding_dtype, dim0=batch_size, dim1=seq_len + ) + + q = torch.rand(seq_len, batch_size, embedding, requires_grad=True) + k = torch.rand(seq_len, batch_size, embedding, requires_grad=True) + v = torch.rand(seq_len, batch_size, embedding, requires_grad=True) + + _reset_seeds() + + original_mha = MultiheadAttention( + embedding, + num_heads, + dropout=0.0, + xformers_att_config=None, + add_bias_kv=add_bias_kv, + add_zero_attn=add_zero_attn, + ) + + xformers_mha = MultiheadAttention( + embedding, + num_heads, + dropout=0.0, + xformers_att_config=xformers_att_config, + add_bias_kv=add_bias_kv, + add_zero_attn=add_zero_attn, + ) + + def original_bench_fw(q, k, v, key_padding_mask, attn_mask, static_kv): + original_mha( + query=q, + key=k, + value=v, + key_padding_mask=key_padding_mask, + attn_mask=attn_mask, + static_kv=static_kv, + ) + + def xformers_bench_fw(q, k, v, key_padding_mask, attn_mask, static_kv): + xformers_mha( + query=q, + key=k, + value=v, + key_padding_mask=key_padding_mask, + attn_mask=attn_mask, + static_kv=static_kv, + ) + + def original_bench_fw_bw(q, k, v, key_padding_mask, attn_mask, static_kv): + output, _ = original_mha( + query=q, + key=k, + value=v, + key_padding_mask=key_padding_mask, + attn_mask=attn_mask, + static_kv=static_kv, + ) + loss = torch.norm(output) + loss.backward() + + def xformers_bench_fw_bw(q, k, v, key_padding_mask, attn_mask, static_kv): + output, _ = xformers_mha( + query=q, + key=k, + value=v, + key_padding_mask=key_padding_mask, + attn_mask=attn_mask, + static_kv=static_kv, + ) + loss = 
torch.norm(output) + loss.backward() + + fns = [ + original_bench_fw, + xformers_bench_fw, + original_bench_fw_bw, + xformers_bench_fw_bw, + ] + + for fn in fns: + results.append( + benchmark.Timer( + stmt="fn(q, k, v, key_padding_mask, attn_mask, static_kv)", + globals={ + "q": q, + "k": k, + "v": v, + "key_padding_mask": key_padding_mask, + "attn_mask": attn_mask, + "static_kv": static_kv, + "fn": fn, + }, + label="multihead fw + bw", + sub_label=f"{fn.__name__}", + description=label, + ).blocked_autorange(min_run_time=1) + ) + + compare = benchmark.Compare(results) + compare.print() + + +def run_benchmarks(): + for attn_dtype, key_padding_dtype, add_bias_kv, add_zero_attn in itertools.product( + ATTN_MASK_DTYPE, KEY_PADDING_MASK_DTYPE, [True, False], [True, False] + ): + label = f"attn_dtype {attn_dtype}, key_padding_dtype {key_padding_dtype}, \ + add_bias_kv {add_bias_kv}, add_zero_attn {add_zero_attn}" + benchmark_multihead_attention( + label=label, + attn_dtype=attn_dtype, + key_padding_dtype=key_padding_dtype, + add_bias_kv=add_bias_kv, + add_zero_attn=add_zero_attn, + ) + + +run_benchmarks() diff --git a/fairseq/fairseq/benchmark/dummy_dataset.py b/fairseq/fairseq/benchmark/dummy_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..2f051754af55966e26850e94c121e0ff439bfd28 --- /dev/null +++ b/fairseq/fairseq/benchmark/dummy_dataset.py @@ -0,0 +1,36 @@ +import numpy as np +from fairseq.data import FairseqDataset + + +class DummyDataset(FairseqDataset): + def __init__(self, batch, num_items, item_size): + super().__init__() + self.batch = batch + self.num_items = num_items + self.item_size = item_size + + def __getitem__(self, index): + return index + + def __len__(self): + return self.num_items + + def collater(self, samples): + return self.batch + + @property + def sizes(self): + return np.array([self.item_size] * self.num_items) + + def num_tokens(self, index): + return self.item_size + + def size(self, index): + return self.item_size + + def ordered_indices(self): + return np.arange(self.num_items) + + @property + def supports_prefetch(self): + return False diff --git a/fairseq/fairseq/benchmark/dummy_lm.py b/fairseq/fairseq/benchmark/dummy_lm.py new file mode 100644 index 0000000000000000000000000000000000000000..c6246a0c0e338fa36244b3aa4fb57f189fbffcb6 --- /dev/null +++ b/fairseq/fairseq/benchmark/dummy_lm.py @@ -0,0 +1,83 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
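The attention benchmark added above (`fairseq/fairseq/benchmark/benchmark_multihead_attention.py`) compares the stock `MultiheadAttention` against the xformers-backed variant and sweeps every combination of mask dtype, `add_bias_kv`, and `add_zero_attn` via `run_benchmarks()`, which is called at module level. A minimal sketch of driving a single configuration instead of the full sweep (hypothetical usage; it assumes xformers is installed, since the `xformers_att_config` path requires it):

```python
# Hypothetical single-configuration run; note that importing the module as
# written already executes the full run_benchmarks() sweep at import time.
import torch
from fairseq.benchmark.benchmark_multihead_attention import benchmark_multihead_attention

benchmark_multihead_attention(
    label="bool attn/key-padding masks",
    attn_dtype=torch.bool,
    key_padding_dtype=torch.bool,
    add_bias_kv=False,
    add_zero_attn=False,
    batch_size=20,  # one of the batch sizes exercised by the full sweep
)
```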
+ +import logging +from dataclasses import dataclass, field +from typing import Optional + +import torch +from .dummy_dataset import DummyDataset +from fairseq.data import Dictionary +from fairseq.dataclass import FairseqDataclass +from fairseq.tasks import FairseqTask, register_task +from omegaconf import II + + +logger = logging.getLogger(__name__) + + +@dataclass +class DummyLMConfig(FairseqDataclass): + dict_size: int = 49996 + dataset_size: int = 100000 + tokens_per_sample: int = field( + default=512, metadata={"help": "max sequence length"} + ) + add_bos_token: bool = False + batch_size: Optional[int] = II("dataset.batch_size") + max_tokens: Optional[int] = II("dataset.max_tokens") + max_target_positions: int = II("task.tokens_per_sample") + + +@register_task("dummy_lm", dataclass=DummyLMConfig) +class DummyLMTask(FairseqTask): + def __init__(self, cfg: DummyLMConfig): + super().__init__(cfg) + + # load dictionary + self.dictionary = Dictionary() + for i in range(cfg.dict_size): + self.dictionary.add_symbol("word{}".format(i)) + self.dictionary.pad_to_multiple_(8) # often faster if divisible by 8 + logger.info("dictionary: {} types".format(len(self.dictionary))) + + seq = torch.arange(cfg.tokens_per_sample + 1) + self.dictionary.pad() + 1 + + self.dummy_src = seq[:-1] + self.dummy_tgt = seq[1:] + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + """Load a given dataset split. + Args: + split (str): name of the split (e.g., train, valid, test) + """ + if self.cfg.batch_size is not None: + bsz = self.cfg.batch_size + else: + bsz = max(1, self.cfg.max_tokens // self.cfg.tokens_per_sample) + self.datasets[split] = DummyDataset( + { + "id": 1, + "net_input": { + "src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]), + "src_lengths": torch.full( + (bsz,), self.cfg.tokens_per_sample, dtype=torch.long + ), + }, + "target": torch.stack([self.dummy_tgt for _ in range(bsz)]), + "nsentences": bsz, + "ntokens": bsz * self.cfg.tokens_per_sample, + }, + num_items=self.cfg.dataset_size, + item_size=self.cfg.tokens_per_sample, + ) + + @property + def source_dictionary(self): + return self.dictionary + + @property + def target_dictionary(self): + return self.dictionary diff --git a/fairseq/fairseq/benchmark/dummy_masked_lm.py b/fairseq/fairseq/benchmark/dummy_masked_lm.py new file mode 100644 index 0000000000000000000000000000000000000000..12b9c5d0f55993bf8750564882a351fc3f8055f0 --- /dev/null +++ b/fairseq/fairseq/benchmark/dummy_masked_lm.py @@ -0,0 +1,94 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
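The `dummy_lm` task defined above builds a fixed synthetic batch, so it can be exercised without any data directory. A minimal sketch of constructing the task directly and inspecting one batch (hypothetical usage that bypasses the usual `fairseq-train`/Hydra entry point; the config values are illustrative):

```python
# Minimal sketch: build DummyLMTask by hand and inspect the synthetic batch it serves.
from fairseq.benchmark.dummy_lm import DummyLMConfig, DummyLMTask

cfg = DummyLMConfig(dict_size=1000, dataset_size=64, tokens_per_sample=128, batch_size=2)
task = DummyLMTask(cfg)
task.load_dataset("train")

batch = task.datasets["train"].collater(None)  # DummyDataset always returns the same batch
print(batch["net_input"]["src_tokens"].shape)  # torch.Size([2, 128])
print(batch["ntokens"])                        # 2 * 128 = 256
```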
+ +import logging +from dataclasses import dataclass, field +from typing import Optional + +import torch +from omegaconf import II + +from .dummy_dataset import DummyDataset +from fairseq.data import Dictionary +from fairseq.dataclass import FairseqDataclass +from fairseq.tasks import FairseqTask, register_task + +logger = logging.getLogger(__name__) + + +@dataclass +class DummyMaskedLMConfig(FairseqDataclass): + dict_size: int = 49996 + dataset_size: int = 100000 + tokens_per_sample: int = field( + default=512, + metadata={ + "help": "max number of total tokens over all" + " segments per sample for BERT dataset" + }, + ) + batch_size: Optional[int] = II("dataset.batch_size") + max_tokens: Optional[int] = II("dataset.max_tokens") + max_target_positions: int = II("task.tokens_per_sample") + + +@register_task("dummy_masked_lm", dataclass=DummyMaskedLMConfig) +class DummyMaskedLMTask(FairseqTask): + def __init__(self, cfg: DummyMaskedLMConfig): + super().__init__(cfg) + + self.dictionary = Dictionary() + for i in range(cfg.dict_size): + self.dictionary.add_symbol("word{}".format(i)) + logger.info("dictionary: {} types".format(len(self.dictionary))) + # add mask token + self.mask_idx = self.dictionary.add_symbol("") + self.dictionary.pad_to_multiple_(8) # often faster if divisible by 8 + + mask_idx = 0 + pad_idx = 1 + seq = torch.arange(cfg.tokens_per_sample) + pad_idx + 1 + mask = torch.arange(2, cfg.tokens_per_sample, 7) # ~15% + src = seq.clone() + src[mask] = mask_idx + tgt = torch.full_like(seq, pad_idx) + tgt[mask] = seq[mask] + + self.dummy_src = src + self.dummy_tgt = tgt + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + """Load a given dataset split. + Args: + split (str): name of the split (e.g., train, valid, test) + """ + if self.cfg.batch_size is not None: + bsz = self.cfg.batch_size + else: + bsz = max(1, self.cfg.max_tokens // self.cfg.tokens_per_sample) + self.datasets[split] = DummyDataset( + { + "id": 1, + "net_input": { + "src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]), + "src_lengths": torch.full( + (bsz,), self.cfg.tokens_per_sample, dtype=torch.long + ), + }, + "target": torch.stack([self.dummy_tgt for _ in range(bsz)]), + "nsentences": bsz, + "ntokens": bsz * self.cfg.tokens_per_sample, + }, + num_items=self.cfg.dataset_size, + item_size=self.cfg.tokens_per_sample, + ) + + @property + def source_dictionary(self): + return self.dictionary + + @property + def target_dictionary(self): + return self.dictionary diff --git a/fairseq/fairseq/benchmark/dummy_model.py b/fairseq/fairseq/benchmark/dummy_model.py new file mode 100644 index 0000000000000000000000000000000000000000..ff26e4fe655d8e8d7f9942c4bd3df7cd267405fb --- /dev/null +++ b/fairseq/fairseq/benchmark/dummy_model.py @@ -0,0 +1,96 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
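The masked-LM variant above synthesizes its source/target pair once: roughly every 7th position (starting at index 2) is masked, which works out to about 15% of tokens for the default 512-token samples, and the target carries padding everywhere except at masked positions. A small standalone sketch of that pattern, restated outside the task for clarity:

```python
# Sketch of the synthetic masking pattern used by DummyMaskedLMTask above.
import torch

tokens_per_sample = 512
mask_idx, pad_idx = 0, 1
seq = torch.arange(tokens_per_sample) + pad_idx + 1
mask = torch.arange(2, tokens_per_sample, 7)   # positions 2, 9, 16, ...

src = seq.clone()
src[mask] = mask_idx                           # masked input
tgt = torch.full_like(seq, pad_idx)
tgt[mask] = seq[mask]                          # loss is only defined at masked positions

print(mask.numel() / tokens_per_sample)        # ~0.14, close to the usual 15% masking rate
```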
+ +import torch.nn as nn +import torch.nn.functional as F +from fairseq.data import Dictionary +from fairseq.models import ( + FairseqDecoder, + FairseqLanguageModel, + register_model, + register_model_architecture, +) + + +@register_model("dummy_model") +class DummyModel(FairseqLanguageModel): + def __init__(self, args, encoder): + super().__init__(encoder) + self.args = args + + @staticmethod + def add_args(parser): + parser.add_argument("--num-layers", type=int, default=24) + parser.add_argument("--embed-dim", type=int, default=1024) + + @classmethod + def build_model(cls, args, task): + encoder = DummyEncoder( + num_embed=len(task.target_dictionary), + embed_dim=args.embed_dim, + num_layers=args.num_layers, + ) + return cls(args, encoder) + + def forward(self, src_tokens, masked_tokens=None, **kwargs): + return self.decoder(src_tokens, masked_tokens=masked_tokens) + + +class DummyEncoder(FairseqDecoder): + def __init__(self, num_embed=50000, embed_dim=1024, num_layers=24): + super().__init__(Dictionary()) + self.embed = nn.Embedding( + num_embeddings=num_embed, embedding_dim=embed_dim, padding_idx=0 + ) + self.layers_a = nn.ModuleList( + [ + nn.Sequential( + nn.LayerNorm(embed_dim), + nn.Linear(embed_dim, 3 * embed_dim), # q, k, v input projection + nn.Linear(3 * embed_dim, embed_dim), # skip self-attention + nn.Linear(embed_dim, embed_dim), # output projection + nn.Dropout(), + ) + for i in range(num_layers) + ] + ) + self.layers_b = nn.ModuleList( + [ + nn.Sequential( + nn.LayerNorm(embed_dim), + nn.Linear(embed_dim, 4 * embed_dim), # FFN + nn.ReLU(), + nn.Linear(4 * embed_dim, embed_dim), # FFN + nn.Dropout(0.1), + ) + for i in range(num_layers) + ] + ) + self.out_proj = nn.Linear(embed_dim, num_embed) + + def forward(self, tokens, masked_tokens=None): + x = self.embed(tokens) + for layer_a, layer_b in zip(self.layers_a, self.layers_b): + x = x + layer_a(x) + x = x + layer_b(x) + x = self.out_proj(x) + if masked_tokens is not None: + x = x[masked_tokens] + return (x,) + + def max_positions(self): + return 1024 + + def get_normalized_probs(self, net_output, log_probs, sample=None): + logits = net_output[0].float() + if log_probs: + return F.log_softmax(logits, dim=-1) + else: + return F.softmax(logits, dim=-1) + + +@register_model_architecture("dummy_model", "dummy_model") +def base_architecture(args): + pass diff --git a/fairseq/fairseq/benchmark/dummy_mt.py b/fairseq/fairseq/benchmark/dummy_mt.py new file mode 100644 index 0000000000000000000000000000000000000000..28d78cffdbf8c2bcee69b454a79891cb34def200 --- /dev/null +++ b/fairseq/fairseq/benchmark/dummy_mt.py @@ -0,0 +1,119 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
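`dummy_model` above registers a decoder-only stack of linear blocks that stands in for a real transformer during benchmarking. A short sketch of running its `DummyEncoder` on a fake batch to sanity-check shapes (hypothetical usage; the sizes are illustrative):

```python
# Minimal sketch: forward a fake batch through the DummyEncoder defined above.
import torch
from fairseq.benchmark.dummy_model import DummyEncoder

model = DummyEncoder(num_embed=1000, embed_dim=64, num_layers=2)
tokens = torch.randint(1, 1000, (4, 16))   # (batch, seq_len); index 0 is the padding idx
logits = model(tokens)[0]                  # forward returns a 1-tuple of logits
print(logits.shape)                        # torch.Size([4, 16, 1000])
log_probs = model.get_normalized_probs((logits,), log_probs=True)
```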
+ +import logging + +import numpy as np +import torch + +from fairseq.data import Dictionary, FairseqDataset +from fairseq.tasks import LegacyFairseqTask, register_task + +logger = logging.getLogger(__name__) + + +@register_task("dummy_mt") +class DummyMTTask(LegacyFairseqTask): + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + parser.add_argument("--dict-size", default=49996, type=int) + parser.add_argument("--dataset-size", default=100000, type=int) + parser.add_argument("--src-len", default=30, type=int) + parser.add_argument("--tgt-len", default=30, type=int) + + def __init__(self, args, dictionary): + super().__init__(args) + self.dictionary = dictionary + self.seed = args.seed + + dictionary.pad_to_multiple_(8) # often faster if divisible by 8 + + self.dummy_src = torch.arange(args.src_len + 1) + dictionary.pad() + 1 + self.dummy_tgt = torch.arange(args.tgt_len + 1) + dictionary.pad() + 1 + + @classmethod + def setup_task(cls, args, **kwargs): + """Setup the task.""" + dictionary = Dictionary() + for i in range(args.dict_size): + dictionary.add_symbol("word{}".format(i)) + logger.info("dictionary: {} types".format(len(dictionary))) + + args.max_source_positions = args.src_len + dictionary.pad() + 2 + args.max_target_positions = args.tgt_len + dictionary.pad() + 2 + + return cls(args, dictionary) + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + """Load a given dataset split. + Args: + split (str): name of the split (e.g., train, valid, test) + """ + item_size = max(self.args.src_len, self.args.tgt_len) + if self.args.batch_size is not None: + bsz = self.args.batch_size + else: + bsz = max(1, self.args.max_tokens // item_size) + tgt = torch.stack([self.dummy_tgt for _ in range(bsz)]) + self.datasets[split] = DummyDataset( + { + "id": 1, + "net_input": { + "src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]), + "src_lengths": torch.full( + (bsz,), self.args.src_len, dtype=torch.long + ), + "prev_output_tokens": tgt.clone(), + }, + "target": tgt, + "nsentences": bsz, + "ntokens": bsz * self.args.tgt_len, + }, + num_items=self.args.dataset_size, + item_size=item_size, + ) + + @property + def source_dictionary(self): + return self.dictionary + + @property + def target_dictionary(self): + return self.dictionary + + +class DummyDataset(FairseqDataset): + def __init__(self, batch, num_items, item_size): + super().__init__() + self.batch = batch + self.num_items = num_items + self.item_size = item_size + + def __getitem__(self, index): + return index + + def __len__(self): + return self.num_items + + def collater(self, samples): + return self.batch + + @property + def sizes(self): + return np.array([self.item_size] * self.num_items) + + def num_tokens(self, index): + return self.item_size + + def size(self, index): + return self.item_size + + def ordered_indices(self): + return np.arange(self.num_items) + + @property + def supports_prefetch(self): + return False diff --git a/fairseq/fairseq/clib/cuda/ngram_repeat_block_cuda.cpp b/fairseq/fairseq/clib/cuda/ngram_repeat_block_cuda.cpp new file mode 100644 index 0000000000000000000000000000000000000000..707219105a17a691e43de1296a72bbaffa0c7fe9 --- /dev/null +++ b/fairseq/fairseq/clib/cuda/ngram_repeat_block_cuda.cpp @@ -0,0 +1,55 @@ +/* +Copyright (c) Microsoft Corporation. +Licensed under the MIT License. 
+*/ + +#include +#include + +/* +CPP Binding for CUDA OP +*/ + +// CUDA forward declarations +torch::Tensor ngram_repeat_block_cuda_forward( + torch::Tensor tokens, + torch::Tensor lprobs, + int bsz, + int step, + int beam_size, + int no_repeat_ngram_size); + +#define CHECK_CUDA(x) \ + TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +// Input check and call to CUDA OP +// Backward method not required +torch::Tensor ngram_repeat_block_forward( + torch::Tensor tokens, + torch::Tensor lprobs, + int bsz, + int step, + int beam_size, + int no_repeat_ngram_size) { + CHECK_INPUT(tokens); + CHECK_INPUT(lprobs); + assert(bsz > 0); + assert(step >= 0); + assert(beam_size > 0); + assert(no_repeat_ngram_size > 0); + + return ngram_repeat_block_cuda_forward( + tokens, lprobs, bsz, step, beam_size, no_repeat_ngram_size); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def( + "forward", + &ngram_repeat_block_forward, + "No Repeat Ngram Block forward (CUDA)"); +} diff --git a/fairseq/fairseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu b/fairseq/fairseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..bd6106cba0672c3ff29c925b0f5cea557ab3eced --- /dev/null +++ b/fairseq/fairseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu @@ -0,0 +1,82 @@ +/* +Copyright (c) Microsoft Corporation. +Licensed under the MIT License. +*/ + +/* +Kernel implementation for blocking repeated n-grams. +*/ + +#include +#include +#include +#include +#include + +// Ban repeated ngrams of length = 'no_repeat_ngram_size' +__global__ void banRepeatedTokens( + long* __restrict__ tokens, + float* __restrict__ lprobs, + int max_predict_len, + int vocab_size, + int no_repeat_ngram_size) { + auto row = blockIdx.x; + auto col = threadIdx.x; + auto start = row * (max_predict_len) + col; + // Each thread compares ngram starting from + // thread index with final ngram starting from + // step - no_repeat_ngram_size +2 + auto check_start_pos = blockDim.x; + auto lprob_start = row * vocab_size; + bool is_banned = true; + extern __shared__ long tokens_shm[]; + tokens_shm[col] = tokens[start]; + if (col == blockDim.x - 1) { + for (int i = 1; i < no_repeat_ngram_size; i++) { + if (col + i < max_predict_len) { + tokens_shm[col + i] = tokens[start + i]; + } + } + } + __syncthreads(); + + for (int k = 0; k < no_repeat_ngram_size - 1; k++) { + if (tokens_shm[col + k] != tokens_shm[check_start_pos + k]) { + is_banned = false; + } + } + if (is_banned == true) { + auto token_to_be_banned = tokens_shm[col + no_repeat_ngram_size - 1]; + lprobs[lprob_start + token_to_be_banned] = -INFINITY; + } +} + +// Allocate blocks and threads based on +// batch size and sequence length and launch +// kernel +torch::Tensor ngram_repeat_block_cuda_forward( + const torch::Tensor tokens, + torch::Tensor lprobs, + int bsz, + int step, + int beam_size, + int no_repeat_ngram_size) { + int threads = step - no_repeat_ngram_size + 2; + if (threads <= 0) + return lprobs; + int max_predict_len = tokens.size(1); + int vocab_size = lprobs.size(1); + auto token_ptr = tokens.data_ptr(); + auto lprob_ptr = lprobs.data_ptr(); + int blocks = bsz * beam_size; + int shared_mem_size = (step + 1) * sizeof(long); + + // Launching N blocks where N is number of samples in a batch (beams*bsz) + // Launching T threads where T is number of previous ngrams in a sample + 
// Allocating shared mem per block for fastser access of input tokens since + // each token will be accessed N times to compare with current Ngram where + // N is Ngram size. + banRepeatedTokens<<>>( + token_ptr, lprob_ptr, max_predict_len, vocab_size, no_repeat_ngram_size); + return lprobs; +} diff --git a/fairseq/fairseq/clib/libbase/balanced_assignment.cpp b/fairseq/fairseq/clib/libbase/balanced_assignment.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1a5a1061f3892be5a17e49192f744c39e0d395e8 --- /dev/null +++ b/fairseq/fairseq/clib/libbase/balanced_assignment.cpp @@ -0,0 +1,109 @@ +/** + * Copyright 2017-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* +C++ code for solving the linear assignment problem. +Based on the Auction Algorithm from +https://dspace.mit.edu/bitstream/handle/1721.1/3265/P-2108-26912652.pdf and the +implementation from: https://github.com/bkj/auction-lap Adapted to be more +efficient when each worker is looking for k jobs instead of 1. +*/ +#include +#include +using namespace torch::indexing; +torch::Tensor balanced_assignment(torch::Tensor job_and_worker_to_score) { + int max_iterations = 100; + torch::Tensor epsilon = + (job_and_worker_to_score.max() - job_and_worker_to_score.min()) / 50; + epsilon.clamp_min_(1e-04); + torch::Tensor worker_and_job_to_score = + job_and_worker_to_score.detach().transpose(0, 1).contiguous(); + int num_workers = worker_and_job_to_score.size(0); + int num_jobs = worker_and_job_to_score.size(1); + auto device = worker_and_job_to_score.device(); + int jobs_per_worker = num_jobs / num_workers; + torch::Tensor value = worker_and_job_to_score.clone(); + int counter = 0; + torch::Tensor max_value = worker_and_job_to_score.max(); + + torch::Tensor bid_indices; + torch::Tensor cost = worker_and_job_to_score.new_zeros({1, num_jobs}); + torch::Tensor bids = + worker_and_job_to_score.new_empty({num_workers, num_jobs}); + torch::Tensor bid_increments = + worker_and_job_to_score.new_empty({num_workers, jobs_per_worker}); + torch::Tensor top_values = + worker_and_job_to_score.new_empty({num_workers, jobs_per_worker + 1}); + torch::Tensor high_bids = worker_and_job_to_score.new_empty({num_jobs}); + + torch::Tensor top_index = top_values.to(torch::kLong); + torch::Tensor high_bidders = top_index.new_empty({num_jobs}); + torch::Tensor have_bids = high_bidders.to(torch::kBool); + torch::Tensor jobs_indices = + torch::arange({num_jobs}, torch::dtype(torch::kLong).device(device)); + torch::Tensor true_tensor = + torch::ones({1}, torch::dtype(torch::kBool).device(device)); + + while (true) { + bids.zero_(); + torch::topk_out(top_values, top_index, value, jobs_per_worker + 1, 1); + + // Each worker bids the difference in value between that job and the k+1th + // job + torch::sub_out( + bid_increments, + top_values.index({Slice(None, None), Slice(0, jobs_per_worker)}), + top_values.index({Slice(None, None), jobs_per_worker}).unsqueeze(1)); + + bid_increments.add_(epsilon); + bids.scatter_( + 1, + top_index.index({Slice(None, None), Slice(0, jobs_per_worker)}), + bid_increments); + + if (counter < max_iterations && counter > 0) { + // Put in a minimal bid to retain items from the last round if no-one else + // bids for them this round + bids.view(-1).index_put_({bid_indices}, epsilon); + } + + // Find the highest bidding worker per job + torch::max_out(high_bids, high_bidders, bids, 0); + 
torch::gt_out(have_bids, high_bids, 0); + + if (have_bids.all().item()) { + // All jobs were bid for + break; + } + + // Make popular items more expensive + cost.add_(high_bids); + torch::sub_out(value, worker_and_job_to_score, cost); + + bid_indices = ((high_bidders * num_jobs) + jobs_indices).index({have_bids}); + + if (counter < max_iterations) { + // Make sure that this item will be in the winning worker's top-k next + // time. + value.view(-1).index_put_({bid_indices}, max_value); + } else { + // Suboptimal approximation that converges quickly from current solution + value.view(-1).index_put_( + {bid_indices}, worker_and_job_to_score.view(-1).index({bid_indices})); + } + + counter += 1; + } + + return top_index.index({Slice(None, None), Slice(0, jobs_per_worker)}) + .reshape(-1); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("balanced_assignment", &balanced_assignment, "Balanced Assignment"); +} diff --git a/fairseq/fairseq/clib/libbleu/libbleu.cpp b/fairseq/fairseq/clib/libbleu/libbleu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..939d9e1174e398fa48c840009b592c753a67939a --- /dev/null +++ b/fairseq/fairseq/clib/libbleu/libbleu.cpp @@ -0,0 +1,157 @@ +/** + * Copyright 2017-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +// NOLINTNEXTLINE +typedef struct { + size_t reflen; + size_t predlen; + size_t match1; + size_t count1; + size_t match2; + size_t count2; + size_t match3; + size_t count3; + size_t match4; + size_t count4; +} bleu_stat; + +// left trim (remove pad) +void bleu_ltrim(size_t* len, int** sent, int pad) { + size_t start = 0; + while (start < *len) { + if (*(*sent + start) != pad) { + break; + } + start++; + } + *sent += start; + *len -= start; +} + +// right trim remove (eos) +void bleu_rtrim(size_t* len, int** sent, int pad, int eos) { + size_t end = *len - 1; + while (end > 0) { + if (*(*sent + end) != eos && *(*sent + end) != pad) { + break; + } + end--; + } + *len = end + 1; +} + +// left and right trim +void bleu_trim(size_t* len, int** sent, int pad, int eos) { + bleu_ltrim(len, sent, pad); + bleu_rtrim(len, sent, pad, eos); +} + +size_t bleu_hash(int len, int* data) { + size_t h = 14695981039346656037ul; + size_t prime = 0x100000001b3; + char* b = (char*)data; + size_t blen = sizeof(int) * len; + + while (blen-- > 0) { + h ^= *b++; + h *= prime; + } + + return h; +} + +void bleu_addngram( + size_t* ntotal, + size_t* nmatch, + size_t n, + size_t reflen, + int* ref, + size_t predlen, + int* pred) { + if (predlen < n) { + return; + } + + predlen = predlen - n + 1; + (*ntotal) += predlen; + + if (reflen < n) { + return; + } + + reflen = reflen - n + 1; + + std::map count; + while (predlen > 0) { + size_t w = bleu_hash(n, pred++); + count[w]++; + predlen--; + } + + while (reflen > 0) { + size_t w = bleu_hash(n, ref++); + if (count[w] > 0) { + (*nmatch)++; + count[w] -= 1; + } + reflen--; + } +} + +extern "C" { + +#ifdef _WIN64 +__declspec(dllexport) +#endif + void bleu_zero_init(bleu_stat* stat) { + std::memset(stat, 0, sizeof(bleu_stat)); +} + +#ifdef _WIN64 +__declspec(dllexport) +#endif + void bleu_one_init(bleu_stat* stat) { + bleu_zero_init(stat); + stat->count1 = 0; + stat->count2 = 1; + stat->count3 = 1; + stat->count4 = 1; + stat->match1 = 0; + stat->match2 = 1; + stat->match3 = 1; + stat->match4 = 1; +} + +#ifdef _WIN64 +__declspec(dllexport) 
+#endif + void bleu_add( + bleu_stat* stat, + size_t reflen, + int* ref, + size_t predlen, + int* pred, + int pad, + int eos) { + + bleu_trim(&reflen, &ref, pad, eos); + bleu_trim(&predlen, &pred, pad, eos); + stat->reflen += reflen; + stat->predlen += predlen; + + bleu_addngram(&stat->count1, &stat->match1, 1, reflen, ref, predlen, pred); + bleu_addngram(&stat->count2, &stat->match2, 2, reflen, ref, predlen, pred); + bleu_addngram(&stat->count3, &stat->match3, 3, reflen, ref, predlen, pred); + bleu_addngram(&stat->count4, &stat->match4, 4, reflen, ref, predlen, pred); +} +} diff --git a/fairseq/fairseq/clib/libbleu/module.cpp b/fairseq/fairseq/clib/libbleu/module.cpp new file mode 100644 index 0000000000000000000000000000000000000000..35288b3177185670135f7bdc1f1589c5bb992304 --- /dev/null +++ b/fairseq/fairseq/clib/libbleu/module.cpp @@ -0,0 +1,33 @@ +/** + * Copyright 2017-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +static PyMethodDef method_def[] = {{NULL, NULL, 0, NULL}}; // NOLINT + +static struct PyModuleDef module_def = { + PyModuleDef_HEAD_INIT, + "libbleu", /* name of module */ + // NOLINTNEXTLINE + NULL, /* module documentation, may be NULL */ + -1, /* size of per-interpreter state of the module, + or -1 if the module keeps state in global variables. */ + method_def}; // NOLINT + +#if PY_MAJOR_VERSION == 2 +PyMODINIT_FUNC init_libbleu() +#else +PyMODINIT_FUNC PyInit_libbleu() +#endif +{ + PyObject* m = PyModule_Create(&module_def); + if (!m) { + return NULL; + } + return m; +} diff --git a/fairseq/fairseq/clib/libnat/edit_dist.cpp b/fairseq/fairseq/clib/libnat/edit_dist.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9ffb60569d74d2868ed8113b7c787ef870e9da20 --- /dev/null +++ b/fairseq/fairseq/clib/libnat/edit_dist.cpp @@ -0,0 +1,231 @@ +/** + * Copyright 2017-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include // @manual=//caffe2:torch_extension +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace ::std; + +vector> edit_distance2_with_dp( + vector& x, + vector& y) { + uint32_t lx = x.size(); + uint32_t ly = y.size(); + vector> d(lx + 1, vector(ly + 1)); + for (uint32_t i = 0; i < lx + 1; i++) { + d[i][0] = i; + } + for (uint32_t j = 0; j < ly + 1; j++) { + d[0][j] = j; + } + for (uint32_t i = 1; i < lx + 1; i++) { + for (uint32_t j = 1; j < ly + 1; j++) { + d[i][j] = + min(min(d[i - 1][j], d[i][j - 1]) + 1, + d[i - 1][j - 1] + 2 * (x.at(i - 1) == y.at(j - 1) ? 
0 : 1)); + } + } + return d; +} + +vector> edit_distance2_backtracking( + vector>& d, + vector& x, + vector& y, + uint32_t terminal_symbol) { + vector seq; + vector> edit_seqs(x.size() + 2, vector()); + /* + edit_seqs: + 0~x.size() cell is the insertion sequences + last cell is the delete sequence + */ + + if (x.size() == 0) { + edit_seqs.at(0) = y; + return edit_seqs; + } + + uint32_t i = d.size() - 1; + uint32_t j = d.at(0).size() - 1; + + while ((i >= 0) && (j >= 0)) { + if ((i == 0) && (j == 0)) { + break; + } + + if ((j > 0) && (d.at(i).at(j - 1) < d.at(i).at(j))) { + seq.push_back(1); // insert + seq.push_back(y.at(j - 1)); + j--; + } else if ((i > 0) && (d.at(i - 1).at(j) < d.at(i).at(j))) { + seq.push_back(2); // delete + seq.push_back(x.at(i - 1)); + i--; + } else { + seq.push_back(3); // keep + seq.push_back(x.at(i - 1)); + i--; + j--; + } + } + + uint32_t prev_op, op, s, word; + prev_op = 0, s = 0; + for (uint32_t k = 0; k < seq.size() / 2; k++) { + op = seq.at(seq.size() - 2 * k - 2); + word = seq.at(seq.size() - 2 * k - 1); + if (prev_op != 1) { + s++; + } + if (op == 1) // insert + { + edit_seqs.at(s - 1).push_back(word); + } else if (op == 2) // delete + { + edit_seqs.at(x.size() + 1).push_back(1); + } else { + edit_seqs.at(x.size() + 1).push_back(0); + } + + prev_op = op; + } + + for (uint32_t k = 0; k < edit_seqs.size(); k++) { + if (edit_seqs[k].size() == 0) { + edit_seqs[k].push_back(terminal_symbol); + } + } + return edit_seqs; +} + +vector> edit_distance2_backtracking_with_delete( + vector>& d, + vector& x, + vector& y, + uint32_t terminal_symbol, + uint32_t deletion_symbol) { + vector seq; + vector> edit_seqs(x.size() + 1, vector()); + /* + edit_seqs: + 0~x.size() cell is the insertion sequences + last cell is the delete sequence + */ + + if (x.size() == 0) { + edit_seqs.at(0) = y; + return edit_seqs; + } + + uint32_t i = d.size() - 1; + uint32_t j = d.at(0).size() - 1; + + while ((i >= 0) && (j >= 0)) { + if ((i == 0) && (j == 0)) { + break; + } + + if ((j > 0) && (d.at(i).at(j - 1) < d.at(i).at(j))) { + seq.push_back(1); // insert + seq.push_back(y.at(j - 1)); + j--; + } else if ((i > 0) && (d.at(i - 1).at(j) < d.at(i).at(j))) { + seq.push_back(2); // delete + seq.push_back(x.at(i - 1)); + i--; + } else { + seq.push_back(3); // keep + seq.push_back(x.at(i - 1)); + i--; + j--; + } + } + + uint32_t prev_op, op, s, word; + prev_op = 0, s = 0; + for (uint32_t k = 0; k < seq.size() / 2; k++) { + op = seq.at(seq.size() - 2 * k - 2); + word = seq.at(seq.size() - 2 * k - 1); + if (prev_op != 1) { + s++; + } + if (op == 1) // insert + { + edit_seqs.at(s - 1).push_back(word); + } else if (op == 2) // delete + { + edit_seqs.at(s - 1).push_back(deletion_symbol); + } + + prev_op = op; + } + + for (uint32_t k = 0; k < edit_seqs.size(); k++) { + if (edit_seqs.at(k).size() == 0) { + edit_seqs.at(k).push_back(terminal_symbol); + } + } + return edit_seqs; +} + +vector compute_ed2( + vector>& xs, + vector>& ys) { + vector distances(xs.size()); + for (uint32_t i = 0; i < xs.size(); i++) { + vector> d = edit_distance2_with_dp(xs.at(i), ys.at(i)); + distances.at(i) = d.at(xs.at(i).size()).at(ys.at(i).size()); + } + return distances; +} + +vector>> suggested_ed2_path( + vector>& xs, + vector>& ys, + uint32_t terminal_symbol) { + vector>> seq(xs.size()); + for (uint32_t i = 0; i < xs.size(); i++) { + vector> d = edit_distance2_with_dp(xs.at(i), ys.at(i)); + seq.at(i) = + edit_distance2_backtracking(d, xs.at(i), ys.at(i), terminal_symbol); + } + return seq; +} + +vector>> 
suggested_ed2_path_with_delete( + vector>& xs, + vector>& ys, + uint32_t terminal_symbol, + uint32_t deletion_symbol) { + vector>> seq(xs.size()); + for (uint32_t i = 0; i < xs.size(); i++) { + vector> d = edit_distance2_with_dp(xs.at(i), ys.at(i)); + seq.at(i) = edit_distance2_backtracking_with_delete( + d, xs.at(i), ys.at(i), terminal_symbol, deletion_symbol); + } + return seq; +} + +PYBIND11_MODULE(libnat, m) { + m.def("compute_ed2", &compute_ed2, "compute_ed2"); + m.def("suggested_ed2_path", &suggested_ed2_path, "suggested_ed2_path"); + m.def( + "suggested_ed2_path_with_delete", + &suggested_ed2_path_with_delete, + "suggested_ed2_path_with_delete"); +} diff --git a/fairseq/fairseq/clib/libnat_cuda/binding.cpp b/fairseq/fairseq/clib/libnat_cuda/binding.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ced91c0d0afab9071842911d9876e6360d90284a --- /dev/null +++ b/fairseq/fairseq/clib/libnat_cuda/binding.cpp @@ -0,0 +1,67 @@ +/** + * Copyright 2017-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + This code is partially adpoted from + https://github.com/1ytic/pytorch-edit-distance + */ + +#include +#include "edit_dist.h" + +#ifndef TORCH_CHECK +#define TORCH_CHECK AT_CHECK +#endif + +#define CHECK_CUDA(x) \ + TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +torch::Tensor LevenshteinDistance( + torch::Tensor source, + torch::Tensor target, + torch::Tensor source_length, + torch::Tensor target_length) { + CHECK_INPUT(source); + CHECK_INPUT(target); + CHECK_INPUT(source_length); + CHECK_INPUT(target_length); + return LevenshteinDistanceCuda(source, target, source_length, target_length); +} + +torch::Tensor GenerateDeletionLabel( + torch::Tensor source, + torch::Tensor operations) { + CHECK_INPUT(source); + CHECK_INPUT(operations); + return GenerateDeletionLabelCuda(source, operations); +} + +std::pair GenerateInsertionLabel( + torch::Tensor target, + torch::Tensor operations) { + CHECK_INPUT(target); + CHECK_INPUT(operations); + return GenerateInsertionLabelCuda(target, operations); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("levenshtein_distance", &LevenshteinDistance, "Levenshtein distance"); + m.def( + "generate_deletion_labels", + &GenerateDeletionLabel, + "Generate Deletion Label"); + m.def( + "generate_insertion_labels", + &GenerateInsertionLabel, + "Generate Insertion Label"); +} diff --git a/fairseq/fairseq/clib/libnat_cuda/edit_dist.cu b/fairseq/fairseq/clib/libnat_cuda/edit_dist.cu new file mode 100644 index 0000000000000000000000000000000000000000..1ea5ec7e3cb31557fde20bc457f986bbcecc9cb2 --- /dev/null +++ b/fairseq/fairseq/clib/libnat_cuda/edit_dist.cu @@ -0,0 +1,344 @@ +/** + * Copyright 2017-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include "edit_dist.h" + +#include +#include +#include +#include +#include // std::pair + +template +__global__ void generate_deletion_label_kernel( + const scalar_t* __restrict__ source, + const size_t source_size, + const size_t operation_size, + int* __restrict__ operations, + int* __restrict__ labels) { + const int index = blockIdx.x; + const int offset = index * operation_size; + const int offset_label = index * source_size; + + for (int i = 0; i < source_size; i++) { + labels[offset_label + i] = 0; + } + + int k = 0; + for (int i = 0; i < operation_size; i++) { + if (operations[offset + i] == 0) { + break; + } else if (operations[offset + i] == 1) { + continue; + } else { + labels[offset_label + k] = 3 - operations[offset + i]; + k++; + } + } +} + +template +__global__ void generate_insertion_label_kernel( + const scalar_t* __restrict__ target, + const size_t target_size, + const size_t operation_size, + int* __restrict__ operations, + int* __restrict__ labels, + int* __restrict__ masks) { + const int index = blockIdx.x; + const int offset = index * operation_size; + const int offset_label = index * target_size; + + int k = 0; + int u = 0; + int m = 0; + + for (int i = 0; i < target_size; i++) { + labels[offset_label + i] = 0; + masks[offset_label + i] = 0; + } + + for (int i = 0; i < operation_size - 1; i++) { + if (operations[offset + i] == 0) { + break; + } else if (operations[offset + i] == 2) { + continue; + } else if (operations[offset + i] == 1) { + masks[offset_label + m] = 1; + u++; + m++; + } else { + labels[offset_label + k] = u; + masks[offset_label + m] = 0; + k++; + m++; + u = 0; + } + } +} + +template +__global__ void levenshtein_distance_kernel( + const scalar_t* __restrict__ source, + const scalar_t* __restrict__ target, + const int* __restrict__ source_length, + const int* __restrict__ target_length, + const size_t source_size, + const size_t target_size, + int* __restrict__ operations, + int* __restrict__ errors_curr) { + const int index = blockIdx.x; + const int offset = index * (source_size + target_size); + const int d = index * (source_size + 1) * (target_size + 1); + const int t = target_size + 1; + + auto err_idx = [d, t](int i, int j) { return d + i * t + j; }; + auto opt_idx = [offset](int k) { return offset + k; }; + + const int hyp_len = source_length[index]; + const int ref_len = target_length[index]; + const scalar_t* hyp_begin = source + index * source_size; + const scalar_t* ref_begin = target + index * target_size; + + // dynamic programming + for (int i = 0; i <= hyp_len; i++) { + errors_curr[err_idx(i, 0)] = i; + } + for (int j = 0; j <= ref_len; j++) { + errors_curr[err_idx(0, j)] = j; + } + for (int i = 1; i <= hyp_len; i++) { + for (int j = 1; j <= ref_len; j++) { + errors_curr[err_idx(i, j)] = min( + min(errors_curr[err_idx(i - 1, j)], errors_curr[err_idx(i, j - 1)]) + + 1, + errors_curr[err_idx(i - 1, j - 1)] + + 2 * (*(hyp_begin + i - 1) == *(ref_begin + j - 1) ? 
0 : 1)); + } + } + + // back-tracing + int i = hyp_len; + int j = ref_len; + int o = hyp_len + ref_len; + + for (int k = 0; k < source_size + target_size; k++) { + operations[opt_idx(k)] = 0; + } + + while ((i >= 0) && (j >= 0)) { + if ((i == 0) && (j == 0)) { + break; + } + + if ((j > 0) && + (errors_curr[err_idx(i, j - 1)] < errors_curr[err_idx(i, j)])) { + o--; + operations[opt_idx(o)] = 1; + j--; // insertion + } else if ( + (i > 0) && + (errors_curr[err_idx(i - 1, j)] < errors_curr[err_idx(i, j)])) { + o--; + operations[opt_idx(o)] = 2; + i--; // deletion + } else { + o--; + operations[opt_idx(o)] = 3; + i--; + j--; // do nothing + } + } + + // moving to the left + for (int k = 0; k < hyp_len + ref_len; k++) { + if (k + o < hyp_len + ref_len) { + operations[opt_idx(k)] = operations[opt_idx(k + o)]; + } else { + operations[opt_idx(k)] = 0; // padding + } + } +} + +template +__global__ void faster_levenshtein_distance_kernel( + const scalar_t* __restrict__ source, + const scalar_t* __restrict__ target, + const int* __restrict__ source_length, + const int* __restrict__ target_length, + const size_t source_size, + const size_t target_size, + int* __restrict__ operations) { + extern __shared__ short errors[]; + auto errors_curr = errors; + + const int index = blockIdx.x; + const int offset = index * (source_size + target_size); + const int t = target_size + 1; + + auto err_idx = [t](int i, int j) { return i * t + j; }; + auto opt_idx = [offset](int k) { return offset + k; }; + + const int hyp_len = source_length[index]; + const int ref_len = target_length[index]; + const scalar_t* hyp_begin = source + index * source_size; + const scalar_t* ref_begin = target + index * target_size; + + // dynamic programming + for (int i = 0; i <= hyp_len; i++) { + errors_curr[err_idx(i, 0)] = i; + } + for (int j = 0; j <= ref_len; j++) { + errors_curr[err_idx(0, j)] = j; + } + for (int i = 1; i <= hyp_len; i++) { + for (int j = 1; j <= ref_len; j++) { + errors_curr[err_idx(i, j)] = min( + min(errors_curr[err_idx(i - 1, j)], errors_curr[err_idx(i, j - 1)]) + + 1, + errors_curr[err_idx(i - 1, j - 1)] + + 2 * (*(hyp_begin + i - 1) == *(ref_begin + j - 1) ? 
0 : 1)); + } + } + + // back-tracing + int i = hyp_len; + int j = ref_len; + int o = hyp_len + ref_len; + + for (int k = 0; k < source_size + target_size; k++) { + operations[opt_idx(k)] = 0; + } + + while ((i >= 0) && (j >= 0)) { + if ((i == 0) && (j == 0)) { + break; + } + + if ((j > 0) && + (errors_curr[err_idx(i, j - 1)] < errors_curr[err_idx(i, j)])) { + o--; + operations[opt_idx(o)] = 1; + j--; // insertion + } else if ( + (i > 0) && + (errors_curr[err_idx(i - 1, j)] < errors_curr[err_idx(i, j)])) { + o--; + operations[opt_idx(o)] = 2; + i--; // deletion + } else { + o--; + operations[opt_idx(o)] = 3; + i--; + j--; // do nothing + } + } + + // moving to the left + for (int k = 0; k < hyp_len + ref_len; k++) { + if (k + o < hyp_len + ref_len) { + operations[opt_idx(k)] = operations[opt_idx(k + o)]; + } else { + operations[opt_idx(k)] = 0; // padding + } + } +} + +torch::Tensor GenerateDeletionLabelCuda( + torch::Tensor source, + torch::Tensor operations) { + const auto batch_size = source.size(0); + at::TensorOptions options(source.device()); + options = options.dtype(at::ScalarType::Int); + auto labels = torch::empty({batch_size, source.size(1)}, options); + auto stream = at::cuda::getCurrentCUDAStream(source.device().index()); + + AT_DISPATCH_ALL_TYPES(source.scalar_type(), "generate_deletion_labels", ([&] { + generate_deletion_label_kernel + <<>>( + source.data_ptr(), + source.size(1), + operations.size(1), + operations.data_ptr(), + labels.data_ptr()); + })); + + return labels; +} + +std::pair GenerateInsertionLabelCuda( + torch::Tensor target, + torch::Tensor operations) { + const auto batch_size = target.size(0); + at::TensorOptions options(target.device()); + options = options.dtype(at::ScalarType::Int); + auto labels = torch::empty({batch_size, target.size(1)}, options); + auto masks = torch::empty({batch_size, target.size(1)}, options); + auto stream = at::cuda::getCurrentCUDAStream(target.device().index()); + + AT_DISPATCH_ALL_TYPES( + target.scalar_type(), "generate_insertion_labels", ([&] { + generate_insertion_label_kernel<<>>( + target.data_ptr(), + target.size(1), + operations.size(1), + operations.data_ptr(), + labels.data_ptr(), + masks.data_ptr()); + })); + + return std::make_pair(labels, masks); +} + +torch::Tensor LevenshteinDistanceCuda( + torch::Tensor source, + torch::Tensor target, + torch::Tensor source_length, + torch::Tensor target_length) { + const auto batch_size = source.size(0); + const auto shared_size = + (source.size(1) + 1) * (target.size(1) + 1) * sizeof(short); + + at::TensorOptions options(source.device()); + options = options.dtype(at::ScalarType::Int); + auto operations = + torch::empty({batch_size, source.size(1) + target.size(1)}, options); + auto stream = at::cuda::getCurrentCUDAStream(source.device().index()); + + if (shared_size > 40000) { + auto distances = torch::empty( + {batch_size, (source.size(1) + 1) * (target.size(1) + 1)}, options); + AT_DISPATCH_ALL_TYPES(source.scalar_type(), "levenshtein_distance", ([&] { + levenshtein_distance_kernel + <<>>( + source.data_ptr(), + target.data_ptr(), + source_length.data_ptr(), + target_length.data_ptr(), + source.size(1), + target.size(1), + operations.data_ptr(), + distances.data_ptr()); + })); + } else { + AT_DISPATCH_ALL_TYPES( + source.scalar_type(), "faster_levenshtein_distance", ([&] { + faster_levenshtein_distance_kernel + <<>>( + source.data_ptr(), + target.data_ptr(), + source_length.data_ptr(), + target_length.data_ptr(), + source.size(1), + target.size(1), + operations.data_ptr()); + 
})); + } + + return operations; +} diff --git a/fairseq/fairseq/clib/libnat_cuda/edit_dist.h b/fairseq/fairseq/clib/libnat_cuda/edit_dist.h new file mode 100644 index 0000000000000000000000000000000000000000..5220c52fd80529b90a67ba74e9ca73c668dab099 --- /dev/null +++ b/fairseq/fairseq/clib/libnat_cuda/edit_dist.h @@ -0,0 +1,25 @@ +/** + * Copyright 2017-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +torch::Tensor LevenshteinDistanceCuda( + torch::Tensor source, + torch::Tensor target, + torch::Tensor source_length, + torch::Tensor target_length); + +torch::Tensor GenerateDeletionLabelCuda( + torch::Tensor source, + torch::Tensor operations); + +std::pair GenerateInsertionLabelCuda( + torch::Tensor source, + torch::Tensor operations); diff --git a/fairseq/fairseq/config/__init__.py b/fairseq/fairseq/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6264236915a7269a4d920ee8213004374dd86a9a --- /dev/null +++ b/fairseq/fairseq/config/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/fairseq/fairseq/config/config.yaml b/fairseq/fairseq/config/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2ed7168cb7f7473c43d864478c5c6ce51639e030 --- /dev/null +++ b/fairseq/fairseq/config/config.yaml @@ -0,0 +1,19 @@ +# @package _group_ + +hydra: + run: + dir: . + +defaults: + - _self_ + - task: null + - model: null + - criterion: cross_entropy + - optimizer: null + - lr_scheduler: fixed + - bpe: null + - tokenizer: null + - scoring: null + - generation: null + - common_eval: null + - eval_lm: null diff --git a/fairseq/fairseq/config/fb_run_config/slurm.yaml b/fairseq/fairseq/config/fb_run_config/slurm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..20cf8f52016ddc23d5bcb09ef94d900a035b81ca --- /dev/null +++ b/fairseq/fairseq/config/fb_run_config/slurm.yaml @@ -0,0 +1,29 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '__' + exclude_keys: + - fb_run_config + - distributed_training.distributed_port + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + launcher: + cpus_per_task: 60 + gpus_per_node: ??? + tasks_per_node: 1 + nodes: 1 + partition: learnfair + mem_gb: 400 + timeout_min: 4320 + max_num_timeout: 10 + name: ${env:PREFIX}_${hydra.job.config_name} + submitit_folder: ${hydra.sweep.dir} + +distributed_training: + ddp_backend: c10d + distributed_world_size: ??? + distributed_port: ??? 
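The YAML files that follow are Hydra config-group entries (note the `# @package _group_` header) selected through the `defaults` list in `config.yaml`. As a quick sanity check, any of them can also be loaded directly with OmegaConf, which fairseq already depends on; a small sketch assuming the path is resolved relative to the repository root:

```python
# Hypothetical check: load one of the model config-group files added below and
# inspect a few of its fields with OmegaConf.
from omegaconf import OmegaConf

cfg = OmegaConf.load(
    "fairseq/fairseq/config/model/transformer_lm/transformer_lm_gpt.yaml"
)
print(cfg.decoder_embed_dim, cfg.decoder_layers, cfg.decoder_attention_heads)  # 768 12 12
```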
diff --git a/fairseq/fairseq/config/model/transformer_lm/transformer_lm_baevski_gbw.yaml b/fairseq/fairseq/config/model/transformer_lm/transformer_lm_baevski_gbw.yaml new file mode 100644 index 0000000000000000000000000000000000000000..30b1a4f1e0f5e7f7c2671ff8ec995cc32363f10f --- /dev/null +++ b/fairseq/fairseq/config/model/transformer_lm/transformer_lm_baevski_gbw.yaml @@ -0,0 +1,36 @@ +# @package _group_ +activation_fn: "relu" +dropout: 0.1 +attention_dropout: 0.1 +activation_dropout: 0.0 +relu_dropout: 0.0 +decoder_embed_dim: 512 +decoder_output_dim: 512 +decoder_input_dim: 512 +decoder_ffn_embed_dim: 4096 +decoder_layers: 12 +decoder_attention_heads: 16 +decoder_normalize_before: true +no_decoder_final_norm: true +adaptive_softmax_cutoff: null +adaptive_softmax_dropout: 0 +adaptive_softmax_factor: 4 +no_token_positional_embeddings: false +share_decoder_input_output_embed: false +character_embeddings: false +character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" +character_embedding_dim: 4 +char_embedder_highway_layers: 2 +adaptive_input: false +adaptive_input_factor: 4 +adaptive_input_cutoff: null +tie_adaptive_weights: false +tie_adaptive_proj: false +decoder_learned_pos: false +decoder_layerdrop: 0 +decoder_layers_to_keep: null +layernorm_embedding: false +no_scale_embedding: false +quant_noise_pq: 0 +quant_noise_pq_block_size: 8 +quant_noise_scalar: 0 diff --git a/fairseq/fairseq/config/model/transformer_lm/transformer_lm_baevski_wiki103.yaml b/fairseq/fairseq/config/model/transformer_lm/transformer_lm_baevski_wiki103.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1154cfa660ee5ce6a272cd1a0049eead1e92c117 --- /dev/null +++ b/fairseq/fairseq/config/model/transformer_lm/transformer_lm_baevski_wiki103.yaml @@ -0,0 +1,36 @@ +# @package _group_ +activation_fn: "relu" +dropout: 0.3 +attention_dropout: 0.1 +activation_dropout: 0.1 +relu_dropout: 0.1 +decoder_embed_dim: 1024 +decoder_output_dim: 1024 +decoder_input_dim: 1024 +decoder_ffn_embed_dim: 4096 +decoder_layers: 16 +decoder_attention_heads: 8 +decoder_normalize_before: true +no_decoder_final_norm: true +adaptive_softmax_cutoff: "20000,60000" +adaptive_softmax_dropout: 0.2 +adaptive_softmax_factor: 4 +no_token_positional_embeddings: false +share_decoder_input_output_embed: false +character_embeddings: false +character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" +character_embedding_dim: 4 +char_embedder_highway_layers: 2 +adaptive_input: true +adaptive_input_factor: 4 +adaptive_input_cutoff: "20000,60000" +tie_adaptive_weights: true +tie_adaptive_proj: true +decoder_learned_pos: false +decoder_layerdrop: 0 +decoder_layers_to_keep: null +layernorm_embedding: false +no_scale_embedding: false +quant_noise_pq: 0 +quant_noise_pq_block_size: 8 +quant_noise_scalar: 0 diff --git a/fairseq/fairseq/config/model/transformer_lm/transformer_lm_big.yaml b/fairseq/fairseq/config/model/transformer_lm/transformer_lm_big.yaml new file mode 100644 index 0000000000000000000000000000000000000000..309575310bfc5d9c5cde31563073bef18abc646e --- /dev/null +++ b/fairseq/fairseq/config/model/transformer_lm/transformer_lm_big.yaml @@ -0,0 +1,36 @@ +# @package _group_ +activation_fn: "relu" +dropout: 0.1 +attention_dropout: 0.0 +activation_dropout: 0.0 +relu_dropout: 0.0 +decoder_embed_dim: 1024 +decoder_output_dim: 1024 +decoder_input_dim: 1024 +decoder_ffn_embed_dim: 4096 +decoder_layers: 12 +decoder_attention_heads: 16 +decoder_normalize_before: true 
+no_decoder_final_norm: false +adaptive_softmax_cutoff: null +adaptive_softmax_dropout: 0 +adaptive_softmax_factor: 4 +no_token_positional_embeddings: false +share_decoder_input_output_embed: false +character_embeddings: false +character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" +character_embedding_dim: 4 +char_embedder_highway_layers: 2 +adaptive_input: false +adaptive_input_factor: 4 +adaptive_input_cutoff: null +tie_adaptive_weights: false +tie_adaptive_proj: false +decoder_learned_pos: false +decoder_layerdrop: 0 +decoder_layers_to_keep: null +layernorm_embedding: false +no_scale_embedding: false +quant_noise_pq: 0 +quant_noise_pq_block_size: 8 +quant_noise_scalar: 0 diff --git a/fairseq/fairseq/config/model/transformer_lm/transformer_lm_gbw.yaml b/fairseq/fairseq/config/model/transformer_lm/transformer_lm_gbw.yaml new file mode 100644 index 0000000000000000000000000000000000000000..30b1a4f1e0f5e7f7c2671ff8ec995cc32363f10f --- /dev/null +++ b/fairseq/fairseq/config/model/transformer_lm/transformer_lm_gbw.yaml @@ -0,0 +1,36 @@ +# @package _group_ +activation_fn: "relu" +dropout: 0.1 +attention_dropout: 0.1 +activation_dropout: 0.0 +relu_dropout: 0.0 +decoder_embed_dim: 512 +decoder_output_dim: 512 +decoder_input_dim: 512 +decoder_ffn_embed_dim: 4096 +decoder_layers: 12 +decoder_attention_heads: 16 +decoder_normalize_before: true +no_decoder_final_norm: true +adaptive_softmax_cutoff: null +adaptive_softmax_dropout: 0 +adaptive_softmax_factor: 4 +no_token_positional_embeddings: false +share_decoder_input_output_embed: false +character_embeddings: false +character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" +character_embedding_dim: 4 +char_embedder_highway_layers: 2 +adaptive_input: false +adaptive_input_factor: 4 +adaptive_input_cutoff: null +tie_adaptive_weights: false +tie_adaptive_proj: false +decoder_learned_pos: false +decoder_layerdrop: 0 +decoder_layers_to_keep: null +layernorm_embedding: false +no_scale_embedding: false +quant_noise_pq: 0 +quant_noise_pq_block_size: 8 +quant_noise_scalar: 0 diff --git a/fairseq/fairseq/config/model/transformer_lm/transformer_lm_gpt.yaml b/fairseq/fairseq/config/model/transformer_lm/transformer_lm_gpt.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2c6cb7be3801115371566932ffc78651c9ac6c0f --- /dev/null +++ b/fairseq/fairseq/config/model/transformer_lm/transformer_lm_gpt.yaml @@ -0,0 +1,36 @@ +# @package _group_ +activation_fn: "gelu" +dropout: 0.1 +attention_dropout: 0.1 +activation_dropout: 0.0 +relu_dropout: 0.0 +decoder_embed_dim: 768 +decoder_output_dim: 768 +decoder_input_dim: 768 +decoder_ffn_embed_dim: 3072 +decoder_layers: 12 +decoder_attention_heads: 12 +decoder_normalize_before: true +no_decoder_final_norm: false +adaptive_softmax_cutoff: null +adaptive_softmax_dropout: 0 +adaptive_softmax_factor: 4 +no_token_positional_embeddings: false +share_decoder_input_output_embed: false +character_embeddings: false +character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" +character_embedding_dim: 4 +char_embedder_highway_layers: 2 +adaptive_input: false +adaptive_input_factor: 4 +adaptive_input_cutoff: null +tie_adaptive_weights: false +tie_adaptive_proj: false +decoder_learned_pos: false +decoder_layerdrop: 0 +decoder_layers_to_keep: null +layernorm_embedding: false +no_scale_embedding: false +quant_noise_pq: 0 +quant_noise_pq_block_size: 8 +quant_noise_scalar: 0 diff --git 
a/fairseq/fairseq/config/model/transformer_lm/transformer_lm_gpt2_big.yaml b/fairseq/fairseq/config/model/transformer_lm/transformer_lm_gpt2_big.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a08769a1781abdb13302bf57bf1338bcaf68a0ec --- /dev/null +++ b/fairseq/fairseq/config/model/transformer_lm/transformer_lm_gpt2_big.yaml @@ -0,0 +1,36 @@ +# @package _group_ +activation_fn: "gelu" +dropout: 0.1 +attention_dropout: 0.1 +activation_dropout: 0.0 +relu_dropout: 0.0 +decoder_embed_dim: 1600 +decoder_output_dim: 1600 +decoder_input_dim: 1600 +decoder_ffn_embed_dim: 6400 +decoder_layers: 48 +decoder_attention_heads: 25 +decoder_normalize_before: true +no_decoder_final_norm: false +adaptive_softmax_cutoff: null +adaptive_softmax_dropout: 0 +adaptive_softmax_factor: 4 +no_token_positional_embeddings: false +share_decoder_input_output_embed: false +character_embeddings: false +character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" +character_embedding_dim: 4 +char_embedder_highway_layers: 2 +adaptive_input: false +adaptive_input_factor: 4 +adaptive_input_cutoff: null +tie_adaptive_weights: false +tie_adaptive_proj: false +decoder_learned_pos: false +decoder_layerdrop: 0 +decoder_layers_to_keep: null +layernorm_embedding: false +no_scale_embedding: false +quant_noise_pq: 0 +quant_noise_pq_block_size: 8 +quant_noise_scalar: 0 diff --git a/fairseq/fairseq/config/model/transformer_lm/transformer_lm_gpt2_medium.yaml b/fairseq/fairseq/config/model/transformer_lm/transformer_lm_gpt2_medium.yaml new file mode 100644 index 0000000000000000000000000000000000000000..64261d793c0f1ae091c9bf5c8c77093a07326137 --- /dev/null +++ b/fairseq/fairseq/config/model/transformer_lm/transformer_lm_gpt2_medium.yaml @@ -0,0 +1,36 @@ +# @package _group_ +activation_fn: "gelu" +dropout: 0.1 +attention_dropout: 0.1 +activation_dropout: 0.0 +relu_dropout: 0.0 +decoder_embed_dim: 1280 +decoder_output_dim: 1280 +decoder_input_dim: 1280 +decoder_ffn_embed_dim: 5120 +decoder_layers: 36 +decoder_attention_heads: 20 +decoder_normalize_before: true +no_decoder_final_norm: false +adaptive_softmax_cutoff: null +adaptive_softmax_dropout: 0 +adaptive_softmax_factor: 4 +no_token_positional_embeddings: false +share_decoder_input_output_embed: false +character_embeddings: false +character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" +character_embedding_dim: 4 +char_embedder_highway_layers: 2 +adaptive_input: false +adaptive_input_factor: 4 +adaptive_input_cutoff: null +tie_adaptive_weights: false +tie_adaptive_proj: false +decoder_learned_pos: false +decoder_layerdrop: 0 +decoder_layers_to_keep: null +layernorm_embedding: false +no_scale_embedding: false +quant_noise_pq: 0 +quant_noise_pq_block_size: 8 +quant_noise_scalar: 0 diff --git a/fairseq/fairseq/config/model/transformer_lm/transformer_lm_gpt2_small.yaml b/fairseq/fairseq/config/model/transformer_lm/transformer_lm_gpt2_small.yaml new file mode 100644 index 0000000000000000000000000000000000000000..702e81f466c82edf40433589d389edbe0a7b96db --- /dev/null +++ b/fairseq/fairseq/config/model/transformer_lm/transformer_lm_gpt2_small.yaml @@ -0,0 +1,36 @@ +# @package _group_ +activation_fn: "gelu" +dropout: 0.1 +attention_dropout: 0.1 +activation_dropout: 0.0 +relu_dropout: 0.0 +decoder_embed_dim: 1024 +decoder_output_dim: 1024 +decoder_input_dim: 1024 +decoder_ffn_embed_dim: 4096 +decoder_layers: 24 +decoder_attention_heads: 16 +decoder_normalize_before: true +no_decoder_final_norm: false 
+adaptive_softmax_cutoff: null +adaptive_softmax_dropout: 0 +adaptive_softmax_factor: 4 +no_token_positional_embeddings: false +share_decoder_input_output_embed: false +character_embeddings: false +character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" +character_embedding_dim: 4 +char_embedder_highway_layers: 2 +adaptive_input: false +adaptive_input_factor: 4 +adaptive_input_cutoff: null +tie_adaptive_weights: false +tie_adaptive_proj: false +decoder_learned_pos: false +decoder_layerdrop: 0 +decoder_layers_to_keep: null +layernorm_embedding: false +no_scale_embedding: false +quant_noise_pq: 0 +quant_noise_pq_block_size: 8 +quant_noise_scalar: 0 diff --git a/fairseq/fairseq/config/model/transformer_lm/transformer_lm_wiki103.yaml b/fairseq/fairseq/config/model/transformer_lm/transformer_lm_wiki103.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1154cfa660ee5ce6a272cd1a0049eead1e92c117 --- /dev/null +++ b/fairseq/fairseq/config/model/transformer_lm/transformer_lm_wiki103.yaml @@ -0,0 +1,36 @@ +# @package _group_ +activation_fn: "relu" +dropout: 0.3 +attention_dropout: 0.1 +activation_dropout: 0.1 +relu_dropout: 0.1 +decoder_embed_dim: 1024 +decoder_output_dim: 1024 +decoder_input_dim: 1024 +decoder_ffn_embed_dim: 4096 +decoder_layers: 16 +decoder_attention_heads: 8 +decoder_normalize_before: true +no_decoder_final_norm: true +adaptive_softmax_cutoff: "20000,60000" +adaptive_softmax_dropout: 0.2 +adaptive_softmax_factor: 4 +no_token_positional_embeddings: false +share_decoder_input_output_embed: false +character_embeddings: false +character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" +character_embedding_dim: 4 +char_embedder_highway_layers: 2 +adaptive_input: true +adaptive_input_factor: 4 +adaptive_input_cutoff: "20000,60000" +tie_adaptive_weights: true +tie_adaptive_proj: true +decoder_learned_pos: false +decoder_layerdrop: 0 +decoder_layers_to_keep: null +layernorm_embedding: false +no_scale_embedding: false +quant_noise_pq: 0 +quant_noise_pq_block_size: 8 +quant_noise_scalar: 0 diff --git a/fairseq/fairseq/config/model/wav2vec/vq_wav2vec_gumbel.yaml b/fairseq/fairseq/config/model/wav2vec/vq_wav2vec_gumbel.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ee1329bf4612d8bb295c6cc3d8bc0a3bcef1777d --- /dev/null +++ b/fairseq/fairseq/config/model/wav2vec/vq_wav2vec_gumbel.yaml @@ -0,0 +1,5 @@ +# @package _group_ +activation: gelu +vq_type: gumbel +vq_depth: 2 +combine_groups: true diff --git a/fairseq/fairseq/config/model/wav2vec2/wav2vec2_base.yaml b/fairseq/fairseq/config/model/wav2vec2/wav2vec2_base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ce65499b808b9a3821cee4ca87c36e84d09005a1 --- /dev/null +++ b/fairseq/fairseq/config/model/wav2vec2/wav2vec2_base.yaml @@ -0,0 +1,8 @@ +# @package _group_ + +quantize_targets: true +final_dim: 256 +encoder_layerdrop: 0.05 +dropout_input: 0.1 +dropout_features: 0.1 +feature_grad_mult: 0.1 diff --git a/fairseq/fairseq/config/model/wav2vec2/wav2vec2_large.yaml b/fairseq/fairseq/config/model/wav2vec2/wav2vec2_large.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5846f75243f27f201c85bfe6820815c015971275 --- /dev/null +++ b/fairseq/fairseq/config/model/wav2vec2/wav2vec2_large.yaml @@ -0,0 +1,20 @@ +# @package _group_ + +quantize_targets: true +extractor_mode: layer_norm +layer_norm_first: true +final_dim: 768 +latent_temp: [2.0,0.1,0.999995] +encoder_layerdrop: 0.0 +dropout_input: 
0.0 +dropout_features: 0.0 +dropout: 0.0 +attention_dropout: 0.0 +conv_bias: true + +encoder_layers: 24 +encoder_embed_dim: 1024 +encoder_ffn_embed_dim: 4096 +encoder_attention_heads: 16 + +feature_grad_mult: 1.0 diff --git a/fairseq/fairseq/criterions/__init__.py b/fairseq/fairseq/criterions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ecd65d34adec4222ac8781106560ebc5dc2622f5 --- /dev/null +++ b/fairseq/fairseq/criterions/__init__.py @@ -0,0 +1,36 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""isort:skip_file""" + +import importlib +import os + +from fairseq import registry +from fairseq.criterions.fairseq_criterion import ( # noqa + FairseqCriterion, + LegacyFairseqCriterion, +) +from omegaconf import DictConfig + + +( + build_criterion_, + register_criterion, + CRITERION_REGISTRY, + CRITERION_DATACLASS_REGISTRY, +) = registry.setup_registry( + "--criterion", base_class=FairseqCriterion, default="cross_entropy" +) + + +def build_criterion(cfg: DictConfig, task, from_checkpoint=False): + return build_criterion_(cfg, task, from_checkpoint=from_checkpoint) + + +# automatically import any Python files in the criterions/ directory +for file in sorted(os.listdir(os.path.dirname(__file__))): + if file.endswith(".py") and not file.startswith("_"): + file_name = file[: file.find(".py")] + importlib.import_module("fairseq.criterions." + file_name) diff --git a/fairseq/fairseq/criterions/__pycache__/fairseq_criterion.cpython-310.pyc b/fairseq/fairseq/criterions/__pycache__/fairseq_criterion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0eb8eb442549851841d7daf251e501d917ad1032 Binary files /dev/null and b/fairseq/fairseq/criterions/__pycache__/fairseq_criterion.cpython-310.pyc differ diff --git a/fairseq/fairseq/criterions/adaptive_loss.py b/fairseq/fairseq/criterions/adaptive_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..fc1ac8540461ab402e6ba6b4b40afe363774fffc --- /dev/null +++ b/fairseq/fairseq/criterions/adaptive_loss.py @@ -0,0 +1,124 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
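Note: fairseq/criterions/__init__.py above wires up the criterion registry; registry.setup_registry exposes register_criterion (default criterion: cross_entropy), and the import loop at the bottom pulls in every module in the package so its @register_criterion decorators run at import time. A hedged sketch of plugging a custom criterion into that registry follows; the class, config, and field names are invented for illustration, and build_criterion (defined later in fairseq_criterion.py) fills the __init__ arguments by matching their names against the dataclass fields.

    from dataclasses import dataclass, field

    import torch.nn.functional as F
    from fairseq.criterions import FairseqCriterion, register_criterion
    from fairseq.dataclass import FairseqDataclass


    @dataclass
    class ToyCriterionConfig(FairseqDataclass):
        # hypothetical knob, only for illustration
        scale: float = field(default=1.0, metadata={"help": "loss scaling factor"})


    @register_criterion("toy_scaled_nll", dataclass=ToyCriterionConfig)
    class ToyScaledNLLCriterion(FairseqCriterion):
        def __init__(self, task, scale):
            super().__init__(task)
            self.scale = scale

        def forward(self, model, sample, reduce=True):
            # same pattern as cross_entropy.py below, with an extra scaling factor
            net_output = model(**sample["net_input"])
            lprobs = model.get_normalized_probs(net_output, log_probs=True)
            target = model.get_targets(sample, net_output)
            loss = self.scale * F.nll_loss(
                lprobs.view(-1, lprobs.size(-1)),
                target.view(-1),
                ignore_index=self.padding_idx,
                reduction="sum" if reduce else "none",
            )
            sample_size = sample["ntokens"]
            logging_output = {"loss": loss.data, "sample_size": sample_size}
            return loss, sample_size, logging_output

Dropped into a file under fairseq/criterions/, the auto-import loop above would register it without further changes, and it becomes selectable by name ("toy_scaled_nll") through the --criterion registry.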
+ +import math +from dataclasses import dataclass + +import torch.nn.functional as F +from fairseq import utils +from fairseq.logging import metrics +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.constants import DDP_BACKEND_CHOICES +from omegaconf import II + + +@dataclass +class AdaptiveLossConfig(FairseqDataclass): + sentence_avg: bool = II("optimization.sentence_avg") + ddp_backend: DDP_BACKEND_CHOICES = II("distributed_training.ddp_backend") + + +@register_criterion("adaptive_loss", dataclass=AdaptiveLossConfig) +class AdaptiveLoss(FairseqCriterion): + """This is an implementation of the loss function accompanying the adaptive softmax approximation for + graphical processing units (GPU), described in the paper "Efficient softmax approximation for GPUs" + (http://arxiv.org/abs/1609.04309).""" + + def __init__(self, task, sentence_avg): + super().__init__(task) + self.sentence_avg = sentence_avg + + @classmethod + def build_criterion(cls, cfg: AdaptiveLossConfig, task): + if cfg.ddp_backend in {"c10d", "pytorch_ddp"}: + raise Exception( + "AdaptiveLoss is not compatible with the PyTorch " + "version of DistributedDataParallel. Please use " + "`--ddp-backend=legacy_ddp` instead." + ) + return cls(task, cfg.sentence_avg) + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + + assert ( + hasattr(model.decoder, "adaptive_softmax") + and model.decoder.adaptive_softmax is not None + ) + adaptive_softmax = model.decoder.adaptive_softmax + + net_output = model(**sample["net_input"]) + orig_target = model.get_targets(sample, net_output) + + nsentences = orig_target.size(0) + orig_target = orig_target.view(-1) + + bsz = orig_target.size(0) + + logits, target = adaptive_softmax(net_output[0], orig_target) + assert len(target) == len(logits) + + loss = net_output[0].new(1 if reduce else bsz).zero_() + + for i in range(len(target)): + if target[i] is not None: + assert target[i].min() >= 0 and target[i].max() <= logits[i].size(1) + loss += F.cross_entropy( + logits[i], + target[i], + ignore_index=self.padding_idx, + reduction="sum" if reduce else "none", + ) + + orig = utils.strip_pad(orig_target, self.padding_idx) + ntokens = orig.numel() + sample_size = sample["target"].size(0) if self.sentence_avg else ntokens + logging_output = { + "loss": loss.data, + "ntokens": ntokens, + "nsentences": nsentences, + "sample_size": sample_size, + } + return loss, sample_size, logging_output + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs)) + ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs)) + sample_size = utils.item( + sum(log.get("sample_size", 0) for log in logging_outputs) + ) + + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + if sample_size != ntokens: + metrics.log_scalar( + "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3 + ) + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg) + ) + else: + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg) + ) + + @staticmethod + def 
logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return True diff --git a/fairseq/fairseq/criterions/composite_loss.py b/fairseq/fairseq/criterions/composite_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..98e835fa6e4c0bcad062df9c519701bf795c98be --- /dev/null +++ b/fairseq/fairseq/criterions/composite_loss.py @@ -0,0 +1,100 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq import utils +from fairseq.criterions import LegacyFairseqCriterion, register_criterion +from torch import nn + + +@register_criterion("composite_loss") +class CompositeLoss(LegacyFairseqCriterion): + """This is a composite loss that, given a list of model outputs and a list of targets, + computes an average of losses for each output-target pair""" + + def __init__(self, args, task): + super().__init__(args, task) + self.underlying_criterion = args.underlying_criterion + + @staticmethod + def add_args(parser): + """Add criterion-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--underlying-criterion', type=str, metavar='VAL', required=True, + help='underlying criterion to use for the composite loss') + # fmt: on + + @staticmethod + def build_underlying_criterion(args, task): + saved_criterion = args.criterion + args.criterion = args.underlying_criterion + assert saved_criterion != args.underlying_criterion + underlying_criterion = task.build_criterion(args) + args.criterion = saved_criterion + return underlying_criterion + + @classmethod + def build_criterion(cls, args, task): + underlying_criterion = CompositeLoss.build_underlying_criterion(args, task) + + class FakeModel(nn.Module): + def __init__(self, model, net_out, target): + super().__init__() + self.model = model + self.net_out = net_out + self.target = target + + def forward(self, **unused): + return self.net_out + + def get_normalized_probs(self, net_output, log_probs, sample=None): + return self.model.get_normalized_probs( + net_output, log_probs, sample=sample + ) + + def get_targets(self, *unused): + return self.target + + @property + def decoder(self): + return self.model.decoder + + class _CompositeLoss(LegacyFairseqCriterion): + def __init__(self, args, task, underlying_criterion): + super().__init__(args, task) + self.underlying_criterion = underlying_criterion + + def forward(self, model, sample, reduce=True): + net_outputs = model(**sample["net_input"]) + targets = sample["target"] + + bsz = targets[0].size(0) + loss = net_outputs[0][0].new(1 if reduce else bsz).float().zero_() + + sample_size = 0 + logging_output = {} + for o, t in zip(net_outputs[0], targets): + m = FakeModel(model, (o, net_outputs[1]), t) + sample["target"] = t + l, ss, logging_output = self.underlying_criterion(m, sample, reduce) + loss += l + sample_size += ss + + loss.div_(len(targets)) + sample_size /= len(targets) + + logging_output["loss"] = utils.item(loss.data) if reduce else loss.data + return loss, sample_size, logging_output + + @staticmethod + def aggregate_logging_outputs(logging_outputs): + return underlying_criterion.__class__.aggregate_logging_outputs( + logging_outputs + ) + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + 
underlying_criterion.__class__.reduce_metrics(logging_outputs) + + return _CompositeLoss(args, task, underlying_criterion) diff --git a/fairseq/fairseq/criterions/cross_entropy.py b/fairseq/fairseq/criterions/cross_entropy.py new file mode 100644 index 0000000000000000000000000000000000000000..24d6bcd6128f9bfb7667adcdb3e7f001cc57a523 --- /dev/null +++ b/fairseq/fairseq/criterions/cross_entropy.py @@ -0,0 +1,91 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from dataclasses import dataclass + +import torch.nn.functional as F +from fairseq import utils +from fairseq.logging import metrics +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass +from omegaconf import II + + +@dataclass +class CrossEntropyCriterionConfig(FairseqDataclass): + sentence_avg: bool = II("optimization.sentence_avg") + + +@register_criterion("cross_entropy", dataclass=CrossEntropyCriterionConfig) +class CrossEntropyCriterion(FairseqCriterion): + def __init__(self, task, sentence_avg): + super().__init__(task) + self.sentence_avg = sentence_avg + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model(**sample["net_input"]) + loss, _ = self.compute_loss(model, net_output, sample, reduce=reduce) + sample_size = ( + sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + ) + logging_output = { + "loss": loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample["target"].size(0), + "sample_size": sample_size, + } + return loss, sample_size, logging_output + + def compute_loss(self, model, net_output, sample, reduce=True): + lprobs = model.get_normalized_probs(net_output, log_probs=True) + lprobs = lprobs.view(-1, lprobs.size(-1)) + target = model.get_targets(sample, net_output).view(-1) + loss = F.nll_loss( + lprobs, + target, + ignore_index=self.padding_idx, + reduction="sum" if reduce else "none", + ) + return loss, loss + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + + # we divide by log(2) to convert the loss from base e to base 2 + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + if sample_size != ntokens: + metrics.log_scalar( + "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3 + ) + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg) + ) + else: + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg) + ) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. 
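Note: cross_entropy.py sums the natural-log NLL over tokens; reduce_metrics then divides by log(2) so the reported "loss" is in bits per token, and perplexity is derived from that value. A quick numeric illustration with made-up numbers:

    import math

    loss_sum = 6931.47   # hypothetical summed base-e NLL over a batch
    ntokens = 5000
    bits_per_token = loss_sum / ntokens / math.log(2)   # ~2.0 bits/token
    perplexity = 2 ** bits_per_token                    # ~4.0
    print(round(bits_per_token, 3), round(perplexity, 3))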
+ """ + return True diff --git a/fairseq/fairseq/criterions/ctc.py b/fairseq/fairseq/criterions/ctc.py new file mode 100644 index 0000000000000000000000000000000000000000..368213cb2b05bcfcec3ae84aef68c82bd792492b --- /dev/null +++ b/fairseq/fairseq/criterions/ctc.py @@ -0,0 +1,325 @@ +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import math +from argparse import Namespace +from dataclasses import dataclass, field +from omegaconf import II +from typing import Optional + +import torch +import torch.nn.functional as F + +from fairseq import utils +from fairseq.logging import metrics +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass +from fairseq.data.data_utils import post_process +from fairseq.tasks import FairseqTask +from fairseq.logging.meters import safe_round + + +@dataclass +class CtcCriterionConfig(FairseqDataclass): + zero_infinity: bool = field( + default=False, + metadata={"help": "zero inf loss when source length <= target length"}, + ) + sentence_avg: bool = II("optimization.sentence_avg") + post_process: str = field( + default="letter", + metadata={ + "help": "how to post process predictions into words. can be letter, " + "wordpiece, BPE symbols, etc. " + "See fairseq.data.data_utils.post_process() for full list of options" + }, + ) + wer_kenlm_model: Optional[str] = field( + default=None, + metadata={ + "help": "if this is provided, use kenlm to compute wer (along with other wer_* args)" + }, + ) + wer_lexicon: Optional[str] = field( + default=None, + metadata={"help": "lexicon to use with wer_kenlm_model"}, + ) + wer_lm_weight: float = field( + default=2.0, + metadata={"help": "lm weight to use with wer_kenlm_model"}, + ) + wer_word_score: float = field( + default=-1.0, + metadata={"help": "lm word score to use with wer_kenlm_model"}, + ) + wer_sil_weight: float = field( + default=0, + metadata={"help": "lm word score to use with wer_kenlm_model"}, + ) + + wer_args: Optional[str] = field( + default=None, + metadata={ + "help": "DEPRECATED: tuple of (wer_kenlm_model, wer_lexicon, wer_lm_weight, wer_word_score)" + }, + ) + + +@register_criterion("ctc", dataclass=CtcCriterionConfig) +class CtcCriterion(FairseqCriterion): + def __init__( + self, cfg: CtcCriterionConfig, task: FairseqTask, rdrop_alpha: int = 0.0 + ): + super().__init__(task) + self.blank_idx = ( + task.target_dictionary.index(task.blank_symbol) + if hasattr(task, "blank_symbol") + else 0 + ) + self.pad_idx = task.target_dictionary.pad() + self.eos_idx = task.target_dictionary.eos() + self.post_process = cfg.post_process + + self.rdrop_alpha = rdrop_alpha + + if cfg.wer_args is not None: + ( + cfg.wer_kenlm_model, + cfg.wer_lexicon, + cfg.wer_lm_weight, + cfg.wer_word_score, + ) = eval(cfg.wer_args) + + if cfg.wer_kenlm_model is not None and cfg.wer_kenlm_model != "": + from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder + + dec_args = Namespace() + dec_args.nbest = 1 + dec_args.criterion = "ctc" + dec_args.kenlm_model = cfg.wer_kenlm_model + dec_args.lexicon = cfg.wer_lexicon + dec_args.beam = 50 + dec_args.beam_size_token = min(50, len(task.target_dictionary)) + dec_args.beam_threshold = min(50, len(task.target_dictionary)) + dec_args.lm_weight = cfg.wer_lm_weight + dec_args.word_score = cfg.wer_word_score + dec_args.sil_weight = 
cfg.wer_sil_weight + dec_args.unk_weight = -math.inf + dec_args.sil_weight = 0 + + self.w2l_decoder = W2lKenLMDecoder(dec_args, task.target_dictionary) + else: + self.w2l_decoder = None + + self.zero_infinity = cfg.zero_infinity + self.sentence_avg = cfg.sentence_avg + + def forward(self, model, sample, reduce=True, **kwargs): + net_output = model(**sample["net_input"]) + lprobs = model.get_normalized_probs( + net_output, log_probs=True + ).contiguous() # (T, B, C) from the encoder + + # CTC loss is calculated over duplicated inputs + # sample is already duplicated for R-Drop + if self.rdrop_alpha > 0: + for k, v in sample.items(): + if k in ["target", "target_lengths"]: + sample[k] = torch.cat([v, v.clone()], dim=0) + elif k == "net_input": + if sample[k]["src_tokens"].size(1) != sample[k]["src_lengths"].size( + 0 + ): + # for decoder CTC loss + sample[k]["src_lengths"] = torch.cat( + [ + sample[k]["src_lengths"], + sample[k]["src_lengths"].clone(), + ], + dim=0, + ) + + if "src_lengths" in sample["net_input"]: + input_lengths = sample["net_input"]["src_lengths"] + else: + if net_output["padding_mask"] is not None: + non_padding_mask = ~net_output["padding_mask"] + input_lengths = non_padding_mask.long().sum(-1) + else: + input_lengths = lprobs.new_full( + (lprobs.size(1),), lprobs.size(0), dtype=torch.long + ) + + pad_mask = (sample["target"] != self.pad_idx) & ( + sample["target"] != self.eos_idx + ) + targets_flat = sample["target"].masked_select(pad_mask) + if "target_lengths" in sample: + target_lengths = sample["target_lengths"] + else: + target_lengths = pad_mask.sum(-1) + + with torch.backends.cudnn.flags(enabled=False): + loss = F.ctc_loss( + lprobs, + targets_flat, + input_lengths, + target_lengths, + blank=self.blank_idx, + reduction="sum", + zero_infinity=self.zero_infinity, + ) + + ntokens = ( + sample["ntokens"] if "ntokens" in sample else target_lengths.sum().item() + ) + + sample_size = sample["target"].size(0) if self.sentence_avg else ntokens + logging_output = { + "loss": utils.item(loss.data), # * sample['ntokens'], + "ntokens": ntokens, + "nsentences": sample["id"].numel(), + "sample_size": sample_size, + } + + if not model.training: + import editdistance + + with torch.no_grad(): + lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu() + + c_err = 0 + c_len = 0 + w_errs = 0 + w_len = 0 + wv_errs = 0 + for lp, t, inp_l in zip( + lprobs_t, + sample["target_label"] + if "target_label" in sample + else sample["target"], + input_lengths, + ): + lp = lp[:inp_l].unsqueeze(0) + + decoded = None + if self.w2l_decoder is not None: + decoded = self.w2l_decoder.decode(lp) + if len(decoded) < 1: + decoded = None + else: + decoded = decoded[0] + if len(decoded) < 1: + decoded = None + else: + decoded = decoded[0] + + p = (t != self.task.target_dictionary.pad()) & ( + t != self.task.target_dictionary.eos() + ) + targ = t[p] + targ_units = self.task.target_dictionary.string(targ) + targ_units_arr = targ.tolist() + + toks = lp.argmax(dim=-1).unique_consecutive() + pred_units_arr = toks[toks != self.blank_idx].tolist() + + c_err += editdistance.eval(pred_units_arr, targ_units_arr) + c_len += len(targ_units_arr) + + targ_words = post_process(targ_units, self.post_process).split() + + pred_units = self.task.target_dictionary.string(pred_units_arr) + pred_words_raw = post_process(pred_units, self.post_process).split() + + if decoded is not None and "words" in decoded: + pred_words = decoded["words"] + w_errs += editdistance.eval(pred_words, targ_words) + wv_errs += 
editdistance.eval(pred_words_raw, targ_words) + else: + dist = editdistance.eval(pred_words_raw, targ_words) + w_errs += dist + wv_errs += dist + + w_len += len(targ_words) + + logging_output["wv_errors"] = wv_errs + logging_output["w_errors"] = w_errs + logging_output["w_total"] = w_len + logging_output["c_errors"] = c_err + logging_output["c_total"] = c_len + + return loss, sample_size, logging_output + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + + loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs)) + ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs)) + nsentences = utils.item( + sum(log.get("nsentences", 0) for log in logging_outputs) + ) + sample_size = utils.item( + sum(log.get("sample_size", 0) for log in logging_outputs) + ) + + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + metrics.log_scalar("ntokens", ntokens) + metrics.log_scalar("nsentences", nsentences) + if sample_size != ntokens: + metrics.log_scalar( + "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3 + ) + + c_errors = sum(log.get("c_errors", 0) for log in logging_outputs) + metrics.log_scalar("_c_errors", c_errors) + c_total = sum(log.get("c_total", 0) for log in logging_outputs) + metrics.log_scalar("_c_total", c_total) + w_errors = sum(log.get("w_errors", 0) for log in logging_outputs) + metrics.log_scalar("_w_errors", w_errors) + wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs) + metrics.log_scalar("_wv_errors", wv_errors) + w_total = sum(log.get("w_total", 0) for log in logging_outputs) + metrics.log_scalar("_w_total", w_total) + + if c_total > 0: + metrics.log_derived( + "uer", + lambda meters: safe_round( + meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3 + ) + if meters["_c_total"].sum > 0 + else float("nan"), + ) + if w_total > 0: + metrics.log_derived( + "wer", + lambda meters: safe_round( + meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3 + ) + if meters["_w_total"].sum > 0 + else float("nan"), + ) + metrics.log_derived( + "raw_wer", + lambda meters: safe_round( + meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3 + ) + if meters["_w_total"].sum > 0 + else float("nan"), + ) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return True diff --git a/fairseq/fairseq/criterions/fairseq_criterion.py b/fairseq/fairseq/criterions/fairseq_criterion.py new file mode 100644 index 0000000000000000000000000000000000000000..0b1e64a8e36c55be90d7e3f854effd99ed5bcc44 --- /dev/null +++ b/fairseq/fairseq/criterions/fairseq_criterion.py @@ -0,0 +1,121 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
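Note: CtcCriterion above hands F.ctc_loss encoder log-probabilities in (T, B, C) layout, a flattened 1-D target tensor, and explicit input/target lengths, with cuDNN disabled around the call. A standalone sketch of that call pattern; all shapes and values below are made up:

    import torch
    import torch.nn.functional as F

    T, B, C = 50, 2, 32            # time steps, batch, vocab size incl. blank at index 0
    lprobs = torch.randn(T, B, C).log_softmax(dim=-1)
    target_lengths = torch.tensor([12, 9])
    targets_flat = torch.randint(1, C, (int(target_lengths.sum()),))
    input_lengths = torch.full((B,), T, dtype=torch.long)

    with torch.backends.cudnn.flags(enabled=False):
        loss = F.ctc_loss(
            lprobs, targets_flat, input_lengths, target_lengths,
            blank=0, reduction="sum", zero_infinity=True,
        )
    print(loss.item())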
+ +import inspect +from typing import Any, Dict, List + +from fairseq import utils +from fairseq.logging import metrics +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.utils import gen_parser_from_dataclass +from torch.nn.modules.loss import _Loss + + +class FairseqCriterion(_Loss): + def __init__(self, task): + super().__init__() + self.task = task + if hasattr(task, "target_dictionary"): + tgt_dict = task.target_dictionary + self.padding_idx = tgt_dict.pad() if tgt_dict is not None else -100 + + @classmethod + def add_args(cls, parser): + """Add criterion-specific arguments to the parser.""" + dc = getattr(cls, "__dataclass", None) + if dc is not None: + gen_parser_from_dataclass(parser, dc()) + + @classmethod + def build_criterion(cls, cfg: FairseqDataclass, task): + """Construct a criterion from command-line args.""" + # arguments in the __init__. + init_args = {} + for p in inspect.signature(cls).parameters.values(): + if ( + p.kind == p.POSITIONAL_ONLY + or p.kind == p.VAR_POSITIONAL + or p.kind == p.VAR_KEYWORD + ): + # we haven't implemented inference for these argument types, + # but PRs welcome :) + raise NotImplementedError("{} not supported".format(p.kind)) + + assert p.kind in {p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY} + + if p.name == "task": + init_args["task"] = task + elif p.name == "cfg": + init_args["cfg"] = cfg + elif hasattr(cfg, p.name): + init_args[p.name] = getattr(cfg, p.name) + elif p.default != p.empty: + pass # we'll use the default value + else: + raise NotImplementedError( + "Unable to infer Criterion arguments, please implement " + "{}.build_criterion".format(cls.__name__) + ) + return cls(**init_args) + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + raise NotImplementedError + + @staticmethod + def aggregate_logging_outputs( + logging_outputs: List[Dict[str, Any]] + ) -> Dict[str, Any]: + """Aggregate logging outputs from data parallel training.""" + utils.deprecation_warning( + "The aggregate_logging_outputs API is deprecated. " + "Please use the reduce_metrics API instead." + ) + raise NotImplementedError + + @classmethod + def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None: + """Aggregate logging outputs from data parallel training.""" + utils.deprecation_warning( + "Criterions should implement the reduce_metrics API. " + "Falling back to deprecated aggregate_logging_outputs API." + ) + agg_logging_outputs = cls.aggregate_logging_outputs(logging_outputs) + for k, v in agg_logging_outputs.items(): + if k in {"nsentences", "ntokens", "sample_size"}: + continue + metrics.log_scalar(k, v) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return False + + +class LegacyFairseqCriterion(FairseqCriterion): + def __init__(self, args, task): + super().__init__(task=task) + self.args = args + + utils.deprecation_warning( + "Criterions should take explicit arguments instead of an " + "argparse.Namespace object, please update your criterion by " + "extending FairseqCriterion instead of LegacyFairseqCriterion." 
+ ) + + @classmethod + def build_criterion(cls, args, task): + """Construct a criterion from command-line args.""" + return cls(args, task) diff --git a/fairseq/fairseq/criterions/fastspeech2_loss.py b/fairseq/fairseq/criterions/fastspeech2_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..ab7cd08e3bd9c8d5c4be017095034b18362d77e0 --- /dev/null +++ b/fairseq/fairseq/criterions/fastspeech2_loss.py @@ -0,0 +1,137 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +from typing import List, Dict, Any +from dataclasses import dataclass, field + +import torch +import torch.nn.functional as F + +from fairseq import utils +from fairseq.logging import metrics +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass +from fairseq.data.data_utils import lengths_to_mask +from fairseq.models.fairseq_model import FairseqEncoderModel + + +@dataclass +class FastSpeech2CriterionConfig(FairseqDataclass): + ctc_weight: float = field(default=0.0, metadata={"help": "weight for CTC loss"}) + + +@register_criterion("fastspeech2", dataclass=FastSpeech2CriterionConfig) +class FastSpeech2Loss(FairseqCriterion): + def __init__(self, task, ctc_weight): + super().__init__(task) + self.ctc_weight = ctc_weight + + def forward(self, model: FairseqEncoderModel, sample, reduction="mean"): + src_tokens = sample["net_input"]["src_tokens"] + src_lens = sample["net_input"]["src_lengths"] + tgt_lens = sample["target_lengths"] + _feat_out, _feat_out_post, _, log_dur_out, pitch_out, energy_out = model( + src_tokens=src_tokens, + src_lengths=src_lens, + prev_output_tokens=sample["net_input"]["prev_output_tokens"], + incremental_state=None, + target_lengths=tgt_lens, + speaker=sample["speaker"], + durations=sample["durations"], + pitches=sample["pitches"], + energies=sample["energies"], + ) + + src_mask = lengths_to_mask(sample["net_input"]["src_lengths"]) + tgt_mask = lengths_to_mask(sample["target_lengths"]) + + pitches, energies = sample["pitches"], sample["energies"] + pitch_out, pitches = pitch_out[src_mask], pitches[src_mask] + energy_out, energies = energy_out[src_mask], energies[src_mask] + + feat_out, feat = _feat_out[tgt_mask], sample["target"][tgt_mask] + l1_loss = F.l1_loss(feat_out, feat, reduction=reduction) + if _feat_out_post is not None: + l1_loss += F.l1_loss(_feat_out_post[tgt_mask], feat, reduction=reduction) + + pitch_loss = F.mse_loss(pitch_out, pitches, reduction=reduction) + energy_loss = F.mse_loss(energy_out, energies, reduction=reduction) + + log_dur_out = log_dur_out[src_mask] + dur = sample["durations"].float() + dur = dur.half() if log_dur_out.type().endswith(".HalfTensor") else dur + log_dur = torch.log(dur + 1)[src_mask] + dur_loss = F.mse_loss(log_dur_out, log_dur, reduction=reduction) + + ctc_loss = torch.tensor(0.0).type_as(l1_loss) + if self.ctc_weight > 0.0: + lprobs = model.get_normalized_probs((_feat_out,), log_probs=True) + lprobs = lprobs.transpose(0, 1) # T x B x C + src_mask = lengths_to_mask(src_lens) + src_tokens_flat = src_tokens.masked_select(src_mask) + ctc_loss = ( + F.ctc_loss( + lprobs, + src_tokens_flat, + tgt_lens, + src_lens, + reduction=reduction, + zero_infinity=True, + ) + * self.ctc_weight + ) + + loss = l1_loss + dur_loss + pitch_loss + energy_loss 
+ ctc_loss + + sample_size = sample["nsentences"] + logging_output = { + "loss": utils.item(loss.data), + "ntokens": sample["ntokens"], + "nsentences": sample["nsentences"], + "sample_size": sample_size, + "l1_loss": utils.item(l1_loss.data), + "dur_loss": utils.item(dur_loss.data), + "pitch_loss": utils.item(pitch_loss.data), + "energy_loss": utils.item(energy_loss.data), + "ctc_loss": utils.item(ctc_loss.data), + } + return loss, sample_size, logging_output + + @classmethod + def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None: + ns = [log.get("sample_size", 0) for log in logging_outputs] + ntot = sum(ns) + ws = [n / (ntot + 1e-8) for n in ns] + for key in [ + "loss", + "l1_loss", + "dur_loss", + "pitch_loss", + "energy_loss", + "ctc_loss", + ]: + vals = [log.get(key, 0) for log in logging_outputs] + val = sum(val * w for val, w in zip(vals, ws)) + metrics.log_scalar(key, val, ntot, round=3) + metrics.log_scalar("sample_size", ntot, len(logging_outputs)) + + # inference metrics + if "targ_frames" not in logging_outputs[0]: + return + n = sum(log.get("targ_frames", 0) for log in logging_outputs) + for key, new_key in [ + ("mcd_loss", "mcd_loss"), + ("pred_frames", "pred_ratio"), + ("nins", "ins_rate"), + ("ndel", "del_rate"), + ]: + val = sum(log.get(key, 0) for log in logging_outputs) + metrics.log_scalar(new_key, val / n, n, round=3) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + return False diff --git a/fairseq/fairseq/criterions/hubert_criterion.py b/fairseq/fairseq/criterions/hubert_criterion.py new file mode 100644 index 0000000000000000000000000000000000000000..262874b582aa4765981fd9cc958c7221596d681e --- /dev/null +++ b/fairseq/fairseq/criterions/hubert_criterion.py @@ -0,0 +1,195 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +import re +from dataclasses import dataclass, field +from typing import List, Optional + +import torch +import torch.nn.functional as F +from fairseq import utils +from fairseq.logging import metrics +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass + + +@dataclass +class HubertCriterionConfig(FairseqDataclass): + pred_masked_weight: float = field( + default=1.0, + metadata={"help": "weight for predictive loss for masked frames"}, + ) + pred_nomask_weight: float = field( + default=0.0, + metadata={"help": "weight for predictive loss for unmasked frames"}, + ) + loss_weights: Optional[List[float]] = field( + default=None, + metadata={"help": "weights for additional loss terms (not first one)"}, + ) + log_keys: List[str] = field( + default_factory=lambda: [], + metadata={"help": "output keys to log"}, + ) + + +@register_criterion("hubert", dataclass=HubertCriterionConfig) +class HubertCriterion(FairseqCriterion): + def __init__( + self, + task, + pred_masked_weight, + pred_nomask_weight, + loss_weights=None, + log_keys=None, + ): + super().__init__(task) + self.pred_masked_weight = pred_masked_weight + self.pred_nomask_weight = pred_nomask_weight + self.loss_weights = loss_weights + self.log_keys = [] if log_keys is None else log_keys + + def forward(self, model, sample, reduce=True, log_pred=False): + """Compute the loss for the given sample. 
+ Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model(target_list=sample["target_list"], **sample["net_input"]) + loss = 0.0 + sample_size = 0 + logging_output = {} + reduction = "sum" if reduce else "none" + + loss_m_list = [] + logp_m_list = model.get_logits(net_output, True) + targ_m_list = model.get_targets(net_output, True) + assert self.pred_masked_weight == 0 or len(logp_m_list) > 0 + for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)): + loss_m = F.cross_entropy(logp_m, targ_m, reduction=reduction) + loss_m_list.append(loss_m) + logging_output[f"loss_m_{i}"] = loss_m.detach().item() + if self.pred_masked_weight > 0: + loss += self.pred_masked_weight * sum(loss_m_list) + sample_size += targ_m_list[0].numel() + + loss_u_list = [] + logp_u_list = model.get_logits(net_output, False) + targ_u_list = model.get_targets(net_output, False) + assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0 + for i, (logp_u, targ_u) in enumerate(zip(logp_u_list, targ_u_list)): + loss_u = F.cross_entropy(logp_u, targ_u, reduction=reduction) + loss_u_list.append(loss_u) + logging_output[f"loss_u_{i}"] = loss_u.detach().item() + if self.pred_nomask_weight > 0: + loss += self.pred_nomask_weight * sum(loss_u_list) + sample_size += targ_u_list[0].numel() + + if self.loss_weights is not None: + assert hasattr(model, "get_extra_losses") + extra_losses, names = model.get_extra_losses(net_output) + if torch.is_tensor(extra_losses): + extra_losses = [extra_losses] + names = [names] + if len(self.loss_weights) == 1 and len(extra_losses) != 1: + self.loss_weights = [self.loss_weights[0]] * len(extra_losses) + assert len(extra_losses) == len( + self.loss_weights + ), f"{len(extra_losses)}, {len(self.loss_weights)}" + for p, n, coef in zip(extra_losses, names, self.loss_weights): + if coef != 0 and p is not None: + p = coef * p.float() * sample_size + loss += p + logging_output[f"loss_{n}"] = p.item() + + logging_output = { + "loss": loss.item() if reduce else loss, + "ntokens": sample_size, + "nsentences": sample["id"].numel(), + "sample_size": sample_size, + **logging_output, + } + + for lk in self.log_keys: + if lk in net_output: + logging_output[lk] = float((net_output[lk])) + + def compute_correct(logits): + if logits.numel() == 0: + return 0, 0 + else: + assert logits.dim() > 1, logits.shape + max = logits.argmax(-1) == 0 + min = logits.argmin(-1) == 0 + both = max & min + corr = max.long().sum().item() - both.long().sum().item() + count = max.numel() + return corr, count + + with torch.no_grad(): + for i, logp_m in enumerate(logp_m_list): + corr_m, count_m = compute_correct(logp_m) + logging_output[f"correct_m_{i}"] = corr_m + logging_output[f"count_m_{i}"] = count_m + + for i, logp_u in enumerate(logp_u_list): + corr_u, count_u = compute_correct(logp_u) + logging_output[f"correct_u_{i}"] = corr_u + logging_output[f"count_u_{i}"] = count_u + + return loss, sample_size, logging_output + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training (copied from normal cross entropy).""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + if 
sample_size != ntokens: + metrics.log_scalar( + "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3 + ) + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg) + ) + else: + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg) + ) + + counts = {} + for lk in logging_outputs[0].keys(): + if lk.startswith("count_"): + val = sum(log[lk] for log in logging_outputs) + metrics.log_scalar(lk, val) + counts[lk] = val + + for lk in logging_outputs[0].keys(): + if lk.startswith("loss_"): + val = sum(log[lk] for log in logging_outputs) + metrics.log_scalar(lk, val / sample_size / math.log(2), round=3) + elif lk.startswith("correct_"): + val = sum(log[lk] for log in logging_outputs) + metrics.log_scalar(lk, val / counts[re.sub("correct", "count", lk)]) + + @staticmethod + def aggregate_logging_outputs(logging_outputs): + """Aggregate logging outputs from data parallel training.""" + raise NotImplementedError() + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return False diff --git a/fairseq/fairseq/criterions/label_smoothed_cross_entropy.py b/fairseq/fairseq/criterions/label_smoothed_cross_entropy.py new file mode 100644 index 0000000000000000000000000000000000000000..325679bb1678928b9fe644293b39f00115300a15 --- /dev/null +++ b/fairseq/fairseq/criterions/label_smoothed_cross_entropy.py @@ -0,0 +1,168 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
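Note: HubertCriterion, shown above, applies a cross-entropy over the model's masked-frame logits (and optionally the unmasked ones), weights the two terms by pred_masked_weight and pred_nomask_weight, and uses the number of masked targets as the sample size. A schematic with invented shapes; the real logits come from model.get_logits(net_output, is_masked), and index 0 is the positive candidate, as in the criterion's compute_correct helper:

    import torch
    import torch.nn.functional as F

    pred_masked_weight, pred_nomask_weight = 1.0, 0.0   # defaults from the config
    logp_m = torch.randn(8, 504)                 # (masked frames, candidates)
    targ_m = torch.zeros(8, dtype=torch.long)    # positive candidate at index 0
    logp_u = torch.randn(120, 504)               # unmasked frames
    targ_u = torch.zeros(120, dtype=torch.long)

    loss = torch.tensor(0.0)
    sample_size = 0
    if pred_masked_weight > 0:
        loss = loss + pred_masked_weight * F.cross_entropy(logp_m, targ_m, reduction="sum")
        sample_size += targ_m.numel()
    if pred_nomask_weight > 0:
        loss = loss + pred_nomask_weight * F.cross_entropy(logp_u, targ_u, reduction="sum")
        sample_size += targ_u.numel()
    print(loss.item(), sample_size)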
+ +import math +from dataclasses import dataclass, field + +import torch +from fairseq import utils +from fairseq.logging import metrics +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass +from omegaconf import II + + +@dataclass +class LabelSmoothedCrossEntropyCriterionConfig(FairseqDataclass): + label_smoothing: float = field( + default=0.0, + metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"}, + ) + report_accuracy: bool = field( + default=False, + metadata={"help": "report accuracy metric"}, + ) + ignore_prefix_size: int = field( + default=0, + metadata={"help": "Ignore first N tokens"}, + ) + sentence_avg: bool = II("optimization.sentence_avg") + + +def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True): + if target.dim() == lprobs.dim() - 1: + target = target.unsqueeze(-1) + nll_loss = -lprobs.gather(dim=-1, index=target) + smooth_loss = -lprobs.sum(dim=-1, keepdim=True) + if ignore_index is not None: + pad_mask = target.eq(ignore_index) + nll_loss.masked_fill_(pad_mask, 0.0) + smooth_loss.masked_fill_(pad_mask, 0.0) + else: + nll_loss = nll_loss.squeeze(-1) + smooth_loss = smooth_loss.squeeze(-1) + if reduce: + nll_loss = nll_loss.sum() + smooth_loss = smooth_loss.sum() + eps_i = epsilon / (lprobs.size(-1) - 1) + loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss + return loss, nll_loss + + +@register_criterion( + "label_smoothed_cross_entropy", dataclass=LabelSmoothedCrossEntropyCriterionConfig +) +class LabelSmoothedCrossEntropyCriterion(FairseqCriterion): + def __init__( + self, + task, + sentence_avg, + label_smoothing, + ignore_prefix_size=0, + report_accuracy=False, + ): + super().__init__(task) + self.sentence_avg = sentence_avg + self.eps = label_smoothing + self.ignore_prefix_size = ignore_prefix_size + self.report_accuracy = report_accuracy + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. 
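Note: label_smoothed_nll_loss above mixes the NLL of the gold token with the summed negative log-probability over the whole vocabulary: eps_i = eps / (V - 1) and loss = (1 - eps - eps_i) * nll + eps_i * smooth. A toy numeric check of that mixing, assuming a vocabulary of 4 and eps = 0.1, with arbitrary logits:

    import torch

    lprobs = torch.log_softmax(torch.tensor([[2.0, 0.5, 0.1, -1.0]]), dim=-1)
    target = torch.tensor([0])
    eps, V = 0.1, lprobs.size(-1)

    nll = -lprobs[0, target]           # -log p(gold token)
    smooth = -lprobs.sum(dim=-1)       # -sum of log p over the vocabulary
    eps_i = eps / (V - 1)
    loss = (1.0 - eps - eps_i) * nll + eps_i * smooth
    print(nll.item(), loss.item())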
+ + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model(**sample["net_input"]) + loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce) + sample_size = ( + sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + ) + logging_output = { + "loss": loss.data, + "nll_loss": nll_loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample["target"].size(0), + "sample_size": sample_size, + } + if self.report_accuracy: + n_correct, total = self.compute_accuracy(model, net_output, sample) + logging_output["n_correct"] = utils.item(n_correct.data) + logging_output["total"] = utils.item(total.data) + return loss, sample_size, logging_output + + def get_lprobs_and_target(self, model, net_output, sample): + lprobs = model.get_normalized_probs(net_output, log_probs=True) + target = model.get_targets(sample, net_output) + if self.ignore_prefix_size > 0: + # lprobs: B x T x C + lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous() + target = target[:, self.ignore_prefix_size :].contiguous() + return lprobs.view(-1, lprobs.size(-1)), target.view(-1) + + def compute_loss(self, model, net_output, sample, reduce=True): + lprobs, target = self.get_lprobs_and_target(model, net_output, sample) + loss, nll_loss = label_smoothed_nll_loss( + lprobs, + target, + self.eps, + ignore_index=self.padding_idx, + reduce=reduce, + ) + return loss, nll_loss + + def compute_accuracy(self, model, net_output, sample): + lprobs, target = self.get_lprobs_and_target(model, net_output, sample) + mask = target.ne(self.padding_idx) + n_correct = torch.sum( + lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask)) + ) + total = torch.sum(mask) + return n_correct, total + + @classmethod + def reduce_metrics(cls, logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs) + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + metrics.log_scalar( + "nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=3 + ) + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg) + ) + + total = utils.item(sum(log.get("total", 0) for log in logging_outputs)) + if total > 0: + metrics.log_scalar("total", total) + n_correct = utils.item( + sum(log.get("n_correct", 0) for log in logging_outputs) + ) + metrics.log_scalar("n_correct", n_correct) + metrics.log_derived( + "accuracy", + lambda meters: round( + meters["n_correct"].sum * 100.0 / meters["total"].sum, 3 + ) + if meters["total"].sum > 0 + else float("nan"), + ) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. 
+ """ + return True diff --git a/fairseq/fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py b/fairseq/fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py new file mode 100644 index 0000000000000000000000000000000000000000..6eaedab9cf35dd6636e3463fdab1f0d4f9dda7e4 --- /dev/null +++ b/fairseq/fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py @@ -0,0 +1,221 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field +import torch +from fairseq import utils +from fairseq.logging import metrics +from fairseq.criterions import register_criterion +from fairseq.criterions.label_smoothed_cross_entropy import ( + LabelSmoothedCrossEntropyCriterion, + LabelSmoothedCrossEntropyCriterionConfig, +) + +try: + from simuleval.metrics.latency import ( + AverageLagging, + AverageProportion, + DifferentiableAverageLagging, + ) + + LATENCY_METRICS = { + "average_lagging": AverageLagging, + "average_proportion": AverageProportion, + "differentiable_average_lagging": DifferentiableAverageLagging, + } +except ImportError: + LATENCY_METRICS = None + + +@dataclass +class LabelSmoothedCrossEntropyCriterionLatencyAugmentConfig( + LabelSmoothedCrossEntropyCriterionConfig +): + latency_avg_weight: float = field( + default=0.0, + metadata={"help": "weight fot average latency loss."}, + ) + latency_var_weight: float = field( + default=0.0, + metadata={"help": "weight fot variance latency loss."}, + ) + latency_avg_type: str = field( + default="differentiable_average_lagging", + metadata={"help": "latency type for average loss"}, + ) + latency_var_type: str = field( + default="variance_delay", + metadata={"help": "latency typ for variance loss"}, + ) + latency_gather_method: str = field( + default="weighted_average", + metadata={"help": "method to gather latency loss for all heads"}, + ) + latency_update_after: int = field( + default=0, + metadata={"help": "Add latency loss after certain steps"}, + ) + + +@register_criterion( + "latency_augmented_label_smoothed_cross_entropy", + dataclass=LabelSmoothedCrossEntropyCriterionLatencyAugmentConfig, +) +class LatencyAugmentedLabelSmoothedCrossEntropyCriterion( + LabelSmoothedCrossEntropyCriterion +): + def __init__( + self, + task, + sentence_avg, + label_smoothing, + ignore_prefix_size, + report_accuracy, + latency_avg_weight, + latency_var_weight, + latency_avg_type, + latency_var_type, + latency_gather_method, + latency_update_after, + ): + super().__init__( + task, sentence_avg, label_smoothing, ignore_prefix_size, report_accuracy + ) + assert LATENCY_METRICS is not None, "Please make sure SimulEval is installed." + + self.latency_avg_weight = latency_avg_weight + self.latency_var_weight = latency_var_weight + self.latency_avg_type = latency_avg_type + self.latency_var_type = latency_var_type + self.latency_gather_method = latency_gather_method + self.latency_update_after = latency_update_after + + def forward(self, model, sample, reduce=True): + net_output = model(**sample["net_input"]) + # 1. Compute cross entropy loss + loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce) + + # 2. 
Compute cross latency loss + latency_loss, expected_latency, expected_delays_var = self.compute_latency_loss( + model, sample, net_output + ) + + if self.latency_update_after > 0: + num_updates = getattr(model.decoder, "num_updates", None) + assert ( + num_updates is not None + ), "model.decoder doesn't have attribute 'num_updates'" + if num_updates <= self.latency_update_after: + latency_loss = 0 + + loss += latency_loss + + sample_size = ( + sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + ) + + logging_output = { + "loss": loss.data, + "nll_loss": nll_loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample["target"].size(0), + "sample_size": sample_size, + "latency": expected_latency, + "delays_var": expected_delays_var, + "latency_loss": latency_loss, + } + + if self.report_accuracy: + n_correct, total = self.compute_accuracy(model, net_output, sample) + logging_output["n_correct"] = utils.item(n_correct.data) + logging_output["total"] = utils.item(total.data) + return loss, sample_size, logging_output + + def compute_latency_loss(self, model, sample, net_output): + assert ( + net_output[-1].encoder_padding_mask is None + or not net_output[-1].encoder_padding_mask[:, 0].any() + ), "Only right padding on source is supported." + # 1. Obtain the expected alignment + alpha_list = [item["alpha"] for item in net_output[1].attn_list] + num_layers = len(alpha_list) + bsz, num_heads, tgt_len, src_len = alpha_list[0].size() + + # bsz * num_layers * num_heads, tgt_len, src_len + alpha_all = torch.cat(alpha_list, dim=1).view(-1, tgt_len, src_len) + + # 2 compute expected delays + # bsz * num_heads * num_layers, tgt_len, src_len for MMA + steps = ( + torch.arange(1, 1 + src_len) + .unsqueeze(0) + .unsqueeze(1) + .expand_as(alpha_all) + .type_as(alpha_all) + ) + + expected_delays = torch.sum(steps * alpha_all, dim=-1) + + target_padding_mask = ( + model.get_targets(sample, net_output) + .eq(self.padding_idx) + .unsqueeze(1) + .expand(bsz, num_layers * num_heads, tgt_len) + .contiguous() + .view(-1, tgt_len) + ) + + src_lengths = ( + sample["net_input"]["src_lengths"] + .unsqueeze(1) + .expand(bsz, num_layers * num_heads) + .contiguous() + .view(-1) + ) + expected_latency = LATENCY_METRICS[self.latency_avg_type]( + expected_delays, src_lengths, None, target_padding_mask=target_padding_mask + ) + + # 2.1 average expected latency of heads + # bsz, num_layers * num_heads + expected_latency = expected_latency.view(bsz, -1) + if self.latency_gather_method == "average": + # bsz * tgt_len + expected_latency = expected_delays.mean(dim=1) + elif self.latency_gather_method == "weighted_average": + weights = torch.nn.functional.softmax(expected_latency, dim=1) + expected_latency = torch.sum(expected_latency * weights, dim=1) + elif self.latency_gather_method == "max": + expected_latency = expected_latency.max(dim=1)[0] + else: + raise NotImplementedError + + expected_latency = expected_latency.sum() + avg_loss = self.latency_avg_weight * expected_latency + + # 2.2 variance of expected delays + expected_delays_var = ( + expected_delays.view(bsz, -1, tgt_len).var(dim=1).mean(dim=1) + ) + expected_delays_var = expected_delays_var.sum() + var_loss = self.latency_avg_weight * expected_delays_var + + # 3. 
Final loss + latency_loss = avg_loss + var_loss + + return latency_loss, expected_latency, expected_delays_var + + @classmethod + def reduce_metrics(cls, logging_outputs) -> None: + super().reduce_metrics(logging_outputs) + latency = sum(log.get("latency", 0) for log in logging_outputs) + delays_var = sum(log.get("delays_var", 0) for log in logging_outputs) + latency_loss = sum(log.get("latency_loss", 0) for log in logging_outputs) + nsentences = sum(log.get("nsentences", 0) for log in logging_outputs) + metrics.log_scalar("latency", latency.float() / nsentences, nsentences, round=3) + metrics.log_scalar("delays_var", delays_var / nsentences, nsentences, round=3) + metrics.log_scalar( + "latency_loss", latency_loss / nsentences, nsentences, round=3 + ) diff --git a/fairseq/fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py b/fairseq/fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py new file mode 100644 index 0000000000000000000000000000000000000000..b55f65e5cc7e9e949208786a4974b4ba09b0de66 --- /dev/null +++ b/fairseq/fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py @@ -0,0 +1,131 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math + +from fairseq import utils +from fairseq.logging import metrics +from fairseq.criterions import register_criterion + +from .label_smoothed_cross_entropy import ( + LabelSmoothedCrossEntropyCriterion, + LabelSmoothedCrossEntropyCriterionConfig, +) + +from dataclasses import dataclass, field + + +@dataclass +class LabelSmoothedCrossEntropyCriterionWithAlignmentConfig( + LabelSmoothedCrossEntropyCriterionConfig +): + alignment_lambda: float = field( + default=0.05, metadata={"help": "weight for the alignment loss"} + ) + + +@register_criterion( + "label_smoothed_cross_entropy_with_alignment", + dataclass=LabelSmoothedCrossEntropyCriterionWithAlignmentConfig, +) +class LabelSmoothedCrossEntropyCriterionWithAlignment( + LabelSmoothedCrossEntropyCriterion +): + def __init__(self, task, sentence_avg, label_smoothing, alignment_lambda): + super().__init__(task, sentence_avg, label_smoothing) + self.alignment_lambda = alignment_lambda + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model(**sample["net_input"]) + loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce) + sample_size = ( + sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + ) + logging_output = { + "loss": utils.item(loss.data) if reduce else loss.data, + "nll_loss": utils.item(nll_loss.data) if reduce else nll_loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample["target"].size(0), + "sample_size": sample_size, + } + + alignment_loss = None + + # Compute alignment loss only for training set and non dummy batches. 
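Referring back to compute_latency_loss in the latency-augmented criterion above, the following toy sketch (fabricated attention weights, single head) shows how expected delays are read off a monotonic attention matrix: each target step's delay is the attention-weighted average of source positions 1..src_len.

import torch

# one head, tgt_len=2, src_len=3; each row of attention weights sums to 1
alpha_all = torch.tensor([[[0.7, 0.2, 0.1],
                           [0.0, 0.3, 0.7]]])

steps = (
    torch.arange(1, 1 + alpha_all.size(-1))
    .unsqueeze(0)
    .unsqueeze(1)
    .expand_as(alpha_all)
    .type_as(alpha_all)
)
# attention-weighted source position per target step -> tensor([[1.4, 2.7]])
expected_delays = torch.sum(steps * alpha_all, dim=-1)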
+ if "alignments" in sample and sample["alignments"] is not None: + alignment_loss = self.compute_alignment_loss(sample, net_output) + + if alignment_loss is not None: + logging_output["alignment_loss"] = utils.item(alignment_loss.data) + loss += self.alignment_lambda * alignment_loss + + return loss, sample_size, logging_output + + def compute_alignment_loss(self, sample, net_output): + attn_prob = net_output[1]["attn"][0] + bsz, tgt_sz, src_sz = attn_prob.shape + attn = attn_prob.view(bsz * tgt_sz, src_sz) + + align = sample["alignments"] + align_weights = sample["align_weights"].float() + + if len(align) > 0: + # Alignment loss computation. align (shape [:, 2]) contains the src-tgt index pairs corresponding to + # the alignments. align_weights (shape [:]) contains the 1 / frequency of a tgt index for normalizing. + loss = -( + (attn[align[:, 1][:, None], align[:, 0][:, None]]).log() + * align_weights[:, None] + ).sum() + else: + return None + + return loss + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs)) + nll_loss_sum = utils.item( + sum(log.get("nll_loss", 0) for log in logging_outputs) + ) + alignment_loss_sum = utils.item( + sum(log.get("alignment_loss", 0) for log in logging_outputs) + ) + ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs)) + sample_size = utils.item( + sum(log.get("sample_size", 0) for log in logging_outputs) + ) + + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + metrics.log_scalar( + "nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=3 + ) + metrics.log_scalar( + "alignment_loss", + alignment_loss_sum / sample_size / math.log(2), + sample_size, + round=3, + ) + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg) + ) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return True diff --git a/fairseq/fairseq/criterions/label_smoothed_cross_entropy_with_ctc.py b/fairseq/fairseq/criterions/label_smoothed_cross_entropy_with_ctc.py new file mode 100644 index 0000000000000000000000000000000000000000..f2e8cdf3bfe0caea99125c6f9607dff9495891cf --- /dev/null +++ b/fairseq/fairseq/criterions/label_smoothed_cross_entropy_with_ctc.py @@ -0,0 +1,97 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import math +from dataclasses import dataclass, field + +import torch +import torch.nn.functional as F + +from fairseq import utils +from fairseq.logging import metrics +from fairseq.criterions import register_criterion +from fairseq.criterions.label_smoothed_cross_entropy import ( + LabelSmoothedCrossEntropyCriterion, + LabelSmoothedCrossEntropyCriterionConfig, +) +from fairseq.data.data_utils import lengths_to_mask + + +@dataclass +class LabelSmoothedCrossEntropyWithCtcCriterionConfig( + LabelSmoothedCrossEntropyCriterionConfig +): + ctc_weight: float = field(default=1.0, metadata={"help": "weight for CTC loss"}) + + +@register_criterion( + "label_smoothed_cross_entropy_with_ctc", + dataclass=LabelSmoothedCrossEntropyWithCtcCriterionConfig, +) +class LabelSmoothedCrossEntropyWithCtcCriterion(LabelSmoothedCrossEntropyCriterion): + def __init__( + self, + task, + sentence_avg, + label_smoothing, + ignore_prefix_size, + report_accuracy, + ctc_weight, + ): + super().__init__( + task, sentence_avg, label_smoothing, ignore_prefix_size, report_accuracy + ) + self.ctc_weight = ctc_weight + + def forward(self, model, sample, reduce=True): + net_output = model(**sample["net_input"]) + loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce) + + ctc_loss = torch.tensor(0.0).type_as(loss) + if self.ctc_weight > 0.0: + ctc_lprobs, ctc_lens = model.get_ctc_output(net_output, sample) + ctc_tgt, ctc_tgt_lens = model.get_ctc_target(sample) + ctc_tgt_mask = lengths_to_mask(ctc_tgt_lens) + ctc_tgt_flat = ctc_tgt.masked_select(ctc_tgt_mask) + reduction = "sum" if reduce else "none" + ctc_loss = ( + F.ctc_loss( + ctc_lprobs, + ctc_tgt_flat, + ctc_lens, + ctc_tgt_lens, + reduction=reduction, + zero_infinity=True, + ) + * self.ctc_weight + ) + loss += ctc_loss + + sample_size = ( + sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + ) + logging_output = { + "loss": utils.item(loss.data), + "nll_loss": utils.item(nll_loss.data), + "ctc_loss": utils.item(ctc_loss.data), + "ntokens": sample["ntokens"], + "nsentences": sample["target"].size(0), + "sample_size": sample_size, + } + if self.report_accuracy: + n_correct, total = self.compute_accuracy(model, net_output, sample) + logging_output["n_correct"] = utils.item(n_correct.data) + logging_output["total"] = utils.item(total.data) + return loss, sample_size, logging_output + + @classmethod + def reduce_metrics(cls, logging_outputs) -> None: + super().reduce_metrics(logging_outputs) + loss_sum = sum(log.get("ctc_loss", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + + metrics.log_scalar( + "ctc_loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) diff --git a/fairseq/fairseq/criterions/label_smoothed_cross_entropy_with_rdrop.py b/fairseq/fairseq/criterions/label_smoothed_cross_entropy_with_rdrop.py new file mode 100644 index 0000000000000000000000000000000000000000..47ee263a8de63261e4c8838ba44fe269553f5f3b --- /dev/null +++ b/fairseq/fairseq/criterions/label_smoothed_cross_entropy_with_rdrop.py @@ -0,0 +1,177 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
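For the CTC term above, a self-contained sketch of the F.ctc_loss call with toy shapes (all names and sizes here are invented; ctc_lprobs stands in for the time-major log-probabilities that model.get_ctc_output would return):

import torch
import torch.nn.functional as F

T, N, C = 50, 2, 20                       # input length, batch, vocab (index 0 = blank)
ctc_lprobs = torch.randn(T, N, C).log_softmax(-1)
ctc_lens = torch.full((N,), T, dtype=torch.long)
ctc_tgt_lens = torch.tensor([10, 7])
ctc_tgt_flat = torch.randint(1, C, (int(ctc_tgt_lens.sum()),))   # concatenated targets

ctc_loss = F.ctc_loss(
    ctc_lprobs,
    ctc_tgt_flat,
    ctc_lens,
    ctc_tgt_lens,
    reduction="sum",
    zero_infinity=True,
)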
+ +import math +from dataclasses import dataclass, field + +import torch + +from fairseq import utils +from fairseq.logging import metrics +from fairseq.criterions import register_criterion +from fairseq.criterions.label_smoothed_cross_entropy import ( + LabelSmoothedCrossEntropyCriterion, + LabelSmoothedCrossEntropyCriterionConfig, + label_smoothed_nll_loss, +) + + +@dataclass +class RdropLabelSmoothedCrossEntropyCriterionConfig( + LabelSmoothedCrossEntropyCriterionConfig +): + rdrop_alpha: float = field( + default=0.0, + metadata={"help": "alpha for r-drop, 0 means no r-drop"}, + ) + + +@register_criterion( + "label_smoothed_cross_entropy_with_rdrop", + dataclass=RdropLabelSmoothedCrossEntropyCriterionConfig, +) +class RdropLabelSmoothedCrossEntropyCriterion(LabelSmoothedCrossEntropyCriterion): + def __init__( + self, + task, + sentence_avg, + label_smoothing, + ignore_prefix_size=0, + report_accuracy=False, + rdrop_alpha=0.0, + ): + super().__init__( + task, + sentence_avg, + label_smoothing, + ignore_prefix_size=ignore_prefix_size, + report_accuracy=report_accuracy, + ) + self.sentence_avg = sentence_avg + self.eps = label_smoothing + self.ignore_prefix_size = ignore_prefix_size + self.report_accuracy = report_accuracy + self.rdrop_alpha = rdrop_alpha + + def forward(self, model, sample, reduce=True, net_output=None): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + if net_output is None: + if self.rdrop_alpha > 0 and sample["net_input"]["src_tokens"].size( + 0 + ) == sample["target"].size(0): + sample = duplicate_input(sample) + net_output = model(**sample["net_input"]) + loss, nll_loss, rdrop_kl_loss = self.compute_loss( + model, net_output, sample, reduce=reduce + ) + sample_size = ( + sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + ) + logging_output = { + "loss": loss.data, + "nll_loss": nll_loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample["target"].size(0), + "sample_size": sample_size, + } + if self.report_accuracy: + n_correct, total = self.compute_accuracy(model, net_output, sample) + logging_output["n_correct"] = utils.item(n_correct.data) + logging_output["total"] = utils.item(total.data) + if self.rdrop_alpha > 0: + logging_output["rdrop_kl_loss"] = utils.item(rdrop_kl_loss.data) + return loss, sample_size, logging_output + + def get_lprobs_and_target(self, model, net_output, sample): + lprobs = model.get_normalized_probs(net_output, log_probs=True) + target = model.get_targets(sample, net_output) + if self.rdrop_alpha > 0 or target.size(0) != lprobs.size(0): + target = torch.cat([target, target.clone()], dim=0) + + if self.ignore_prefix_size > 0: + # lprobs: B x T x C + lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous() + target = target[:, self.ignore_prefix_size :].contiguous() + return lprobs.view(-1, lprobs.size(-1)), target.view(-1) + + def compute_loss(self, model, net_output, sample, reduce=True): + lprobs, target = self.get_lprobs_and_target(model, net_output, sample) + loss, nll_loss = label_smoothed_nll_loss( + lprobs, + target, + self.eps, + ignore_index=self.padding_idx, + reduce=reduce, + ) + + if self.rdrop_alpha > 0: + pad_mask = target[: target.size(0) // 2].unsqueeze(-1).eq(self.padding_idx) + rdrop_kl_loss = compute_kl_loss(model, net_output, pad_mask) + loss += self.rdrop_alpha * rdrop_kl_loss + else: + rdrop_kl_loss = 
loss.new_zeros(1) + return loss, nll_loss, rdrop_kl_loss + + @classmethod + def reduce_metrics(cls, logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + super().reduce_metrics(logging_outputs) + + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + + rdrop_kl_loss = utils.item( + sum(log.get("rdrop_kl_loss", 0) for log in logging_outputs) + / sample_size + / math.log(2) + ) + if rdrop_kl_loss > 0: + metrics.log_scalar("rdrop_kl_loss", rdrop_kl_loss) + + +def duplicate_input(sample): + if "net_input" in sample.keys(): + sample_input = sample["net_input"] + else: + sample_input = sample + + for k, v in sample_input.items(): + if isinstance(v, torch.Tensor): + sample_input[k] = torch.cat([v, v.clone()], dim=0) + if "net_input" in sample.keys(): + sample["net_input"] = sample_input + else: + sample = sample_input + return sample + + +def compute_kl_loss(model, net_output, pad_mask=None, reduce=True): + net_prob = model.get_normalized_probs(net_output, log_probs=True) + net_prob_tec = model.get_normalized_probs(net_output, log_probs=False) + + net_prob = net_prob.view(-1, net_prob.size(-1)) + net_prob_tec = net_prob_tec.view(-1, net_prob_tec.size(-1)) + + p, q = torch.split(net_prob, net_prob.size(0) // 2, dim=0) + p_tec, q_tec = torch.split(net_prob_tec, net_prob_tec.size(0) // 2, dim=0) + + p_loss = torch.nn.functional.kl_div(p, q_tec, reduction="none") + q_loss = torch.nn.functional.kl_div(q, p_tec, reduction="none") + + if pad_mask is not None: + p_loss.masked_fill_(pad_mask, 0.0) + q_loss.masked_fill_(pad_mask, 0.0) + + if reduce: + p_loss = p_loss.sum() + q_loss = q_loss.sum() + + loss = (p_loss + q_loss) / 2 + return loss diff --git a/fairseq/fairseq/criterions/legacy_masked_lm.py b/fairseq/fairseq/criterions/legacy_masked_lm.py new file mode 100644 index 0000000000000000000000000000000000000000..5cf70df2ab97eef1ec454ddc8ccaf5a86cc3c153 --- /dev/null +++ b/fairseq/fairseq/criterions/legacy_masked_lm.py @@ -0,0 +1,178 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import torch +import torch.nn.functional as F +from fairseq import utils +from fairseq.logging import metrics +from fairseq.criterions import FairseqCriterion, register_criterion + + +def compute_cross_entropy_loss(logits, targets, ignore_index=-100): + """ + Function to compute the cross entropy loss. The default value of + ignore_index is the same as the default value for F.cross_entropy in + pytorch. + """ + assert logits.size(0) == targets.size( + -1 + ), "Logits and Targets tensor shapes don't match up" + + loss = F.nll_loss( + F.log_softmax(logits, -1, dtype=torch.float32), + targets, + reduction="sum", + ignore_index=ignore_index, + ) + return loss + + +@register_criterion("legacy_masked_lm_loss") +class LegacyMaskedLmLoss(FairseqCriterion): + """ + Implementation for the loss used in masked language model (MLM) training. + This optionally also computes the next sentence prediction (NSP) loss and + adds it to the overall loss based on the specified args. There are three + cases to consider: + 1) Generic MLM training without NSP loss. In this case sentence_targets + and sentence_logits are both None. + 2) BERT training without NSP loss. In this case sentence_targets is + not None but sentence_logits is None and we should not be computing + a sentence level loss. + 3) BERT training with NSP loss. 
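Referring back to compute_kl_loss in the R-Drop criterion above, a small sketch with toy tensors of the symmetric KL term computed between the two forward passes of the duplicated batch: each half's log-probabilities are scored against the other half's probabilities, and the two directions are averaged.

import torch
import torch.nn.functional as F

logits = torch.randn(8, 11)                 # 2 * (bsz * tgt_len) rows, toy vocab of 11
net_prob = F.log_softmax(logits, dim=-1)    # log-probabilities
net_prob_tec = F.softmax(logits, dim=-1)    # probabilities

p, q = torch.split(net_prob, net_prob.size(0) // 2, dim=0)
p_tec, q_tec = torch.split(net_prob_tec, net_prob_tec.size(0) // 2, dim=0)

p_loss = F.kl_div(p, q_tec, reduction="none").sum()
q_loss = F.kl_div(q, p_tec, reduction="none").sum()
rdrop_kl_loss = (p_loss + q_loss) / 2       # added to the loss scaled by rdrop_alpha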
In this case both sentence_targets and + sentence_logits are not None and we should be computing a sentence + level loss. The weight of the sentence level loss is specified as + an argument. + """ + + def __init__(self, task, masked_lm_only, nsp_loss_weight): + super().__init__(task) + self.masked_lm_only = masked_lm_only + self.nsp_loss_weight = nsp_loss_weight + + @staticmethod + def add_args(parser): + """Args for MaskedLM Loss""" + # Default for masked_lm_only is False so as to not break BERT training + parser.add_argument( + "--masked-lm-only", + default=False, + action="store_true", + help="compute MLM loss only", + ) + parser.add_argument( + "--nsp-loss-weight", + default=1.0, + type=float, + help="weight for next sentence prediction" " loss (default 1)", + ) + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + lm_logits, output_metadata = model(**sample["net_input"]) + + # reshape lm_logits from (N,T,C) to (N*T,C) + lm_logits = lm_logits.view(-1, lm_logits.size(-1)) + lm_targets = sample["lm_target"].view(-1) + lm_loss = compute_cross_entropy_loss(lm_logits, lm_targets, self.padding_idx) + + # compute the number of tokens for which loss is computed. This is used + # to normalize the loss + ntokens = utils.strip_pad(lm_targets, self.padding_idx).numel() + loss = lm_loss / ntokens + nsentences = sample["nsentences"] + # nsentences = 0 + + # Compute sentence loss if masked_lm_only is False + sentence_loss = None + if not self.masked_lm_only: + sentence_logits = output_metadata["sentence_logits"] + sentence_targets = sample["sentence_target"].view(-1) + # This needs to be recomputed due to some differences between + # TokenBlock and BlockPair dataset. This can be resolved with a + # refactor of BERTModel which we will do in the future. + # TODO: Remove this after refactor of BERTModel + nsentences = sentence_targets.size(0) + + # Check for logits being none which can happen when remove_heads + # is set to true in the BERT model. Ideally we should set + # masked_lm_only to true in this case, but that requires some + # refactor in the BERT model. 
+ if sentence_logits is not None: + sentence_loss = compute_cross_entropy_loss( + sentence_logits, sentence_targets + ) + + loss += self.nsp_loss_weight * (sentence_loss / nsentences) + + # NOTE: as we are summing up per token mlm loss and per sentence nsp loss + # we don't need to use sample_size as denominator for the gradient + # here sample_size is just used for logging + sample_size = 1 + logging_output = { + "loss": utils.item(loss.data) if reduce else loss.data, + "lm_loss": utils.item(lm_loss.data) if reduce else lm_loss.data, + # sentence loss is not always computed + "sentence_loss": ( + (utils.item(sentence_loss.data) if reduce else sentence_loss.data) + if sentence_loss is not None + else 0.0 + ), + "ntokens": ntokens, + "nsentences": nsentences, + "sample_size": sample_size, + } + return loss, sample_size, logging_output + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + lm_loss_sum = sum(log.get("lm_loss", 0) for log in logging_outputs) + sentence_loss_sum = sum(log.get("sentence_loss", 0) for log in logging_outputs) + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + nsentences = sum(log.get("nsentences", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + agg_loss = sum(log.get("loss", 0) for log in logging_outputs) + + metrics.log_scalar( + "loss", + agg_loss / sample_size / math.log(2) if sample_size > 0 else 0.0, + sample_size, + round=3, + ) + metrics.log_scalar( + "lm_loss", + lm_loss_sum / ntokens / math.log(2) if ntokens > 0 else 0.0, + ntokens, + round=3, + ) + metrics.log_scalar( + "sentence_loss", + sentence_loss_sum / nsentences / math.log(2) if nsentences > 0 else 0.0, + nsentences, + round=3, + ) + metrics.log_scalar( + "nll_loss", + lm_loss_sum / ntokens / math.log(2) if ntokens > 0 else 0.0, + ntokens, + round=3, + ) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return True diff --git a/fairseq/fairseq/criterions/masked_lm.py b/fairseq/fairseq/criterions/masked_lm.py new file mode 100644 index 0000000000000000000000000000000000000000..09ddd9f3e6d41c4771521bb187ad981223e09e95 --- /dev/null +++ b/fairseq/fairseq/criterions/masked_lm.py @@ -0,0 +1,99 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +import math +from omegaconf import II + +import torch +from fairseq import modules, utils +from fairseq.logging import metrics +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass + + +@dataclass +class MaskedLmConfig(FairseqDataclass): + tpu: bool = II("common.tpu") + + +@register_criterion("masked_lm", dataclass=MaskedLmConfig) +class MaskedLmLoss(FairseqCriterion): + """ + Implementation for the loss used in masked language model (MLM) training. + """ + + def __init__(self, cfg: MaskedLmConfig, task): + super().__init__(task) + self.tpu = cfg.tpu + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. 
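Referring back to LegacyMaskedLmLoss.forward above, a rough sketch (synthetic tensors and an invented NSP loss value) of how the loss is assembled: the summed token-level cross entropy is normalized by the number of non-padding targets, and the NSP term is added scaled by nsp_loss_weight and averaged over sentences.

import torch
import torch.nn.functional as F

pad_idx = 1
lm_logits = torch.randn(6, 30)                             # (N*T, vocab)
lm_targets = torch.tensor([4, 5, pad_idx, 7, 9, pad_idx])

lm_loss = F.nll_loss(
    F.log_softmax(lm_logits, -1, dtype=torch.float32),
    lm_targets,
    reduction="sum",
    ignore_index=pad_idx,
)
ntokens = lm_targets.ne(pad_idx).sum()                     # what strip_pad(...).numel() counts
loss = lm_loss / ntokens

nsp_loss_weight, nsentences = 1.0, 2
sentence_loss = torch.tensor(1.2)                          # summed NSP cross entropy (made up)
loss = loss + nsp_loss_weight * (sentence_loss / nsentences)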
+ + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + masked_tokens = sample["target"].ne(self.padding_idx) + sample_size = masked_tokens.int().sum() + + # Rare: when all tokens are masked, project all tokens. + # We use torch.where to avoid device-to-host transfers, + # except on CPU where torch.where is not well supported + # (see github.com/pytorch/pytorch/issues/26247). + if self.tpu: + masked_tokens = None # always project all tokens on TPU + elif masked_tokens.device == torch.device("cpu"): + if not masked_tokens.any(): + masked_tokens = None + else: + masked_tokens = torch.where( + masked_tokens.any(), + masked_tokens, + masked_tokens.new([True]), + ) + + logits = model(**sample["net_input"], masked_tokens=masked_tokens)[0] + targets = model.get_targets(sample, [logits]) + if masked_tokens is not None: + targets = targets[masked_tokens] + + loss = modules.cross_entropy( + logits.view(-1, logits.size(-1)), + targets.view(-1), + reduction="sum", + ignore_index=self.padding_idx, + ) + + logging_output = { + "loss": loss if self.tpu else loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample["nsentences"], + "sample_size": sample_size, + } + return loss, sample_size, logging_output + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg) + ) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return True diff --git a/fairseq/fairseq/criterions/model_criterion.py b/fairseq/fairseq/criterions/model_criterion.py new file mode 100644 index 0000000000000000000000000000000000000000..4c020ddbd2fde9a879d94fe63e745d3e6b9a627e --- /dev/null +++ b/fairseq/fairseq/criterions/model_criterion.py @@ -0,0 +1,177 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from dataclasses import dataclass, field +from typing import Dict, List + +import torch + +from fairseq import utils +from fairseq.logging import metrics +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass +from fairseq.logging.meters import safe_round + + +logger = logging.getLogger(__name__) + + +@dataclass +class ModelCriterionConfig(FairseqDataclass): + loss_weights: Dict[str, float] = field( + default_factory=dict, + metadata={"help": "weights for the loss terms"}, + ) + log_keys: List[str] = field( + default_factory=list, + metadata={"help": "additional output keys to log"}, + ) + can_sum: bool = True + + +@register_criterion("model", dataclass=ModelCriterionConfig) +class ModelCriterion(FairseqCriterion): + """ + This criterion relies on the model to supply losses. 
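Referring back to the masked-token selection in MaskedLmLoss.forward above, a quick sketch (toy tensor) of the torch.where trick: when anything is masked the mask is kept as-is, otherwise an all-True mask is substituted without a device-to-host sync.

import torch

padding_idx = 1
target = torch.full((2, 3), padding_idx)      # pathological batch: nothing is masked
masked_tokens = target.ne(padding_idx)

masked_tokens = torch.where(
    masked_tokens.any(),                      # 0-dim bool condition
    masked_tokens,
    masked_tokens.new([True]),                # broadcasts to an all-True mask
)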
+ The losses should be a dictionary of name -> scalar returned by + the model either by including it in the net_output dict or by + implementing a get_losses(net_output, sample) method. The final loss is + a scaled sum of all losses according to weights in loss_weights. + If no weights are provided, then all losses are scaled by 1.0. + + The losses will be automatically logged. Additional keys from + net_output dict can be logged via the log_keys parameter. + """ + + def __init__(self, task, loss_weights=None, log_keys=None, can_sum=True): + super().__init__(task) + self.loss_weights = loss_weights + self.log_keys = log_keys + self.can_sum = can_sum + + def forward(self, model, sample, reduce=True): + net_output = model(**sample["net_input"]) + + scaled_losses = {} + + if hasattr(model, "get_losses"): + losses = model.get_losses(net_output, sample) + elif isinstance(net_output, dict) and "losses" in net_output: + losses = net_output["losses"] + else: + raise Exception("Could not retrieve losses") + + for lk, p in losses.items(): + try: + coef = 1.0 if len(self.loss_weights) == 0 else self.loss_weights[lk] + except KeyError: + logger.error( + f"weight for loss {lk} is not in loss_weights ({self.loss_weights})" + ) + raise + if coef != 0 and p is not None: + scaled_losses[lk] = coef * p.float().sum() + + loss = sum(scaled_losses.values()) + + if "sample_size" in net_output: + sample_size = net_output["sample_size"] + else: + sample_size = loss.numel() + + if reduce and loss.numel() > 1: + loss = loss.sum() + + logging_output = { + "loss": loss.data, + "ntokens": sample_size, + "nsentences": sample["id"].numel(), + "sample_size": sample_size, + "_world_size": 1, + } + + for lk in self.log_keys: + if lk in net_output and net_output[lk] is not None: + if not torch.is_tensor(net_output[lk]) or net_output[lk].numel() == 1: + logging_output[lk] = float(net_output[lk]) + elif lk.startswith("_"): + logging_output[lk] = net_output[lk] + else: + for i, v in enumerate(net_output[lk]): + logging_output[f"{lk}_{i}"] = float(v) + + if len(scaled_losses) > 1: + for lk, l in scaled_losses.items(): + if l.numel() > 1: + l = l.sum() + logging_output[f"loss_{lk}"] = l.item() + + if "logs" in net_output: + for lgw in net_output["logs"]: + logging_output[lgw] = net_output["logs"][lgw] + + return loss, sample_size, logging_output + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs)) + ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs)) + nsentences = utils.item( + sum(log.get("nsentences", 0) for log in logging_outputs) + ) + sample_size = utils.item( + sum(log.get("sample_size", 0) for log in logging_outputs) + ) + + metrics.log_scalar("loss", loss_sum / sample_size, sample_size, round=3) + metrics.log_scalar("ntokens", ntokens) + metrics.log_scalar("nsentences", nsentences) + metrics.log_scalar("sample_size", sample_size) + + builtin_keys = { + "loss", + "ntokens", + "nsentences", + "sample_size", + "_world_size", + } + + world_size = utils.item( + sum(log.get("_world_size", 0) for log in logging_outputs) + ) + + for k in logging_outputs[0]: + if k not in builtin_keys and not k.startswith("_"): + val = sum(log.get(k, 0) for log in logging_outputs) + if k.startswith("loss_"): + metrics.log_scalar(k, val / sample_size, sample_size, round=3) + else: + metrics.log_scalar(k, val / world_size, round=3) + + correct = sum(log.get("correct", 0) for 
log in logging_outputs) + total = sum(log.get("count", 0) for log in logging_outputs) + + if total > 0: + metrics.log_scalar("_correct", correct) + metrics.log_scalar("_total", total) + + metrics.log_derived( + "accuracy", + lambda meters: safe_round( + meters["_correct"].sum / meters["_total"].sum, 5 + ) + if meters["_total"].sum > 0 + else float("nan"), + ) + + def logging_outputs_can_be_summed(self) -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return self.can_sum diff --git a/fairseq/fairseq/criterions/nat_loss.py b/fairseq/fairseq/criterions/nat_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..fc0bdaf8510d8feae20779aad49b53a4d84d37db --- /dev/null +++ b/fairseq/fairseq/criterions/nat_loss.py @@ -0,0 +1,181 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import torch +import torch.nn.functional as F +from fairseq import utils +from fairseq.logging import metrics +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass +from torch import Tensor + +from dataclasses import dataclass, field + + +@dataclass +class LabelSmoothedDualImitationCriterionConfig(FairseqDataclass): + label_smoothing: float = field( + default=0.0, + metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"}, + ) + + +@register_criterion("nat_loss", dataclass=LabelSmoothedDualImitationCriterionConfig) +class LabelSmoothedDualImitationCriterion(FairseqCriterion): + def __init__(self, task, label_smoothing): + super().__init__(task) + self.label_smoothing = label_smoothing + + def _compute_loss( + self, outputs, targets, masks=None, label_smoothing=0.0, name="loss", factor=1.0 + ): + """ + outputs: batch x len x d_model + targets: batch x len + masks: batch x len + + policy_logprob: if there is some policy + depends on the likelihood score as rewards. + """ + + def mean_ds(x: Tensor, dim=None) -> Tensor: + return ( + x.float().mean().type_as(x) + if dim is None + else x.float().mean(dim).type_as(x) + ) + + if masks is not None: + outputs, targets = outputs[masks], targets[masks] + + if masks is not None and not masks.any(): + nll_loss = torch.tensor(0) + loss = nll_loss + else: + logits = F.log_softmax(outputs, dim=-1) + if targets.dim() == 1: + losses = F.nll_loss(logits, targets.to(logits.device), reduction="none") + + else: # soft-labels + losses = F.kl_div(logits, targets.to(logits.device), reduction="none") + losses = losses.sum(-1) + + nll_loss = mean_ds(losses) + if label_smoothing > 0: + loss = ( + nll_loss * (1 - label_smoothing) - mean_ds(logits) * label_smoothing + ) + else: + loss = nll_loss + + loss = loss * factor + return {"name": name, "loss": loss, "nll_loss": nll_loss, "factor": factor} + + def _custom_loss(self, loss, name="loss", factor=1.0): + return {"name": name, "loss": loss, "factor": factor} + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. 
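Referring back to ModelCriterion.forward above, a minimal sketch (hypothetical loss names and weights) of how model-supplied losses are combined: every entry in the losses dict is scaled by its configured weight, defaulting to 1.0 when loss_weights is empty, and the scaled terms are summed.

import torch

losses = {"contrastive": torch.tensor(2.0), "diversity": torch.tensor(0.5)}
loss_weights = {"contrastive": 1.0, "diversity": 0.1}

scaled_losses = {}
for lk, p in losses.items():
    coef = 1.0 if len(loss_weights) == 0 else loss_weights[lk]
    if coef != 0 and p is not None:
        scaled_losses[lk] = coef * p.float().sum()

loss = sum(scaled_losses.values())            # tensor(2.0500)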
+ Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + nsentences, ntokens = sample["nsentences"], sample["ntokens"] + + # B x T + src_tokens, src_lengths = ( + sample["net_input"]["src_tokens"], + sample["net_input"]["src_lengths"], + ) + tgt_tokens, prev_output_tokens = sample["target"], sample["prev_target"] + + outputs = model(src_tokens, src_lengths, prev_output_tokens, tgt_tokens) + losses, nll_loss = [], [] + + for obj in outputs: + if outputs[obj].get("loss", None) is None: + _losses = self._compute_loss( + outputs[obj].get("out"), + outputs[obj].get("tgt"), + outputs[obj].get("mask", None), + outputs[obj].get("ls", 0.0), + name=obj + "-loss", + factor=outputs[obj].get("factor", 1.0), + ) + else: + _losses = self._custom_loss( + outputs[obj].get("loss"), + name=obj + "-loss", + factor=outputs[obj].get("factor", 1.0), + ) + + losses += [_losses] + if outputs[obj].get("nll_loss", False): + nll_loss += [_losses.get("nll_loss", 0.0)] + + loss = sum(l["loss"] for l in losses) + nll_loss = sum(l for l in nll_loss) if len(nll_loss) > 0 else loss.new_tensor(0) + + # NOTE: + # we don't need to use sample_size as denominator for the gradient + # here sample_size is just used for logging + sample_size = 1 + logging_output = { + "loss": loss.data, + "nll_loss": nll_loss.data, + "ntokens": ntokens, + "nsentences": nsentences, + "sample_size": sample_size, + } + + for l in losses: + logging_output[l["name"]] = ( + utils.item(l["loss"].data / l["factor"]) + if reduce + else l[["loss"]].data / l["factor"] + ) + + return loss, sample_size, logging_output + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + sample_size = utils.item( + sum(log.get("sample_size", 0) for log in logging_outputs) + ) + loss = utils.item(sum(log.get("loss", 0) for log in logging_outputs)) + nll_loss = utils.item(sum(log.get("nll_loss", 0) for log in logging_outputs)) + + metrics.log_scalar( + "loss", loss / sample_size / math.log(2), sample_size, round=3 + ) + metrics.log_scalar( + "nll_loss", nll_loss / sample_size / math.log(2), sample_size, round=3 + ) + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg) + ) + + for key in logging_outputs[0]: + if key[-5:] == "-loss": + val = sum(log.get(key, 0) for log in logging_outputs) + metrics.log_scalar( + key[:-5], + val / sample_size / math.log(2) if sample_size > 0 else 0.0, + sample_size, + round=3, + ) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return True diff --git a/fairseq/fairseq/criterions/sentence_prediction.py b/fairseq/fairseq/criterions/sentence_prediction.py new file mode 100644 index 0000000000000000000000000000000000000000..298b80576814d785e8722ae4726698a0a928a20f --- /dev/null +++ b/fairseq/fairseq/criterions/sentence_prediction.py @@ -0,0 +1,288 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
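Referring back to _compute_loss in the NAT criterion above, a toy sketch (hard targets, no mask, invented sizes) of the label-smoothing combination it applies: the NLL term is mixed with the mean log-probability, which acts as the uniform-smoothing term.

import torch
import torch.nn.functional as F

outputs = torch.randn(5, 12)                  # (len, vocab) logits
targets = torch.randint(0, 12, (5,))
label_smoothing = 0.1

logits = F.log_softmax(outputs, dim=-1)
losses = F.nll_loss(logits, targets.to(logits.device), reduction="none")
nll_loss = losses.float().mean().type_as(losses)

loss = nll_loss * (1 - label_smoothing) - logits.float().mean() * label_smoothing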
+ +import math +from dataclasses import dataclass, field +from itertools import chain + +import numpy as np +import torch +import torch.nn.functional as F +from sklearn.metrics import f1_score +from sklearn.metrics import matthews_corrcoef as _matthews_corrcoef +from scipy.stats import pearsonr, spearmanr + +from fairseq.logging import metrics +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass +from fairseq.logging.meters import safe_round + + +def simple_accuracy(preds, labels): + return (preds == labels).mean() + + +def acc_and_f1(preds, labels): + acc = simple_accuracy(preds, labels) + f1 = f1_score(y_true=labels, y_pred=preds) + return { + "acc": acc, + "f1": f1, + "acc_and_f1": (acc + f1) / 2, + } + + +def pearson_and_spearman(preds, labels): + pearson_corr = pearsonr(preds, labels)[0] + spearman_corr = spearmanr(preds, labels)[0] + return { + "pearson": pearson_corr, + "spearmanr": spearman_corr, + "corr": (pearson_corr + spearman_corr) / 2, + } + + +def matthews_corrcoef(preds, labels): + # make it consistent with other metrics taking (preds, labels) as input + mcc = _matthews_corrcoef(labels, preds) + return mcc + + +@dataclass +class SentencePredictionConfig(FairseqDataclass): + classification_head_name: str = field( + default="sentence_classification_head", + metadata={"help": "name of the classification head to use"}, + ) + regression_target: bool = field( + default=False, + ) + report_mcc: bool = False + report_acc_and_f1: bool = False + report_pearson_and_spearman: bool = False + + +@register_criterion("sentence_prediction", dataclass=SentencePredictionConfig) +class SentencePredictionCriterion(FairseqCriterion): + def __init__(self, cfg: SentencePredictionConfig, task): + super().__init__(task) + self.classification_head_name = cfg.classification_head_name + self.regression_target = cfg.regression_target + self.keep_pred_and_targ = ( + cfg.report_mcc or cfg.report_acc_and_f1 or cfg.report_pearson_and_spearman + ) + self.report_mcc = cfg.report_mcc + self.report_acc_and_f1 = cfg.report_acc_and_f1 + self.report_pearson_and_spearman = cfg.report_pearson_and_spearman + self.label_dict = task.label_dictionary + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. 
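Example use of the metric helpers defined above, with made-up predictions and labels (values chosen only for illustration):

import numpy as np
from sklearn.metrics import f1_score
from scipy.stats import pearsonr, spearmanr

preds = np.array([1, 0, 1, 1])
labels = np.array([1, 0, 0, 1])

acc = (preds == labels).mean()                              # simple_accuracy -> 0.75
f1 = f1_score(y_true=labels, y_pred=preds)                  # 0.8
corr = (pearsonr(preds, labels)[0] + spearmanr(preds, labels)[0]) / 2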
+ + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + assert ( + hasattr(model, "classification_heads") + and self.classification_head_name in model.classification_heads + ), "model must provide sentence classification head for --criterion=sentence_prediction" + + logits, _ = model( + **sample["net_input"], + features_only=True, + classification_head_name=self.classification_head_name, + ) + targets = model.get_targets(sample, [logits]).view(-1) + sample_size = targets.numel() + + if not self.regression_target: + lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) + task_loss = F.nll_loss(lprobs, targets, reduction="sum") + else: + logits = logits.view(-1).float() + targets = targets.float() + task_loss = F.mse_loss(logits, targets, reduction="sum") + + logging_output = {} + loss = task_loss + # mha & ffn regularization update + if ( + hasattr(model, "args") + and hasattr(model.args, "mha_reg_scale_factor") + and model.args.mha_reg_scale_factor != 0.0 + ): + mha_reg_loss = model._get_adaptive_head_loss() + loss += mha_reg_loss + logging_output.update({"mha_reg_loss": mha_reg_loss}) + if ( + hasattr(model, "args") + and hasattr(model.args, "ffn_reg_scale_factor") + and model.args.ffn_reg_scale_factor != 0.0 + ): + ffn_reg_loss = model._get_adaptive_ffn_loss() + loss += ffn_reg_loss + logging_output.update({"ffn_reg_loss": ffn_reg_loss}) + + logging_output.update( + { + "loss": loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample_size, + "sample_size": sample_size, + } + ) + if not self.regression_target: + preds = logits.argmax(dim=1) + logging_output["ncorrect"] = (preds == targets).sum() + if self.keep_pred_and_targ and not model.training: + if self.regression_target: + logging_output["pred"] = logits.detach().cpu().tolist() + logging_output["targ"] = targets.detach().cpu().tolist() + else: + # remove offset `self.label_dict.nspecial` from OffsetTokensDataset + preds = self.label_dict.string(preds + self.label_dict.nspecial).split() + targets = self.label_dict.string( + targets + self.label_dict.nspecial + ).split() + logging_output["pred"] = list(map(int, preds)) + logging_output["targ"] = list(map(int, targets)) + + if self.report_mcc: + logging_output["report_mcc"] = True + if self.report_acc_and_f1: + logging_output["report_acc_and_f1"] = True + if self.report_pearson_and_spearman: + logging_output["report_pearson_and_spearman"] = True + + return loss, sample_size, logging_output + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + nsentences = sum(log.get("nsentences", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + mha_reg_loss_sum = sum(log.get("mha_reg_loss", 0) for log in logging_outputs) + ffn_reg_loss_sum = sum(log.get("ffn_reg_loss", 0) for log in logging_outputs) + + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + if mha_reg_loss_sum: + metrics.log_scalar( + "mha_reg_loss", + mha_reg_loss_sum / sample_size / math.log(2), + sample_size, + round=3, + ) + if ffn_reg_loss_sum: + metrics.log_scalar( + "ffn_reg_loss", + ffn_reg_loss_sum / sample_size / math.log(2), + sample_size, + round=3, + ) + if sample_size != ntokens: + 
metrics.log_scalar( + "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3 + ) + + if len(logging_outputs) > 0 and "ncorrect" in logging_outputs[0]: + ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs) + metrics.log_scalar( + "accuracy", 100.0 * ncorrect / nsentences, nsentences, round=1 + ) + + # Metrics used by GLUE + pred = np.array( + list(chain.from_iterable(log.get("pred", []) for log in logging_outputs)) + ) + targ = np.array( + list(chain.from_iterable(log.get("targ", []) for log in logging_outputs)) + ) + if len(pred): + metrics.log_concat_tensor("pred", torch.from_numpy(pred), dim=0) + metrics.log_concat_tensor("targ", torch.from_numpy(targ), dim=0) + if any("report_mcc" in log for log in logging_outputs): + metrics.log_derived( + "mcc", + lambda meters: safe_round( + matthews_corrcoef( + meters["pred"].tensor.numpy(), + meters["targ"].tensor.numpy(), + ) + * 100, + 1, + ), + ) + if any("report_acc_and_f1" in log for log in logging_outputs): + metrics.log_derived( + "acc_and_f1", + lambda meters: safe_round( + acc_and_f1( + meters["pred"].tensor.numpy(), + meters["targ"].tensor.numpy(), + )["acc_and_f1"] + * 100, + 1, + ), + ) + metrics.log_derived( + "f1", + lambda meters: safe_round( + acc_and_f1( + meters["pred"].tensor.numpy(), + meters["targ"].tensor.numpy(), + )["f1"] + * 100, + 1, + ), + ) + if any("report_pearson_and_spearman" in log for log in logging_outputs): + metrics.log_derived( + "pearson_and_spearman", + lambda meters: safe_round( + pearson_and_spearman( + meters["pred"].tensor.numpy(), + meters["targ"].tensor.numpy(), + )["corr"] + * 100, + 1, + ), + ) + metrics.log_derived( + "pearson", + lambda meters: safe_round( + pearson_and_spearman( + meters["pred"].tensor.numpy(), + meters["targ"].tensor.numpy(), + )["pearson"] + * 100, + 1, + ), + ) + metrics.log_derived( + "spearman", + lambda meters: safe_round( + pearson_and_spearman( + meters["pred"].tensor.numpy(), + meters["targ"].tensor.numpy(), + )["spearmanr"] + * 100, + 1, + ), + ) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return True diff --git a/fairseq/fairseq/criterions/sentence_prediction_adapters.py b/fairseq/fairseq/criterions/sentence_prediction_adapters.py new file mode 100644 index 0000000000000000000000000000000000000000..8a873a45b3b121730c8a27d64facfa2922f8eb88 --- /dev/null +++ b/fairseq/fairseq/criterions/sentence_prediction_adapters.py @@ -0,0 +1,63 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F +from fairseq.criterions import register_criterion +from fairseq.criterions.sentence_prediction import ( + SentencePredictionCriterion, + SentencePredictionConfig, +) + + +@register_criterion("sentence_prediction_adapters", dataclass=SentencePredictionConfig) +class SentencePredictionCriterionAdapters(SentencePredictionCriterion): + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. 
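Referring back to SentencePredictionCriterion.forward above, a compact sketch (toy tensors) of its two loss branches: classification heads use summed NLL over log-softmax outputs, while regression heads (regression_target=True) use summed MSE on the raw head output.

import torch
import torch.nn.functional as F

# classification head: (nsentences, num_classes) logits
logits = torch.randn(4, 3)
targets = torch.tensor([0, 2, 1, 2])
lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
task_loss = F.nll_loss(lprobs, targets, reduction="sum")

# regression head (e.g. STS-B): one logit per sentence, float targets
reg_logits = torch.randn(4, 1).view(-1).float()
reg_targets = torch.tensor([0.2, 4.6, 3.0, 1.5])
reg_loss = F.mse_loss(reg_logits, reg_targets, reduction="sum")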
+ + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + assert ( + hasattr(model, "classification_heads") + and self.classification_head_name in model.classification_heads + ), "model must provide sentence classification head for --criterion=sentence_prediction" + + if not hasattr(sample, "lang_id"): + # If no language ID is given, we fall back to English + lang_id = ["en_XX"] * sample["nsentences"] + else: + lang_id = sample["lang_id"] + + logits, _ = model( + **sample["net_input"], + features_only=True, + classification_head_name=self.classification_head_name, + lang_id=lang_id, + ) + targets = model.get_targets(sample, [logits]).view(-1) + sample_size = targets.numel() + + if not self.regression_target: + lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) + loss = F.nll_loss(lprobs, targets, reduction="sum") + else: + logits = logits.view(-1).float() + targets = targets.float() + loss = F.mse_loss(logits, targets, reduction="sum") + + logging_output = { + "loss": loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample_size, + "sample_size": sample_size, + } + if not self.regression_target: + preds = logits.argmax(dim=1) + logging_output["ncorrect"] = (preds == targets).sum() + + return loss, sample_size, logging_output diff --git a/fairseq/fairseq/criterions/sentence_ranking.py b/fairseq/fairseq/criterions/sentence_ranking.py new file mode 100644 index 0000000000000000000000000000000000000000..bfb9f058f9208a1eb1218715df0f6f2183085dd9 --- /dev/null +++ b/fairseq/fairseq/criterions/sentence_ranking.py @@ -0,0 +1,121 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import torch +import torch.nn.functional as F +from fairseq import utils +from fairseq.logging import metrics +from fairseq.criterions import FairseqCriterion, register_criterion + + +@register_criterion("sentence_ranking") +class SentenceRankingCriterion(FairseqCriterion): + def __init__(self, task, ranking_head_name, save_predictions, num_classes): + super().__init__(task) + self.ranking_head_name = ranking_head_name + if save_predictions is not None: + self.prediction_h = open(save_predictions, "w") + else: + self.prediction_h = None + self.num_classes = num_classes + + def __del__(self): + if self.prediction_h is not None: + self.prediction_h.close() + + @staticmethod + def add_args(parser): + # fmt: off + parser.add_argument('--save-predictions', metavar='FILE', + help='file to save predictions to') + parser.add_argument('--ranking-head-name', + default='sentence_classification_head', + help='name of the ranking head to use') + # fmt: on + + def forward(self, model, sample, reduce=True): + """Compute ranking loss for the given sample. 
+ + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + assert ( + hasattr(model, "classification_heads") + and self.ranking_head_name in model.classification_heads + ), "model must provide sentence ranking head for --criterion=sentence_ranking" + + scores = [] + for idx in range(self.num_classes): + score, _ = model( + **sample["net_input{idx}".format(idx=idx + 1)], + classification_head_name=self.ranking_head_name, + ) + scores.append(score) + + logits = torch.cat(scores, dim=1) + sample_size = logits.size(0) + + if "target" in sample: + targets = model.get_targets(sample, [logits]).view(-1) + lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) + loss = F.nll_loss(lprobs, targets, reduction="sum") + else: + targets = None + loss = torch.tensor(0.0, requires_grad=True) + + if self.prediction_h is not None: + preds = logits.argmax(dim=1) + for i, (id, pred) in enumerate(zip(sample["id"].tolist(), preds.tolist())): + if targets is not None: + label = targets[i].item() + print("{}\t{}\t{}".format(id, pred, label), file=self.prediction_h) + else: + print("{}\t{}".format(id, pred), file=self.prediction_h) + + logging_output = { + "loss": loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample_size, + "sample_size": sample_size, + } + if targets is not None: + logging_output["ncorrect"] = (logits.argmax(dim=1) == targets).sum() + + return loss, sample_size, logging_output + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + nsentences = sum(log.get("nsentences", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + if sample_size != ntokens: + metrics.log_scalar( + "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3 + ) + + if len(logging_outputs) > 0 and "ncorrect" in logging_outputs[0]: + ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs) + metrics.log_scalar( + "accuracy", 100.0 * ncorrect / nsentences, nsentences, round=1 + ) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return True diff --git a/fairseq/fairseq/criterions/speech_dlm_criterion.py b/fairseq/fairseq/criterions/speech_dlm_criterion.py new file mode 100644 index 0000000000000000000000000000000000000000..888818011408fd81e5cd3e3c9b074e5082702c79 --- /dev/null +++ b/fairseq/fairseq/criterions/speech_dlm_criterion.py @@ -0,0 +1,335 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
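Referring back to SentenceRankingCriterion.forward above, a rough sketch (fake per-candidate scores) of how one score per candidate becomes a ranking loss: the (bsz, 1) scores are concatenated along dim 1 and treated as class logits over the candidates.

import torch
import torch.nn.functional as F

num_classes, bsz = 3, 2
scores = [torch.randn(bsz, 1) for _ in range(num_classes)]   # one head output per candidate

logits = torch.cat(scores, dim=1)                            # (bsz, num_classes)
targets = torch.tensor([0, 2])                               # index of the correct candidate
lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
loss = F.nll_loss(lprobs, targets, reduction="sum")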
+ +import math +from dataclasses import dataclass, field +from typing import Optional + +import torch.nn.functional as F +from fairseq import metrics, utils +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass +from omegaconf import II + + +@dataclass +class SpeechDLMCriterionConfig(FairseqDataclass): + sentence_avg: bool = II("optimization.sentence_avg") + main_and_cross_weights: Optional[str] = field( + default="1,0", + metadata={ + "help": "Comma-separated list of weights of Main-channel vs Cross-channel Prediction Losses" + "(default: 1,0)" + }, + ) + general_unit_loss_weight: float = field( + default=0, + metadata={ + "help": "The weight of the General Prediction Loss (Next-step Unit Prediction Loss)" + "(default: 0)" + }, + ) + edge_unit_loss_weight: float = field( + default=1, + metadata={"help": "The weight of the Edge Unit Prediction Loss" "(default: 1)"}, + ) + duration_loss_weight: float = field( + default=1, + metadata={ + "help": "The weight of the Edge Unit Duration Prediction Loss" + "(default: 1)" + }, + ) + + +@register_criterion("speech_dlm_criterion", dataclass=SpeechDLMCriterionConfig) +class SpeechDLMCriterion(FairseqCriterion): + """Criteron for the SpeechDLM model as described in the paper: + https://arxiv.org/pdf/2203.16502.pdf + + There are 3 possible losses depending on the targets of the model: + - general_unit_loss : The next unit prediction loss, corresponding to + 'next' target + - edge_unit_loss : The edge unit prediction loss, corresponding to + 'edge' target + - duration_loss : The duration prediction loss, corresponding to + 'duration' target + """ + + def __init__( + self, + task, + sentence_avg, + main_and_cross_weights, + general_unit_loss_weight, + edge_unit_loss_weight, + duration_loss_weight, + ): + super().__init__(task) + self.sentence_avg = sentence_avg + + self.channels = task.channels + self.targets = task.targets + self.delayed_duration_target = task.delayed_duration_target + + self.main_channel_weight = float(main_and_cross_weights.split(",")[0]) + self.cross_channel_weight = float(main_and_cross_weights.split(",")[1]) + assert self.main_channel_weight >= 0 and self.cross_channel_weight >= 0 + + self.channel_weights = { + channel: weight + for channel, weight in zip(self.channels, task.channel_weights) + } + + self.target_weights = {} + for t in self.targets: + if t == "next": + self.target_weights[t] = general_unit_loss_weight + assert ( + general_unit_loss_weight > 0 + ), "Expect a positive --general-unit-loss-weight for next unit prediction" + elif t == "edge": + self.target_weights[t] = edge_unit_loss_weight + assert ( + edge_unit_loss_weight > 0 + ), "Expect a positive --edge-unit-loss-weight for edge unit prediction" + elif t == "duration": + self.target_weights[t] = duration_loss_weight + assert ( + duration_loss_weight > 0 + ), "Expect a positive --duration-loss-weight for duration prediction" + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. 
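Referring back to SpeechDLMCriterion.__init__ above, a small sketch (invented numbers) of the weight bookkeeping: main_and_cross_weights is split into a main-channel and a cross-channel factor, and the training loss is the sum of per-target losses weighted by their configured weights.

main_and_cross_weights = "1,0.5"
main_channel_weight = float(main_and_cross_weights.split(",")[0])    # 1.0
cross_channel_weight = float(main_and_cross_weights.split(",")[1])   # 0.5

target_weights = {"next": 0.5, "edge": 1.0, "duration": 1.0}
loss_all = {"next": 4.2, "edge": 1.3, "duration": 0.7}               # per-target loss sums (made up)

training_loss = sum(loss_all[t] * target_weights[t] for t in target_weights)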
+ + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model(**sample["net_input"]) + loss_dict, stats_dict = self.compute_loss( + model, net_output, sample, reduce=reduce + ) + nsentences = sample["net_input"]["src_tokens"][self.channels[0]].size(0) + + logging_output = { + "nsentences": nsentences, + } + logging_output["nsentences"] = nsentences + + loss_all = {t: 0 for t in self.targets} + correct_all = {t: 0 for t in self.targets} + count_all = {t: 0 for t in self.targets} + ntokens_all = 0 + sample_size_all = 0 + for channel in loss_dict: + for pred_channel in loss_dict[channel]: + # Get ntokens & sample_size + ntokens = sample["net_input"]["src_tokens"][channel].numel() + sample_size = nsentences if self.sentence_avg else ntokens + prefix = "[{}-{}]".format(channel, pred_channel) + log_keys = { + "next": "general_token", + "edge": "edge_token", + "duration": "edge_duration", + } + + # Log & Update the sizes + logging_output["{}ntokens".format(prefix)] = ntokens + logging_output["{}sample_size".format(prefix)] = sample_size + ntokens_all += ntokens + sample_size_all += sample_size + + for t in self.targets: + log_key = log_keys[t] + loss = loss_dict[channel][pred_channel][t] + correct, count = stats_dict[channel][pred_channel][t] + + # Log the statistics + logging_output["{}{}_loss".format(prefix, log_key)] = loss.data + logging_output["{}{}_correct".format(prefix, log_key)] = correct + logging_output["{}{}_count".format(prefix, log_key)] = count + + # Scale the training loss by weights + target_loss = loss * self.channel_weights[channel] + if pred_channel == channel: + target_loss = target_loss * self.main_channel_weight + else: + target_loss = target_loss * self.cross_channel_weight + # Normalize the losses in the training by the number of edges + if t in ["edge", "duration"]: + target_loss = target_loss / count * sample_size + + # Update the statistics + loss_all[t] += target_loss + correct_all[t] += correct + count_all[t] += count + + # Logging the average statistics + logging_output["ntokens"] = ntokens_all + logging_output["sample_size"] = sample_size_all + for t in self.targets: + log_key = { + "next": "general_token", + "edge": "edge_token", + "duration": "edge_duration", + }[t] + logging_output["{}_loss".format(log_key)] = loss_all[t].data + logging_output["{}_correct".format(log_key)] = correct_all[t] + logging_output["{}_count".format(log_key)] = count_all[t] + + # Define the training loss + training_loss = 0 + for t in self.targets: + training_loss += loss_all[t] * self.target_weights[t] + logging_output["loss"] = training_loss.data + + return training_loss, sample_size_all, logging_output + + def compute_loss(self, model, net_output, sample, reduce=True): + # Get the model outputs and target + lprobs_dict = model.get_normalized_probs(net_output, log_probs=True) + target_dict = model.get_targets(sample, net_output) + + # Init the dictionaries + loss_dict, stats_dict = {}, {} + + for channel in lprobs_dict: + # Init the dictionaries + loss_dict[channel], stats_dict[channel] = {}, {} + + for pred_channel in lprobs_dict[channel]: + # Init the dictionaries + loss_dict[channel][pred_channel] = {} + stats_dict[channel][pred_channel] = {} + + # Get token & duration predictions + outputs = lprobs_dict[channel][pred_channel] + if not isinstance(outputs, dict): + token_lprobs = outputs + else: + token_lprobs = outputs["pred_token"] + dur_preds = 
outputs["pred_duration"] + dur_preds = dur_preds.view(-1) + token_lprobs = token_lprobs.view(-1, token_lprobs.size(-1)) + token_preds = token_lprobs.argmax(dim=-1) + + # Get edge indices + if "edge" in self.targets or "duration" in self.targets: + edge_indices = target_dict["edge_indices"][pred_channel] + + # Compute loss and statistics + for t in self.targets: + if t in ["next", "edge"]: + if t == "next": + target = target_dict["next"][pred_channel].view(-1) + lprobs = token_lprobs + preds = token_preds + elif t == "edge": + target = target_dict["edge"][pred_channel] + lprobs = token_lprobs[edge_indices] + preds = token_preds[edge_indices] + + loss = F.nll_loss( + lprobs, + target, + ignore_index=self.padding_idx, + reduction="sum" if reduce else "none", + ) + elif t == "duration": + target = target_dict["duration"][pred_channel] + if self.delayed_duration_target: + duration_indices = edge_indices + 1 + if duration_indices[-1] == len(dur_preds): + duration_indices = duration_indices[:-1] + target = target[:-1] + else: + duration_indices = edge_indices + preds = dur_preds[duration_indices] + + loss = F.l1_loss( + preds, + target, + reduction="sum" if reduce else "none", + ) + preds = preds.round() + + correct = (preds == target).sum().float().cpu().item() + count = float(target.size(0)) + + loss_dict[channel][pred_channel][t] = loss + stats_dict[channel][pred_channel][t] = (correct, count) + + return loss_dict, stats_dict + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + logging_keys = next(iter(logging_outputs)).keys() + channels = [item[:-7] for item in logging_keys if item.endswith("ntokens")] + target_prefixes = set( + [ + item[:-5].split("]")[-1] + for item in logging_keys + if item.endswith("_loss") + ] + ) + for channel_prefix in channels: + for target_prefix in target_prefixes: + prefix = "{}{}".format(channel_prefix, target_prefix) + count_sum = sum( + log.get("{}_count".format(prefix), 0) for log in logging_outputs + ) + correct_sum = sum( + log.get("{}_correct".format(prefix), 0) for log in logging_outputs + ) + loss_sum = sum( + log.get("{}_loss".format(prefix), 0) for log in logging_outputs + ) + + if "duration" not in target_prefix: + # we divide by log(2) to convert the loss from base e to base 2 + metrics.log_scalar( + "{}_loss".format(prefix), + loss_sum / count_sum / math.log(2), + count_sum, + round=3, + ) + metrics.log_derived( + "{}_ppl".format(prefix), + lambda meters, prefix=prefix: utils.get_perplexity( + meters["{}_loss".format(prefix)].avg + ), + ) + else: + # for duration we don't need to divide by log(2) + metrics.log_scalar( + "{}_loss".format(prefix), + loss_sum / count_sum, + count_sum, + round=3, + ) + + accuracy = 100 * correct_sum / count_sum + metrics.log_scalar("{}_pred_acc".format(prefix), accuracy, round=3) + + # Logging training loss + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + + # we divide by log(2) to convert the loss from base e to base 2 + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. 
+ """ + return True diff --git a/fairseq/fairseq/criterions/speech_ulm_criterion.py b/fairseq/fairseq/criterions/speech_ulm_criterion.py new file mode 100644 index 0000000000000000000000000000000000000000..eea74bae2641e285c56e86feb6e9866464c9673f --- /dev/null +++ b/fairseq/fairseq/criterions/speech_ulm_criterion.py @@ -0,0 +1,126 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from dataclasses import dataclass, field + +import torch.nn.functional as F +from fairseq.logging import metrics +from fairseq.tasks import FairseqTask +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass +from omegaconf import II + + +@dataclass +class SpeechUnitLmCriterionConfig(FairseqDataclass): + sentence_avg: bool = II("optimization.sentence_avg") + loss_weights: str = field( + default="1.;0.0;0.0", + metadata={ + "help": "Weights of the losses that correspond to token, duration, and F0 streams" + }, + ) + discrete_duration: bool = II("task.discrete_duration") + discrete_f0: bool = II("task.discrete_f0") + + +def mae_loss(pred, targ, mask, reduce=True): + if pred.ndim == 3: + pred = pred.squeeze(2) + else: + assert pred.ndim == 2 + loss = (pred.float() - targ.float()).abs() * (~mask).float() + loss = loss.sum() if reduce else loss.view(-1) + return loss + + +def nll_loss(pred, targ, mask, reduce=True): + lprob = F.log_softmax(pred, dim=-1) + loss = F.nll_loss(lprob.view(-1, lprob.size(-1)), targ.view(-1), reduction="none") + loss = loss * (~mask).float().view(-1) + loss = loss.sum() if reduce else loss.view(-1) + return loss + + +@register_criterion("speech_unit_lm_criterion", dataclass=SpeechUnitLmCriterionConfig) +class SpeechUnitLmCriterion(FairseqCriterion): + def __init__(self, cfg: SpeechUnitLmCriterionConfig, task: FairseqTask): + super().__init__(task) + self.sentence_avg = cfg.sentence_avg + self.weights = torch.tensor([float(w) for w in cfg.loss_weights.split(";")]) + assert self.weights.size(0) == 3 + assert (self.weights >= 0.0).all() + + self.dur_loss_fn = nll_loss if cfg.discrete_duration else mae_loss + self.f0_loss_fn = nll_loss if cfg.discrete_f0 else mae_loss + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. 
+ + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model(**sample["net_input"]) + + token_loss = nll_loss( + net_output["token"], sample["target"], sample["mask"], reduce + ) + dur_loss = self.dur_loss_fn( + net_output["duration"], + sample["dur_target"], + sample["dur_mask"], + reduce, + ) + f0_loss = self.f0_loss_fn( + net_output["f0"], + sample["f0_target"], + sample["f0_mask"], + reduce, + ) + loss = self.weights.to(token_loss.device) * torch.stack( + [token_loss, dur_loss, f0_loss], dim=-1 + ) + loss = loss.sum() if reduce else loss.sum(-1) + + sample_size = ( + sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + ) + logging_output = { + "loss": loss.detach().sum().item(), + "token_loss": token_loss.detach().sum().item(), + "dur_loss": dur_loss.detach().sum().item(), + "f0_loss": f0_loss.detach().sum().item(), + "ntokens": sample["ntokens"], + "nsentences": sample["target"].size(0), + "sample_size": sample_size, + } + return loss, sample_size, logging_output + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + token_loss_sum = sum(log.get("token_loss", 0) for log in logging_outputs) + dur_loss_sum = sum(log.get("dur_loss", 0) for log in logging_outputs) + f0_loss_sum = sum(log.get("f0_loss", 0) for log in logging_outputs) + + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + + metrics.log_scalar("loss", loss_sum / sample_size, sample_size, round=3) + + metrics.log_scalar( + "token_loss", token_loss_sum / sample_size, sample_size, round=3 + ) + + metrics.log_scalar("dur_loss", dur_loss_sum / sample_size, sample_size, round=3) + + metrics.log_scalar("f0_loss", f0_loss_sum / sample_size, sample_size, round=3) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + return True diff --git a/fairseq/fairseq/file_utils.py b/fairseq/fairseq/file_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b99da2e8cd82a7f4e419fc0abdbc00d617efc611 --- /dev/null +++ b/fairseq/fairseq/file_utils.py @@ -0,0 +1,370 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Utilities for working with the local dataset cache. +This file is adapted from `AllenNLP `_. +and `huggingface `_. 
+""" + +import fnmatch +import json +import logging +import os +import shutil +import tarfile +import tempfile +from functools import partial, wraps +from hashlib import sha256 +from io import open + + +try: + from torch.hub import _get_torch_home + + torch_cache_home = _get_torch_home() +except ImportError: + torch_cache_home = os.path.expanduser( + os.getenv( + "TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch") + ) + ) +default_cache_path = os.path.join(torch_cache_home, "pytorch_fairseq") + +try: + from urllib.parse import urlparse +except ImportError: + from urlparse import urlparse + +try: + from pathlib import Path + + PYTORCH_FAIRSEQ_CACHE = Path(os.getenv("PYTORCH_FAIRSEQ_CACHE", default_cache_path)) +except (AttributeError, ImportError): + PYTORCH_FAIRSEQ_CACHE = os.getenv("PYTORCH_FAIRSEQ_CACHE", default_cache_path) + +CONFIG_NAME = "config.json" +WEIGHTS_NAME = "pytorch_model.bin" + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + + +def load_archive_file(archive_file): + # redirect to the cache, if necessary + try: + resolved_archive_file = cached_path(archive_file, cache_dir=None) + except EnvironmentError: + logger.info( + "Archive name '{}' was not found in archive name list. " + "We assumed '{}' was a path or URL but couldn't find any file " + "associated to this path or URL.".format( + archive_file, + archive_file, + ) + ) + return None + + if resolved_archive_file == archive_file: + logger.info("loading archive file {}".format(archive_file)) + else: + logger.info( + "loading archive file {} from cache at {}".format( + archive_file, resolved_archive_file + ) + ) + + # Extract archive to temp dir and replace .tar.bz2 if necessary + tempdir = None + if not os.path.isdir(resolved_archive_file): + tempdir = tempfile.mkdtemp() + logger.info( + "extracting archive file {} to temp dir {}".format( + resolved_archive_file, tempdir + ) + ) + ext = os.path.splitext(archive_file)[1][1:] + with tarfile.open(resolved_archive_file, "r:" + ext) as archive: + top_dir = os.path.commonprefix(archive.getnames()) + archive.extractall(tempdir) + os.remove(resolved_archive_file) + shutil.move(os.path.join(tempdir, top_dir), resolved_archive_file) + shutil.rmtree(tempdir) + + return resolved_archive_file + + +def url_to_filename(url, etag=None): + """ + Convert `url` into a hashed filename in a repeatable way. + If `etag` is specified, append its hash to the URL's, delimited + by a period. + """ + url_bytes = url.encode("utf-8") + url_hash = sha256(url_bytes) + filename = url_hash.hexdigest() + + if etag: + etag_bytes = etag.encode("utf-8") + etag_hash = sha256(etag_bytes) + filename += "." + etag_hash.hexdigest() + + return filename + + +def filename_to_url(filename, cache_dir=None): + """ + Return the url and etag (which may be ``None``) stored for `filename`. + Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. 
+ """ + if cache_dir is None: + cache_dir = PYTORCH_FAIRSEQ_CACHE + if isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + cache_path = os.path.join(cache_dir, filename) + if not os.path.exists(cache_path): + raise EnvironmentError("file {} not found".format(cache_path)) + + meta_path = cache_path + ".json" + if not os.path.exists(meta_path): + raise EnvironmentError("file {} not found".format(meta_path)) + + with open(meta_path, encoding="utf-8") as meta_file: + metadata = json.load(meta_file) + url = metadata["url"] + etag = metadata["etag"] + + return url, etag + + +def cached_path_from_pm(url_or_filename): + """ + Tries to cache the specified URL using PathManager class. + Returns the cached path if success otherwise failure. + """ + try: + from fairseq.file_io import PathManager + + local_path = PathManager.get_local_path(url_or_filename) + return local_path + except Exception: + return None + + +def cached_path(url_or_filename, cache_dir=None): + """ + Given something that might be a URL (or might be a local path), + determine which. If it's a URL, download the file and cache it, and + return the path to the cached file. If it's already a local path, + make sure the file exists and then return the path. + """ + if cache_dir is None: + cache_dir = PYTORCH_FAIRSEQ_CACHE + if isinstance(url_or_filename, Path): + url_or_filename = str(url_or_filename) + if isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + parsed = urlparse(url_or_filename) + + if parsed.scheme in ("http", "https", "s3"): + # URL, so get it from the cache (downloading if necessary) + return get_from_cache(url_or_filename, cache_dir) + elif os.path.exists(url_or_filename): + # File, and it exists. + return url_or_filename + elif parsed.scheme == "": + # File, but it doesn't exist. + raise EnvironmentError("file {} not found".format(url_or_filename)) + else: + cached_path = cached_path_from_pm(url_or_filename) + if cached_path: + return cached_path + # Something unknown + raise ValueError( + "unable to parse {} as a URL or as a local path".format(url_or_filename) + ) + + +def split_s3_path(url): + """Split a full s3 path into the bucket name and path.""" + parsed = urlparse(url) + if not parsed.netloc or not parsed.path: + raise ValueError("bad s3 path {}".format(url)) + bucket_name = parsed.netloc + s3_path = parsed.path + # Remove '/' at beginning of path. + if s3_path.startswith("/"): + s3_path = s3_path[1:] + return bucket_name, s3_path + + +def s3_request(func): + """ + Wrapper function for s3 requests in order to create more helpful error + messages. 
+ """ + + @wraps(func) + def wrapper(url, *args, **kwargs): + from botocore.exceptions import ClientError + + try: + return func(url, *args, **kwargs) + except ClientError as exc: + if int(exc.response["Error"]["Code"]) == 404: + raise EnvironmentError("file {} not found".format(url)) + else: + raise + + return wrapper + + +@s3_request +def s3_etag(url): + """Check ETag on S3 object.""" + import boto3 + + s3_resource = boto3.resource("s3") + bucket_name, s3_path = split_s3_path(url) + s3_object = s3_resource.Object(bucket_name, s3_path) + return s3_object.e_tag + + +@s3_request +def s3_get(url, temp_file): + """Pull a file directly from S3.""" + import boto3 + + s3_resource = boto3.resource("s3") + bucket_name, s3_path = split_s3_path(url) + s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) + + +def request_wrap_timeout(func, url): + import requests + + for attempt, timeout in enumerate([10, 20, 40, 60, 60]): + try: + return func(timeout=timeout) + except requests.exceptions.Timeout as e: + logger.warning( + "Request for %s timed-out (attempt %d). Retrying with a timeout of %d secs", + url, + attempt, + timeout, + exc_info=e, + ) + continue + raise RuntimeError(f"Unable to fetch file {url}") + + +def http_get(url, temp_file): + import requests + from tqdm import tqdm + + req = request_wrap_timeout(partial(requests.get, url, stream=True), url) + content_length = req.headers.get("Content-Length") + total = int(content_length) if content_length is not None else None + progress = tqdm(unit="B", total=total) + for chunk in req.iter_content(chunk_size=1024): + if chunk: # filter out keep-alive new chunks + progress.update(len(chunk)) + temp_file.write(chunk) + progress.close() + + +def get_from_cache(url, cache_dir=None): + """ + Given a URL, look for the corresponding dataset in the local cache. + If it's not there, download it. Then return the path to the cached file. + """ + if cache_dir is None: + cache_dir = PYTORCH_FAIRSEQ_CACHE + if isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + if not os.path.exists(cache_dir): + os.makedirs(cache_dir) + + # Get eTag to add to filename, if it exists. + if url.startswith("s3://"): + etag = s3_etag(url) + else: + try: + import requests + + response = request_wrap_timeout( + partial(requests.head, url, allow_redirects=True), url + ) + if response.status_code != 200: + etag = None + else: + etag = response.headers.get("ETag") + except RuntimeError: + etag = None + + filename = url_to_filename(url, etag) + + # get cache path to put the file + cache_path = os.path.join(cache_dir, filename) + + # If we don't have a connection (etag is None) and can't identify the file + # try to get the last downloaded one + if not os.path.exists(cache_path) and etag is None: + matching_files = fnmatch.filter(os.listdir(cache_dir), filename + ".*") + matching_files = list(filter(lambda s: not s.endswith(".json"), matching_files)) + if matching_files: + cache_path = os.path.join(cache_dir, matching_files[-1]) + + if not os.path.exists(cache_path): + # Download to temporary file, then copy to cache dir once finished. + # Otherwise you get corrupt cache entries if the download gets interrupted. 
+ with tempfile.NamedTemporaryFile() as temp_file: + logger.info("%s not found in cache, downloading to %s", url, temp_file.name) + + # GET file object + if url.startswith("s3://"): + s3_get(url, temp_file) + else: + http_get(url, temp_file) + + # we are copying the file before closing it, so flush to avoid truncation + temp_file.flush() + # shutil.copyfileobj() starts at the current position, so go to the start + temp_file.seek(0) + + logger.info("copying %s to cache at %s", temp_file.name, cache_path) + with open(cache_path, "wb") as cache_file: + shutil.copyfileobj(temp_file, cache_file) + + logger.info("creating metadata file for %s", cache_path) + meta = {"url": url, "etag": etag} + meta_path = cache_path + ".json" + with open(meta_path, "w") as meta_file: + output_string = json.dumps(meta) + meta_file.write(output_string) + + logger.info("removing temp file %s", temp_file.name) + + return cache_path + + +def read_set_from_file(filename): + """ + Extract a de-duped collection (set) of text from a file. + Expected file format is one item per line. + """ + collection = set() + with open(filename, "r", encoding="utf-8") as file_: + for line in file_: + collection.add(line.rstrip()) + return collection + + +def get_file_extension(path, dot=True, lower=True): + ext = os.path.splitext(path)[1] + ext = ext if dot else ext[1:] + return ext.lower() if lower else ext diff --git a/fairseq/fairseq/libbleu.cpython-310-x86_64-linux-gnu.so b/fairseq/fairseq/libbleu.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..09691098eb19dae1fb026acb70b8296e486198fe Binary files /dev/null and b/fairseq/fairseq/libbleu.cpython-310-x86_64-linux-gnu.so differ diff --git a/fairseq/fairseq/nan_detector.py b/fairseq/fairseq/nan_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..bd0f9110731461f78e85196185303f1b5ea62c91 --- /dev/null +++ b/fairseq/fairseq/nan_detector.py @@ -0,0 +1,108 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
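As a rough usage sketch of the cache helpers added in file_utils.py above (url_to_filename, cached_path, PYTORCH_FAIRSEQ_CACHE); the URL below is a placeholder, not a real asset:

from fairseq.file_utils import url_to_filename, cached_path

url = "https://example.com/checkpoints/model.tar.gz"  # placeholder URL for illustration
print(url_to_filename(url))                  # deterministic sha256-based cache filename
print(url_to_filename(url, etag="abc123"))   # the ETag hash is appended after a period
# cached_path(url) would download into PYTORCH_FAIRSEQ_CACHE (writing a "<filename>.json"
# sidecar holding the url and etag) and return the local path; for an existing local
# file it simply returns that path unchanged.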
+ +import logging + +import torch + + +logger = logging.getLogger(__name__) + + +class NanDetector: + """ + Detects the first NaN or Inf in forward and/or backward pass and logs, together with the module name + """ + + def __init__(self, model, forward=True, backward=True): + self.bhooks = [] + self.fhooks = [] + self.forward = forward + self.backward = backward + self.named_parameters = list(model.named_parameters()) + self.reset() + + for name, mod in model.named_modules(): + mod.__module_name = name + self.add_hooks(mod) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + # Dump out all model gnorms to enable better debugging + norm = {} + gradients = {} + for name, param in self.named_parameters: + if param.grad is not None: + grad_norm = torch.norm(param.grad.data.float(), p=2) + norm[name] = param.norm().item() + if torch.isnan(grad_norm).any() or torch.isinf(grad_norm).any(): + gradients[name] = param.grad.data + if len(gradients) > 0: + logger.info("Detected nan/inf grad norm, dumping norms...") + logger.info(f"norms: {norm}") + logger.info(f"gradients: {gradients}") + + self.close() + + def add_hooks(self, module): + if self.forward: + self.fhooks.append(module.register_forward_hook(self.fhook_fn)) + if self.backward: + self.bhooks.append(module.register_backward_hook(self.bhook_fn)) + + def reset(self): + self.has_printed_f = False + self.has_printed_b = False + + def _detect(self, tensor, name, backward): + err = None + if ( + torch.is_floating_point(tensor) + # single value tensors (like the loss) will not provide much info + and tensor.numel() >= 2 + ): + with torch.no_grad(): + if torch.isnan(tensor).any(): + err = "NaN" + elif torch.isinf(tensor).any(): + err = "Inf" + if err is not None: + err = f"{err} detected in output of {name}, shape: {tensor.shape}, {'backward' if backward else 'forward'}" + return err + + def _apply(self, module, inp, x, backward): + if torch.is_tensor(x): + if isinstance(inp, tuple) and len(inp) > 0: + inp = inp[0] + err = self._detect(x, module.__module_name, backward) + if err is not None: + if torch.is_tensor(inp) and not backward: + err += ( + f" input max: {inp.max().item()}, input min: {inp.min().item()}" + ) + + has_printed_attr = "has_printed_b" if backward else "has_printed_f" + logger.warning(err) + setattr(self, has_printed_attr, True) + elif isinstance(x, dict): + for v in x.values(): + self._apply(module, inp, v, backward) + elif isinstance(x, list) or isinstance(x, tuple): + for v in x: + self._apply(module, inp, v, backward) + + def fhook_fn(self, module, inp, output): + if not self.has_printed_f: + self._apply(module, inp, output, backward=False) + + def bhook_fn(self, module, inp, output): + if not self.has_printed_b: + self._apply(module, inp, output, backward=True) + + def close(self): + for hook in self.fhooks + self.bhooks: + hook.remove() diff --git a/fairseq/fairseq/ngram_repeat_block.py b/fairseq/fairseq/ngram_repeat_block.py new file mode 100644 index 0000000000000000000000000000000000000000..4eb50303116671a47d03528fbe8a808647cbb116 --- /dev/null +++ b/fairseq/fairseq/ngram_repeat_block.py @@ -0,0 +1,120 @@ +# Originally from Microsoft Corporation. +# Licensed under the MIT License. 
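A brief usage sketch for the NanDetector context manager defined in nan_detector.py above; the toy module and input are stand-ins, not names from this diff:

import torch
from fairseq.nan_detector import NanDetector

model = torch.nn.Linear(8, 8)   # stand-in module
x = torch.randn(4, 8)
with NanDetector(model, forward=True, backward=True):
    out = model(x)
    out.sum().backward()        # any non-finite activation or gradient is logged with its module name
# on exit the hooks are removed, and gradient norms are dumped if a NaN/Inf gradient was seen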
+ +""" Wrapper for ngram_repeat_block cuda extension """ +import math +import warnings +from typing import List + +import torch +from torch import nn + +try: + from fairseq import ngram_repeat_block_cuda + + EXTENSION_BUILT = True +except ImportError: + EXTENSION_BUILT = False + + +def is_cuda_extension_usable() -> bool: + """Check whether ngram_repeat_block_cuda is built properly""" + if not EXTENSION_BUILT or not torch.cuda.is_available(): + return False + bsz = 2 + tokens = torch.tensor([[4, 4, 3, 2], [1, 2, 3, 4]], dtype=torch.long, device="cuda") + lprobs = torch.rand((8, 12), device="cuda") + try: + outputs = ngram_repeat_block_cuda.forward(tokens, lprobs, bsz, 3, 4, 3) + outputs = outputs + 4 # This line breaks if the extension is built incorrectly. + return True + except RuntimeError: + warnings.warn( + "NGramRepeatBlock extension must be rebuilt." + 'Run TORCH_CUDA_ARCH_LIST="6.0;6.1;7.0" python setup.py build_ext --inplace' + ) + return False + + +class NGramRepeatBlock(nn.Module): + """Wrapper class for calling ngram_repeat_block cuda extension""" + + def __init__(self, no_repeat_ngram_size: int, use_extension: bool = True): + super().__init__() + self.use_extension = is_cuda_extension_usable() if use_extension else False + self.no_repeat_ngram_size = no_repeat_ngram_size + + def reset_parameters(self): + pass + + @torch.jit.unused + def call_cuda_extension( + self, + tokens, + lprobs, + bsz: int, + beam_size: int, + step: int, + ): + return ngram_repeat_block_cuda.forward( + tokens, lprobs, bsz, step, beam_size, self.no_repeat_ngram_size + ) + + def forward( + self, + tokens, + lprobs, + bsz: int, + beam_size: int, + step: int, + ): + """ + Args: + tokens(Tensor): Input tokens(Bsz*beam, seq_len) + lprobs(Tensor): likelihood probability, + Expected to be updated in place.(Bsz*beam, vocab_size) + bsz(int): batch size + step(int): current step + beam_size(int): beam size + no_repeat_ngram_size(int): Ngram size + """ + msg = f"expected {bsz *beam_size} got" + assert tokens.size(0) == bsz * beam_size, f"{msg} {tokens.size(0)}" + assert lprobs.size(0) == bsz * beam_size, f"{msg} {lprobs.size(0)}" + if self.use_extension: + return self.call_cuda_extension(tokens, lprobs, bsz, beam_size, step) + + else: + return self._no_repeat_ngram( + tokens, + lprobs, + bsz, + beam_size, + step, + ) + + def _no_repeat_ngram(self, tokens, lprobs, bsz: int, beam_size: int, step: int): + """For each hypothesis generate a list of previous ngrams and set associated lprobs to -inf""" + banned_tokens = [ + torch.jit.annotate(List[int], []) for bbsz_idx in range(bsz * beam_size) + ] + if step + 2 - self.no_repeat_ngram_size >= 0: + cpu_tokens: List[List[int]] = tokens.cpu().tolist() + check_start_pos = step + 2 - self.no_repeat_ngram_size + for bbsz_idx in range(bsz * beam_size): + ngram_to_check = cpu_tokens[bbsz_idx][ + -(self.no_repeat_ngram_size - 1) : + ] + for i in range(check_start_pos): + if ( + ngram_to_check + == cpu_tokens[bbsz_idx][i : i + self.no_repeat_ngram_size - 1] + ): + banned_tokens[bbsz_idx].append( + cpu_tokens[bbsz_idx][i + self.no_repeat_ngram_size - 1] + ) + for bbsz_idx in range(bsz * beam_size): + lprobs[bbsz_idx][ + torch.tensor(banned_tokens[bbsz_idx], dtype=torch.int64) + ] = torch.tensor(-math.inf).to(lprobs) + return lprobs diff --git a/fairseq/fairseq/quantization_utils.py b/fairseq/fairseq/quantization_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..11fc414c852b199b80a569bf024272535929abcc --- /dev/null +++ 
b/fairseq/fairseq/quantization_utils.py @@ -0,0 +1,143 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +from fairseq.modules.quantization import pq, quantization_options, scalar +from omegaconf import DictConfig + + +logger = logging.getLogger(__name__) + + +def quantize_model_scalar(model, model_cfg: DictConfig): + quant_noise_scalar = getattr(model_cfg, "quant_noise_scalar", 0) or 0 + if quant_noise_scalar > 0: + # quantize_model edits the model in place + scalar.quantize_model_(model, p=quant_noise_scalar, bits=8, update_step=1000) + return model + + +class Quantizer(object): + def __init__(self, config_path, max_epoch, max_update): + try: + import yaml + except ImportError: + raise ImportError("Please install yaml with: pip install yaml") + + # parse config + if config_path: + with open(config_path) as config_file: + config = quantization_options.parse_config_yaml( + yaml.safe_load(config_file) + ) + else: + config = quantization_options.parse_config_yaml({}) + + self.n_centroids_config = config["n_centroids"] + self.block_sizes_config = config["block_sizes"] + self.layers_to_quantize = config["layers_to_quantize"] + + # We assume that training will run for a fixed number of epochs + # (or updates) and that we should train for equal durations + # between iterations of PQ. + num_iterations = len(self.layers_to_quantize) + if max_epoch > 0: + assert max_epoch % num_iterations == 0, ( + "for iterative PQ, --max-epoch (={}) must be evenly divisible by " + "len(layers_to_quantize) (={})".format(max_epoch, num_iterations) + ) + self.epoch_schedule = max_epoch // num_iterations + else: + self.epoch_schedule = None + if max_update > 0: + assert max_update % num_iterations == 0, ( + "for iterative PQ, --max-update (={}) must be evenly divisible by " + "len(layers_to_quantize) (={})".format(max_update, num_iterations) + ) + self.update_schedule = max_update // num_iterations + else: + self.update_schedule = None + assert (self.epoch_schedule is not None) ^ ( + self.update_schedule is not None + ), "for iterative PQ, cannot specify both --max-update and --max-epoch" + + # 0 is a special value for quantization step, which will force + # the first call to begin_epoch() to call step() + self.quantization_step = 0 + + def set_trainer(self, trainer): + self.trainer = trainer + self.size_tracker = pq.SizeTracker(self.trainer.get_model()) + + def step(self): + """Move to the next stage of quantization.""" + if self.quantization_step >= len(self.layers_to_quantize): + # Maybe we just finished the last training step or we loaded + # a checkpoint for an iterative PQ model which previously + # finished training. Either way, don't quantize again. 
+ return + + logger.info( + "quantizing model (step={}; layers_to_quantize[step]={})".format( + self.quantization_step, self.layers_to_quantize[self.quantization_step] + ) + ) + quantized_layers = pq.quantize_model_( + self.trainer.get_model(), + self.size_tracker, + self.layers_to_quantize, + self.block_sizes_config, + self.n_centroids_config, + step=self.quantization_step, + ) + logger.info("quantized layers: {}".format(quantized_layers)) + logger.info(self.size_tracker) + + self.quantization_step += 1 + + # reintialize the Trainer since model parameters have changed + self.trainer.reinitialize() + + def begin_epoch(self, epoch): + """Called at the beginning of each epoch (epochs start at 1).""" + if ( + ( + self.epoch_schedule is not None + and epoch > 0 + and (epoch - 1) % self.epoch_schedule == 0 + ) + # we always step once in the beginning, even if using + # update-based quantization + or self.quantization_step == 0 + ): + self.step() + + def step_update(self, num_updates): + """Called at the end of each step.""" + if ( + self.update_schedule is not None + and num_updates > 0 + and num_updates % self.update_schedule == 0 + ): + self.step() + + def state_dict(self): + return { + "n_centroids_config": self.n_centroids_config, + "block_sizes_config": self.block_sizes_config, + "layers_to_quantize": self.layers_to_quantize, + "epoch_schedule": self.epoch_schedule, + "update_schedule": self.update_schedule, + "quantization_step": self.quantization_step, + } + + def load_state_dict(self, state_dict): + self.n_centroids_config = state_dict["n_centroids_config"] + self.block_sizes_config = state_dict["block_sizes_config"] + self.layers_to_quantize = state_dict["layers_to_quantize"] + self.epoch_schedule = state_dict["epoch_schedule"] + self.update_schedule = state_dict["update_schedule"] + self.quantization_step = state_dict["quantization_step"] diff --git a/fairseq/fairseq/registry.py b/fairseq/fairseq/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..904ffcd60253c069f466a0b7ba0aaa2136c78c82 --- /dev/null +++ b/fairseq/fairseq/registry.py @@ -0,0 +1,104 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
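As a small worked sketch of the epoch schedule that Quantizer.begin_epoch above implies, assuming max_epoch=15 and three entries in layers_to_quantize (so epoch_schedule=5):

max_epoch, num_iterations = 15, 3             # assumed values for illustration
epoch_schedule = max_epoch // num_iterations  # 5, as computed in Quantizer.__init__
trigger_epochs = [e for e in range(1, max_epoch + 1) if (e - 1) % epoch_schedule == 0]
print(trigger_epochs)                         # [1, 6, 11] -> one quantization step per PQ iteration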
+ +from argparse import Namespace + +from typing import Union +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.utils import merge_with_parent +from hydra.core.config_store import ConfigStore +from omegaconf import DictConfig + +REGISTRIES = {} + + +def setup_registry(registry_name: str, base_class=None, default=None, required=False): + assert registry_name.startswith("--") + registry_name = registry_name[2:].replace("-", "_") + + REGISTRY = {} + REGISTRY_CLASS_NAMES = set() + DATACLASS_REGISTRY = {} + + # maintain a registry of all registries + if registry_name in REGISTRIES: + return # registry already exists + REGISTRIES[registry_name] = { + "registry": REGISTRY, + "default": default, + "dataclass_registry": DATACLASS_REGISTRY, + } + + def build_x(cfg: Union[DictConfig, str, Namespace], *extra_args, **extra_kwargs): + if isinstance(cfg, DictConfig): + choice = cfg._name + + if choice and choice in DATACLASS_REGISTRY: + from_checkpoint = extra_kwargs.get("from_checkpoint", False) + dc = DATACLASS_REGISTRY[choice] + cfg = merge_with_parent(dc(), cfg, remove_missing=from_checkpoint) + elif isinstance(cfg, str): + choice = cfg + if choice in DATACLASS_REGISTRY: + cfg = DATACLASS_REGISTRY[choice]() + else: + choice = getattr(cfg, registry_name, None) + if choice in DATACLASS_REGISTRY: + cfg = DATACLASS_REGISTRY[choice].from_namespace(cfg) + + if choice is None: + if required: + raise ValueError("{} is required!".format(registry_name)) + return None + + cls = REGISTRY[choice] + if hasattr(cls, "build_" + registry_name): + builder = getattr(cls, "build_" + registry_name) + else: + builder = cls + + if "from_checkpoint" in extra_kwargs: + del extra_kwargs["from_checkpoint"] + + return builder(cfg, *extra_args, **extra_kwargs) + + def register_x(name, dataclass=None): + def register_x_cls(cls): + if name in REGISTRY: + raise ValueError( + "Cannot register duplicate {} ({})".format(registry_name, name) + ) + if cls.__name__ in REGISTRY_CLASS_NAMES: + raise ValueError( + "Cannot register {} with duplicate class name ({})".format( + registry_name, cls.__name__ + ) + ) + if base_class is not None and not issubclass(cls, base_class): + raise ValueError( + "{} must extend {}".format(cls.__name__, base_class.__name__) + ) + + if dataclass is not None and not issubclass(dataclass, FairseqDataclass): + raise ValueError( + "Dataclass {} must extend FairseqDataclass".format(dataclass) + ) + + cls.__dataclass = dataclass + if cls.__dataclass is not None: + DATACLASS_REGISTRY[name] = cls.__dataclass + + cs = ConfigStore.instance() + node = dataclass() + node._name = name + cs.store(name=name, group=registry_name, node=node, provider="fairseq") + + REGISTRY[name] = cls + + return cls + + return register_x_cls + + return build_x, register_x, REGISTRY, DATACLASS_REGISTRY diff --git a/fairseq/fairseq/search.py b/fairseq/fairseq/search.py new file mode 100644 index 0000000000000000000000000000000000000000..c7378bbb514342cd3f9f56c8514d0fa5cb351316 --- /dev/null +++ b/fairseq/fairseq/search.py @@ -0,0 +1,892 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
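A hedged sketch of how the setup_registry factory above is typically consumed; the "--widget" registry name, NoopWidgetConfig, and NoopWidget are illustrative assumptions, not part of this diff:

from dataclasses import dataclass, field
from fairseq.dataclass import FairseqDataclass
from fairseq.registry import setup_registry

build_widget, register_widget, WIDGET_REGISTRY, WIDGET_DATACLASS_REGISTRY = setup_registry(
    "--widget", default=None
)

@dataclass
class NoopWidgetConfig(FairseqDataclass):
    scale: float = field(default=1.0, metadata={"help": "illustrative parameter"})

@register_widget("noop", dataclass=NoopWidgetConfig)
class NoopWidget:
    def __init__(self, cfg):
        self.scale = cfg.scale

widget = build_widget("noop")  # the string choice is resolved to NoopWidgetConfig() and NoopWidget is built from it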
+ +import math + +from typing import List, Optional + +import torch +import torch.nn as nn +from fairseq.token_generation_constraints import ( + ConstraintState, + OrderedConstraintState, + UnorderedConstraintState, +) +from torch import Tensor + + +class Search(nn.Module): + def __init__(self, tgt_dict): + super().__init__() + self.pad = tgt_dict.pad() + self.unk = tgt_dict.unk() + self.eos = tgt_dict.eos() + self.vocab_size = len(tgt_dict) + self.src_lengths = torch.tensor(-1) + self.supports_constraints = False + self.stop_on_max_len = False + + def step( + self, step, lprobs, scores, prev_output_tokens=None, original_batch_idxs=None + ): + """Take a single search step. + + Args: + step: the current search step, starting at 0 + lprobs: (bsz x input_beam_size x vocab_size) + the model's log-probabilities over the vocabulary at the current step + scores: (bsz x input_beam_size x step) + the historical model scores of each hypothesis up to this point + prev_output_tokens: (bsz x step) + the previously generated oputput tokens + original_batch_idxs: (bsz) + the tensor with the batch indices, in the range [0, bsz) + this is useful in case there has been applied a re-ordering + and we need to know the orignal indices + + Return: A tuple of (scores, indices, beams) where: + scores: (bsz x output_beam_size) + the scores of the chosen elements; output_beam_size can be + larger than input_beam_size, e.g., we may return + 2*input_beam_size to account for EOS + indices: (bsz x output_beam_size) + the indices of the chosen elements + beams: (bsz x output_beam_size) + the hypothesis ids of the chosen elements, in the range [0, input_beam_size) + """ + raise NotImplementedError + + @torch.jit.export + def set_src_lengths(self, src_lengths): + self.src_lengths = src_lengths + + @torch.jit.export + def init_constraints(self, batch_constraints: Optional[Tensor], beam_size: int): + """Initialize constraint states for constrained decoding (if supported). + + Args: + batch_constraints: (torch.Tensor, optional) + the list of constraints, in packed form + beam_size: (int) + the beam size + Returns: + *encoder_out* rearranged according to *new_order* + """ + pass + + def prune_sentences(self, batch_idxs: Tensor): + """ + Removes constraint states for completed sentences (if supported). + This is called from sequence_generator._generate() when sentences are + deleted from the batch. + + Args: + batch_idxs: Indices of *sentences* whose constraint state should be *kept*. + """ + pass + + def update_constraints(self, active_hypos: Tensor): + """ + Updates the constraint states by selecting the beam items that are retained. + This is called at each time step of sequence_generator._generate() when + the set of 2 * {beam_size} candidate hypotheses are reduced to the beam size. + + Args: + active_hypos: (batch size, beam size) + list of integers denoting, for each sentence, which beam candidate items + should be kept. 
+ """ + pass + + +class BeamSearch(Search): + def __init__(self, tgt_dict): + super().__init__(tgt_dict) + self.constraint_states = None + + @torch.jit.export + def step( + self, + step: int, + lprobs, + scores: Optional[Tensor], + prev_output_tokens: Optional[Tensor] = None, + original_batch_idxs: Optional[Tensor] = None, + candidate_multiple: int = 2, + ): + bsz, beam_size, vocab_size = lprobs.size() + + if step == 0: + # at the first step all hypotheses are equally likely, so use + # only the first beam + lprobs = lprobs[:, ::beam_size, :].contiguous() + else: + # make probs contain cumulative scores for each hypothesis + assert scores is not None + lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1) + + top_prediction = torch.topk( + lprobs.view(bsz, -1), + k=min( + # Take the best `candidate_muliple`(default 2) x beam_size predictions. We'll choose the first + # beam_size of these which don't predict eos to continue with. + candidate_multiple * beam_size, + lprobs.view(bsz, -1).size(1) - 1, # -1 so we never select pad + ), + ) + scores_buf = top_prediction[0] + indices_buf = top_prediction[1] + # Project back into relative indices and beams + beams_buf = torch.div(indices_buf, vocab_size, rounding_mode="trunc") + indices_buf = indices_buf.fmod(vocab_size) + + # At this point, beams_buf and indices_buf are single-dim and contain relative indices + return scores_buf, indices_buf, beams_buf + + +class PrefixConstrainedBeamSearch(Search): + def __init__(self, tgt_dict, prefix_allowed_tokens_fn): + super().__init__(tgt_dict) + self.prefix_allowed_tokens_fn = prefix_allowed_tokens_fn + self.stop_on_max_len = True + + @torch.jit.export + def apply_mask(self, x, prev_output_tokens, original_batch_idxs): + beam_size = x.shape[0] // original_batch_idxs.shape[0] + original_batch_idxs = ( + original_batch_idxs.unsqueeze(-1).repeat((1, beam_size)).flatten().tolist() + ) + + mask = torch.full_like(x, -math.inf) + for sent_i, (sent, batch_i) in enumerate( + zip(prev_output_tokens, original_batch_idxs) + ): + mask[sent_i, :, self.prefix_allowed_tokens_fn(batch_i, sent)] = 0 + + return mask + + @torch.jit.export + def step( + self, + step: int, + lprobs: Tensor, + scores: Tensor, + prev_output_tokens: Tensor, + original_batch_idxs: Tensor, + ): + bsz, beam_size, vocab_size = lprobs.size() + + lprobs += self.apply_mask( + lprobs.view(bsz * beam_size, 1, vocab_size), + prev_output_tokens, + original_batch_idxs, + ).view(bsz, beam_size, vocab_size) + + if step == 0: + # at the first step all hypotheses are equally likely, so use + # only the first beam + lprobs = lprobs[:, ::beam_size, :].contiguous() + else: + # make probs contain cumulative scores for each hypothesis + assert scores is not None + lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1) + + top_prediction = torch.topk( + lprobs.view(bsz, -1), + k=min( + # Take the best beam_size predictions. We'll choose the first + # beam_size of these which don't predict eos to continue with. + beam_size, + lprobs.view(bsz, -1).size(1) - 1, # -1 so we never select pad + ), + ) + scores_buf = top_prediction[0] + indices_buf = top_prediction[1] + beams_buf = indices_buf // vocab_size + indices_buf = indices_buf.fmod(vocab_size) + return scores_buf, indices_buf, beams_buf + + +class LexicallyConstrainedBeamSearch(Search): + """Implements lexically constrained beam search as described in + + Fast Lexically Constrained Decoding with Dynamic Beam + Allocation for Neural Machine Translation. Post & Vilar, + NAACL 2018. 
https://www.aclweb.org/anthology/N18-1119/ + + and + + Improved Lexically Constrained Decoding for Translation and + Monolingual Rewriting. Hu et al, NAACL + 2019. https://www.aclweb.org/anthology/N19-1090/ + + This is accomplished by maintaining, for each beam hypothesis, a + ConstraintState object (see constraints.py) that tracks which + constraints have been generated and using this information to + shape the beam for each input sentence. + """ + + def __init__(self, tgt_dict, representation): + super().__init__(tgt_dict) + self.representation = representation + self.vocab_size = len(tgt_dict) + self.num_cands = 0 + self.supports_constraints = True + + @torch.jit.export + def init_constraints(self, batch_constraints: Optional[Tensor], beam_size: int): + self.constraint_states = [] + for constraint_tensor in batch_constraints: + if self.representation == "ordered": + constraint_state = OrderedConstraintState.create(constraint_tensor) + elif self.representation == "unordered": + constraint_state = UnorderedConstraintState.create(constraint_tensor) + + self.constraint_states.append([constraint_state for i in range(beam_size)]) + + @torch.jit.export + def prune_sentences(self, batch_idxs: Tensor): + self.constraint_states = [ + self.constraint_states[i] for i in batch_idxs.tolist() + ] + + @torch.jit.export + def update_constraints(self, active_hypos: Tensor): + if self.constraint_states: + batch_size = active_hypos.size(0) + for sentid in range(batch_size): + self.constraint_states[sentid] = [ + self.constraint_states[sentid][i] for i in active_hypos[sentid] + ] + + @torch.jit.export + def step( + self, + step: int, + lprobs: Tensor, + scores: Optional[Tensor], + prev_output_tokens: Optional[Tensor] = None, + original_batch_idxs: Optional[Tensor] = None, + ): + """ + A constrained step builds a large candidates list from the following: + - the top 2 * {beam_size} items over the whole beam + - for each item in the beam + - the top {each_k} (default 1) + - all next constraints + We then compute the constrained state of each beam item, and assign + stripe codes: 0 to the best in each bank, 1 to the 2nd-best, and so + on. We then sort by (stripe, score), and truncate the list at + 2 * beam size. + + Args: + step: the decoder step + lprobs: (batch size, beam size, target vocab) + the target-vocab distributions for each item in the beam. + Retrun: A tuple of (scores, indices, beams, constraints) where: + scores: (batch, output beam size) + the scores of the chosen elements + indices: (batch, output beam size) + the target vocab indices of the chosen elements + beams: (batch, output beam size) + the 0-indexed hypothesis ids of the chosen elements + constraints: (batch, output beam size) + the new constraint states + """ + each_k = 1 + device = lprobs.device + + batch_size, beam_size, vocab_size = lprobs.size() + + self.num_cands = min( + # Just take the k-best. We'll get another k from the 1-best from each + # row, plus more from the constraints + beam_size * 2, + lprobs.view(batch_size, -1).size(1) - 1, # -1 so we never select pad + ) + + # STEP 0: Preliminary. 
Prevent EOS for unfinished hyps across all batch items + constraint_states = self.constraint_states + if constraint_states and step > 0: + not_finished_indices = [] + for sentno, sent_constraints in enumerate(constraint_states): + for beamno, state in enumerate(sent_constraints): + index = sentno * beam_size + beamno + if not state.finished: + not_finished_indices.append(index) + not_finished_indices = torch.tensor(not_finished_indices) + if not_finished_indices.numel() > 0: + lprobs.view(batch_size * beam_size, -1)[ + not_finished_indices, self.eos + ] = -math.inf + + if step == 0: + # at the first step all hypotheses are equally likely, so use + # only the first beam entry for each batch item + lprobs = lprobs[:, ::beam_size, :].contiguous() + else: + # make probs contain cumulative scores for each hypothesis + assert scores is not None + lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1) + + top_prediction = torch.topk( + lprobs.view(batch_size, -1), + self.num_cands, + ) + scores_buf, indices_buf = top_prediction + # Project back into relative indices and beams + beams_buf = indices_buf // vocab_size + indices_buf = indices_buf.fmod(vocab_size) + + # Short circuit if there are no constraints in this batch + if not constraint_states: + return scores_buf, indices_buf, beams_buf + + # STEP 1: get top-1 from each hypothesis across all sentences in the batch + if step > 0: + top_scores, top_indices = torch.topk( + lprobs.view(batch_size * beam_size, -1), + k=each_k, + dim=1, + ) + top_scores = top_scores.view(batch_size, -1) + top_indices = top_indices.view(batch_size, -1) + scores_buf = torch.cat((scores_buf, top_scores), dim=1) + indices_buf = torch.cat((indices_buf, top_indices), dim=1) + new_beams = torch.arange(0, beam_size, device=device).repeat(batch_size, 1) + beams_buf = torch.cat((beams_buf, new_beams), dim=1) + + # Now, process sentences in the batch one by one. + new_scores_buf = torch.zeros((batch_size, 2 * beam_size), device=device) + new_indices_buf = torch.zeros((batch_size, 2 * beam_size), device=device).long() + new_beams_buf = torch.zeros((batch_size, 2 * beam_size), device=device).long() + for sentno, states in enumerate(constraint_states): + scores, indices, beams, new_states = self.step_sentence( + step, + sentno, + lprobs[sentno], + constraint_states[sentno], + beams_buf[sentno].clone(), + indices_buf[sentno].clone(), + scores_buf[sentno].clone(), + ) + new_scores_buf[sentno] = scores + new_indices_buf[sentno] = indices + new_beams_buf[sentno] = beams + self.constraint_states[sentno] = new_states + + return new_scores_buf, new_indices_buf, new_beams_buf + + @torch.jit.export + def step_sentence( + self, + step: int, + sentno: int, + lprobs: Tensor, + constraint_states: List[List[ConstraintState]], + beams_buf: Tensor, + indices_buf: Tensor, + scores_buf: Tensor, + ): + """Does per-sentence processing. Adds all constraints for each + hypothesis to the list of candidates; then removes duplicates, + sorts, and dynamically stripes across the banks. All tensor inputs + are collapsed to those pertaining to a single input sentence. 
+ """ + device = lprobs.device + + # STEP 2: Add all constraints for each beam item + for beamno, state in enumerate(constraint_states): + next_tokens = torch.tensor(list(state.next_tokens()), device=device).long() + if next_tokens.numel() != 0: + indices_buf = torch.cat((indices_buf, next_tokens)) + next_beams = ( + torch.tensor(beamno, device=device) + .repeat(next_tokens.size(0)) + .long() + ) + beams_buf = torch.cat((beams_buf, next_beams)) + next_values = lprobs[beamno].take(next_tokens.view(-1)) + scores_buf = torch.cat((scores_buf, next_values)) + + # At the 0th time step, there is just one beam item + if step == 0: + break + + # STEP 3: Compute the "bank" for each candidate. This is the + # number of constraints it's generated. We need this so that + # we can do round-robin allocation of the beam across these + # banks. If C is the number of constraints, we select the best + # item in bank C, then the best in bank C-1, etc, followed by + # the 2nd-best in bank C, the 2nd-best in bank C-1, etc, and so + # on, until the maximum beam size. We accomplish this by + # creating a sort key and striping across the banks. + + # Compute the new states for all candidates + cands_size = indices_buf.size(0) + constraint_states = [ + constraint_states[beams_buf[i]].advance(indices_buf[i]) + for i in range(cands_size) + ] + + banks = torch.tensor([state.bank for state in constraint_states], device=device) + + # STEP 4: Sort + num_constraint_tokens = len(state.tokens) + + # Sort by keys (bank, score) (i.e., sort banks together, and scores + # within banks). AFAIK pytorch doesn't support either stable sort or + # multi-key sorting, so we have to hack this. + MAX_SCORE = -100 + sort_key = (num_constraint_tokens - banks) * MAX_SCORE + scores_buf + sort_values, sort_indices = sort_key.sort(dim=0, descending=True) + scores_buf = scores_buf[sort_indices] + indices_buf = indices_buf[sort_indices] + beams_buf = beams_buf[sort_indices] + banks = banks[sort_indices] + + # Sort the constraints to follow suit + constraint_states = [constraint_states[i] for i in sort_indices] + + # STEP 5: Remove duplicates. The topk calls (overall and + # per-row) plus the per-row generation of constraints will + # produce duplicates. Here we remove them. + + def roll(t): + """Rolls a 1d tensor left by 1. + + [0, 1, 2, 3, 4] becomes [4, 0, 1, 2, 3] + """ + return torch.cat((t[-1].unsqueeze(0), t[0:-1]), dim=0) + + # We map candidates (beam, token_id) to a single dimension. + # This is then shifted by 1. We can then easily identify + # duplicates and create a mask that identifies unique + # extensions. + uniques_mask = beams_buf * (self.vocab_size + 1) + indices_buf + uniques_mask = roll(uniques_mask) != uniques_mask + + # Use the mask to pare down the data structures + scores_buf = torch.masked_select(scores_buf, uniques_mask) + indices_buf = torch.masked_select(indices_buf, uniques_mask) + beams_buf = torch.masked_select(beams_buf, uniques_mask) + banks = torch.masked_select(banks, uniques_mask) + i = 1 + for mask in uniques_mask[1:]: + if not mask: + constraint_states.pop(i) + i += mask + + # STEP 6: Assign IDs round-robin across banks, sort, and + # truncate. Now that the candidates are sorted by (bank, + # score) and uniqed, we dynamically allocate the {beam_size} + # beam by striping across the candidates. These stripes will + # be used as sort keys to do round-robin selection. This is + # accomplished in a single pass with offsets. 
Sorting by + # highest-banks (furthest-along hypotheses) first ensures + # progress through the constraints. + # + # e.g., BANKS: 3 3 3 2 2 2 2 1 1 1 0 0 + # OLD STRIPES: 0 1 2 0 1 2 3 0 1 2 0 1 + # NEW STRIPES: 0 1+4 2+8 0+1 1+5 2+9 3+11 0+2 1+6 2+10 0+3 1+7 + # = 0 5 10 1 6 11 13 2 7 12 3 8 + # + # Sorting by this then gives the following banks: + # + # 3 2 1 0 3 2 1 0 3 2 1 2 + # + # We'll take the top {beam_size} of these. + stripe_offsets = [offset * (len(banks) + 1) for offset in range(len(banks) + 1)] + stripes = torch.zeros_like(banks) + cur_bank_count = -1 + cur_bank = banks[0] + for i, bank in enumerate(banks): + if bank != cur_bank: + cur_bank_count = 0 + cur_bank = bank + else: + cur_bank_count += 1 + stripes[i] = num_constraint_tokens - bank + stripe_offsets[cur_bank_count] + + # STEP 7: Sort by the stripes values + sort_values, sort_indices = stripes.sort(dim=0) + scores_buf = scores_buf[sort_indices] + indices_buf = indices_buf[sort_indices] + beams_buf = beams_buf[sort_indices] + constraint_states = [constraint_states[i] for i in sort_indices] + + # STEP 8: Truncate to the candidates size! + scores_buf = scores_buf[: self.num_cands] + indices_buf = indices_buf[: self.num_cands] + beams_buf = beams_buf[: self.num_cands] + + return scores_buf, indices_buf, beams_buf, constraint_states + + +class LengthConstrainedBeamSearch(Search): + def __init__(self, tgt_dict, min_len_a, min_len_b, max_len_a, max_len_b): + super().__init__(tgt_dict) + self.min_len_a = min_len_a + self.min_len_b = min_len_b + self.max_len_a = max_len_a + self.max_len_b = max_len_b + self.beam = BeamSearch(tgt_dict) + self.needs_src_lengths = True + + def step( + self, + step: int, + lprobs, + scores, + prev_output_tokens: Optional[Tensor] = None, + original_batch_idxs: Optional[Tensor] = None, + ): + min_lens = self.min_len_a * self.src_lengths + self.min_len_b + max_lens = self.max_len_a * self.src_lengths + self.max_len_b + lprobs[step < min_lens, :, self.eos] = -math.inf + lprobs[step >= max_lens, :, self.eos] = 0 + return self.beam.step(step, lprobs, scores) + + +class DiverseBeamSearch(Search): + """Diverse Beam Search. + + See "Diverse Beam Search: Decoding Diverse Solutions from Neural Sequence + Models" for details. + + We implement cumulative diversity penalty here as default, optionally provide Hamming diversity described + in the original paper, and a way to interpolate between the two through diversity_discount. + + Take the example below for illustration of cumulative diversity implemented. + A) I like dogs. + B) I like ____. + C) There are ___. + And we are at step=2, trying to fill in the blank: + + Hamming diversity: + Penalty for B from A is 1 for "dogs" and 0 for any other words like "cats". + Penalty for C from A is 1 for "dogs" and 0 for any other words like "cats". + + Cumulative diversity (default): + Penalty for B from A is 3 for "dogs" and 0 for any other words like "cats". + Penalty for C from A is 1 for "dogs" and 0 for any other words like "cats". + B and C differ because B matches with A for "I" and "like" at respective steps incurring 2 cumulative penalty. + + Using divesrity_discount to interpolate between the two: + if diverstiy_discount = 0.5, then + Penalty for B from A is 1.75 (1 + 0.5 + 0.25) for "dogs" and 0 for any other words like "cats". + Penalty for C from A is 1 for "dogs" and 0 for any other words like "cats". + "I" and "like" matched for B and A at step 0 and 1 respectively. 
Since "I" is two steps away and "like" is one step away, they are discounted by (0.5)^2 and 0.5 respectively. + When diversity_discount = 0, we recover Hammning diversity and when diversity_discount = 1, we recover cumulative diversity. + + NB: During beam search for each diversity group, `candidate_mutiple` is set to 1 rather than BeamSearch default(2). + This is to ensure we have final `beam_size` candidates so that no diversity groups would be dropped during final token selection in sequence generation. + For full backwards compatibility, use diversity_discount=0 and candidate_multiple=2. + + """ + + def __init__( + self, + tgt_dict, + num_groups, + diversity_strength, + diversity_discount=1.0, + candidate_multiple=1, + ): + super().__init__(tgt_dict) + self.num_groups = num_groups + self.diversity_strength = -diversity_strength + self.beam = BeamSearch(tgt_dict) + self.diversity_discount = diversity_discount + self.candidate_multiple = candidate_multiple + + # Float tensor to keep track of overlap between groups. + # Each token shared at the same step between two groups is counted as one. + # Then token counts are discounted by `diversity_discount` for every next timestep. + # Once initialized, dimension is batch_size * num_groups * num_groups. + self.group_overlap = torch.empty(0) + + @torch.jit.export + def step( + self, + step: int, + lprobs, + scores, + prev_output_tokens: Optional[Tensor] = None, + original_batch_idxs: Optional[Tensor] = None, + ): + bsz, beam_size, vocab_size = lprobs.size() + if beam_size % self.num_groups != 0: + raise ValueError( + "DiverseBeamSearch requires --beam to be divisible by the number of groups" + ) + + # initialize diversity penalty + diversity_buf = torch.zeros(lprobs[:, 0, :].size()).to(lprobs) + + scores_G, beams_G = [], [] + + # pre-allocating tensor for indices for all groups + indices_G_stacked = torch.empty( + bsz, + int(beam_size / self.num_groups) * self.candidate_multiple, + self.num_groups, + dtype=torch.long, + device=lprobs.device, + ) + + for g in range(self.num_groups): + lprobs_g = lprobs[:, g :: self.num_groups, :] + scores_g = scores[:, g :: self.num_groups, :] if step > 0 else None + + diversity_buf.zero_() + # apply diversity penalty + if g > 0: + indices_ = indices_G_stacked[:, :, :g] + if step > 0: + penalty_val = 1 + self.group_overlap[original_batch_idxs, g, :g] + penalty_val = penalty_val.unsqueeze(1) + else: + penalty_val = torch.ones(bsz, 1, 1) + diversity_buf.scatter_add_( + 1, + indices_.reshape(bsz, -1), + penalty_val.expand(indices_.size()) + .reshape(bsz, -1) + .to(diversity_buf), + ) + + lprobs_g = torch.add( + lprobs_g, + other=diversity_buf.unsqueeze(1), + alpha=self.diversity_strength, + ) + else: + lprobs_g = lprobs_g.contiguous() + + scores_buf, indices_buf, beams_buf = self.beam.step( + step, lprobs_g, scores_g, candidate_multiple=self.candidate_multiple + ) + beams_buf.mul_(self.num_groups).add_(g) + + scores_G.append(scores_buf.clone()) + beams_G.append(beams_buf.clone()) + + indices_G_stacked[:, :, g] = indices_buf + + # interleave results from different groups + scores_buf = torch.stack(scores_G, dim=2).view(bsz, -1) + indices_buf = indices_G_stacked.view(bsz, -1) + beams_buf = torch.stack(beams_G, dim=2).view(bsz, -1) + # find num of overlapped tokens for each group pair + # then discount it for next timestamp + overlap = self.diversity_discount * torch.sum( + indices_G_stacked.unsqueeze(2).eq(indices_G_stacked.unsqueeze(3)), dim=1 + ) + if step == 0: + self.group_overlap = overlap + else: + 
self.group_overlap[original_batch_idxs] = ( + self.group_overlap[original_batch_idxs] * self.diversity_discount + + overlap + ) + + return scores_buf, indices_buf, beams_buf + + +class Sampling(Search): + sampling_topk: int + sampling_topp: float + + def __init__(self, tgt_dict, sampling_topk=-1, sampling_topp=-1.0): + super().__init__(tgt_dict) + self.sampling_topk = sampling_topk + self.sampling_topp = sampling_topp + + def _sample_topp(self, lprobs): + """Sample among the smallest set of elements whose cumulative probability mass exceeds p. + + See `"The Curious Case of Neural Text Degeneration" + (Holtzman et al., 2019) `_. + + Args: + lprobs: (bsz x input_beam_size x vocab_size) + the model's log-probabilities over the vocabulary at the current step + + Return: A tuple of (trimed_probs, truncated_indices) where: + trimed_probs: (bsz x input_beam_size x ?) + the model's probabilities over the elements selected to sample from. The + width of the third dimension is determined by top-P. + truncated_indices: (bsz x input_beam_size x ?) + the indices of the chosen elements. + """ + probs = lprobs.exp_() + + # sort the last dimension (vocab dimension) in descending order + sorted_probs, sorted_indices = probs.sort(descending=True) + + # compute a mask to indicate the words to be included in the top-P set. + cumsum_probs = sorted_probs.cumsum(dim=2) + mask = cumsum_probs.lt(self.sampling_topp) + + # note that mask was computed by 'lt'. One more word needs to be included + # so that the cumulative probability mass can exceed p. + cumsum_mask = mask.cumsum(dim=2) + last_included = cumsum_mask[:, :, -1:] + last_included.clamp_(0, mask.size()[2] - 1) + mask = mask.scatter_(2, last_included, 1) + + # truncate unnecessary dims. + max_dim = last_included.max() + truncated_mask = mask[:, :, : max_dim + 1] + truncated_probs = sorted_probs[:, :, : max_dim + 1] + truncated_indices = sorted_indices[:, :, : max_dim + 1] + + # trim the words that are not in top-P by setting their probabilities + # to 0, so that they would not be sampled later. 
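# Illustrative sketch (standalone, made-up numbers): with sorted probabilities
# [0.5, 0.3, 0.1, 0.1] and sampling_topp = 0.7, cumsum.lt(p) keeps only the first
# entry, one extra entry is then included so the kept mass (0.8) exceeds p, and
# everything outside that set is zeroed before torch.multinomial is called.
#
#   import torch
#   probs = torch.tensor([[[0.5, 0.3, 0.1, 0.1]]])       # bsz x beam x vocab
#   mask = probs.cumsum(dim=2).lt(0.7)                    # [[[True, False, False, False]]]
#   last = mask.cumsum(dim=2)[:, :, -1:].clamp_(0, 3)     # index of the extra entry to add
#   mask = mask.scatter_(2, last, 1)                      # [[[True, True, False, False]]]
#   kept = probs.masked_fill(~mask, 0)                    # [[[0.5, 0.3, 0.0, 0.0]]]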
+ trim_mask = ~truncated_mask + trimed_probs = truncated_probs.masked_fill_(trim_mask, 0) + return trimed_probs, truncated_indices + + @torch.jit.export + def step( + self, + step: int, + lprobs, + scores, + prev_output_tokens: Optional[Tensor] = None, + original_batch_idxs: Optional[Tensor] = None, + ): + bsz, beam_size, vocab_size = lprobs.size() + + if step == 0: + # at the first step all hypotheses are equally likely, so use + # only the first beam + lprobs = lprobs[:, ::beam_size, :].contiguous() + + if self.sampling_topp > 0: + # only sample from the smallest set of words whose cumulative probability mass exceeds p + probs, top_indices = self._sample_topp(lprobs) + elif self.sampling_topk > 0: + # only sample from top-k candidates + lprobs, top_indices = lprobs.topk(self.sampling_topk) + probs = lprobs.exp_() + else: + probs = lprobs.exp_() + + # dummy data to be consistent with true branch for type check + top_indices = torch.empty(0).to(probs) + # sample + if step == 0: + indices_buf = torch.multinomial( + probs.view(bsz, -1), + beam_size, + replacement=True, + ).view(bsz, beam_size) + else: + indices_buf = torch.multinomial( + probs.view(bsz * beam_size, -1), + 1, + replacement=True, + ).view(bsz, beam_size) + + if step == 0: + # expand to beam size + probs = probs.expand(bsz, beam_size, -1) + + # gather scores + scores_buf = torch.gather(probs, dim=2, index=indices_buf.unsqueeze(-1)) + scores_buf = scores_buf.log_().view(bsz, -1) + + # remap indices if using top-k or top-P sampling + if self.sampling_topk > 0 or self.sampling_topp > 0: + indices_buf = torch.gather( + top_indices.expand(bsz, beam_size, -1), + dim=2, + index=indices_buf.unsqueeze(-1), + ).squeeze(2) + + if step == 0: + beams_buf = indices_buf.new_zeros(bsz, beam_size) + else: + beams_buf = torch.arange(0, beam_size).to(indices_buf).repeat(bsz, 1) + # make scores cumulative + scores_buf.add_( + torch.gather(scores[:, :, step - 1], dim=1, index=beams_buf) + ) + + return scores_buf, indices_buf, beams_buf + + +class DiverseSiblingsSearch(Search): + """ + Beam search with diverse siblings. + + See "A Simple, Fast Diverse Decoding Algorithm for Neural Generation" for details. + https://arxiv.org/abs/1611.08562 + + 1/ Calculate hypotheses for each beam + 2/ Intra-sibling ordering + 3/ Rewrite scores + 4/ Choose top K hypotheses + + if diversity_rate == 0 is equivalent to BeamSearch + """ + + def __init__(self, tgt_dict, diversity_rate): + super().__init__(tgt_dict) + self.diversity_rate = diversity_rate + self.beam = BeamSearch(tgt_dict) + + def step( + self, + step: int, + lprobs, + scores, + prev_output_tokens: Optional[Tensor] = None, + original_batch_idxs: Optional[Tensor] = None, + ): + bsz, beam_size, vocab_size = lprobs.size() + k = min( + # Take the best 2 x beam_size predictions. We'll choose the first + # beam_size of these which don't predict eos to continue with. 
+ beam_size * 2, + lprobs.view(bsz, -1).size(1) - 1, # -1 so we never select pad + ) + s_list: List[Tensor] + i_list: List[Tensor] + s_list = [torch.empty(0).to(lprobs) for i in range(beam_size)] + i_list = [torch.LongTensor().to(device=lprobs.device) for i in range(beam_size)] + sibling_score = torch.arange(1, k + 1).to(lprobs) * self.diversity_rate + + if step == 0: + return self.beam.step(step, lprobs, scores) + lprobs.add_(scores[:, :, step - 1].unsqueeze(-1)) + + # 1/ Calculate hypotheses for each beam + for i in range(beam_size): + torch.topk(lprobs[:, i, :].view(bsz, -1), k, out=(s_list[i], i_list[i])) + i_list[i].fmod_(vocab_size) + + # 2/ Intra-sibling ordering by default from topk + 3/ Rewrite scores + s_list[i].sub_(sibling_score) + + # 4/ Choose top K hypotheses + indices = torch.stack(i_list, dim=1).view(bsz, -1) + + final_scores = torch.empty(0).to(lprobs) + final_indices = torch.LongTensor().to(device=lprobs.device) + final_beams = torch.LongTensor().to(device=lprobs.device) + (final_scores, final_indices) = torch.topk( + torch.stack(s_list, dim=1).view(bsz, -1), + k, + ) + + final_beams = final_indices // k + + for i in range(bsz): + final_indices[i] = indices[i][final_indices[i]] + + return final_scores, final_indices, final_beams diff --git a/fairseq/fairseq/sequence_generator.py b/fairseq/fairseq/sequence_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..78db504e6ce75ac31980e71cfdbf436e07739025 --- /dev/null +++ b/fairseq/fairseq/sequence_generator.py @@ -0,0 +1,1020 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +import sys +from typing import Dict, List, Optional + +import torch +import torch.nn as nn +from torch import Tensor + +from fairseq import search, utils +from fairseq.data import data_utils +from fairseq.models import FairseqIncrementalDecoder +from fairseq.ngram_repeat_block import NGramRepeatBlock + + +class SequenceGenerator(nn.Module): + def __init__( + self, + models, + tgt_dict, + beam_size=1, + max_len_a=0, + max_len_b=200, + max_len=0, + min_len=1, + normalize_scores=True, + len_penalty=1.0, + unk_penalty=0.0, + temperature=1.0, + match_source_len=False, + no_repeat_ngram_size=0, + search_strategy=None, + eos=None, + symbols_to_strip_from_output=None, + lm_model=None, + lm_weight=1.0, + tokens_to_suppress=(), + ): + """Generates translations of a given source sentence. 
+ + Args: + models (List[~fairseq.models.FairseqModel]): ensemble of models, + currently support fairseq.models.TransformerModel for scripting + beam_size (int, optional): beam width (default: 1) + max_len_a/b (int, optional): generate sequences of maximum length + ax + b, where x is the source length + max_len (int, optional): the maximum length of the generated output + (not including end-of-sentence) + min_len (int, optional): the minimum length of the generated output + (not including end-of-sentence) + normalize_scores (bool, optional): normalize scores by the length + of the output (default: True) + len_penalty (float, optional): length penalty, where <1.0 favors + shorter, >1.0 favors longer sentences (default: 1.0) + unk_penalty (float, optional): unknown word penalty, where <0 + produces more unks, >0 produces fewer (default: 0.0) + temperature (float, optional): temperature, where values + >1.0 produce more uniform samples and values <1.0 produce + sharper samples (default: 1.0) + match_source_len (bool, optional): outputs should match the source + length (default: False) + """ + super().__init__() + if isinstance(models, EnsembleModel): + self.model = models + else: + self.model = EnsembleModel(models) + self.tgt_dict = tgt_dict + self.pad = tgt_dict.pad() + self.unk = tgt_dict.unk() + self.eos = tgt_dict.eos() if eos is None else eos + self.symbols_to_strip_from_output = ( + symbols_to_strip_from_output.union({self.eos}) + if symbols_to_strip_from_output is not None + else {self.eos} + ) + + self.token_indices_to_suppress: Optional[Tensor] = None + token_indices_to_suppress = [] + for token_string in tokens_to_suppress: + token_index = tgt_dict.index(token_string) + assert token_index != self.unk + token_indices_to_suppress.append(token_index) + if len(token_indices_to_suppress) > 0: + self.token_indices_to_suppress = torch.Tensor( + token_indices_to_suppress + ).long() + + self.vocab_size = len(tgt_dict) + self.beam_size = beam_size + # the max beam size is the dictionary size - 1, since we never select pad + self.beam_size = min(beam_size, self.vocab_size - 1) + self.model.set_decoder_beam_size(self.beam_size) + self.max_len_a = max_len_a + self.max_len_b = max_len_b + self.min_len = min_len + self.max_len = max_len or self.model.max_decoder_positions() + + self.normalize_scores = normalize_scores + self.len_penalty = len_penalty + self.unk_penalty = unk_penalty + self.temperature = temperature + self.match_source_len = match_source_len + + if no_repeat_ngram_size > 0: + self.repeat_ngram_blocker = NGramRepeatBlock(no_repeat_ngram_size) + else: + self.repeat_ngram_blocker = None + + assert temperature > 0, "--temperature must be greater than 0" + + self.search = ( + search.BeamSearch(tgt_dict) if search_strategy is None else search_strategy + ) + # We only need to set src_lengths in LengthConstrainedBeamSearch. + # As a module attribute, setting it would break in multithread + # settings when the model is shared. + self.should_set_src_lengths = ( + hasattr(self.search, "needs_src_lengths") and self.search.needs_src_lengths + ) + + self.model.eval() + + self.lm_model = lm_model + self.lm_weight = lm_weight + if self.lm_model is not None: + self.lm_model.eval() + + def cuda(self): + self.model.cuda() + return self + + @torch.no_grad() + def forward( + self, + sample: Dict[str, Dict[str, Tensor]], + prefix_tokens: Optional[Tensor] = None, + bos_token: Optional[int] = None, + ): + """Generate a batch of translations. 
+ + Args: + sample (dict): batch + prefix_tokens (torch.LongTensor, optional): force decoder to begin + with these tokens + bos_token (int, optional): beginning of sentence token + (default: self.eos) + """ + return self._generate(sample, prefix_tokens, bos_token=bos_token) + + # TODO(myleott): unused, deprecate after pytorch-translate migration + def generate_batched_itr(self, data_itr, beam_size=None, cuda=False, timer=None): + """Iterate over a batched dataset and yield individual translations. + Args: + cuda (bool, optional): use GPU for generation + timer (StopwatchMeter, optional): time generations + """ + for sample in data_itr: + s = utils.move_to_cuda(sample) if cuda else sample + if "net_input" not in s: + continue + input = s["net_input"] + # model.forward normally channels prev_output_tokens into the decoder + # separately, but SequenceGenerator directly calls model.encoder + encoder_input = { + k: v for k, v in input.items() if k != "prev_output_tokens" + } + if timer is not None: + timer.start() + with torch.no_grad(): + hypos = self.generate(encoder_input) + if timer is not None: + timer.stop(sum(len(h[0]["tokens"]) for h in hypos)) + for i, id in enumerate(s["id"].data): + # remove padding + src = utils.strip_pad(input["src_tokens"].data[i, :], self.pad) + ref = ( + utils.strip_pad(s["target"].data[i, :], self.pad) + if s["target"] is not None + else None + ) + yield id, src, ref, hypos[i] + + @torch.no_grad() + def generate( + self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs + ) -> List[List[Dict[str, Tensor]]]: + """Generate translations. Match the api of other fairseq generators. + + Args: + models (List[~fairseq.models.FairseqModel]): ensemble of models + sample (dict): batch + prefix_tokens (torch.LongTensor, optional): force decoder to begin + with these tokens + constraints (torch.LongTensor, optional): force decoder to include + the list of constraints + bos_token (int, optional): beginning of sentence token + (default: self.eos) + """ + return self._generate(sample, **kwargs) + + def _generate( + self, + sample: Dict[str, Dict[str, Tensor]], + prefix_tokens: Optional[Tensor] = None, + constraints: Optional[Tensor] = None, + bos_token: Optional[int] = None, + ): + incremental_states = torch.jit.annotate( + List[Dict[str, Dict[str, Optional[Tensor]]]], + [ + torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {}) + for i in range(self.model.models_size) + ], + ) + net_input = sample["net_input"] + + if "src_tokens" in net_input: + src_tokens = net_input["src_tokens"] + # length of the source text being the character length except EndOfSentence and pad + # if src_lengths exists in net_input (speech_to_text dataset case), then use it + if "src_lengths" in net_input: + src_lengths = net_input["src_lengths"] + else: + src_lengths = ( + (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)) + .long() + .sum(dim=1) + ) + elif "source" in net_input: + src_tokens = net_input["source"] + src_lengths = ( + net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1) + if net_input["padding_mask"] is not None + else torch.tensor(src_tokens.size(-1)).to(src_tokens) + ) + elif "features" in net_input: + src_tokens = net_input["features"] + src_lengths = ( + net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1) + if net_input["padding_mask"] is not None + else torch.tensor(src_tokens.size(-1)).to(src_tokens) + ) + else: + raise Exception( + "expected src_tokens or source in net input. 
input keys: " + + str(net_input.keys()) + ) + + # bsz: total number of sentences in beam + # Note that src_tokens may have more than 2 dimensions (i.e. audio features) + bsz, src_len = src_tokens.size()[:2] + beam_size = self.beam_size + + if constraints is not None and not self.search.supports_constraints: + raise NotImplementedError( + "Target-side constraints were provided, but search method doesn't support them" + ) + + # Initialize constraints, when active + self.search.init_constraints(constraints, beam_size) + + max_len: int = -1 + if self.match_source_len: + max_len = src_lengths.max().item() + else: + max_len = min( + int(self.max_len_a * src_len + self.max_len_b), + self.max_len - 1, + ) + assert ( + self.min_len <= max_len + ), "min_len cannot be larger than max_len, please adjust these!" + # compute the encoder output for each beam + with torch.autograd.profiler.record_function("EnsembleModel: forward_encoder"): + encoder_outs = self.model.forward_encoder(net_input) + + # placeholder of indices for bsz * beam_size to hold tokens and accumulative scores + new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1) + new_order = new_order.to(src_tokens.device).long() + encoder_outs = self.model.reorder_encoder_out(encoder_outs, new_order) + # ensure encoder_outs is a List. + assert encoder_outs is not None + + # initialize buffers + scores = ( + torch.zeros(bsz * beam_size, max_len + 1).to(src_tokens).float() + ) # +1 for eos; pad is never chosen for scoring + tokens = ( + torch.zeros(bsz * beam_size, max_len + 2) + .to(src_tokens) + .long() + .fill_(self.pad) + ) # +2 for eos and pad + tokens[:, 0] = self.eos if bos_token is None else bos_token + attn: Optional[Tensor] = None + + # A list that indicates candidates that should be ignored. + # For example, suppose we're sampling and have already finalized 2/5 + # samples. Then cands_to_ignore would mark 2 positions as being ignored, + # so that we only finalize the remaining 3 samples. 
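# Illustrative sketch (made-up sizes): hypotheses are stored flat as bsz * beam_size
# rows of `tokens`/`scores`, and bbsz_offsets (built below) converts per-sentence
# beam ids into indices over those flat rows, as done further down via
# cand_bbsz_idx = cand_beams.add(bbsz_offsets).
#
#   import torch
#   bsz, beam_size = 2, 2                                             # cand_size = 4
#   bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1)    # [[0], [2]]
#   cand_beams = torch.tensor([[0, 1, 1, 0], [1, 0, 0, 1]])           # per-sentence beam ids
#   cand_bbsz_idx = cand_beams + bbsz_offsets
#   # -> [[0, 1, 1, 0], [3, 2, 2, 3]]: flat rows to read each candidate's prefix from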
+ cands_to_ignore = ( + torch.zeros(bsz, beam_size).to(src_tokens).eq(-1) + ) # forward and backward-compatible False mask + + # list of completed sentences + finalized = torch.jit.annotate( + List[List[Dict[str, Tensor]]], + [torch.jit.annotate(List[Dict[str, Tensor]], []) for i in range(bsz)], + ) # contains lists of dictionaries of infomation about the hypothesis being finalized at each step + + # a boolean array indicating if the sentence at the index is finished or not + finished = [False for i in range(bsz)] + num_remaining_sent = bsz # number of sentences remaining + + # number of candidate hypos per step + cand_size = 2 * beam_size # 2 x beam size in case half are EOS + + # offset arrays for converting between different indexing schemes + bbsz_offsets = ( + (torch.arange(0, bsz) * beam_size) + .unsqueeze(1) + .type_as(tokens) + .to(src_tokens.device) + ) + cand_offsets = torch.arange(0, cand_size).type_as(tokens).to(src_tokens.device) + + reorder_state: Optional[Tensor] = None + batch_idxs: Optional[Tensor] = None + + original_batch_idxs: Optional[Tensor] = None + if "id" in sample and isinstance(sample["id"], Tensor): + original_batch_idxs = sample["id"] + else: + original_batch_idxs = torch.arange(0, bsz).type_as(tokens) + + for step in range(max_len + 1): # one extra step for EOS marker + # reorder decoder internal states based on the prev choice of beams + if reorder_state is not None: + if batch_idxs is not None: + # update beam indices to take into account removed sentences + corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as( + batch_idxs + ) + reorder_state.view(-1, beam_size).add_( + corr.unsqueeze(-1) * beam_size + ) + original_batch_idxs = original_batch_idxs[batch_idxs] + self.model.reorder_incremental_state(incremental_states, reorder_state) + encoder_outs = self.model.reorder_encoder_out( + encoder_outs, reorder_state + ) + with torch.autograd.profiler.record_function( + "EnsembleModel: forward_decoder" + ): + lprobs, avg_attn_scores = self.model.forward_decoder( + tokens[:, : step + 1], + encoder_outs, + incremental_states, + self.temperature, + ) + + if self.lm_model is not None: + lm_out = self.lm_model(tokens[:, : step + 1]) + probs = self.lm_model.get_normalized_probs( + lm_out, log_probs=True, sample=None + ) + probs = probs[:, -1, :] * self.lm_weight + lprobs += probs + + lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs) + + lprobs[:, self.pad] = -math.inf # never select pad + lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty + + # handle max length constraint + if step >= max_len: + lprobs[:, : self.eos] = -math.inf + lprobs[:, self.eos + 1 :] = -math.inf + + # handle prefix tokens (possibly with different lengths) + if ( + prefix_tokens is not None + and step < prefix_tokens.size(1) + and step < max_len + ): + lprobs, tokens, scores = self._prefix_tokens( + step, lprobs, scores, tokens, prefix_tokens, beam_size + ) + else: + if step < self.min_len: + # minimum length constraint (does not apply if using prefix_tokens) + lprobs[:, self.eos] = -math.inf + + if self.token_indices_to_suppress is not None: + lprobs[:, self.token_indices_to_suppress] = -math.inf + + # Record attention scores, only support avg_attn_scores is a Tensor + if avg_attn_scores is not None: + if attn is None: + attn = torch.empty( + bsz * beam_size, avg_attn_scores.size(1), max_len + 2 + ).to(scores) + attn[:, :, step + 1].copy_(avg_attn_scores) + + scores = scores.type_as(lprobs) + eos_bbsz_idx = torch.empty(0).to( + tokens + ) # indices of hypothesis ending 
with eos (finished sentences) + eos_scores = torch.empty(0).to( + scores + ) # scores of hypothesis ending with eos (finished sentences) + + if self.should_set_src_lengths: + self.search.set_src_lengths(src_lengths) + + if self.repeat_ngram_blocker is not None: + lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz, beam_size, step) + + # Shape: (batch, cand_size) + cand_scores, cand_indices, cand_beams = self.search.step( + step, + lprobs.view(bsz, -1, self.vocab_size), + scores.view(bsz, beam_size, -1)[:, :, :step], + tokens[:, : step + 1], + original_batch_idxs, + ) + + # cand_bbsz_idx contains beam indices for the top candidate + # hypotheses, with a range of values: [0, bsz*beam_size), + # and dimensions: [bsz, cand_size] + cand_bbsz_idx = cand_beams.add(bbsz_offsets) + + # finalize hypotheses that end in eos + # Shape of eos_mask: (batch size, beam size) + eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf) + eos_mask[:, :beam_size][cands_to_ignore] = torch.tensor(0).to(eos_mask) + + # only consider eos when it's among the top beam_size indices + # Now we know what beam item(s) to finish + # Shape: 1d list of absolute-numbered + eos_bbsz_idx = torch.masked_select( + cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size] + ) + + finalized_sents: List[int] = [] + if eos_bbsz_idx.numel() > 0: + eos_scores = torch.masked_select( + cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size] + ) + + finalized_sents = self.finalize_hypos( + step, + eos_bbsz_idx, + eos_scores, + tokens, + scores, + finalized, + finished, + beam_size, + attn, + src_lengths, + max_len, + ) + num_remaining_sent -= len(finalized_sents) + + assert num_remaining_sent >= 0 + if num_remaining_sent == 0: + break + if self.search.stop_on_max_len and step >= max_len: + break + assert step < max_len, f"{step} < {max_len}" + + # Remove finalized sentences (ones for which {beam_size} + # finished hypotheses have been generated) from the batch. + if len(finalized_sents) > 0: + new_bsz = bsz - len(finalized_sents) + + # construct batch_idxs which holds indices of batches to keep for the next pass + batch_mask = torch.ones( + bsz, dtype=torch.bool, device=cand_indices.device + ) + batch_mask[finalized_sents] = False + # TODO replace `nonzero(as_tuple=False)` after TorchScript supports it + batch_idxs = torch.arange( + bsz, device=cand_indices.device + ).masked_select(batch_mask) + + # Choose the subset of the hypothesized constraints that will continue + self.search.prune_sentences(batch_idxs) + + eos_mask = eos_mask[batch_idxs] + cand_beams = cand_beams[batch_idxs] + bbsz_offsets.resize_(new_bsz, 1) + cand_bbsz_idx = cand_beams.add(bbsz_offsets) + cand_scores = cand_scores[batch_idxs] + cand_indices = cand_indices[batch_idxs] + + if prefix_tokens is not None: + prefix_tokens = prefix_tokens[batch_idxs] + src_lengths = src_lengths[batch_idxs] + cands_to_ignore = cands_to_ignore[batch_idxs] + + scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1) + tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1) + if attn is not None: + attn = attn.view(bsz, -1)[batch_idxs].view( + new_bsz * beam_size, attn.size(1), -1 + ) + bsz = new_bsz + else: + batch_idxs = None + + # Set active_mask so that values > cand_size indicate eos hypos + # and values < cand_size indicate candidate active hypos. + # After, the min values per row are the top candidate active hypos + + # Rewrite the operator since the element wise or is not supported in torchscript. 
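# Illustrative check (standalone): the next line is just an element-wise OR written
# via De Morgan's law, since `|` on bool tensors was not supported by TorchScript.
#
#   import torch
#   a = torch.tensor([True, False, False, True])
#   b = torch.tensor([False, False, True, True])
#   assert torch.equal(~((~a) & (~b)), a | b)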
+ + eos_mask[:, :beam_size] = ~((~cands_to_ignore) & (~eos_mask[:, :beam_size])) + active_mask = torch.add( + eos_mask.type_as(cand_offsets) * cand_size, + cand_offsets[: eos_mask.size(1)], + ) + + # get the top beam_size active hypotheses, which are just + # the hypos with the smallest values in active_mask. + # {active_hypos} indicates which {beam_size} hypotheses + # from the list of {2 * beam_size} candidates were + # selected. Shapes: (batch size, beam size) + new_cands_to_ignore, active_hypos = torch.topk( + active_mask, k=beam_size, dim=1, largest=False + ) + + # update cands_to_ignore to ignore any finalized hypos. + cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size] + # Make sure there is at least one active item for each sentence in the batch. + assert (~cands_to_ignore).any(dim=1).all() + + # update cands_to_ignore to ignore any finalized hypos + + # {active_bbsz_idx} denotes which beam number is continued for each new hypothesis (a beam + # can be selected more than once). + active_bbsz_idx = torch.gather(cand_bbsz_idx, dim=1, index=active_hypos) + active_scores = torch.gather(cand_scores, dim=1, index=active_hypos) + + active_bbsz_idx = active_bbsz_idx.view(-1) + active_scores = active_scores.view(-1) + + # copy tokens and scores for active hypotheses + + # Set the tokens for each beam (can select the same row more than once) + tokens[:, : step + 1] = torch.index_select( + tokens[:, : step + 1], dim=0, index=active_bbsz_idx + ) + # Select the next token for each of them + tokens.view(bsz, beam_size, -1)[:, :, step + 1] = torch.gather( + cand_indices, dim=1, index=active_hypos + ) + if step > 0: + scores[:, :step] = torch.index_select( + scores[:, :step], dim=0, index=active_bbsz_idx + ) + scores.view(bsz, beam_size, -1)[:, :, step] = torch.gather( + cand_scores, dim=1, index=active_hypos + ) + + # Update constraints based on which candidates were selected for the next beam + self.search.update_constraints(active_hypos) + + # copy attention for active hypotheses + if attn is not None: + attn[:, :, : step + 2] = torch.index_select( + attn[:, :, : step + 2], dim=0, index=active_bbsz_idx + ) + + # reorder incremental state in decoder + reorder_state = active_bbsz_idx + + # sort by score descending + for sent in range(len(finalized)): + scores = torch.tensor( + [float(elem["score"].item()) for elem in finalized[sent]] + ) + _, sorted_scores_indices = torch.sort(scores, descending=True) + finalized[sent] = [finalized[sent][ssi] for ssi in sorted_scores_indices] + finalized[sent] = torch.jit.annotate( + List[Dict[str, Tensor]], finalized[sent] + ) + return finalized + + def _prefix_tokens( + self, step: int, lprobs, scores, tokens, prefix_tokens, beam_size: int + ): + """Handle prefix tokens""" + prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1) + prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1)) + prefix_mask = prefix_toks.ne(self.pad) + lprobs[prefix_mask] = torch.tensor(-math.inf).to(lprobs) + lprobs[prefix_mask] = lprobs[prefix_mask].scatter( + -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask] + ) + # if prefix includes eos, then we should make sure tokens and + # scores are the same across all beams + eos_mask = prefix_toks.eq(self.eos) + if eos_mask.any(): + # validate that the first beam matches the prefix + first_beam = tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[ + :, 0, 1 : step + 1 + ] + eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0] + target_prefix = 
prefix_tokens[eos_mask_batch_dim][:, :step] + assert (first_beam == target_prefix).all() + + # copy tokens, scores and lprobs from the first beam to all beams + tokens = self.replicate_first_beam(tokens, eos_mask_batch_dim, beam_size) + scores = self.replicate_first_beam(scores, eos_mask_batch_dim, beam_size) + lprobs = self.replicate_first_beam(lprobs, eos_mask_batch_dim, beam_size) + return lprobs, tokens, scores + + def replicate_first_beam(self, tensor, mask, beam_size: int): + tensor = tensor.view(-1, beam_size, tensor.size(-1)) + tensor[mask] = tensor[mask][:, :1, :] + return tensor.view(-1, tensor.size(-1)) + + def finalize_hypos( + self, + step: int, + bbsz_idx, + eos_scores, + tokens, + scores, + finalized: List[List[Dict[str, Tensor]]], + finished: List[bool], + beam_size: int, + attn: Optional[Tensor], + src_lengths, + max_len: int, + ): + """Finalize hypothesis, store finalized information in `finalized`, and change `finished` accordingly. + A sentence is finalized when {beam_size} finished items have been collected for it. + + Returns number of sentences (not beam items) being finalized. + These will be removed from the batch and not processed further. + Args: + bbsz_idx (Tensor): + """ + assert bbsz_idx.numel() == eos_scores.numel() + + # clone relevant token and attention tensors. + # tokens is (batch * beam, max_len). So the index_select + # gets the newly EOS rows, then selects cols 1..{step + 2} + tokens_clone = tokens.index_select(0, bbsz_idx)[ + :, 1 : step + 2 + ] # skip the first index, which is EOS + + tokens_clone[:, step] = self.eos + attn_clone = ( + attn.index_select(0, bbsz_idx)[:, :, 1 : step + 2] + if attn is not None + else None + ) + + # compute scores per token position + pos_scores = scores.index_select(0, bbsz_idx)[:, : step + 1] + pos_scores[:, step] = eos_scores + # convert from cumulative to per-position scores + pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1] + + # normalize sentence-level scores + if self.normalize_scores: + eos_scores /= (step + 1) ** self.len_penalty + + # cum_unfin records which sentences in the batch are finished. + # It helps match indexing between (a) the original sentences + # in the batch and (b) the current, possibly-reduced set of + # sentences. 
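# Illustrative sketch (made-up values): with finished = [False, True, False, True,
# False], the loop below builds cum_unfin = [0, 1, 2], so sentence i of the current
# reduced batch maps back to original index i + cum_unfin[i] (0 -> 0, 1 -> 2, 2 -> 4).
#
#   finished = [False, True, False, True, False]
#   cum_unfin, prev = [], 0
#   for f in finished:
#       if f: prev += 1
#       else: cum_unfin.append(prev)
#   assert [i + c for i, c in enumerate(cum_unfin)] == [0, 2, 4]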
+ cum_unfin: List[int] = [] + prev = 0 + for f in finished: + if f: + prev += 1 + else: + cum_unfin.append(prev) + cum_fin_tensor = torch.tensor(cum_unfin, dtype=torch.int).to(bbsz_idx) + + unfin_idx = torch.div(bbsz_idx, beam_size, rounding_mode="trunc") + sent = unfin_idx + torch.index_select(cum_fin_tensor, 0, unfin_idx) + + # Create a set of "{sent}{unfin_idx}", where + # "unfin_idx" is the index in the current (possibly reduced) + # list of sentences, and "sent" is the index in the original, + # unreduced batch + # For every finished beam item + # sentence index in the current (possibly reduced) batch + seen = (sent << 32) + unfin_idx + unique_seen: List[int] = torch.unique(seen).tolist() + + if self.match_source_len: + condition = step > torch.index_select(src_lengths, 0, unfin_idx) + eos_scores = torch.where(condition, torch.tensor(-math.inf), eos_scores) + sent_list: List[int] = sent.tolist() + for i in range(bbsz_idx.size()[0]): + # An input sentence (among those in a batch) is finished when + # beam_size hypotheses have been collected for it + if len(finalized[sent_list[i]]) < beam_size: + if attn_clone is not None: + # remove padding tokens from attn scores + hypo_attn = attn_clone[i] + else: + hypo_attn = torch.empty(0) + + finalized[sent_list[i]].append( + { + "tokens": tokens_clone[i], + "score": eos_scores[i], + "attention": hypo_attn, # src_len x tgt_len + "alignment": torch.empty(0), + "positional_scores": pos_scores[i], + } + ) + + newly_finished: List[int] = [] + for unique_s in unique_seen: + # check termination conditions for this sentence + unique_sent: int = unique_s >> 32 + unique_unfin_idx: int = unique_s - (unique_sent << 32) + + if not finished[unique_sent] and self.is_finished( + step, unique_unfin_idx, max_len, len(finalized[unique_sent]), beam_size + ): + finished[unique_sent] = True + newly_finished.append(unique_unfin_idx) + + return newly_finished + + def is_finished( + self, + step: int, + unfin_idx: int, + max_len: int, + finalized_sent_len: int, + beam_size: int, + ): + """ + Check whether decoding for a sentence is finished, which + occurs when the list of finalized sentences has reached the + beam size, or when we reach the maximum length. 
+ """ + assert finalized_sent_len <= beam_size + if finalized_sent_len == beam_size or step == max_len: + return True + return False + + +class EnsembleModel(nn.Module): + """A wrapper around an ensemble of models.""" + + def __init__(self, models): + super().__init__() + self.models_size = len(models) + # method '__len__' is not supported in ModuleList for torch script + self.single_model = models[0] + self.models = nn.ModuleList(models) + + self.has_incremental: bool = False + if all( + hasattr(m, "decoder") and isinstance(m.decoder, FairseqIncrementalDecoder) + for m in models + ): + self.has_incremental = True + + def forward(self): + pass + + def has_encoder(self): + return hasattr(self.single_model, "encoder") + + def has_incremental_states(self): + return self.has_incremental + + def max_decoder_positions(self): + return min( + [ + m.max_decoder_positions() + for m in self.models + if hasattr(m, "max_decoder_positions") + ] + + [sys.maxsize] + ) + + def set_decoder_beam_size(self, beam_size): + """Set beam size for efficient beamable enc-dec attention.""" + if beam_size > 1: + for model in self.models: + if hasattr(model, "set_beam_size"): + model.set_beam_size(beam_size) + + @torch.jit.export + def forward_encoder(self, net_input: Dict[str, Tensor]): + if not self.has_encoder(): + return None + return [model.encoder.forward_torchscript(net_input) for model in self.models] + + @torch.jit.export + def forward_decoder( + self, + tokens, + encoder_outs: List[Dict[str, List[Tensor]]], + incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]], + temperature: float = 1.0, + ): + log_probs = [] + avg_attn: Optional[Tensor] = None + encoder_out: Optional[Dict[str, List[Tensor]]] = None + for i, model in enumerate(self.models): + if self.has_encoder(): + encoder_out = encoder_outs[i] + # decode each model + if self.has_incremental_states(): + decoder_out = model.decoder.forward( + tokens, + encoder_out=encoder_out, + incremental_state=incremental_states[i], + ) + else: + if hasattr(model, "decoder"): + decoder_out = model.decoder.forward(tokens, encoder_out=encoder_out) + else: + decoder_out = model.forward(tokens) + + attn: Optional[Tensor] = None + decoder_len = len(decoder_out) + if decoder_len > 1 and decoder_out[1] is not None: + if isinstance(decoder_out[1], Tensor): + attn = decoder_out[1] + else: + attn_holder = decoder_out[1]["attn"] + if isinstance(attn_holder, Tensor): + attn = attn_holder + elif attn_holder is not None: + attn = attn_holder[0] + if attn is not None: + attn = attn[:, -1, :] + + decoder_out_tuple = ( + decoder_out[0][:, -1:, :].div_(temperature), + None if decoder_len <= 1 else decoder_out[1], + ) + probs = model.get_normalized_probs( + decoder_out_tuple, log_probs=True, sample=None + ) + probs = probs[:, -1, :] + if self.models_size == 1: + return probs, attn + + log_probs.append(probs) + if attn is not None: + if avg_attn is None: + avg_attn = attn + else: + avg_attn.add_(attn) + + avg_probs = torch.logsumexp(torch.stack(log_probs, dim=0), dim=0) - math.log( + self.models_size + ) + + if avg_attn is not None: + avg_attn.div_(self.models_size) + return avg_probs, avg_attn + + @torch.jit.export + def reorder_encoder_out( + self, encoder_outs: Optional[List[Dict[str, List[Tensor]]]], new_order + ): + """ + Reorder encoder output according to *new_order*. 
+ + Args: + encoder_out: output from the ``forward()`` method + new_order (LongTensor): desired order + + Returns: + *encoder_out* rearranged according to *new_order* + """ + new_outs: List[Dict[str, List[Tensor]]] = [] + if not self.has_encoder(): + return new_outs + for i, model in enumerate(self.models): + assert encoder_outs is not None + new_outs.append( + model.encoder.reorder_encoder_out(encoder_outs[i], new_order) + ) + return new_outs + + @torch.jit.export + def reorder_incremental_state( + self, + incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]], + new_order, + ): + if not self.has_incremental_states(): + return + for i, model in enumerate(self.models): + model.decoder.reorder_incremental_state_scripting( + incremental_states[i], new_order + ) + + +class SequenceGeneratorWithAlignment(SequenceGenerator): + def __init__( + self, models, tgt_dict, left_pad_target=False, print_alignment="hard", **kwargs + ): + """Generates translations of a given source sentence. + + Produces alignments following "Jointly Learning to Align and + Translate with Transformer Models" (Garg et al., EMNLP 2019). + + Args: + left_pad_target (bool, optional): Whether or not the + hypothesis should be left padded or not when they are + teacher forced for generating alignments. + """ + super().__init__(EnsembleModelWithAlignment(models), tgt_dict, **kwargs) + self.left_pad_target = left_pad_target + + if print_alignment == "hard": + self.extract_alignment = utils.extract_hard_alignment + elif print_alignment == "soft": + self.extract_alignment = utils.extract_soft_alignment + + @torch.no_grad() + def generate(self, models, sample, **kwargs): + finalized = super()._generate(sample, **kwargs) + + src_tokens = sample["net_input"]["src_tokens"] + bsz = src_tokens.shape[0] + beam_size = self.beam_size + ( + src_tokens, + src_lengths, + prev_output_tokens, + tgt_tokens, + ) = self._prepare_batch_for_alignment(sample, finalized) + if any(getattr(m, "full_context_alignment", False) for m in self.model.models): + attn = self.model.forward_align(src_tokens, src_lengths, prev_output_tokens) + else: + attn = [ + finalized[i // beam_size][i % beam_size]["attention"].transpose(1, 0) + for i in range(bsz * beam_size) + ] + + if src_tokens.device != "cpu": + src_tokens = src_tokens.to("cpu") + tgt_tokens = tgt_tokens.to("cpu") + attn = [i.to("cpu") for i in attn] + + # Process the attn matrix to extract hard alignments. 
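# Illustrative sketch (hedged; the real work is done by utils.extract_hard_alignment,
# which also ignores pad/eos positions): a "hard" alignment keeps, for every target
# position, the source position with the largest attention weight. After the
# transpose above, attn[i] follows a (tgt_len x src_len) layout; values are made up.
#
#   import torch
#   attn_i = torch.rand(6, 4)          # tgt_len x src_len
#   hard = attn_i.argmax(dim=1)        # one source index per target token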
+ for i in range(bsz * beam_size): + alignment = self.extract_alignment( + attn[i], src_tokens[i], tgt_tokens[i], self.pad, self.eos + ) + finalized[i // beam_size][i % beam_size]["alignment"] = alignment + return finalized + + def _prepare_batch_for_alignment(self, sample, hypothesis): + src_tokens = sample["net_input"]["src_tokens"] + bsz = src_tokens.shape[0] + src_tokens = ( + src_tokens[:, None, :] + .expand(-1, self.beam_size, -1) + .contiguous() + .view(bsz * self.beam_size, -1) + ) + src_lengths = sample["net_input"]["src_lengths"] + src_lengths = ( + src_lengths[:, None] + .expand(-1, self.beam_size) + .contiguous() + .view(bsz * self.beam_size) + ) + prev_output_tokens = data_utils.collate_tokens( + [beam["tokens"] for example in hypothesis for beam in example], + self.pad, + self.eos, + self.left_pad_target, + move_eos_to_beginning=True, + ) + tgt_tokens = data_utils.collate_tokens( + [beam["tokens"] for example in hypothesis for beam in example], + self.pad, + self.eos, + self.left_pad_target, + move_eos_to_beginning=False, + ) + return src_tokens, src_lengths, prev_output_tokens, tgt_tokens + + +class EnsembleModelWithAlignment(EnsembleModel): + """A wrapper around an ensemble of models.""" + + def __init__(self, models): + super().__init__(models) + + def forward_align(self, src_tokens, src_lengths, prev_output_tokens): + avg_attn = None + for model in self.models: + decoder_out = model(src_tokens, src_lengths, prev_output_tokens) + attn = decoder_out[1]["attn"][0] + if avg_attn is None: + avg_attn = attn + else: + avg_attn.add_(attn) + if len(self.models) > 1: + avg_attn.div_(len(self.models)) + return avg_attn diff --git a/fairseq/fairseq/sequence_scorer.py b/fairseq/fairseq/sequence_scorer.py new file mode 100644 index 0000000000000000000000000000000000000000..411d4df4445ef8dd3f1907ad56f9de6943d1fed8 --- /dev/null +++ b/fairseq/fairseq/sequence_scorer.py @@ -0,0 +1,153 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import sys + +import torch +from fairseq import utils + + +class SequenceScorer(object): + """Scores the target for a given source sentence.""" + + def __init__( + self, + tgt_dict, + softmax_batch=None, + compute_alignment=False, + eos=None, + symbols_to_strip_from_output=None, + ): + self.pad = tgt_dict.pad() + self.eos = tgt_dict.eos() if eos is None else eos + self.softmax_batch = softmax_batch or sys.maxsize + assert self.softmax_batch > 0 + self.compute_alignment = compute_alignment + self.symbols_to_strip_from_output = ( + symbols_to_strip_from_output.union({self.eos}) + if symbols_to_strip_from_output is not None + else {self.eos} + ) + + @torch.no_grad() + def generate(self, models, sample, **kwargs): + """Score a batch of translations.""" + net_input = sample["net_input"] + + def batch_for_softmax(dec_out, target): + # assumes decoder_out[0] is the only thing needed (may not be correct for future models!) 
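# Illustrative sketch (made-up sizes): when bsz * tsz exceeds softmax_batch, the
# decoder output is flattened to (1, bsz * tsz, dim) and re-emitted in slices of at
# most softmax_batch positions, so the normalized probabilities over a large
# vocabulary are computed piecewise instead of all at once.
#
#   bsz, tsz, softmax_batch = 8, 512, 1024        # 4096 positions -> 4 slices
#   for s in range(0, bsz * tsz, softmax_batch):
#       pass                                      # score flat[:, s:s + softmax_batch]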
+ first, rest = dec_out[0], dec_out[1:] + bsz, tsz, dim = first.shape + if bsz * tsz < self.softmax_batch: + yield dec_out, target, True + else: + flat = first.contiguous().view(1, -1, dim) + flat_tgt = target.contiguous().view(flat.shape[:-1]) + s = 0 + while s < flat.size(1): + e = s + self.softmax_batch + yield (flat[:, s:e],) + rest, flat_tgt[:, s:e], False + s = e + + def gather_target_probs(probs, target): + probs = probs.gather( + dim=2, + index=target.unsqueeze(-1), + ) + return probs + + orig_target = sample["target"] + + # compute scores for each model in the ensemble + avg_probs = None + avg_attn = None + for model in models: + model.eval() + decoder_out = model(**net_input) + attn = decoder_out[1] if len(decoder_out) > 1 else None + if type(attn) is dict: + attn = attn.get("attn", None) + + batched = batch_for_softmax(decoder_out, orig_target) + probs, idx = None, 0 + for bd, tgt, is_single in batched: + sample["target"] = tgt + curr_prob = model.get_normalized_probs( + bd, log_probs=len(models) == 1, sample=sample + ).data + if is_single: + probs = gather_target_probs(curr_prob, orig_target) + else: + if probs is None: + probs = curr_prob.new(orig_target.numel()) + step = curr_prob.size(0) * curr_prob.size(1) + end = step + idx + tgt_probs = gather_target_probs( + curr_prob.view(tgt.shape + (curr_prob.size(-1),)), tgt + ) + probs[idx:end] = tgt_probs.view(-1) + idx = end + sample["target"] = orig_target + + probs = probs.view(sample["target"].shape) + + if avg_probs is None: + avg_probs = probs + else: + avg_probs.add_(probs) + if attn is not None: + if torch.is_tensor(attn): + attn = attn.data + else: + attn = attn[0] + if avg_attn is None: + avg_attn = attn + else: + avg_attn.add_(attn) + if len(models) > 1: + avg_probs.div_(len(models)) + avg_probs.log_() + if avg_attn is not None: + avg_attn.div_(len(models)) + + bsz = avg_probs.size(0) + hypos = [] + start_idxs = sample["start_indices"] if "start_indices" in sample else [0] * bsz + for i in range(bsz): + # remove padding from ref + ref = ( + utils.strip_pad(sample["target"][i, start_idxs[i] :], self.pad) + if sample["target"] is not None + else None + ) + tgt_len = ref.numel() + avg_probs_i = avg_probs[i][start_idxs[i] : start_idxs[i] + tgt_len] + score_i = avg_probs_i.sum() / tgt_len + if avg_attn is not None: + avg_attn_i = avg_attn[i] + if self.compute_alignment: + alignment = utils.extract_hard_alignment( + avg_attn_i, + sample["net_input"]["src_tokens"][i], + sample["target"][i], + self.pad, + self.eos, + ) + else: + alignment = None + else: + avg_attn_i = alignment = None + hypos.append( + [ + { + "tokens": ref, + "score": score_i, + "attention": avg_attn_i, + "alignment": alignment, + "positional_scores": avg_probs_i, + } + ] + ) + return hypos diff --git a/fairseq/fairseq/speech_generator.py b/fairseq/fairseq/speech_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..f2cc8b5e86377d74515e477313eee3864a01d812 --- /dev/null +++ b/fairseq/fairseq/speech_generator.py @@ -0,0 +1,427 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
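# Illustrative sketch (standalone, made-up stats): the generators in this module all
# undo global CMVN via `gcmvn_denormalize`, i.e. x * std + mean broadcast over the
# channel dimension of B x T x C feature outputs.
#
#   import numpy as np
#   import torch
#   stats = {"mean": np.zeros(80, dtype=np.float32), "std": np.ones(80, dtype=np.float32)}
#   x = torch.randn(2, 100, 80)                                    # B x T x C
#   y = x * torch.from_numpy(stats["std"]) + torch.from_numpy(stats["mean"])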
+ +import numpy as np +import torch + +from fairseq.data.audio.speech_to_text_dataset import S2TDataConfig + + +class SpeechGenerator(object): + def __init__(self, model, vocoder, data_cfg: S2TDataConfig): + self.model = model + self.vocoder = vocoder + stats_npz_path = data_cfg.global_cmvn_stats_npz + self.gcmvn_stats = None + if stats_npz_path is not None: + self.gcmvn_stats = np.load(stats_npz_path) + + def gcmvn_denormalize(self, x): + # x: B x T x C + if self.gcmvn_stats is None: + return x + mean = torch.from_numpy(self.gcmvn_stats["mean"]).to(x) + std = torch.from_numpy(self.gcmvn_stats["std"]).to(x) + assert len(x.shape) == 3 and mean.shape[0] == std.shape[0] == x.shape[2] + x = x * std.view(1, 1, -1).expand_as(x) + return x + mean.view(1, 1, -1).expand_as(x) + + def get_waveform(self, feat): + # T x C -> T + return None if self.vocoder is None else self.vocoder(feat).squeeze(0) + + +class AutoRegressiveSpeechGenerator(SpeechGenerator): + def __init__( + self, + model, + vocoder, + data_cfg, + max_iter: int = 6000, + eos_prob_threshold: float = 0.5, + ): + super().__init__(model, vocoder, data_cfg) + self.max_iter = max_iter + self.eos_prob_threshold = eos_prob_threshold + + @torch.no_grad() + def generate(self, model, sample, has_targ=False, **kwargs): + model.eval() + + src_tokens = sample["net_input"]["src_tokens"] + src_lengths = sample["net_input"]["src_lengths"] + bsz, src_len = src_tokens.size()[:2] + n_frames_per_step = model.decoder.n_frames_per_step + out_dim = model.decoder.out_dim + raw_dim = out_dim // n_frames_per_step + + # initialize + encoder_out = model.forward_encoder( + src_tokens, src_lengths, speaker=sample["speaker"] + ) + incremental_state = {} + feat, attn, eos_prob = [], [], [] + finished = src_tokens.new_zeros((bsz,)).bool() + out_lens = src_lengths.new_zeros((bsz,)).long().fill_(self.max_iter) + + prev_feat_out = encoder_out["encoder_out"][0].new_zeros(bsz, 1, out_dim) + for step in range(self.max_iter): + cur_out_lens = out_lens.clone() + cur_out_lens.masked_fill_(cur_out_lens.eq(self.max_iter), step + 1) + _, cur_eos_out, cur_extra = model.forward_decoder( + prev_feat_out, + encoder_out=encoder_out, + incremental_state=incremental_state, + target_lengths=cur_out_lens, + speaker=sample["speaker"], + **kwargs, + ) + cur_eos_prob = torch.sigmoid(cur_eos_out).squeeze(2) + feat.append(cur_extra["feature_out"]) + attn.append(cur_extra["attn"]) + eos_prob.append(cur_eos_prob) + + cur_finished = cur_eos_prob.squeeze(1) > self.eos_prob_threshold + out_lens.masked_fill_((~finished) & cur_finished, step + 1) + finished = finished | cur_finished + if finished.sum().item() == bsz: + break + prev_feat_out = cur_extra["feature_out"] + + feat = torch.cat(feat, dim=1) + feat = model.decoder.postnet(feat) + feat + eos_prob = torch.cat(eos_prob, dim=1) + attn = torch.cat(attn, dim=2) + alignment = attn.max(dim=1)[1] + + feat = feat.reshape(bsz, -1, raw_dim) + feat = self.gcmvn_denormalize(feat) + + eos_prob = eos_prob.repeat_interleave(n_frames_per_step, dim=1) + attn = attn.repeat_interleave(n_frames_per_step, dim=2) + alignment = alignment.repeat_interleave(n_frames_per_step, dim=1) + out_lens = out_lens * n_frames_per_step + + finalized = [ + { + "feature": feat[b, :out_len], + "eos_prob": eos_prob[b, :out_len], + "attn": attn[b, :, :out_len], + "alignment": alignment[b, :out_len], + "waveform": self.get_waveform(feat[b, :out_len]), + } + for b, out_len in zip(range(bsz), out_lens) + ] + + if has_targ: + assert sample["target"].size(-1) == out_dim + tgt_feats = 
sample["target"].view(bsz, -1, raw_dim) + tgt_feats = self.gcmvn_denormalize(tgt_feats) + tgt_lens = sample["target_lengths"] * n_frames_per_step + for b, (f, l) in enumerate(zip(tgt_feats, tgt_lens)): + finalized[b]["targ_feature"] = f[:l] + finalized[b]["targ_waveform"] = self.get_waveform(f[:l]) + return finalized + + +class MultiDecoderSpeechGenerator(SpeechGenerator): + def __init__( + self, + models, + args, + vocoder, + data_cfg, + tgt_dict_mt, + max_iter: int = 6000, + eos_prob_threshold: float = 0.5, + eos_mt=None, + symbols_to_strip_from_output=None, + ): + super().__init__(models[0], vocoder, data_cfg) + self.max_iter = max_iter + self.eos_prob_threshold = eos_prob_threshold + + self.tgt_dict_mt = tgt_dict_mt + self.eos_mt = eos_mt + + from examples.speech_to_speech.unity.sequence_generator import SequenceGenerator + from fairseq import search + + self.text_generator = SequenceGenerator( + models, + tgt_dict_mt, + beam_size=max(1, getattr(args, "beam", 5)), + max_len_a=getattr(args, "max_len_a", 0), + max_len_b=getattr(args, "max_len_b", 200), + min_len=getattr(args, "min_len", 1), + normalize_scores=(not getattr(args, "unnormalized", False)), + len_penalty=getattr(args, "lenpen", 1), + unk_penalty=getattr(args, "unkpen", 0), + temperature=getattr(args, "temperature", 1.0), + match_source_len=getattr(args, "match_source_len", False), + no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0), + search_strategy=search.BeamSearch(tgt_dict_mt), + eos=eos_mt, + symbols_to_strip_from_output=symbols_to_strip_from_output, + ) + + @torch.no_grad() + def generate(self, model, sample, has_targ=False, **kwargs): + model.eval() + + src_tokens = sample["net_input"]["src_tokens"] + src_lengths = sample["net_input"]["src_lengths"] + bsz, src_len = src_tokens.size()[:2] + n_frames_per_step = model.decoder.n_frames_per_step + out_dim = model.decoder.out_dim + raw_dim = out_dim // n_frames_per_step + + # initialize + encoder_out = model.forward_encoder( + src_tokens, src_lengths, speaker=sample["speaker"] + ) + + prefix_tokens = None + constraints = None + bos_token = None + + mt_decoder = getattr(model, f"{model.mt_task_name}_decoder") + + # 1. 
MT decoder + finalized_mt = self.text_generator.generate_decoder( + [encoder_out], + src_tokens, + src_lengths, + sample, + prefix_tokens, + constraints, + bos_token, + aux_task_name=model.mt_task_name, + ) + + # extract decoder output corresponding to the best hypothesis + max_tgt_len = max([len(hypo[0]["tokens"]) for hypo in finalized_mt]) + prev_output_tokens_mt = ( + src_tokens.new_zeros(src_tokens.shape[0], max_tgt_len) + .fill_(mt_decoder.padding_idx) + .int() + ) # B x T + for i, hypo in enumerate(finalized_mt): + i_beam = 0 + tmp = hypo[i_beam]["tokens"].int() # hyp + eos + prev_output_tokens_mt[i, 0] = self.text_generator.eos + if tmp[-1] == self.text_generator.eos: + tmp = tmp[:-1] + prev_output_tokens_mt[i, 1 : len(tmp) + 1] = tmp + + text = "".join([self.tgt_dict_mt[c] for c in tmp]) + text = text.replace("_", " ") + text = text.replace("▁", " ") + text = text.replace("", " ") + text = text.replace("", "") + text = text.replace("", "") + if len(text) > 0 and text[0] == " ": + text = text[1:] + sample_id = sample["id"].tolist()[i] + print("{} (None-{})".format(text, sample_id)) + + mt_decoder_out = mt_decoder( + prev_output_tokens_mt, + encoder_out=encoder_out, + features_only=True, + ) + x = mt_decoder_out[0].transpose(0, 1) + + mt_decoder_padding_mask = None + if prev_output_tokens_mt.eq(mt_decoder.padding_idx).any(): + mt_decoder_padding_mask = prev_output_tokens_mt.eq(mt_decoder.padding_idx) + + # 2. TTS encoder + if getattr(model, "synthesizer_encoder", None) is not None: + synthesizer_encoder_out = model.synthesizer_encoder( + x, + mt_decoder_padding_mask, + ) + else: + synthesizer_encoder_out = { + "encoder_out": [x], # T x B x C + "encoder_padding_mask": [mt_decoder_padding_mask] + if mt_decoder_padding_mask is not None + else [], # B x T + "encoder_embedding": [], + "encoder_states": [], + "src_tokens": [], + "src_lengths": [], + } + + # 3. 
TTS decoder + incremental_state = {} + feat, attn, eos_prob = [], [], [] + finished = src_tokens.new_zeros((bsz,)).bool() + out_lens = src_lengths.new_zeros((bsz,)).long().fill_(self.max_iter) + + prev_feat_out = encoder_out["encoder_out"][0].new_zeros(bsz, 1, out_dim) + for step in range(self.max_iter): + cur_out_lens = out_lens.clone() + cur_out_lens.masked_fill_(cur_out_lens.eq(self.max_iter), step + 1) + _, cur_eos_out, cur_extra = model.forward_decoder( + prev_feat_out, + encoder_out=synthesizer_encoder_out, + incremental_state=incremental_state, + target_lengths=cur_out_lens, + speaker=sample["speaker"], + **kwargs, + ) + cur_eos_prob = torch.sigmoid(cur_eos_out).squeeze(2) + feat.append(cur_extra["feature_out"]) + attn.append(cur_extra["attn"]) + eos_prob.append(cur_eos_prob) + + cur_finished = cur_eos_prob.squeeze(1) > self.eos_prob_threshold + out_lens.masked_fill_((~finished) & cur_finished, step + 1) + finished = finished | cur_finished + if finished.sum().item() == bsz: + break + prev_feat_out = cur_extra["feature_out"] + + feat = torch.cat(feat, dim=1) + feat = model.decoder.postnet(feat) + feat + eos_prob = torch.cat(eos_prob, dim=1) + attn = torch.cat(attn, dim=2) + alignment = attn.max(dim=1)[1] + + feat = feat.reshape(bsz, -1, raw_dim) + feat = self.gcmvn_denormalize(feat) + + eos_prob = eos_prob.repeat_interleave(n_frames_per_step, dim=1) + attn = attn.repeat_interleave(n_frames_per_step, dim=2) + alignment = alignment.repeat_interleave(n_frames_per_step, dim=1) + out_lens = out_lens * n_frames_per_step + + finalized = [ + { + "feature": feat[b, :out_len], + "eos_prob": eos_prob[b, :out_len], + "attn": attn[b, :, :out_len], + "alignment": alignment[b, :out_len], + "waveform": self.get_waveform(feat[b, :out_len]), + } + for b, out_len in zip(range(bsz), out_lens) + ] + + if has_targ: + assert sample["target"].size(-1) == out_dim + tgt_feats = sample["target"].view(bsz, -1, raw_dim) + tgt_feats = self.gcmvn_denormalize(tgt_feats) + tgt_lens = sample["target_lengths"] * n_frames_per_step + for b, (f, l) in enumerate(zip(tgt_feats, tgt_lens)): + finalized[b]["targ_feature"] = f[:l] + finalized[b]["targ_waveform"] = self.get_waveform(f[:l]) + return finalized + + +class NonAutoregressiveSpeechGenerator(SpeechGenerator): + @torch.no_grad() + def generate(self, model, sample, has_targ=False, **kwargs): + model.eval() + + bsz, max_src_len = sample["net_input"]["src_tokens"].size() + n_frames_per_step = model.encoder.n_frames_per_step + out_dim = model.encoder.out_dim + raw_dim = out_dim // n_frames_per_step + + feat, feat_post, out_lens, log_dur_out, _, _ = model( + src_tokens=sample["net_input"]["src_tokens"], + src_lengths=sample["net_input"]["src_lengths"], + prev_output_tokens=sample["net_input"]["prev_output_tokens"], + incremental_state=None, + target_lengths=sample["target_lengths"], + speaker=sample["speaker"], + ) + if feat_post is not None: + feat = feat_post + + feat = feat.view(bsz, -1, raw_dim) + feat = self.gcmvn_denormalize(feat) + + dur_out = torch.clamp(torch.round(torch.exp(log_dur_out) - 1).long(), min=0) + + def get_dur_plot_data(d): + r = [] + for i, dd in enumerate(d): + r += [i + 1] * dd.item() + return r + + out_lens = out_lens * n_frames_per_step + finalized = [ + { + "feature": feat[b, :l] if l > 0 else feat.new_zeros([1, raw_dim]), + "waveform": self.get_waveform( + feat[b, :l] if l > 0 else feat.new_zeros([1, raw_dim]) + ), + "attn": feat.new_tensor(get_dur_plot_data(dur_out[b])), + } + for b, l in zip(range(bsz), out_lens) + ] + + if has_targ: + 
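# Illustrative note (made-up sizes): target features are stored with
# n_frames_per_step frames packed into each decoder step, so with out_dim = 320 and
# n_frames_per_step = 4 a (B, T, 320) target unpacks to (B, 4 * T, 80) raw frames
# below, and target_lengths are scaled by 4 to match.
#
#   import torch
#   B, T, n, raw = 2, 50, 4, 80
#   packed = torch.randn(B, T, n * raw)
#   unpacked = packed.view(B, -1, raw)     # (2, 200, 80), same op as below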
tgt_feats = sample["target"].view(bsz, -1, raw_dim) + tgt_feats = self.gcmvn_denormalize(tgt_feats) + tgt_lens = sample["target_lengths"] * n_frames_per_step + for b, (f, l) in enumerate(zip(tgt_feats, tgt_lens)): + finalized[b]["targ_feature"] = f[:l] + finalized[b]["targ_waveform"] = self.get_waveform(f[:l]) + return finalized + + +class TeacherForcingAutoRegressiveSpeechGenerator(AutoRegressiveSpeechGenerator): + @torch.no_grad() + def generate(self, model, sample, has_targ=False, **kwargs): + model.eval() + + src_tokens = sample["net_input"]["src_tokens"] + src_lens = sample["net_input"]["src_lengths"] + prev_out_tokens = sample["net_input"]["prev_output_tokens"] + tgt_lens = sample["target_lengths"] + n_frames_per_step = model.decoder.n_frames_per_step + raw_dim = model.decoder.out_dim // n_frames_per_step + bsz = src_tokens.shape[0] + + feat, eos_prob, extra = model( + src_tokens, + src_lens, + prev_out_tokens, + incremental_state=None, + target_lengths=tgt_lens, + speaker=sample["speaker"], + ) + + attn = extra["attn"] # B x T_s x T_t + alignment = attn.max(dim=1)[1] + feat = feat.reshape(bsz, -1, raw_dim) + feat = self.gcmvn_denormalize(feat) + eos_prob = eos_prob.repeat_interleave(n_frames_per_step, dim=1) + attn = attn.repeat_interleave(n_frames_per_step, dim=2) + alignment = alignment.repeat_interleave(n_frames_per_step, dim=1) + tgt_lens = sample["target_lengths"] * n_frames_per_step + + finalized = [ + { + "feature": feat[b, :tgt_len], + "eos_prob": eos_prob[b, :tgt_len], + "attn": attn[b, :, :tgt_len], + "alignment": alignment[b, :tgt_len], + "waveform": self.get_waveform(feat[b, :tgt_len]), + } + for b, tgt_len in zip(range(bsz), tgt_lens) + ] + + if has_targ: + tgt_feats = sample["target"].view(bsz, -1, raw_dim) + tgt_feats = self.gcmvn_denormalize(tgt_feats) + for b, (f, l) in enumerate(zip(tgt_feats, tgt_lens)): + finalized[b]["targ_feature"] = f[:l] + finalized[b]["targ_waveform"] = self.get_waveform(f[:l]) + return finalized diff --git a/fairseq/fairseq/tokenizer.py b/fairseq/fairseq/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..42131f7b1d334020c3b48a6e44d4139f7c62ad28 --- /dev/null +++ b/fairseq/fairseq/tokenizer.py @@ -0,0 +1,15 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import re + + +SPACE_NORMALIZER = re.compile(r"\s+") + + +def tokenize_line(line): + line = SPACE_NORMALIZER.sub(" ", line) + line = line.strip() + return line.split() diff --git a/fairseq/fairseq/trainer.py b/fairseq/fairseq/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..16b1b9169738269147a615e1ca52036205e74421 --- /dev/null +++ b/fairseq/fairseq/trainer.py @@ -0,0 +1,1622 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Train a network across multiple GPUs. 
+""" + +import contextlib +import logging +import os +import sys +import time +from argparse import Namespace +from itertools import chain +from typing import Any, Dict, List + +import torch +from omegaconf import OmegaConf + +from fairseq import checkpoint_utils, models, optim, utils +from fairseq.dataclass.configs import FairseqConfig +from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from fairseq.distributed import utils as distributed_utils +from fairseq.file_io import PathManager +from fairseq.logging import meters, metrics +from fairseq.models.ema import build_ema +from fairseq.nan_detector import NanDetector +from fairseq.optim import lr_scheduler +from fairseq.utils import safe_hasattr + +logger = logging.getLogger(__name__) + + +class Trainer(object): + """Main class for data parallel training. + + This class supports synchronous distributed data parallel training, + where multiple workers each have a full model replica and gradients + are accumulated across workers before each update. We use + :class:`~torch.nn.parallel.DistributedDataParallel` to handle + communication of the gradients across workers. + """ + + def __init__(self, cfg: FairseqConfig, task, model, criterion, quantizer=None): + + if isinstance(cfg, Namespace): + logger.warning( + "argparse.Namespace configuration is deprecated! Automatically converting to OmegaConf" + ) + cfg = convert_namespace_to_omegaconf(cfg) + + self.cfg = cfg + self.task = task + + # catalog shared parameters + shared_params = _catalog_shared_params(model) + self.tpu = cfg.common.tpu + self.cuda = torch.cuda.is_available() and not cfg.common.cpu and not self.tpu + if self.cuda: + self.device = torch.device("cuda") + elif self.tpu: + self.device = utils.get_tpu_device() + else: + self.device = torch.device("cpu") + + if self.is_fsdp: + import fairscale + + if self.cfg.common.bf16: + raise ValueError( + "FullyShardedDataParallel is not compatible with --bf16 or " + "--memory-efficient-bf16" + ) + if self.cfg.distributed_training.zero_sharding != "none": + raise ValueError( + "FullyShardedDataParallel is not compatible with --zero-sharding " + "option (it's already built in)" + ) + if ( + max(self.cfg.optimization.update_freq) > 1 + and fairscale.__version__ < "0.4.0" + ): + raise RuntimeError( + "Please update to fairscale 0.4.0 or newer when combining " + "--update-freq with FullyShardedDataParallel" + ) + else: + if ( + hasattr(self.cfg.distributed_training, "cpu_offload") + and self.cfg.distributed_training.cpu_offload + ): + raise ValueError("--cpu-offload requires --ddp-backend=fully_sharded") + + # copy model and criterion to current device/dtype + self._criterion = criterion + self._model = model + if not self.is_fsdp: + if cfg.common.fp16: + assert not cfg.common.amp, "Cannot use fp16 and AMP together" + self._criterion = self._criterion.half() + self._model = self._model.half() + elif cfg.common.bf16: + self._criterion = self._criterion.to(dtype=torch.bfloat16) + self._model = self._model.to(dtype=torch.bfloat16) + elif cfg.common.amp: + self._amp_retries = 0 + if ( + not cfg.distributed_training.pipeline_model_parallel + # the DistributedFairseqModel wrapper will handle moving to device, + # so only handle cases which don't use the wrapper + and not self.use_distributed_wrapper + ): + self._criterion = self._criterion.to(device=self.device) + self._model = self._model.to(device=self.device) + self.pipeline_model_parallel = cfg.distributed_training.pipeline_model_parallel + self.last_device = None + if self.cuda and 
self.pipeline_model_parallel: + self.last_device = torch.device( + cfg.distributed_training.pipeline_devices[-1] + ) + + # check that shared parameters are preserved after device transfer + for shared_param in shared_params: + ref = _get_module_by_path(self._model, shared_param[0]) + for path in shared_param[1:]: + logger.info( + "detected shared parameter: {} <- {}".format(shared_param[0], path) + ) + _set_module_by_path(self._model, path, ref) + + self._dummy_batch = None # indicates we don't have a dummy batch at first + self._lr_scheduler = None + self._num_updates = 0 + self._num_xla_compiles = 0 # for TPUs + self._optim_history = None + self._optimizer = None + self._warn_once = set() + self._wrapped_criterion = None + self._wrapped_model = None + self._ema = None + + # TODO(myleott): support tpu + if self.cuda and self.data_parallel_world_size > 1: + self._grad_norm_buf = torch.cuda.DoubleTensor(self.data_parallel_world_size) + else: + self._grad_norm_buf = None + + self.quantizer = quantizer + if self.quantizer is not None: + self.quantizer.set_trainer(self) + + # get detailed cuda environment + if self.cuda: + self.cuda_env = utils.CudaEnvironment() + if self.data_parallel_world_size > 1: + self.cuda_env_arr = distributed_utils.all_gather_list( + self.cuda_env, group=distributed_utils.get_global_group() + ) + else: + self.cuda_env_arr = [self.cuda_env] + if self.data_parallel_rank == 0: + utils.CudaEnvironment.pretty_print_cuda_env_list(self.cuda_env_arr) + else: + self.cuda_env = None + self.cuda_env_arr = None + + metrics.log_start_time("wall", priority=790, round=0) + + self._start_time = time.time() + self._previous_training_time = 0 + self._cumulative_training_time = None + + def reinitialize(self): + """Reinitialize the Trainer, typically after model params change.""" + self._lr_scheduler = None + self._optimizer = None + self._wrapped_criterion = None + self._wrapped_model = None + + @property + def data_parallel_world_size(self): + if self.cfg.distributed_training.distributed_world_size == 1: + return 1 + return distributed_utils.get_data_parallel_world_size() + + @property + def data_parallel_process_group(self): + return distributed_utils.get_data_parallel_group() + + @property + def data_parallel_rank(self): + if self.cfg.distributed_training.distributed_world_size == 1: + return 0 + return distributed_utils.get_data_parallel_rank() + + @property + def is_data_parallel_master(self): + # NOTE: this returns true for all model parallel replicas with data + # parallel rank 0 + return self.data_parallel_rank == 0 + + @property + def use_distributed_wrapper(self) -> bool: + return ( + self.data_parallel_world_size > 1 and not self.cfg.optimization.use_bmuf + ) or (self.is_fsdp and self.cfg.distributed_training.cpu_offload) + + @property + def should_save_checkpoint_on_current_rank(self) -> bool: + """Indicates whether to save checkpoints on the current DDP rank.""" + if ( + self.is_fsdp and self.cfg.distributed_training.use_sharded_state + ) or getattr(self.cfg.model, "base_layers", 0) > 0: + return True + else: + return self.is_data_parallel_master + + @property + def always_call_state_dict_during_save_checkpoint(self) -> bool: + if self.is_fsdp and not self.cfg.distributed_training.use_sharded_state: + # FSDP calls communication collective when consolidating checkpoints + return True + else: + return False + + @property + def checkpoint_suffix(self) -> str: + """Suffix to add to the checkpoint file name.""" + if self.is_fsdp and 
self.cfg.distributed_training.use_sharded_state: + return self.cfg.checkpoint.checkpoint_suffix + "-shard{0}".format( + self.data_parallel_rank + ) + else: + return self.cfg.checkpoint.checkpoint_suffix or "" + + @property + def criterion(self): + if self._wrapped_criterion is None: + if utils.has_parameters(self._criterion) and self.use_distributed_wrapper: + self._wrapped_criterion = models.DistributedFairseqModel( + self.cfg.distributed_training, + self._criterion, + process_group=self.data_parallel_process_group, + device=self.device, + ) + else: + self._wrapped_criterion = self._criterion + return self._wrapped_criterion + + @property + def model(self): + if self._wrapped_model is None: + if self.use_distributed_wrapper: + self._wrapped_model = models.DistributedFairseqModel( + self.cfg.distributed_training, + self._model, + process_group=self.data_parallel_process_group, + device=self.device, + ) + else: + self._wrapped_model = self._model + return self._wrapped_model + + @property + def ema(self): + if self._ema is None: + self._build_ema() + return self._ema + + def _build_ema(self): + if self.cfg.ema.store_ema: + self._ema = build_ema(self._model, self.cfg.ema, self.device) + logger.info("Exponential Moving Average Shadow Model is initialized.") + + @property + def optimizer(self): + if self._optimizer is None: + self._build_optimizer() + return self._optimizer + + @property + def lr_scheduler(self): + if self._lr_scheduler is None: + self._build_optimizer() # this will initialize self._lr_scheduler + return self._lr_scheduler + + def _build_optimizer(self): + + if ( + self.cfg.optimization.debug_param_names + and self.cfg.common.fp16_no_flatten_grads + ): + params = [] + self.param_names = [] + + for n, p in chain( + self.model.named_parameters(), self.criterion.named_parameters() + ): + if p.requires_grad: + params.append(p) + self.param_names.append(n) + else: + params = list( + filter( + lambda p: p.requires_grad, + chain(self.model.parameters(), self.criterion.parameters()), + ) + ) + + if self.is_fsdp and self.cfg.common.fp16: + # FullyShardedDataParallel always uses MemoryEfficientFP16 wrapper, + # mostly for the grad scaling. But if we don't have the + # --memory-efficient-fp16 flag set, then we're effectively doing + # regular --fp16 and can allow the use of optimizers that would + # otherwise be unsupported by MemoryEfficientFP16Optimizer. 
+ allow_unsupported = not self.cfg.common.memory_efficient_fp16 + self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer( + self.cfg, params, allow_unsupported=allow_unsupported + ) + elif self.cfg.common.fp16 or self.cfg.common.bf16 or self.cfg.common.amp: + if self.cuda and torch.cuda.get_device_capability(0)[0] < 7: + logger.info( + "NOTE: your device does NOT support faster training with --fp16 or --amp, " + "please switch to FP32 which is likely to be faster" + ) + if ( + self.cfg.common.memory_efficient_fp16 + or self.cfg.common.memory_efficient_bf16 + ): + self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer( + self.cfg, params + ) + elif self.cfg.common.amp: + self._optimizer = optim.AMPOptimizer.build_optimizer(self.cfg, params) + else: + self._optimizer = optim.FP16Optimizer.build_optimizer(self.cfg, params) + else: + if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7: + logger.info( + "NOTE: your device may support faster training with --fp16 or --amp" + ) + self._optimizer = optim.build_optimizer(self.cfg.optimizer, params) + + if self.is_fsdp: + assert ( + not self.cfg.optimization.use_bmuf + ), "--ddp-backend=fully_sharded is not compatible with BMUF" + assert self._optimizer.supports_flat_params, ( + "--ddp-backend=fully_sharded is only compatible with pointwise " + "optimizers (e.g., Adam, AdamW, Adadelta, Adamax, SGD, etc.). " + "However, the sharding will result in slightly different results when " + "using non-pointwise optimizers (e.g., Adagrad, Adafactor, LAMB)" + ) + + if self.cfg.optimization.use_bmuf: + self._optimizer = optim.FairseqBMUF( + self.cfg.bmuf, + self._optimizer, + ) + + if self.cfg.distributed_training.zero_sharding == "os": + if ( + self.cfg.common.fp16 + and not self.cfg.common.memory_efficient_fp16 + and not self.cfg.common.memory_efficient_bf16 + ) and not self.cfg.common.fp16_no_flatten_grads: + raise ValueError( + "ZeRO is incompatible with fp16 and flattened grads. " + "Please use --fp16-no-flatten-grads" + ) + else: + optim.shard_(self._optimizer, self.data_parallel_process_group) + + # We should initialize the learning rate scheduler immediately after + # building the optimizer, so that the initial learning rate is set.
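+ # step_update(0) below primes the schedule so that training starts from the learning rate defined for update 0 (e.g. the first warmup value)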
+ self._lr_scheduler = lr_scheduler.build_lr_scheduler( + self.cfg.lr_scheduler, + self.optimizer, + ) + self._lr_scheduler.step_update(0) + + @property + def is_fsdp(self): + return self.cfg.distributed_training.ddp_backend == "fully_sharded" + + def consolidate_optimizer(self): + """For OSS, we need to consolidate the state dict.""" + if self.cfg.checkpoint.no_save_optimizer_state: + return + self._gathered_optim_state = None + if hasattr(self.optimizer.optimizer, "consolidate_state_dict"): + self.optimizer.optimizer.consolidate_state_dict() + elif self.is_fsdp and not self.model.use_sharded_state: + st = self.model.gather_full_optim_state_dict( + self.optimizer + ) # only returns on rank 0 + self._gathered_optim_state = st + + def state_dict(self): + state_dict = { + "args": None, # legacy + "cfg": ( + OmegaConf.to_container(self.cfg, resolve=True, enum_to_str=True) + if OmegaConf.is_config(self.cfg) + else self.cfg + ), + "model": self.model.state_dict(), + "criterion": ( + self.criterion.state_dict() + if utils.has_parameters(self.criterion) + else None + ), + "optimizer_history": (self._optim_history or []) + + [ + { + "criterion_name": self.get_criterion().__class__.__name__, + "optimizer_name": self.optimizer.__class__.__name__, + "lr_scheduler_state": self.lr_scheduler.state_dict(), + "num_updates": self.get_num_updates(), + } + ], + "task_state": self.task.state_dict() if self.task is not None else {}, + "extra_state": { + "metrics": metrics.state_dict(), + "previous_training_time": self.cumulative_training_time(), + }, + } + if self.cfg.ema.store_ema: + # Save EMA model state as extra state + state_dict["extra_state"]["ema"] = self.ema.get_model().state_dict() + if self.cfg.ema.ema_fp32: + # Save EMA params in fp32 + state_dict["extra_state"]["ema_fp32_params"] = self.ema.fp32_params + if not self.cfg.checkpoint.no_save_optimizer_state: + if self._gathered_optim_state is not None: + state_dict["last_optimizer_state"] = self._gathered_optim_state + self._gathered_optim_state = None + else: + state_dict["last_optimizer_state"] = self.optimizer.state_dict() + if self.is_fsdp: + # save meta data for recombining checkpoint upon loading + state_dict["fsdp_metadata"] = self.model.local_metadata_dict() + return state_dict + + def save_checkpoint(self, filename, extra_state): + """Save all training state in a checkpoint file.""" + if self.should_save_checkpoint_on_current_rank: + + logger.info(f"Saving checkpoint to {os.path.abspath(filename)}") + # call state_dict on all ranks in case it needs internal communication + state_dict = utils.move_to_cpu(self.state_dict()) + state_dict["extra_state"].update(extra_state) + + checkpoint_utils.torch_persistent_save( + state_dict, + filename, + async_write=self.cfg.checkpoint.write_checkpoints_asynchronously, + ) + logger.info(f"Finished saving checkpoint to {os.path.abspath(filename)}") + return os.path.abspath(filename) + return None + + def load_checkpoint( + self, + filename, + reset_optimizer=False, + reset_lr_scheduler=False, + optimizer_overrides=None, + reset_meters=False, + ): + """ + Load all training state from a checkpoint file. + rank = 0 will load the checkpoint, and then broadcast it to all + other ranks. 
+ """ + extra_state, self._optim_history, last_optim_state = None, [], None + + logger.info(f"Preparing to load checkpoint {filename}") + is_distributed = self.data_parallel_world_size > 1 + bexists = PathManager.isfile(filename) + if bexists: + load_on_all_ranks = ( + self.cfg.checkpoint.load_checkpoint_on_all_dp_ranks + # TPUs don't support broadcast yet, so load checkpoints + # on every worker for now + or self.tpu + # FSDP requires loading checkpoint shards on all ranks + or (self.is_fsdp and self.cfg.distributed_training.use_sharded_state) + or getattr(self.cfg.model, "base_layers", 0) > 0 + ) + + if load_on_all_ranks or self.data_parallel_rank == 0: + state = checkpoint_utils.load_checkpoint_to_cpu( + filename, load_on_all_ranks=load_on_all_ranks + ) + last_optim_state = state.get("last_optimizer_state", None) + + # If doing zero_sharding, do not broadcast global optimizer + # state. Later we will broadcast sharded states to each rank + # to avoid memory from exploding. + if ( + not load_on_all_ranks + and self.cfg.distributed_training.zero_sharding == "os" + and "last_optimizer_state" in state + and is_distributed + ): + state["last_optimizer_state"] = "SHARDED" + else: + last_optim_state = None + state = None + + if is_distributed and not load_on_all_ranks: + state = distributed_utils.broadcast_object( + state, + src_rank=0, + group=self.data_parallel_process_group, + dist_device=self.device, + ) + if self.data_parallel_rank > 0: + last_optim_state = state.get("last_optimizer_state", None) + + # load model parameters + try: + if ( + "optimizer_history" in state + and len(state["optimizer_history"]) > 0 + and "num_updates" in state["optimizer_history"][-1] + ): + self.model.set_num_updates( + state["optimizer_history"][-1]["num_updates"] + ) + + # this is the code related to AdaPrune + # In short, it removes redundant heads in multi-head attention module based on heads importance provided + # For more info, please refer to the paper: https://openreview.net/forum?id=_CMSV7FTzGI + # The idea of prune in mha can be summarized as + # Fine tune model (e.g. roberta encoder) on a certain datasets with regularization + # After the model is trained. User could use get_reserve_head_index and _adaptive_prune_heads functions to get the top X heads with most importance. + # Then user uses the rank to prune a new roberta encoder and save the pruned ckpt manually. + # User will fine tune the the new roberta encoder via the ckpt saved above + # To get rid of registering different pruned version of Roberta, I use the argument --mha-heads-to-keep to prune the Roberta model into a pruned version which matches the pruned ckpt. 
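+ # the head pruning below reshapes the in-memory model first, so that the strict load_state_dict of the pruned checkpoint further down can succeed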
+ if ( + safe_hasattr(self.model, "args") + and safe_hasattr(self.model.args, "mha_heads_to_keep") + and self.model.args.mha_heads_to_keep != -1 + ): + logger.info( + f"Prune model: keep {self.model.args.mha_heads_to_keep} heads for each multihead attention module" + ) + for layer in self.model.encoder.sentence_encoder.layers: + reserve_head_index = layer.self_attn._get_reserve_head_index( + num_heads_to_keep=self.model.args.mha_heads_to_keep + ) + layer.self_attn._adaptive_prune_heads( + reserve_head_index=reserve_head_index + ) + layer.self_attn._set_skip_embed_dim_check() + logger.info(self.model) + # this is the code related to AdaPrune + # In short, it removes redundant units in feedforward layer in each transformer layer based on importance + # For more info, please refer to the paper: https://openreview.net/forum?id=_CMSV7FTzGI + # The idea of prune in ffn can be summarized as + # Fine tune model (e.g. roberta encoder) on a certain datasets with regularization + # After the model is trained. User could use _get_fc_rank and _prune_fc_layer functions to get the top X units with most importance. + # Then user uses the rank to prune a new roberta encoder and save the pruned ckpt manually. + # User will fine tune the the new roberta encoder via the ckpt saved above + # To get rid of registering different pruned version of Roberta, I use the argument --ffn-blocks-to-remove to prune the Roberta model into a pruned version which matches the pruned ckpt. + if ( + safe_hasattr(self.model, "args") + and safe_hasattr(self.model.args, "ffn_blocks_to_remove") + and self.model.args.ffn_blocks_to_remove != -1 + ): + logger.info( + f"Prune model: remove {self.model.args.ffn_blocks_to_remove} ffn blocks for each transformer layer" + ) + for layer in self.model.encoder.sentence_encoder.layers: + remove_index = layer._get_fc_rank( + remove_num=self.model.args.ffn_blocks_to_remove + ) + layer._prune_fc_layer(remove_index=remove_index) + logger.info(self.model) + + self.model.load_state_dict( + state["model"], strict=True, model_cfg=self.cfg.model + ) + # save memory for later steps + del state["model"] + if utils.has_parameters(self.get_criterion()): + self.get_criterion().load_state_dict( + state["criterion"], strict=True + ) + del state["criterion"] + + except Exception: + raise Exception( + "Cannot load model parameters from checkpoint {}; " + "please ensure that the architectures match.".format(filename) + ) + extra_state = state["extra_state"] + self._optim_history = state["optimizer_history"] + + if last_optim_state is not None and not reset_optimizer: + # rebuild optimizer after loading model, since params may have changed + self._build_optimizer() + + # only reload optimizer and lr_scheduler if they match + last_optim = self._optim_history[-1] + assert ( + last_optim["criterion_name"] == self.get_criterion().__class__.__name__ + ), f"Criterion does not match; please reset the optimizer (--reset-optimizer). {last_optim['criterion_name']} vs {self.get_criterion().__class__.__name__}" + assert ( + last_optim["optimizer_name"] == self.optimizer.__class__.__name__ + ), f"Optimizer does not match; please reset the optimizer (--reset-optimizer). 
{last_optim['optimizer_name']} vs {self.optimizer.__class__.__name__}" + + if not reset_lr_scheduler: + self.lr_scheduler.load_state_dict(last_optim["lr_scheduler_state"]) + + if self.is_fsdp and not self.model.use_sharded_state: + # if use_sharded_state, the last_optim_state is already sharded, skip this + last_optim_state = self.model.get_shard_from_optim_state_dict( + last_optim_state + ) + elif not load_on_all_ranks and is_distributed: + last_optim_state = self.optimizer.broadcast_global_state_dict( + last_optim_state + ) + + self.optimizer.load_state_dict(last_optim_state, optimizer_overrides) + + self.set_num_updates(last_optim["num_updates"]) + + if extra_state is not None: + itr_state = extra_state["train_iterator"] + epoch = itr_state["epoch"] + + if "previous_training_time" in extra_state: + self._previous_training_time = extra_state["previous_training_time"] + self._start_time = time.time() + + self.lr_step(epoch) + + if ( + itr_state.get("version", 1) >= 2 + and itr_state["iterations_in_epoch"] == 0 + ): + # reset meters at start of epoch + reset_meters = True + + if "metrics" in extra_state and not reset_meters: + metrics.load_state_dict(extra_state["metrics"]) + + # reset TimeMeters, since their start times don't make sense anymore + for meter in metrics.get_meters("default"): + if isinstance(meter, meters.TimeMeter): + meter.reset() + + if self.cfg.ema.store_ema: + if "ema" not in extra_state: + logger.warn( + "EMA not found in checkpoint. But store_ema is True. " + "EMA is re-initialized from checkpoint." + ) + self.ema.restore( + state["model"], build_fp32_params=self.cfg.ema.ema_fp32 + ) + else: + logger.info("Loading EMA from checkpoint") + self.ema.restore(extra_state["ema"], build_fp32_params=False) + + if self.cfg.ema.ema_fp32: + if "ema_fp32_params" in extra_state: + logger.info("Loading EMA fp32 params from checkpoint") + self.ema.build_fp32_params(extra_state["ema_fp32_params"]) + else: + logger.info( + "Building EMA fp32 params from EMA model in checkpoint" + ) + self.ema.build_fp32_params() + + logger.info( + "Loaded checkpoint {} (epoch {} @ {} updates)".format( + filename, epoch, self.get_num_updates() + ) + ) + + else: + logger.info("No existing checkpoint found {}".format(filename)) + + return extra_state + + def get_train_iterator( + self, + epoch, + combine=True, + load_dataset=True, + data_selector=None, + shard_batch_itr=True, + disable_iterator_cache=False, + ): + """Return an EpochBatchIterator over the training set for a given epoch.""" + if load_dataset: + logger.info("loading train data for epoch {}".format(epoch)) + self.task.load_dataset( + self.cfg.dataset.train_subset, + epoch=epoch, + combine=combine, + data_selector=data_selector, + tpu=self.tpu, + ) + batch_iterator = self.task.get_batch_iterator( + dataset=self.task.dataset(self.cfg.dataset.train_subset), + max_tokens=self.cfg.dataset.max_tokens, + max_sentences=self.cfg.dataset.batch_size, + max_positions=utils.resolve_max_positions( + self.task.max_positions(), + self.model.max_positions(), + self.cfg.dataset.max_tokens, + ), + ignore_invalid_inputs=True, + required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple, + seed=(self.cfg.common.seed + epoch) + if self.cfg.dataset.update_ordered_indices_seed + else self.cfg.common.seed, + num_shards=self.data_parallel_world_size if shard_batch_itr else 1, + shard_id=self.data_parallel_rank if shard_batch_itr else 0, + num_workers=self.cfg.dataset.num_workers, + epoch=epoch, + data_buffer_size=self.cfg.dataset.data_buffer_size, + 
disable_iterator_cache=disable_iterator_cache, + skip_remainder_batch=self.cfg.optimization.skip_remainder_batch, + grouped_shuffling=self.cfg.dataset.grouped_shuffling, + update_epoch_batch_itr=self.cfg.dataset.update_epoch_batch_itr, + ) + self.reset_dummy_batch(batch_iterator.first_batch) + return batch_iterator + + def get_valid_iterator( + self, + subset, + disable_iterator_cache=False, + ): + """Return an EpochBatchIterator over given validation subset for a given epoch.""" + batch_iterator = self.task.get_batch_iterator( + dataset=self.task.dataset(subset), + max_tokens=self.cfg.dataset.max_tokens_valid, + max_sentences=self.cfg.dataset.batch_size_valid, + max_positions=utils.resolve_max_positions( + self.task.max_positions(), + self.model.max_positions(), + ), + ignore_invalid_inputs=self.cfg.dataset.skip_invalid_size_inputs_valid_test, + required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple, + seed=self.cfg.common.seed, + num_shards=self.data_parallel_world_size, + shard_id=self.data_parallel_rank, + num_workers=self.cfg.dataset.num_workers, + # always pass a fixed "epoch" to keep validation data consistent + # across training epochs + epoch=1, + data_buffer_size=self.cfg.dataset.data_buffer_size, + disable_iterator_cache=disable_iterator_cache, + skip_remainder_batch=False, + ) + self.reset_dummy_batch(batch_iterator.first_batch) + return batch_iterator + + def begin_epoch(self, epoch): + """Called at the beginning of each epoch.""" + logger.info("begin training epoch {}".format(epoch)) + + self.lr_step_begin_epoch(epoch) + + if self.quantizer is not None: + self.quantizer.begin_epoch(epoch) + + # task specific setup per epoch + self.task.begin_epoch(epoch, self.get_model()) + + if self.tpu: + import torch_xla.core.xla_model as xm + + xm.rendezvous("begin_epoch") # wait for all workers + xm.mark_step() + + def begin_valid_epoch(self, epoch): + """Called at the beginning of each validation epoch.""" + + # task specific setup per validation epoch + self.task.begin_valid_epoch(epoch, self.get_model()) + + def reset_dummy_batch(self, batch): + self._dummy_batch = batch + + @metrics.aggregate("train") + def train_step(self, samples, raise_oom=False): + """Do forward, backward and parameter update.""" + self._set_seed() + self.model.train() + self.criterion.train() + self.zero_grad() + + metrics.log_start_time("train_wall", priority=800, round=0) + + # If EMA is enabled through store_ema=True + # and task.uses_ema is True, pass the EMA model as a keyword + # argument to the task. + extra_kwargs = {} + if self.cfg.ema.store_ema and getattr(self.task, "uses_ema", False): + extra_kwargs["ema_model"] = self.ema.get_model() + + has_oom = False + + # forward and backward pass + logging_outputs, sample_size, ooms = [], 0, 0 + for i, sample in enumerate(samples): # delayed update loop + sample, is_dummy_batch = self._prepare_sample(sample) + + def maybe_no_sync(): + """ + Whenever *samples* contains more than one mini-batch, we + want to accumulate gradients locally and only call + all-reduce in the last backwards pass. + """ + if ( + self.data_parallel_world_size > 1 + and hasattr(self.model, "no_sync") + and i < len(samples) - 1 + # The no_sync context manager results in increased memory + # usage with FSDP, since full-size gradients will be + # accumulated on each GPU. It's typically a better tradeoff + # to do the extra communication with FSDP. 
+ and not self.is_fsdp + ): + return self.model.no_sync() + else: + return contextlib.ExitStack() # dummy contextmanager + + try: + with maybe_no_sync(): + # forward and backward + loss, sample_size_i, logging_output = self.task.train_step( + sample=sample, + model=self.model, + criterion=self.criterion, + optimizer=self.optimizer, + update_num=self.get_num_updates(), + ignore_grad=is_dummy_batch, + **extra_kwargs, + ) + del loss + + logging_outputs.append(logging_output) + sample_size += sample_size_i + + # emptying the CUDA cache after the first step can + # reduce the chance of OOM + if self.cuda and self.get_num_updates() == 0: + torch.cuda.empty_cache() + except RuntimeError as e: + if "out of memory" in str(e): + self._log_oom(e) + has_oom = True + if raise_oom: + raise e + else: + raise e + except Exception: + self.consolidate_optimizer() + self.save_checkpoint( + os.path.join(self.cfg.checkpoint.save_dir, "crash.pt"), {} + ) + raise + + if has_oom: + logger.warning( + "attempting to recover from OOM in forward/backward pass" + ) + ooms += 1 + self.zero_grad() + if self.cuda: + torch.cuda.empty_cache() + + if self.cfg.distributed_training.distributed_world_size == 1: + return None + + if self.tpu and i < len(samples) - 1: + # tpu-comment: every XLA operation before marking step is + # appended to the IR graph, and processing too many batches + # before marking step can lead to OOM errors. + # To handle gradient accumulation use case, we explicitly + # mark step here for every forward pass without a backward pass + self._xla_markstep_and_send_to_cpu() + + if is_dummy_batch: + if torch.is_tensor(sample_size): + sample_size.zero_() + else: + sample_size *= 0.0 + + if torch.is_tensor(sample_size): + sample_size = sample_size.float() + else: + sample_size = float(sample_size) + + # gather logging outputs from all replicas + if self._sync_stats(): + train_time = self._local_cumulative_training_time() + ( + logging_outputs, + ( + sample_size, + ooms, + total_train_time, + ), + ) = self._aggregate_logging_outputs( + logging_outputs, sample_size, ooms, train_time, ignore=is_dummy_batch + ) + self._cumulative_training_time = ( + total_train_time / self.data_parallel_world_size + ) + + overflow = False + try: + with torch.autograd.profiler.record_function("reduce-grads"): + # reduce gradients across workers + self.optimizer.all_reduce_grads(self.model) + if utils.has_parameters(self.criterion): + self.optimizer.all_reduce_grads(self.criterion) + + with torch.autograd.profiler.record_function("multiply-grads"): + # multiply gradients by (data_parallel_size / sample_size) since + # DDP normalizes by the number of data parallel workers for + # improved fp16 precision. + # Thus we get (sum_of_gradients / sample_size) at the end. + # In case of fp16, this step also undoes loss scaling. + # (Debugging note: Some optimizers perform this scaling on the + # fly, so inspecting model.parameters() or optimizer.params may + # still show the original, unscaled gradients.) + numer = ( + self.data_parallel_world_size + if not self.cfg.optimization.use_bmuf or self._sync_stats() + else 1 + ) + self.optimizer.multiply_grads(numer / (sample_size or 1.0)) + # Note: (sample_size or 1.0) handles the case of a zero gradient, in a + # way that avoids CPU/device transfers in case sample_size is a GPU or + # TPU object. The assumption is that the gradient itself is also 0. 
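+ # worked example: with 8 data-parallel workers and an aggregated sample_size of 4096, the factor is 8 / 4096; combined with DDP's implicit 1/8 averaging this leaves sum_of_gradients / 4096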
+ + with torch.autograd.profiler.record_function("clip-grads"): + # clip grads + grad_norm = self.clip_grad_norm(self.cfg.optimization.clip_norm) + + # check that grad norms are consistent across workers + # on tpu check tensor is slow + if not self.tpu: + if ( + not self.cfg.optimization.use_bmuf + and self.cfg.distributed_training.ddp_backend != "slowmo" + ): + self._check_grad_norms(grad_norm) + if not torch.isfinite(grad_norm).all(): + # in case of AMP, if gradients are Nan/Inf then + # optimizer step is still required + if self.cfg.common.amp: + overflow = True + else: + # check local gradnorm single GPU case, trigger NanDetector + raise FloatingPointError("gradients are Nan/Inf") + + with torch.autograd.profiler.record_function("optimizer"): + # take an optimization step + self.task.optimizer_step( + self.optimizer, model=self.model, update_num=self.get_num_updates() + ) + if self.cfg.common.amp and overflow: + if self._amp_retries == self.cfg.common.amp_batch_retries: + logger.info("AMP: skipping this batch.") + self._amp_retries = 0 + else: + self._amp_retries += 1 + return self.train_step( + samples, raise_oom + ) # recursion to feed in same batch + + except FloatingPointError: + + self.consolidate_optimizer() + self.save_checkpoint( + os.path.join(self.cfg.checkpoint.save_dir, "crash.pt"), {} + ) + + # re-run the forward and backward pass with hooks attached to print + # out where it fails + self.zero_grad() + with NanDetector(self.get_model()): + for _, sample in enumerate(samples): + sample, _ = self._prepare_sample(sample) + self.task.train_step( + sample, + self.model, + self.criterion, + self.optimizer, + self.get_num_updates(), + ignore_grad=False, + **extra_kwargs, + ) + raise + except OverflowError as e: + overflow = True + logger.info( + f"NOTE: gradient overflow detected, ignoring gradient, {str(e)}" + ) + + if hasattr(self, "param_names") and hasattr( + self.optimizer, "fp32_optimizer" + ): + for p, n in zip(self.optimizer.fp32_optimizer.params, self.param_names): + if torch.isinf(p.grad).any() or torch.isnan(p.grad).any(): + logger.info(f"overflow in param {n}") + + grad_norm = torch.tensor(0.0).cuda() + self.zero_grad() + except RuntimeError as e: + if "out of memory" in str(e): + self._log_oom(e) + logger.error("OOM during optimization, irrecoverable") + raise e + + # Some distributed wrappers (e.g., SlowMo) need access to the optimizer + # after the step + if hasattr(self.model, "perform_slowmo"): + self.model.perform_slowmo( + self.optimizer.optimizer, getattr(self.optimizer, "fp32_params", None) + ) + + logging_output = None + if not overflow or self.cfg.distributed_training.ddp_backend == "slowmo": + self.set_num_updates(self.get_num_updates() + 1) + + if self.cfg.ema.store_ema: + # Step EMA forward with new model. 
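+ # the EMA shadow parameters follow the usual update ema_param = decay * ema_param + (1 - decay) * model_param; the decay actually applied is logged below as ema_decay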
+ self.ema.step( + self.get_model(), + self.get_num_updates(), + ) + metrics.log_scalar( + "ema_decay", + self.ema.get_decay(), + priority=10000, + round=5, + weight=0, + ) + + if self.tpu: + import torch_xla.core.xla_model as xm + + # mark step on TPUs + self._xla_markstep_and_send_to_cpu() + + # only log stats every log_interval steps + # this causes wps to be misreported when log_interval > 1 + logging_output = {} + if self.get_num_updates() % self.cfg.common.log_interval == 0: + # log memory usage + mem_info = xm.get_memory_info(self.device) + gb_free = mem_info["kb_free"] / 1024 / 1024 + gb_total = mem_info["kb_total"] / 1024 / 1024 + metrics.log_scalar( + "gb_free", gb_free, priority=1500, round=1, weight=0 + ) + metrics.log_scalar( + "gb_total", gb_total, priority=1600, round=1, weight=0 + ) + logging_outputs = self._xla_markstep_and_send_to_cpu( + logging_outputs + ) + logging_output = self._reduce_and_log_stats( + logging_outputs, sample_size, grad_norm + ) + + # log whenever there's an XLA compilation, since these + # slow down training and may indicate opportunities for + # optimization + self._check_xla_compilation() + else: + if self.cuda and self.cuda_env is not None: + # log minimum free memory over the iteration + gb_used = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024 + torch.cuda.reset_peak_memory_stats() + gb_free = self.cuda_env.total_memory_in_GB - gb_used + metrics.log_scalar( + "gb_free", gb_free, priority=1500, round=1, weight=0 + ) + + # log stats + logging_output = self._reduce_and_log_stats( + logging_outputs, sample_size, grad_norm + ) + + # clear CUDA cache to reduce memory fragmentation + if ( + self.cuda + and self.cfg.common.empty_cache_freq > 0 + and ( + (self.get_num_updates() + self.cfg.common.empty_cache_freq - 1) + % self.cfg.common.empty_cache_freq + ) + == 0 + ): + torch.cuda.empty_cache() + + if self.cfg.common.fp16 or self.cfg.common.amp: + metrics.log_scalar( + "loss_scale", + ( + self.optimizer.scaler.loss_scale + if self.cfg.common.fp16 + else self.optimizer.scaler.get_scale() + ), + priority=700, + round=4, + weight=0, + ) + + metrics.log_stop_time("train_wall") + return logging_output + + @metrics.aggregate("valid") + def valid_step(self, sample, raise_oom=False): + """Do forward pass in evaluation mode.""" + if self.tpu: + import torch_xla.core.xla_model as xm + + xm.rendezvous("valid_step") # wait for all workers + + # If EMA is enabled through store_ema=True + # and task.uses_ema is True, pass the EMA model as a keyword + # argument to the task. 
+ extra_kwargs = {} + if self.cfg.ema.store_ema and getattr(self.task, "uses_ema", False): + extra_kwargs["ema_model"] = self.ema.get_model() + + with torch.no_grad(): + self.model.eval() + self.criterion.eval() + + sample, is_dummy_batch = self._prepare_sample(sample) + + try: + _loss, sample_size, logging_output = self.task.valid_step( + sample, self.model, self.criterion, **extra_kwargs + ) + except RuntimeError as e: + if "out of memory" in str(e): + self._log_oom(e) + if not raise_oom: + logger.warning( + "ran out of memory in validation step, retrying batch" + ) + for p in self.model.parameters(): + if p.grad is not None: + p.grad = None # free some memory + if self.cuda: + torch.cuda.empty_cache() + return self.valid_step(sample, raise_oom=True) + raise e + + logging_outputs = [logging_output] + if is_dummy_batch: + if torch.is_tensor(sample_size): + sample_size.zero_() + else: + sample_size *= 0.0 + + # gather logging outputs from all replicas + if self.data_parallel_world_size > 1: + logging_outputs, (sample_size,) = self._aggregate_logging_outputs( + logging_outputs, + sample_size, + ignore=is_dummy_batch, + ) + + # log validation stats + if self.tpu: + logging_outputs = self._xla_markstep_and_send_to_cpu(logging_outputs) + logging_output = self._reduce_and_log_stats(logging_outputs, sample_size) + + return logging_output + + def zero_grad(self): + self.optimizer.zero_grad() + + def lr_step_begin_epoch(self, epoch): + """Adjust the learning rate at the beginning of the epoch.""" + self.lr_scheduler.step_begin_epoch(epoch) + # prefer updating the LR based on the number of steps + return self.lr_step_update() + + def lr_step(self, epoch, val_loss=None): + """Adjust the learning rate at the end of the epoch.""" + self.lr_scheduler.step(epoch, val_loss) + # prefer updating the LR based on the number of steps + return self.lr_step_update() + + def lr_step_update(self): + """Update the learning rate after each update.""" + new_lr = self.lr_scheduler.step_update(self.get_num_updates()) + if isinstance(new_lr, dict): + for k, v in new_lr.items(): + metrics.log_scalar(f"lr_{k}", v, weight=0, priority=300) + new_lr = new_lr.get("default", next(iter(new_lr.values()))) + else: + metrics.log_scalar("lr", new_lr, weight=0, priority=300) + return new_lr + + def get_lr(self): + """Get the current learning rate.""" + return self.optimizer.get_lr() + + def get_model(self): + """Get the (non-wrapped) model instance.""" + return self._model + + def get_criterion(self): + """Get the (non-wrapped) criterion instance.""" + return self._criterion + + def get_meter(self, name): + """[deprecated] Get a specific meter by name.""" + from fairseq import meters + + if "get_meter" not in self._warn_once: + self._warn_once.add("get_meter") + utils.deprecation_warning( + "Trainer.get_meter is deprecated. Please use fairseq.metrics instead." 
+ ) + + train_meters = metrics.get_meters("train") + if train_meters is None: + train_meters = {} + + if name == "train_loss" and "loss" in train_meters: + return train_meters["loss"] + elif name == "train_nll_loss": + # support for legacy train.py, which assumed this meter is + # always initialized + m = train_meters.get("nll_loss", None) + return m or meters.AverageMeter() + elif name == "wall": + # support for legacy train.py, which assumed this meter is + # always initialized + m = metrics.get_meter("default", "wall") + return m or meters.TimeMeter() + elif name == "wps": + m = metrics.get_meter("train", "wps") + return m or meters.TimeMeter() + elif name in {"valid_loss", "valid_nll_loss"}: + # support for legacy train.py, which assumed these meters + # are always initialized + k = name[len("valid_") :] + m = metrics.get_meter("valid", k) + return m or meters.AverageMeter() + elif name == "oom": + return meters.AverageMeter() + elif name in train_meters: + return train_meters[name] + return None + + def get_num_updates(self): + """Get the number of parameters updates.""" + return self._num_updates + + def set_num_updates(self, num_updates): + """Set the number of parameters updates.""" + self._num_updates = num_updates + self.lr_step_update() + if self.quantizer: + self.quantizer.step_update(self._num_updates) + metrics.log_scalar("num_updates", self._num_updates, weight=0, priority=200) + + def clip_grad_norm(self, clip_norm): + def agg_norm_fn(total_norm): + total_norm = total_norm.cuda().float() ** 2 + total_norm = distributed_utils.all_reduce( + total_norm, group=self.data_parallel_process_group + ) + return total_norm**0.5 + + should_agg_norm = self.is_fsdp and ( + self.data_parallel_process_group is not None + or torch.distributed.is_initialized() + ) + return self.optimizer.clip_grad_norm( + clip_norm, aggregate_norm_fn=agg_norm_fn if should_agg_norm else None + ) + + def cumulative_training_time(self): + if self._cumulative_training_time is None: + # single GPU + return self._local_cumulative_training_time() + else: + return self._cumulative_training_time + + def _local_cumulative_training_time(self): + """Aggregate training time in seconds.""" + return time.time() - self._start_time + self._previous_training_time + + def _fp_convert_sample(self, sample): + def apply_half(t): + if t.dtype is torch.float32: + return t.to(dtype=torch.half) + return t + + def apply_bfloat16(t): + if t.dtype is torch.float32: + return t.to(dtype=torch.bfloat16) + return t + + if self.cfg.common.fp16: + sample = utils.apply_to_sample(apply_half, sample) + + if self.cfg.common.bf16: + sample = utils.apply_to_sample(apply_bfloat16, sample) + + return sample + + def _prepare_sample(self, sample, is_dummy=False): + if sample == "DUMMY": + raise Exception( + "Trying to use an uninitialized 'dummy' batch. This usually indicates " + "that the total number of batches is smaller than the number of " + "participating GPUs. Try reducing the batch size or using fewer GPUs." + ) + + if sample is None or len(sample) == 0: + assert ( + self._dummy_batch is not None and len(self._dummy_batch) > 0 + ), "Invalid dummy batch: {}".format(self._dummy_batch) + sample, _ = self._prepare_sample(self._dummy_batch, is_dummy=True) + return sample, True + + # Given that PCIe/NVLink bandwidth is significantly smaller than DRAM bandwidth + # it makes sense to do the format conversion on the CPU and then transfer + # a smaller buffer to the device. This also saves GPU memory capacity. 
+ + if self.cfg.common.on_cpu_convert_precision: + sample = self._fp_convert_sample(sample) + + if self.cuda: + if self.pipeline_model_parallel: + if "target" in sample: + sample["target"] = utils.move_to_cuda( + sample["target"], device=self.last_device + ) + else: + sample = utils.move_to_cuda(sample) + elif self.tpu and is_dummy: + # the dummy batch may not be on the appropriate device + sample = utils.move_to_cuda(sample, device=self.device) + + if not self.cfg.common.on_cpu_convert_precision: + sample = self._fp_convert_sample(sample) + + if self._dummy_batch == "DUMMY": + self._dummy_batch = sample + + return sample, False + + def _set_seed(self): + # Set seed based on args.seed and the update number so that we get + # reproducible results when resuming from checkpoints + seed = self.cfg.common.seed + self.get_num_updates() + utils.set_torch_seed(seed) + + def _sync_stats(self): + # Return True if it's using multiple GPUs and DDP or multiple GPUs with + # BMUF and it's a bmuf sync with warmup iterations completed before. + if self.data_parallel_world_size == 1: + return False + elif self.cfg.optimization.use_bmuf: + return ( + self.get_num_updates() + 1 + ) % self.cfg.bmuf.global_sync_iter == 0 and ( + self.get_num_updates() + 1 + ) > self.cfg.bmuf.warmup_iterations + else: + return True + + def _log_oom(self, exc): + msg = "OOM: Ran out of memory with exception: {}".format(exc) + logger.warning(msg) + if torch.cuda.is_available() and hasattr(torch.cuda, "memory_summary"): + for device_idx in range(torch.cuda.device_count()): + logger.warning(torch.cuda.memory_summary(device=device_idx)) + sys.stderr.flush() + + def _aggregate_logging_outputs( + self, + logging_outputs: List[Dict[str, Any]], + *extra_stats_to_sum, + ignore=False, + ): + if self.task.__class__.logging_outputs_can_be_summed(self.get_criterion()): + return self._fast_stat_sync_sum( + logging_outputs, *extra_stats_to_sum, ignore=ignore + ) + else: + return self._all_gather_list_sync( + logging_outputs, *extra_stats_to_sum, ignore=ignore + ) + + def _all_gather_list_sync( + self, + logging_outputs: List[Dict[str, Any]], + *extra_stats_to_sum, + ignore=False, + ): + """ + Sync logging outputs across workers. all_gather_list_sync is + suitable when logging outputs are complex types. + """ + if self.tpu: + raise NotImplementedError + if ignore: + logging_outputs = [] + results = list( + zip( + *distributed_utils.all_gather_list( + [logging_outputs] + list(extra_stats_to_sum), + max_size=getattr(self.cfg.common, "all_gather_list_size", 16384), + group=self.data_parallel_process_group, + ) + ) + ) + logging_outputs, extra_stats_to_sum = results[0], results[1:] + logging_outputs = list(chain.from_iterable(logging_outputs)) + extra_stats_to_sum = [sum(s) for s in extra_stats_to_sum] + return logging_outputs, extra_stats_to_sum + + def _fast_stat_sync_sum( + self, + logging_outputs: List[Dict[str, Any]], + *extra_stats_to_sum, + ignore=False, + ): + """ + Sync logging outputs across workers. fast_stat_sync_sum is + faster than all_gather_list_sync, but is only suitable when + logging outputs are scalars and can be summed. Note that + *logging_outputs* cannot contain any nested dicts/lists. 
+ """ + data = {} + for i, stat in enumerate(extra_stats_to_sum): + data["extra_stats_" + str(i)] = stat + if len(logging_outputs) > 0: + log_keys = list(logging_outputs[0].keys()) + for k in log_keys: + if not ignore: + v = sum(log[k] for log in logging_outputs if k in log) + else: + v = logging_outputs[0][k] + v = torch.zeros_like(v) if torch.is_tensor(v) else 0 + data["logging_outputs_" + k] = v + else: + log_keys = None + + data = distributed_utils.all_reduce_dict( + data, device=self.device, group=self.data_parallel_process_group + ) + + extra_stats_to_sum = [ + data["extra_stats_" + str(i)] for i in range(len(extra_stats_to_sum)) + ] + if log_keys is not None: + logging_outputs = [{k: data["logging_outputs_" + k] for k in log_keys}] + else: + logging_outputs = [] + return logging_outputs, extra_stats_to_sum + + def _check_grad_norms(self, grad_norm): + """Check that grad norms are consistent across workers.""" + if self._grad_norm_buf is not None: + self._grad_norm_buf.zero_() + self._grad_norm_buf[self.data_parallel_rank] = grad_norm + distributed_utils.all_reduce( + self._grad_norm_buf, group=self.data_parallel_process_group + ) + + def is_consistent(tensor): + max_abs_diff = torch.max(torch.abs(tensor - tensor[0])) + return ( + ( + torch.isfinite(tensor).all() + and (max_abs_diff / (tensor[0] + 1e-6) < 1e-6).all() + ) + or (self.cfg.common.amp and not torch.isfinite(tensor).all()) + # in case of amp non-finite grads are fine + ) + + if not is_consistent(self._grad_norm_buf): + pretty_detail = "\n".join( + "rank {:3d} = {:.8f}".format(r, n) + for r, n in enumerate(self._grad_norm_buf.tolist()) + ) + error_detail = "grad_norm across the workers:\n{}\n".format( + pretty_detail + ) + # use FloatingPointError to trigger NanDetector + raise FloatingPointError( + "Fatal error: gradients are inconsistent between workers. " + "Try --ddp-backend=legacy_ddp. " + "Or are you mixing up different generation of GPUs in training?" 
+ + "\n" + + "-" * 80 + + "\n{}\n".format(error_detail) + + "-" * 80 + ) + + def _reduce_and_log_stats(self, logging_outputs, sample_size, grad_norm=None): + if grad_norm is not None and ( + not torch.is_tensor(grad_norm) or torch.isfinite(grad_norm) + ): + metrics.log_speed("ups", 1.0, priority=100, round=2) + metrics.log_scalar("gnorm", grad_norm, priority=400, round=3) + if self.cfg.optimization.clip_norm > 0: + metrics.log_scalar( + "clip", + torch.where( + grad_norm > self.cfg.optimization.clip_norm, + grad_norm.new_tensor(100), + grad_norm.new_tensor(0), + ), + priority=500, + round=1, + ) + + with metrics.aggregate() as agg: + if logging_outputs is not None: + self.task.reduce_metrics(logging_outputs, self.get_criterion()) + del logging_outputs + + # extra warning for criterions that don't properly log a loss value + if "loss" not in agg: + if "loss" not in self._warn_once: + self._warn_once.add("loss") + logger.warning( + "Criterion.reduce_metrics did not log a 'loss' value, " + "which may break some functionality" + ) + metrics.log_scalar("loss", -1) + + # support legacy interface + if self.tpu: + logging_output = {} + else: + logging_output = agg.get_smoothed_values() + logging_output["sample_size"] = sample_size + for key_to_delete in ["ppl", "wps", "wpb", "bsz"]: + if key_to_delete in logging_output: + del logging_output[key_to_delete] + return logging_output + + def _check_xla_compilation(self): + import torch_xla.debug.metrics as met + + compile_stats = met.metric_data("CompileTime") + if compile_stats is None: + return + num_xla_compiles = compile_stats[0] + if num_xla_compiles > self._num_xla_compiles: + logger.warning( + "XLA compilation detected on device #{}; too many of these can lead " + "to slow training, but we expect a few in the beginning".format( + self.cfg.distributed_training.distributed_rank + ) + ) + self._num_xla_compiles = num_xla_compiles + + def _xla_markstep_and_send_to_cpu(self, data=None): + import torch_xla.core.xla_model as xm + + xm.mark_step() + if data is not None: + from fairseq.utils import xla_device_to_cpu + + return xla_device_to_cpu(data) + + +def _catalog_shared_params(module, memo=None, prefix=""): + if memo is None: + first_call = True + memo = {} + else: + first_call = False + for name, param in module._parameters.items(): + param_prefix = prefix + ("." if prefix else "") + name + if param not in memo: + memo[param] = [] + memo[param].append(param_prefix) + for name, m in module._modules.items(): + if m is None: + continue + submodule_prefix = prefix + ("." if prefix else "") + name + _catalog_shared_params(m, memo, submodule_prefix) + if first_call: + return [x for x in memo.values() if len(x) > 1] + + +def _get_module_by_path(module, path): + path = path.split(".") + for name in path: + module = getattr(module, name) + return module + + +def _set_module_by_path(module, path, value): + path = path.split(".") + for name in path[:-1]: + module = getattr(module, name) + setattr(module, path[-1], value) diff --git a/fairseq/fairseq/utils.py b/fairseq/fairseq/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4d4b35052305e7b60aa958d6d9b88a7ce0201045 --- /dev/null +++ b/fairseq/fairseq/utils.py @@ -0,0 +1,951 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import argparse +import collections +import contextlib +import copy +import importlib +import logging +import os +import sys +import warnings +from itertools import accumulate +from typing import TYPE_CHECKING, Callable, Dict, List, Optional + +import torch +import torch.nn.functional as F +from torch import Tensor + +if TYPE_CHECKING: + from fairseq.modules.multihead_attention import MultiheadAttention + +try: + from amp_C import multi_tensor_l2norm + + multi_tensor_l2norm_available = True +except ImportError: + multi_tensor_l2norm_available = False + +try: + import torch_xla.core.xla_model as xm +except ImportError: + xm = None + + +logger = logging.getLogger(__name__) + + +MANIFOLD_PATH_SEP = "|" + + +class FileContentsAction(argparse.Action): + def __init__(self, option_strings, dest, nargs=None, **kwargs): + if nargs is not None: + raise ValueError("nargs not allowed") + super(FileContentsAction, self).__init__(option_strings, dest, **kwargs) + + def __call__(self, parser, namespace, values, option_string=None): + from fairseq.file_io import PathManager + + if PathManager.isfile(values): + with PathManager.open(values) as f: + argument = f.read().strip() + else: + argument = values + setattr(namespace, self.dest, argument) + + +def split_paths(paths: str, separator=os.pathsep) -> List[str]: + return ( + paths.split(separator) if "://" not in paths else paths.split(MANIFOLD_PATH_SEP) + ) + + +def load_ensemble_for_inference(filenames, task, model_arg_overrides=None): + from fairseq import checkpoint_utils + + deprecation_warning( + "utils.load_ensemble_for_inference is deprecated. " + "Please use checkpoint_utils.load_model_ensemble instead." + ) + return checkpoint_utils.load_model_ensemble( + filenames, arg_overrides=model_arg_overrides, task=task + ) + + +def apply_to_sample(f, sample): + if hasattr(sample, "__len__") and len(sample) == 0: + return {} + + def _apply(x): + if torch.is_tensor(x): + return f(x) + elif isinstance(x, collections.OrderedDict): + # OrderedDict has attributes that needs to be preserved + od = collections.OrderedDict( + (key, _apply(value)) for key, value in x.items() + ) + od.__dict__ = x.__dict__ + return od + elif isinstance(x, dict): + return {key: _apply(value) for key, value in x.items()} + elif isinstance(x, list): + return [_apply(x) for x in x] + elif isinstance(x, tuple): + return tuple(_apply(x) for x in x) + elif isinstance(x, set): + return {_apply(x) for x in x} + else: + return x + + return _apply(sample) + + +def move_to_cuda(sample, device=None): + device = device or torch.cuda.current_device() + + def _move_to_cuda(tensor): + # non_blocking is ignored if tensor is not pinned, so we can always set + # to True (see github.com/PyTorchLightning/pytorch-lightning/issues/620) + return tensor.to(device=device, non_blocking=True) + + return apply_to_sample(_move_to_cuda, sample) + + +def move_to_cpu(sample): + def _move_to_cpu(tensor): + # PyTorch has poor support for half tensors (float16) on CPU. + # Move any such tensors to float32. 
+ if tensor.dtype in {torch.bfloat16, torch.float16}: + tensor = tensor.to(dtype=torch.float32) + return tensor.cpu() + + return apply_to_sample(_move_to_cpu, sample) + + +def move_to_tpu(sample): + + import torch_xla.core.xla_model as xm + + device = xm.xla_device() + + def _move_to_tpu(tensor): + return tensor.to(device) + + return apply_to_sample(_move_to_tpu, sample) + + +def get_incremental_state( + module: "MultiheadAttention", + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], + key: str, +) -> Optional[Dict[str, Optional[Tensor]]]: + """Helper for getting incremental state for an nn.Module.""" + return module.get_incremental_state(incremental_state, key) + + +def set_incremental_state( + module: "MultiheadAttention", + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], + key: str, + value: Dict[str, Optional[Tensor]], +) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]: + """Helper for setting incremental state for an nn.Module.""" + if incremental_state is not None: + result = module.set_incremental_state(incremental_state, key, value) + if result is not None: + incremental_state = result + return incremental_state + + +def load_align_dict(replace_unk): + if replace_unk is None: + align_dict = None + elif isinstance(replace_unk, str) and len(replace_unk) > 0: + # Load alignment dictionary for unknown word replacement if it was passed as an argument. + align_dict = {} + with open(replace_unk, "r") as f: + for line in f: + cols = line.split() + align_dict[cols[0]] = cols[1] + else: + # No alignment dictionary provided but we still want to perform unknown word replacement by copying the + # original source word. + align_dict = {} + return align_dict + + +def print_embed_overlap(embed_dict, vocab_dict): + embed_keys = set(embed_dict.keys()) + vocab_keys = set(vocab_dict.symbols) + overlap = len(embed_keys & vocab_keys) + logger.info("found {}/{} types in embedding file".format(overlap, len(vocab_dict))) + + +def parse_embedding(embed_path): + """Parse embedding text file into a dictionary of word and embedding tensors. + + The first line can have vocabulary size and dimension. The following lines + should contain word and embedding separated by spaces. + + Example: + 2 5 + the -0.0230 -0.0264 0.0287 0.0171 0.1403 + at -0.0395 -0.1286 0.0275 0.0254 -0.0932 + """ + embed_dict = {} + with open(embed_path) as f_embed: + next(f_embed) # skip header + for line in f_embed: + pieces = line.rstrip().split(" ") + embed_dict[pieces[0]] = torch.Tensor( + [float(weight) for weight in pieces[1:]] + ) + return embed_dict + + +def load_embedding(embed_dict, vocab, embedding): + for idx in range(len(vocab)): + token = vocab[idx] + if token in embed_dict: + embedding.weight.data[idx] = embed_dict[token] + return embedding + + +def replace_unk(hypo_str, src_str, alignment, align_dict, unk): + from fairseq import tokenizer + + # Tokens are strings here + hypo_tokens = tokenizer.tokenize_line(hypo_str) + # TODO: Very rare cases where the replacement is '' should be handled gracefully + src_tokens = tokenizer.tokenize_line(src_str) + [""] + for i, ht in enumerate(hypo_tokens): + if ht == unk: + src_token = src_tokens[alignment[i]] + # Either take the corresponding value in the aligned dictionary or just copy the original value. 
+ hypo_tokens[i] = align_dict.get(src_token, src_token) + return " ".join(hypo_tokens) + + +def post_process_prediction( + hypo_tokens, + src_str, + alignment, + align_dict, + tgt_dict, + remove_bpe=None, + extra_symbols_to_ignore=None, +): + hypo_str = tgt_dict.string( + hypo_tokens, remove_bpe, extra_symbols_to_ignore=extra_symbols_to_ignore + ) + if align_dict is not None: + hypo_str = replace_unk( + hypo_str, src_str, alignment, align_dict, tgt_dict.unk_string() + ) + if align_dict is not None or remove_bpe is not None: + # Convert back to tokens for evaluating with unk replacement or without BPE + # Note that the dictionary can be modified inside the method. + hypo_tokens = tgt_dict.encode_line(hypo_str, add_if_not_exist=True) + return hypo_tokens, hypo_str, alignment + + +def make_positions(tensor, padding_idx: int, onnx_trace: bool = False): + """Replace non-padding symbols with their position numbers. + + Position numbers begin at padding_idx+1. Padding symbols are ignored. + """ + # The series of casts and type-conversions here are carefully + # balanced to both work with ONNX export and XLA. In particular XLA + # prefers ints, cumsum defaults to output longs, and ONNX doesn't know + # how to handle the dtype kwarg in cumsum. + mask = tensor.ne(padding_idx).int() + return (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx + + +def strip_pad(tensor, pad): + return tensor[tensor.ne(pad)] + + +def buffered_arange(max, device="cpu"): + if not hasattr(buffered_arange, "buf"): + buffered_arange.buf = torch.LongTensor().to(device) + if max > buffered_arange.buf.numel(): + buffered_arange.buf.resize_(max) + torch.arange(max, out=buffered_arange.buf) + return buffered_arange.buf[:max] + + +def convert_padding_direction( + src_tokens, padding_idx, right_to_left: bool = False, left_to_right: bool = False +): + assert right_to_left ^ left_to_right + pad_mask = src_tokens.eq(padding_idx) + if not pad_mask.any(): + # no padding, return early + return src_tokens + if left_to_right and not pad_mask[:, 0].any(): + # already right padded + return src_tokens + if right_to_left and not pad_mask[:, -1].any(): + # already left padded + return src_tokens + max_len = src_tokens.size(1) + buffered = torch.empty(0).long() + if max_len > 0: + torch.arange(max_len, out=buffered) + range = buffered.type_as(src_tokens).expand_as(src_tokens) + num_pads = pad_mask.long().sum(dim=1, keepdim=True) + if right_to_left: + index = torch.remainder(range - num_pads, max_len) + else: + index = torch.remainder(range + num_pads, max_len) + return src_tokens.gather(1, index) + + +def item(tensor): + # tpu-comment: making this a no-op for xla devices. 
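+ # Calling .item() on an XLA tensor would trigger lazy-graph execution and a
+ # device-to-host copy, so the detached tensor is returned instead and the
+ # conversion is deferred.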
+ if torch.is_tensor(tensor) and tensor.device.type == "xla": + return tensor.detach() + if hasattr(tensor, "item"): + return tensor.item() + if hasattr(tensor, "__getitem__"): + return tensor[0] + return tensor + + +def multi_tensor_total_norm(grads, chunk_size=2048 * 32) -> torch.Tensor: + per_device_grads = {} + norms = [] + for grad in grads: + device = grad.device + cur_device_grads = per_device_grads.get(device) + if cur_device_grads is None: + cur_device_grads = [] + per_device_grads[device] = cur_device_grads + cur_device_grads.append(grad) + for device in per_device_grads.keys(): + cur_device_grads = per_device_grads[device] + if device.type == "cuda": + # TODO(msb) return has_inf + has_inf = torch.zeros((1, 1), dtype=torch.int, device=device) + with torch.cuda.device(device): + norm = multi_tensor_l2norm( + chunk_size, has_inf, [cur_device_grads], False + ) + norms.append(norm[0].to(torch.cuda.current_device())) + else: + norms += [torch.norm(g, p=2, dtype=torch.float32) for g in cur_device_grads] + total_norm = torch.norm(torch.stack(norms)) + return total_norm + + +@torch.no_grad() +def clip_grad_norm_(params, max_norm, aggregate_norm_fn=None) -> torch.Tensor: + def grad_exists(p): + return p is not None and getattr(p, "grad", None) is not None + + if isinstance(params, torch.Tensor): + params = [params] + params = list(params) + grads = [ + p.grad.detach() for p in params if grad_exists(p) and not hasattr(p, "expert") + ] + expert_grads = [ + p.grad.detach() for p in params if grad_exists(p) and hasattr(p, "expert") + ] + + if len(grads) == 0: + if len(params) > 0: + return params[0].new_tensor(0.0) + else: + return torch.tensor(0.0) + + if len(grads) == 1: + total_norm = torch.norm(grads[0], p=2, dtype=torch.float32) + else: + if multi_tensor_l2norm_available: + total_norm = multi_tensor_total_norm(grads) + else: + if torch.cuda.is_available(): + warnings.warn( + "amp_C fused kernels unavailable, disabling multi_tensor_l2norm; " + "you may get better performance by installing NVIDIA's apex library" + ) + device = torch.cuda.current_device() + elif grads[0].device.type == "xla": + device = grads[0].device + else: + device = torch.device("cpu") + total_norm = torch.norm( + torch.stack( + [torch.norm(g, p=2, dtype=torch.float32).to(device) for g in grads] + ) + ) + + if aggregate_norm_fn is not None: + total_norm = aggregate_norm_fn(total_norm) + + if max_norm > 0: + max_norm = float(max_norm) + clip_coef = (max_norm / (total_norm + 1e-6)).clamp_(max=1) + torch._foreach_mul_(grads + expert_grads, clip_coef) + + return total_norm + + +def fill_with_neg_inf(t): + """FP16-compatible function that fills a tensor with -inf.""" + return t.float().fill_(float("-inf")).type_as(t) + + +def _match_types(arg1, arg2): + """Convert the numerical argument to the same type as the other argument""" + + def upgrade(arg_number, arg_structure): + if isinstance(arg_structure, tuple): + return tuple([arg_number] * len(arg_structure)) + elif isinstance(arg_structure, dict): + arg = copy.deepcopy(arg_structure) + for k in arg: + arg[k] = upgrade(arg_number, arg_structure[k]) + return arg + else: + return arg_number + + if isinstance(arg1, float) or isinstance(arg1, int): + return upgrade(arg1, arg2), arg2 + elif isinstance(arg2, float) or isinstance(arg2, int): + return arg1, upgrade(arg2, arg1) + + return arg1, arg2 + + +def resolve_max_positions(*args): + """Resolve max position constraints from multiple sources.""" + + def map_value_update(d1, d2): + updated_value = copy.deepcopy(d1) + for key in 
d2: + if key not in updated_value: + updated_value[key] = d2[key] + else: + updated_value[key] = min(d1[key], d2[key]) + return updated_value + + def nullsafe_min(l): + minim = None + for item in l: + if minim is None: + minim = item + elif item is not None and item < minim: + minim = item + return minim + + max_positions = None + for arg in args: + if max_positions is None: + max_positions = arg + elif arg is not None: + max_positions, arg = _match_types(max_positions, arg) + if isinstance(arg, float) or isinstance(arg, int): + max_positions = min(max_positions, arg) + elif isinstance(arg, dict): + max_positions = map_value_update(max_positions, arg) + else: + max_positions = tuple(map(nullsafe_min, zip(max_positions, arg))) + + return max_positions + + +def import_user_module(args): + module_path = getattr(args, "user_dir", None) + if module_path is not None: + module_path = os.path.abspath(args.user_dir) + if not os.path.exists(module_path) and not os.path.isfile( + os.path.dirname(module_path) + ): + fairseq_rel_path = os.path.join(os.path.dirname(__file__), args.user_dir) + if os.path.exists(fairseq_rel_path): + module_path = fairseq_rel_path + else: + fairseq_rel_path = os.path.join( + os.path.dirname(__file__), "..", args.user_dir + ) + if os.path.exists(fairseq_rel_path): + module_path = fairseq_rel_path + else: + raise FileNotFoundError(module_path) + + # ensure that user modules are only imported once + import_user_module.memo = getattr(import_user_module, "memo", set()) + if module_path not in import_user_module.memo: + import_user_module.memo.add(module_path) + + module_parent, module_name = os.path.split(module_path) + if module_name not in sys.modules: + sys.path.insert(0, module_parent) + importlib.import_module(module_name) + + tasks_path = os.path.join(module_path, "tasks") + if os.path.exists(tasks_path): + from fairseq.tasks import import_tasks + + import_tasks(tasks_path, f"{module_name}.tasks") + + models_path = os.path.join(module_path, "models") + if os.path.exists(models_path): + from fairseq.models import import_models + + import_models(models_path, f"{module_name}.models") + elif module_path in sys.modules[module_name].__path__: + logger.info(f"--user-dir={module_path} has already been imported.") + else: + raise ImportError( + "Failed to import --user-dir={} because the corresponding module name " + "({}) is not globally unique. 
Please rename the directory to " + "something unique and try again.".format(module_path, module_name) + ) + + +def softmax(x, dim: int, onnx_trace: bool = False): + if onnx_trace: + return F.softmax(x.float(), dim=dim) + else: + return F.softmax(x, dim=dim, dtype=torch.float32) + + +def log_softmax(x, dim: int, onnx_trace: bool = False): + if onnx_trace: + return F.log_softmax(x.float(), dim=dim) + else: + return F.log_softmax(x, dim=dim, dtype=torch.float32) + + +def get_perplexity(loss, round=2, base=2): + from fairseq.logging.meters import safe_round + + if loss is None: + return 0.0 + try: + return safe_round(base**loss, round) + except OverflowError: + return float("inf") + + +def deprecation_warning(message, stacklevel=3): + # don't use DeprecationWarning, since it's ignored by default + warnings.warn(message, stacklevel=stacklevel) + + +def relu_squared(x: torch.Tensor): + return F.relu(x).pow(2) + + +def get_activation_fn(activation: str) -> Callable: + """Returns the activation function corresponding to `activation`""" + from fairseq.modules import gelu, gelu_accurate + + if activation == "relu": + return F.relu + elif activation == "relu_squared": + return relu_squared + elif activation == "gelu": + return gelu + elif activation == "gelu_fast": + deprecation_warning( + "--activation-fn=gelu_fast has been renamed to gelu_accurate" + ) + return gelu_accurate + elif activation == "gelu_accurate": + return gelu_accurate + elif activation == "tanh": + return torch.tanh + elif activation == "linear": + return lambda x: x + elif activation == "swish": + return torch.nn.SiLU + else: + raise RuntimeError("--activation-fn {} not supported".format(activation)) + + +def get_available_activation_fns() -> List: + return [ + "relu", + "gelu", + "gelu_fast", # deprecated + "gelu_accurate", + "tanh", + "linear", + ] + + +@contextlib.contextmanager +def model_eval(model): + is_training = model.training + model.eval() + yield + model.train(is_training) + + +def has_parameters(module): + try: + next(module.parameters()) + return True + except StopIteration: + return False + + +def get_rng_state(): + state = {"torch_rng_state": torch.get_rng_state()} + if xm is not None: + state["xla_rng_state"] = xm.get_rng_state() + if torch.cuda.is_available(): + state["cuda_rng_state"] = torch.cuda.get_rng_state() + return state + + +def set_rng_state(state): + torch.set_rng_state(state["torch_rng_state"]) + if xm is not None: + xm.set_rng_state(state["xla_rng_state"]) + if torch.cuda.is_available(): + torch.cuda.set_rng_state(state["cuda_rng_state"]) + + +class set_torch_seed(object): + def __init__(self, seed): + assert isinstance(seed, int) + self.rng_state = get_rng_state() + + torch.manual_seed(seed) + if xm is not None: + xm.set_rng_state(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + def __enter__(self): + return self + + def __exit__(self, *exc): + set_rng_state(self.rng_state) + + +def parse_alignment(line): + """ + Parses a single line from the alingment file. + + Args: + line (str): String containing the alignment of the format: + - - .. + -. All indices are 0 indexed. + + Returns: + torch.IntTensor: packed alignments of shape (2 * m). 
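+
+ Example:
+ the line "0-0 1-2" is parsed into torch.IntTensor([0, 0, 1, 2]), i.e.
+ source index followed by target index for each aligned pair.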
+ """ + alignments = line.strip().split() + parsed_alignment = torch.IntTensor(2 * len(alignments)) + for idx, alignment in enumerate(alignments): + src_idx, tgt_idx = alignment.split("-") + parsed_alignment[2 * idx] = int(src_idx) + parsed_alignment[2 * idx + 1] = int(tgt_idx) + return parsed_alignment + + +def get_token_to_word_mapping(tokens, exclude_list): + n = len(tokens) + word_start = [int(token not in exclude_list) for token in tokens] + word_idx = list(accumulate(word_start)) + token_to_word = {i: word_idx[i] for i in range(n)} + return token_to_word + + +def extract_hard_alignment(attn, src_sent, tgt_sent, pad, eos): + tgt_valid = ( + ((tgt_sent != pad) & (tgt_sent != eos)).nonzero(as_tuple=False).squeeze(dim=-1) + ) + src_invalid = ( + ((src_sent == pad) | (src_sent == eos)).nonzero(as_tuple=False).squeeze(dim=-1) + ) + src_token_to_word = get_token_to_word_mapping(src_sent, [eos, pad]) + tgt_token_to_word = get_token_to_word_mapping(tgt_sent, [eos, pad]) + alignment = [] + if len(tgt_valid) != 0 and len(src_invalid) < len(src_sent): + attn_valid = attn[tgt_valid] + attn_valid[:, src_invalid] = float("-inf") + _, src_indices = attn_valid.max(dim=1) + for tgt_idx, src_idx in zip(tgt_valid, src_indices): + alignment.append( + ( + src_token_to_word[src_idx.item()] - 1, + tgt_token_to_word[tgt_idx.item()] - 1, + ) + ) + return alignment + + +def extract_soft_alignment(attn, src_sent, tgt_sent, pad, eos): + tgt_valid = ((tgt_sent != pad)).nonzero(as_tuple=False) + src_valid = ((src_sent != pad)).nonzero(as_tuple=False).squeeze(dim=-1) + alignment = [] + if len(tgt_valid) != 0 and len(src_valid) != 0: + attn_valid = attn[tgt_valid, src_valid] + alignment = [ + ["{:.6f}".format(p) for p in src_probs.tolist()] for src_probs in attn_valid + ] + return alignment + + +def new_arange(x, *size): + """ + Return a Tensor of `size` filled with a range function on the device of x. + If size is empty, using the size of the variable x. 
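+
+ For example, new_arange(x, 2, 3) returns tensor([[0, 1, 2], [0, 1, 2]]) on
+ the device of x, and new_arange(x) uses x's own size.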
+ """ + if len(size) == 0: + size = x.size() + return torch.arange(size[-1], device=x.device).expand(*size).contiguous() + + +def get_tpu_device(): + return xm.xla_device() + + +def tpu_data_loader(itr): + import torch_xla.core.xla_model as xm + import torch_xla.distributed.parallel_loader as pl + + from fairseq.data import iterators + + xm.rendezvous("tpu_data_loader") # wait for all workers + xm.mark_step() + device = xm.xla_device() + return iterators.CountingIterator( + pl.ParallelLoader(itr, [device]).per_device_loader(device), + start=getattr(itr, "n", 0), + total=len(itr), + ) + + +def is_xla_tensor(tensor): + return torch.is_tensor(tensor) and tensor.device.type == "xla" + + +def index_put(tensor, indices, value): + if is_xla_tensor(tensor): + for _ in range(indices.dim(), tensor.dim()): + indices = indices.unsqueeze(-1) + if indices.size(-1) < tensor.size(-1): + indices = indices.expand_as(tensor) + tensor = torch.mul(tensor, ~indices) + torch.mul(value, indices) + else: + tensor[indices] = value + return tensor + + +def xla_device_to_cpu(dat): + import torch_xla.core.xla_model as xm + + return xm._maybe_convert_to_cpu(dat) + + +class CudaEnvironment(object): + def __init__(self): + cur_device = torch.cuda.current_device() + prop = torch.cuda.get_device_properties("cuda:{}".format(cur_device)) + self.name = prop.name + self.major = prop.major + self.minor = prop.minor + self.total_memory_in_GB = prop.total_memory / 1024 / 1024 / 1024 + + @staticmethod + def pretty_print_cuda_env_list(cuda_env_list): + """ + Given a list of CudaEnviorments, pretty print them + """ + num_workers = len(cuda_env_list) + center = "CUDA enviroments for all {} workers".format(num_workers) + banner_len = 40 - len(center) // 2 + first_line = "*" * banner_len + center + "*" * banner_len + logger.info(first_line) + for r, env in enumerate(cuda_env_list): + logger.info( + "rank {:3d}: ".format(r) + + "capabilities = {:2d}.{:<2d} ; ".format(env.major, env.minor) + + "total memory = {:.3f} GB ; ".format(env.total_memory_in_GB) + + "name = {:40s}".format(env.name) + ) + logger.info(first_line) + + +def csv_str_list(x): + return x.split(",") + + +def eval_str_list(x, type=float): + if x is None: + return None + if isinstance(x, str): + x = eval(x) + try: + return list(map(type, x)) + except TypeError: + return [type(x)] + + +def eval_str_dict(x, type=dict): + if x is None: + return None + if isinstance(x, str): + x = eval(x) + return x + + +def eval_bool(x, default=False): + if x is None: + return default + try: + return bool(eval(x)) + except TypeError: + return default + + +def reset_logging(): + root = logging.getLogger() + for handler in root.handlers: + root.removeHandler(handler) + root.setLevel(os.environ.get("LOGLEVEL", "INFO").upper()) + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter( + logging.Formatter( + fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + ) + root.addHandler(handler) + + +def safe_getattr(obj, k, default=None): + """Returns obj[k] if it exists and is not None, otherwise returns default.""" + from omegaconf import OmegaConf + + if OmegaConf.is_config(obj): + return obj[k] if k in obj and obj[k] is not None else default + + return getattr(obj, k, default) + + +def safe_hasattr(obj, k): + """Returns True if the given key exists and is not None.""" + return getattr(obj, k, None) is not None + + +def hotreload_function(name=None): + """ + Decorator to function to enable hot-reload for debugging. 
+ It allows you to debug a function without reloading all heavy models, dataset loading and
+ preprocessing, allowing faster debugging.
+ If you want to change model or dataset loading, consider relaunching your code.
+ -----------------------------------
+ This will run the decorated function func:
+ if func runs successfully:
+ It will pause, let the user edit code, and prompt the user to:
+ Press enter to re-run the function with updated code
+ Type "done" to finish the function and return its output
+ Type "disable" to stop pausing this function and let code continue without pause
+ Ctrl + C to terminate
+ if func raises an error:
+ it will prompt the user to:
+ 1. Edit code, and press enter to retry
+ 2. Ctrl + C to terminate
+ 3. Type "raise" to raise that exception
+ * Requirements:
+ 0. Fairseq was installed with `pip install --editable .`
+ 1. pip install jurigged[develoop]
+ 2. set the environment variables HOTRELOAD_PAUSE=1 CUDA_LAUNCH_BLOCKING=1
+ 3. Run on a single GPU only (no distributed training)
+ * How to use:
+ 1. In Python, import and decorate the top-level function to be re-run after code edits:
+ ```python
+ from fairseq.utils import hotreload_function
+ ....
+ @hotreload_function("train_step")
+ def train_step(self, sample ....):
+ ....
+ ....
+ ```
+ 2. In the bash run script:
+ ```bash
+ watch_dir=/fairseq-py/fairseq/tasks # directory to watch for file changes
+ export CUDA_VISIBLE_DEVICES=0 # single-gpu
+ HOTRELOAD_PAUSE=1 CUDA_LAUNCH_BLOCKING=1 python -m jurigged -w ${watch_dir} --poll 2 -v train.py ......
+ ```
+ * NOTE:
+ 1. -w ${watch_dir} specifies the files to be watched for changes;
+ once functions, classes, ... are changed, all instances in the process are updated (hot-reload)
+ * Limitations:
+ * Distributed debugging does not currently work
+ * train.py must be launched locally (cannot submit jobs)
+ """
+ try:
+ import jurigged
+ except ImportError as e:
+ logger.warning("Please install jurigged: pip install jurigged[develoop]")
+ raise e
+ from fairseq.distributed import utils as distributed_utils
+ import traceback
+
+ def hotreload_decorator(func):
+ assert callable(func), f"not callable: {func}"
+ jname = name or func.__name__
+ logger.info(f"jurigged-hotreload:Apply jurigged on {jname}:{func.__name__}")
+ HOTRELOAD_PAUSE = bool(os.environ.get("HOTRELOAD_PAUSE", 0))
+ cublk = bool(os.environ.get("CUDA_LAUNCH_BLOCKING", 0))
+ prefix = f"HOTRELOAD:{jname}:[cublk={cublk}]"
+ hot_reload_state = {"disable": False}
+
+ def func_wrapper(*args, **kwargs):
+ if not HOTRELOAD_PAUSE or hot_reload_state["disable"]:
+ return func(*args, **kwargs)
+ world_size = distributed_utils.get_global_world_size()
+ assert (
+ world_size <= 1
+ ), f"HOTRELOAD_PAUSE:{jname} currently cannot do distributed training"
+ success = False
+ while not success:
+ try:
+ output = func(*args, **kwargs)
+ # success = True
+ end_action = input(
+ f"{prefix}: PAUSE, you may edit code now. Enter to re-run, ctrl+C to terminate, "
+ f'type "done" to continue (function still being watched), or type "disable" to stop pausing this function :'
+ )
+ if end_action.strip().lower() in ["disable", "done"]:
+ success = True
+ else:
+ logger.warning(
+ f"{prefix}: action={end_action} function will re-run now."
+ )
+ except Exception as e:
+ action = input(
+ f"{prefix}:ERROR: \n{traceback.format_exc()}\n"
+ f'Edit code to try again: enter to continue, ctrl+C to terminate, or type "raise" to raise the exception: '
+ )
+ if action.strip().lower() == "raise":
+ raise e
+
+ if end_action.strip().lower() == "disable":
+ logger.warning(
+ f"{prefix}: Stop pausing {jname}. The function is still being watched, and newly edited code will take effect "
+ f"if {jname} is called again later."
+ f' Run "unset HOTRELOAD_PAUSE" before relaunching to disable hotreload and'
+ f" remove the @hotreload_function decorator from the code."
+ )
+ hot_reload_state["disable"] = True
+ return output
+
+ return func_wrapper
+
+ return hotreload_decorator
diff --git a/fairseq/fairseq/version.py b/fairseq/fairseq/version.py
new file mode 100644
index 0000000000000000000000000000000000000000..76da4a9882c63454f7f915ed547854f52ae38e8f
--- /dev/null
+++ b/fairseq/fairseq/version.py
@@ -0,0 +1 @@
+__version__ = "0.12.2"
diff --git a/fairseq/fairseq/version.txt b/fairseq/fairseq/version.txt
new file mode 100644
index 0000000000000000000000000000000000000000..26acbf080be051b441bc144e358859396d9133cc
--- /dev/null
+++ b/fairseq/fairseq/version.txt
@@ -0,0 +1 @@
+0.12.2