PyTorch · ssl-aasist · custom_code

ash56 committed (verified)
Commit 4d0021e · 1 parent: 7b2dad9

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.

Files changed (50)
  1. .gitattributes +1 -0
  2. fairseq/examples/textless_nlp/speech-resynth/img/fig.png +3 -0
  3. fairseq/fairseq/benchmark/__init__.py +7 -0
  4. fairseq/fairseq/benchmark/__pycache__/__init__.cpython-310.pyc +0 -0
  5. fairseq/fairseq/benchmark/__pycache__/dummy_dataset.cpython-310.pyc +0 -0
  6. fairseq/fairseq/benchmark/__pycache__/dummy_lm.cpython-310.pyc +0 -0
  7. fairseq/fairseq/benchmark/__pycache__/dummy_masked_lm.cpython-310.pyc +0 -0
  8. fairseq/fairseq/benchmark/__pycache__/dummy_model.cpython-310.pyc +0 -0
  9. fairseq/fairseq/benchmark/__pycache__/dummy_mt.cpython-310.pyc +0 -0
  10. fairseq/fairseq/benchmark/benchmark_multihead_attention.py +172 -0
  11. fairseq/fairseq/benchmark/dummy_dataset.py +36 -0
  12. fairseq/fairseq/benchmark/dummy_lm.py +83 -0
  13. fairseq/fairseq/benchmark/dummy_masked_lm.py +94 -0
  14. fairseq/fairseq/benchmark/dummy_model.py +96 -0
  15. fairseq/fairseq/benchmark/dummy_mt.py +119 -0
  16. fairseq/fairseq/clib/cuda/ngram_repeat_block_cuda.cpp +55 -0
  17. fairseq/fairseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu +82 -0
  18. fairseq/fairseq/clib/libbase/balanced_assignment.cpp +109 -0
  19. fairseq/fairseq/clib/libbleu/libbleu.cpp +157 -0
  20. fairseq/fairseq/clib/libbleu/module.cpp +33 -0
  21. fairseq/fairseq/clib/libnat/edit_dist.cpp +231 -0
  22. fairseq/fairseq/clib/libnat_cuda/binding.cpp +67 -0
  23. fairseq/fairseq/clib/libnat_cuda/edit_dist.cu +344 -0
  24. fairseq/fairseq/clib/libnat_cuda/edit_dist.h +25 -0
  25. fairseq/fairseq/config/__init__.py +4 -0
  26. fairseq/fairseq/config/config.yaml +19 -0
  27. fairseq/fairseq/config/fb_run_config/slurm.yaml +29 -0
  28. fairseq/fairseq/config/model/transformer_lm/transformer_lm_baevski_gbw.yaml +36 -0
  29. fairseq/fairseq/config/model/transformer_lm/transformer_lm_baevski_wiki103.yaml +36 -0
  30. fairseq/fairseq/config/model/transformer_lm/transformer_lm_big.yaml +36 -0
  31. fairseq/fairseq/config/model/transformer_lm/transformer_lm_gbw.yaml +36 -0
  32. fairseq/fairseq/config/model/transformer_lm/transformer_lm_gpt.yaml +36 -0
  33. fairseq/fairseq/config/model/transformer_lm/transformer_lm_gpt2_big.yaml +36 -0
  34. fairseq/fairseq/config/model/transformer_lm/transformer_lm_gpt2_medium.yaml +36 -0
  35. fairseq/fairseq/config/model/transformer_lm/transformer_lm_gpt2_small.yaml +36 -0
  36. fairseq/fairseq/config/model/transformer_lm/transformer_lm_wiki103.yaml +36 -0
  37. fairseq/fairseq/config/model/wav2vec/vq_wav2vec_gumbel.yaml +5 -0
  38. fairseq/fairseq/config/model/wav2vec2/wav2vec2_base.yaml +8 -0
  39. fairseq/fairseq/config/model/wav2vec2/wav2vec2_large.yaml +20 -0
  40. fairseq/fairseq/criterions/__init__.py +36 -0
  41. fairseq/fairseq/criterions/__pycache__/fairseq_criterion.cpython-310.pyc +0 -0
  42. fairseq/fairseq/criterions/adaptive_loss.py +124 -0
  43. fairseq/fairseq/criterions/composite_loss.py +100 -0
  44. fairseq/fairseq/criterions/cross_entropy.py +91 -0
  45. fairseq/fairseq/criterions/ctc.py +325 -0
  46. fairseq/fairseq/criterions/fairseq_criterion.py +121 -0
  47. fairseq/fairseq/criterions/fastspeech2_loss.py +137 -0
  48. fairseq/fairseq/criterions/hubert_criterion.py +195 -0
  49. fairseq/fairseq/criterions/label_smoothed_cross_entropy.py +168 -0
  50. fairseq/fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py +221 -0
.gitattributes CHANGED
@@ -39,3 +39,4 @@ fairseq/alignment_train_cuda_binding.cpython-310-x86_64-linux-gnu.so filter=lfs
 fairseq/alignment_train_cpu_binding.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
 fairseq/docs/fairseq.gif filter=lfs diff=lfs merge=lfs -text
 fairseq/examples/hubert/tests/6313-76958-0021.flac filter=lfs diff=lfs merge=lfs -text
+fairseq/examples/textless_nlp/speech-resynth/img/fig.png filter=lfs diff=lfs merge=lfs -text
fairseq/examples/textless_nlp/speech-resynth/img/fig.png ADDED

Git LFS Details

  • SHA256: c19c570f3671d88551f5d5b908e270e69bd75d304c5d358868fa19f342979c17
  • Pointer size: 131 Bytes
  • Size of remote file: 308 kB
fairseq/fairseq/benchmark/__init__.py ADDED
@@ -0,0 +1,7 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# import models/tasks to register them
+from . import dummy_dataset, dummy_lm, dummy_masked_lm, dummy_model, dummy_mt  # noqa
fairseq/fairseq/benchmark/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (335 Bytes)

fairseq/fairseq/benchmark/__pycache__/dummy_dataset.cpython-310.pyc ADDED
Binary file (1.79 kB)

fairseq/fairseq/benchmark/__pycache__/dummy_lm.cpython-310.pyc ADDED
Binary file (3.18 kB)

fairseq/fairseq/benchmark/__pycache__/dummy_masked_lm.cpython-310.pyc ADDED
Binary file (3.39 kB)

fairseq/fairseq/benchmark/__pycache__/dummy_model.cpython-310.pyc ADDED
Binary file (3.48 kB)

fairseq/fairseq/benchmark/__pycache__/dummy_mt.cpython-310.pyc ADDED
Binary file (4.56 kB)
 
fairseq/fairseq/benchmark/benchmark_multihead_attention.py ADDED
@@ -0,0 +1,172 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import itertools
+import random
+
+import torch
+from torch.utils import benchmark
+
+from fairseq.modules.multihead_attention import MultiheadAttention
+
+BATCH = [20, 41, 97]
+SEQ = 64
+EMB = 48
+HEADS = 4
+DROP = 0.1
+DEVICE = torch.device("cuda")
+ATTN_MASK_DTYPE = [torch.uint8, torch.bool, torch.float]
+KEY_PADDING_MASK_DTYPE = [torch.uint8, torch.bool]
+
+
+def _reset_seeds():
+    torch.manual_seed(0)
+    random.seed(0)
+
+
+def _get_mask(to_dtype: torch.dtype, dim0: int, dim1: int):
+    if to_dtype == torch.float:
+        mask = torch.randint(0, 2, (dim0, dim1)).to(dtype=torch.bool)
+        return mask.to(dtype=to_dtype).masked_fill(mask, -float("inf"))
+    return torch.randint(0, 2, (dim0, dim1)).to(dtype=to_dtype)
+
+
+def benchmark_multihead_attention(
+    label="",
+    attn_dtype=torch.uint8,
+    key_padding_dtype=torch.uint8,
+    add_bias_kv=False,
+    add_zero_attn=False,
+    static_kv=False,
+    batch_size=20,
+    embedding=EMB,
+    seq_len=SEQ,
+    num_heads=HEADS,
+):
+
+    results = []
+    # device = torch.device("cuda")
+
+    xformers_att_config = '{"name": "scaled_dot_product"}'
+
+    attn_mask = _get_mask(to_dtype=attn_dtype, dim0=seq_len, dim1=seq_len)
+    key_padding_mask = _get_mask(
+        to_dtype=key_padding_dtype, dim0=batch_size, dim1=seq_len
+    )
+
+    q = torch.rand(seq_len, batch_size, embedding, requires_grad=True)
+    k = torch.rand(seq_len, batch_size, embedding, requires_grad=True)
+    v = torch.rand(seq_len, batch_size, embedding, requires_grad=True)
+
+    _reset_seeds()
+
+    original_mha = MultiheadAttention(
+        embedding,
+        num_heads,
+        dropout=0.0,
+        xformers_att_config=None,
+        add_bias_kv=add_bias_kv,
+        add_zero_attn=add_zero_attn,
+    )
+
+    xformers_mha = MultiheadAttention(
+        embedding,
+        num_heads,
+        dropout=0.0,
+        xformers_att_config=xformers_att_config,
+        add_bias_kv=add_bias_kv,
+        add_zero_attn=add_zero_attn,
+    )
+
+    def original_bench_fw(q, k, v, key_padding_mask, attn_mask, static_kv):
+        original_mha(
+            query=q,
+            key=k,
+            value=v,
+            key_padding_mask=key_padding_mask,
+            attn_mask=attn_mask,
+            static_kv=static_kv,
+        )
+
+    def xformers_bench_fw(q, k, v, key_padding_mask, attn_mask, static_kv):
+        xformers_mha(
+            query=q,
+            key=k,
+            value=v,
+            key_padding_mask=key_padding_mask,
+            attn_mask=attn_mask,
+            static_kv=static_kv,
+        )
+
+    def original_bench_fw_bw(q, k, v, key_padding_mask, attn_mask, static_kv):
+        output, _ = original_mha(
+            query=q,
+            key=k,
+            value=v,
+            key_padding_mask=key_padding_mask,
+            attn_mask=attn_mask,
+            static_kv=static_kv,
+        )
+        loss = torch.norm(output)
+        loss.backward()
+
+    def xformers_bench_fw_bw(q, k, v, key_padding_mask, attn_mask, static_kv):
+        output, _ = xformers_mha(
+            query=q,
+            key=k,
+            value=v,
+            key_padding_mask=key_padding_mask,
+            attn_mask=attn_mask,
+            static_kv=static_kv,
+        )
+        loss = torch.norm(output)
+        loss.backward()
+
+    fns = [
+        original_bench_fw,
+        xformers_bench_fw,
+        original_bench_fw_bw,
+        xformers_bench_fw_bw,
+    ]
+
+    for fn in fns:
+        results.append(
+            benchmark.Timer(
+                stmt="fn(q, k, v, key_padding_mask, attn_mask, static_kv)",
+                globals={
+                    "q": q,
+                    "k": k,
+                    "v": v,
+                    "key_padding_mask": key_padding_mask,
+                    "attn_mask": attn_mask,
+                    "static_kv": static_kv,
+                    "fn": fn,
+                },
+                label="multihead fw + bw",
+                sub_label=f"{fn.__name__}",
+                description=label,
+            ).blocked_autorange(min_run_time=1)
+        )
+
+    compare = benchmark.Compare(results)
+    compare.print()
+
+
+def run_benchmarks():
+    for attn_dtype, key_padding_dtype, add_bias_kv, add_zero_attn in itertools.product(
+        ATTN_MASK_DTYPE, KEY_PADDING_MASK_DTYPE, [True, False], [True, False]
+    ):
+        label = f"attn_dtype {attn_dtype}, key_padding_dtype {key_padding_dtype}, \
+            add_bias_kv {add_bias_kv}, add_zero_attn {add_zero_attn}"
+        benchmark_multihead_attention(
+            label=label,
+            attn_dtype=attn_dtype,
+            key_padding_dtype=key_padding_dtype,
+            add_bias_kv=add_bias_kv,
+            add_zero_attn=add_zero_attn,
+        )
+
+
+run_benchmarks()
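
Because run_benchmarks() is called at module scope, the file acts as a self-running script. A minimal sketch of triggering it from Python (a CUDA device and an xFormers-enabled fairseq build are assumed; this sketch is illustrative, not part of the commit):

    import runpy

    # executing the module fires run_benchmarks() and prints a benchmark.Compare table
    runpy.run_module("fairseq.benchmark.benchmark_multihead_attention", run_name="__main__")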
fairseq/fairseq/benchmark/dummy_dataset.py ADDED
@@ -0,0 +1,36 @@
+import numpy as np
+from fairseq.data import FairseqDataset
+
+
+class DummyDataset(FairseqDataset):
+    def __init__(self, batch, num_items, item_size):
+        super().__init__()
+        self.batch = batch
+        self.num_items = num_items
+        self.item_size = item_size
+
+    def __getitem__(self, index):
+        return index
+
+    def __len__(self):
+        return self.num_items
+
+    def collater(self, samples):
+        return self.batch
+
+    @property
+    def sizes(self):
+        return np.array([self.item_size] * self.num_items)
+
+    def num_tokens(self, index):
+        return self.item_size
+
+    def size(self, index):
+        return self.item_size
+
+    def ordered_indices(self):
+        return np.arange(self.num_items)
+
+    @property
+    def supports_prefetch(self):
+        return False
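
A small sketch of the contract this class implements: every index collates to the same pre-built batch, so iterating over it measures model and trainer overhead rather than data loading (the constructor arguments below are illustrative, not from the commit):

    from fairseq.benchmark.dummy_dataset import DummyDataset

    ds = DummyDataset(batch={"ntokens": 512}, num_items=4, item_size=128)

    assert len(ds) == 4
    assert list(ds.ordered_indices()) == [0, 1, 2, 3]
    # collater ignores the sampled indices and always returns the fixed batch
    assert ds.collater([ds[i] for i in range(len(ds))]) == {"ntokens": 512}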
fairseq/fairseq/benchmark/dummy_lm.py ADDED
@@ -0,0 +1,83 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from dataclasses import dataclass, field
+from typing import Optional
+
+import torch
+from .dummy_dataset import DummyDataset
+from fairseq.data import Dictionary
+from fairseq.dataclass import FairseqDataclass
+from fairseq.tasks import FairseqTask, register_task
+from omegaconf import II
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class DummyLMConfig(FairseqDataclass):
+    dict_size: int = 49996
+    dataset_size: int = 100000
+    tokens_per_sample: int = field(
+        default=512, metadata={"help": "max sequence length"}
+    )
+    add_bos_token: bool = False
+    batch_size: Optional[int] = II("dataset.batch_size")
+    max_tokens: Optional[int] = II("dataset.max_tokens")
+    max_target_positions: int = II("task.tokens_per_sample")
+
+
+@register_task("dummy_lm", dataclass=DummyLMConfig)
+class DummyLMTask(FairseqTask):
+    def __init__(self, cfg: DummyLMConfig):
+        super().__init__(cfg)
+
+        # load dictionary
+        self.dictionary = Dictionary()
+        for i in range(cfg.dict_size):
+            self.dictionary.add_symbol("word{}".format(i))
+        self.dictionary.pad_to_multiple_(8)  # often faster if divisible by 8
+        logger.info("dictionary: {} types".format(len(self.dictionary)))
+
+        seq = torch.arange(cfg.tokens_per_sample + 1) + self.dictionary.pad() + 1
+
+        self.dummy_src = seq[:-1]
+        self.dummy_tgt = seq[1:]
+
+    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
+        """Load a given dataset split.
+        Args:
+            split (str): name of the split (e.g., train, valid, test)
+        """
+        if self.cfg.batch_size is not None:
+            bsz = self.cfg.batch_size
+        else:
+            bsz = max(1, self.cfg.max_tokens // self.cfg.tokens_per_sample)
+        self.datasets[split] = DummyDataset(
+            {
+                "id": 1,
+                "net_input": {
+                    "src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]),
+                    "src_lengths": torch.full(
+                        (bsz,), self.cfg.tokens_per_sample, dtype=torch.long
+                    ),
+                },
+                "target": torch.stack([self.dummy_tgt for _ in range(bsz)]),
+                "nsentences": bsz,
+                "ntokens": bsz * self.cfg.tokens_per_sample,
+            },
+            num_items=self.cfg.dataset_size,
+            item_size=self.cfg.tokens_per_sample,
+        )
+
+    @property
+    def source_dictionary(self):
+        return self.dictionary
+
+    @property
+    def target_dictionary(self):
+        return self.dictionary
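
The task is normally selected from the fairseq CLI (e.g. "fairseq-train --task dummy_lm"); constructing it directly looks roughly like the sketch below, with config values chosen purely for illustration:

    from fairseq.benchmark.dummy_lm import DummyLMConfig, DummyLMTask

    cfg = DummyLMConfig(
        dict_size=1000, dataset_size=16, tokens_per_sample=128, batch_size=8
    )
    task = DummyLMTask(cfg)
    task.load_dataset("train")

    # the collater returns the same synthetic batch for any sample indices
    batch = task.datasets["train"].collater([0])
    print(batch["net_input"]["src_tokens"].shape)  # torch.Size([8, 128])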
fairseq/fairseq/benchmark/dummy_masked_lm.py ADDED
@@ -0,0 +1,94 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from dataclasses import dataclass, field
+from typing import Optional
+
+import torch
+from omegaconf import II
+
+from .dummy_dataset import DummyDataset
+from fairseq.data import Dictionary
+from fairseq.dataclass import FairseqDataclass
+from fairseq.tasks import FairseqTask, register_task
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class DummyMaskedLMConfig(FairseqDataclass):
+    dict_size: int = 49996
+    dataset_size: int = 100000
+    tokens_per_sample: int = field(
+        default=512,
+        metadata={
+            "help": "max number of total tokens over all"
+            " segments per sample for BERT dataset"
+        },
+    )
+    batch_size: Optional[int] = II("dataset.batch_size")
+    max_tokens: Optional[int] = II("dataset.max_tokens")
+    max_target_positions: int = II("task.tokens_per_sample")
+
+
+@register_task("dummy_masked_lm", dataclass=DummyMaskedLMConfig)
+class DummyMaskedLMTask(FairseqTask):
+    def __init__(self, cfg: DummyMaskedLMConfig):
+        super().__init__(cfg)
+
+        self.dictionary = Dictionary()
+        for i in range(cfg.dict_size):
+            self.dictionary.add_symbol("word{}".format(i))
+        logger.info("dictionary: {} types".format(len(self.dictionary)))
+        # add mask token
+        self.mask_idx = self.dictionary.add_symbol("<mask>")
+        self.dictionary.pad_to_multiple_(8)  # often faster if divisible by 8
+
+        mask_idx = 0
+        pad_idx = 1
+        seq = torch.arange(cfg.tokens_per_sample) + pad_idx + 1
+        mask = torch.arange(2, cfg.tokens_per_sample, 7)  # ~15%
+        src = seq.clone()
+        src[mask] = mask_idx
+        tgt = torch.full_like(seq, pad_idx)
+        tgt[mask] = seq[mask]
+
+        self.dummy_src = src
+        self.dummy_tgt = tgt
+
+    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
+        """Load a given dataset split.
+        Args:
+            split (str): name of the split (e.g., train, valid, test)
+        """
+        if self.cfg.batch_size is not None:
+            bsz = self.cfg.batch_size
+        else:
+            bsz = max(1, self.cfg.max_tokens // self.cfg.tokens_per_sample)
+        self.datasets[split] = DummyDataset(
+            {
+                "id": 1,
+                "net_input": {
+                    "src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]),
+                    "src_lengths": torch.full(
+                        (bsz,), self.cfg.tokens_per_sample, dtype=torch.long
+                    ),
+                },
+                "target": torch.stack([self.dummy_tgt for _ in range(bsz)]),
+                "nsentences": bsz,
+                "ntokens": bsz * self.cfg.tokens_per_sample,
+            },
+            num_items=self.cfg.dataset_size,
+            item_size=self.cfg.tokens_per_sample,
+        )
+
+    @property
+    def source_dictionary(self):
+        return self.dictionary
+
+    @property
+    def target_dictionary(self):
+        return self.dictionary
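
The masking pattern above is deterministic: every seventh position starting at index 2 is replaced by the mask symbol, which approximates BERT's ~15% masking rate. A quick check of that arithmetic:

    import torch

    tokens_per_sample = 512
    mask = torch.arange(2, tokens_per_sample, 7)
    print(len(mask), len(mask) / tokens_per_sample)  # 73 0.1426 -> close to 15%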
fairseq/fairseq/benchmark/dummy_model.py ADDED
@@ -0,0 +1,96 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch.nn as nn
+import torch.nn.functional as F
+from fairseq.data import Dictionary
+from fairseq.models import (
+    FairseqDecoder,
+    FairseqLanguageModel,
+    register_model,
+    register_model_architecture,
+)
+
+
+@register_model("dummy_model")
+class DummyModel(FairseqLanguageModel):
+    def __init__(self, args, encoder):
+        super().__init__(encoder)
+        self.args = args
+
+    @staticmethod
+    def add_args(parser):
+        parser.add_argument("--num-layers", type=int, default=24)
+        parser.add_argument("--embed-dim", type=int, default=1024)
+
+    @classmethod
+    def build_model(cls, args, task):
+        encoder = DummyEncoder(
+            num_embed=len(task.target_dictionary),
+            embed_dim=args.embed_dim,
+            num_layers=args.num_layers,
+        )
+        return cls(args, encoder)
+
+    def forward(self, src_tokens, masked_tokens=None, **kwargs):
+        return self.decoder(src_tokens, masked_tokens=masked_tokens)
+
+
+class DummyEncoder(FairseqDecoder):
+    def __init__(self, num_embed=50000, embed_dim=1024, num_layers=24):
+        super().__init__(Dictionary())
+        self.embed = nn.Embedding(
+            num_embeddings=num_embed, embedding_dim=embed_dim, padding_idx=0
+        )
+        self.layers_a = nn.ModuleList(
+            [
+                nn.Sequential(
+                    nn.LayerNorm(embed_dim),
+                    nn.Linear(embed_dim, 3 * embed_dim),  # q, k, v input projection
+                    nn.Linear(3 * embed_dim, embed_dim),  # skip self-attention
+                    nn.Linear(embed_dim, embed_dim),  # output projection
+                    nn.Dropout(),
+                )
+                for i in range(num_layers)
+            ]
+        )
+        self.layers_b = nn.ModuleList(
+            [
+                nn.Sequential(
+                    nn.LayerNorm(embed_dim),
+                    nn.Linear(embed_dim, 4 * embed_dim),  # FFN
+                    nn.ReLU(),
+                    nn.Linear(4 * embed_dim, embed_dim),  # FFN
+                    nn.Dropout(0.1),
+                )
+                for i in range(num_layers)
+            ]
+        )
+        self.out_proj = nn.Linear(embed_dim, num_embed)
+
+    def forward(self, tokens, masked_tokens=None):
+        x = self.embed(tokens)
+        for layer_a, layer_b in zip(self.layers_a, self.layers_b):
+            x = x + layer_a(x)
+            x = x + layer_b(x)
+        x = self.out_proj(x)
+        if masked_tokens is not None:
+            x = x[masked_tokens]
+        return (x,)
+
+    def max_positions(self):
+        return 1024
+
+    def get_normalized_probs(self, net_output, log_probs, sample=None):
+        logits = net_output[0].float()
+        if log_probs:
+            return F.log_softmax(logits, dim=-1)
+        else:
+            return F.softmax(logits, dim=-1)
+
+
+@register_model_architecture("dummy_model", "dummy_model")
+def base_architecture(args):
+    pass
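
The dummy encoder keeps Transformer-shaped weights (a 3x QKV-sized projection and a 4x FFN per layer) while skipping the attention computation itself, so its parameter count tracks a real model of the same dimensions. A quick sketch of checking that at the default sizes (illustrative, not part of the commit):

    from fairseq.benchmark.dummy_model import DummyEncoder

    enc = DummyEncoder(num_embed=50000, embed_dim=1024, num_layers=24)
    n_params = sum(p.numel() for p in enc.parameters())
    print(f"{n_params / 1e6:.0f}M parameters")  # roughly 480M at these sizes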
fairseq/fairseq/benchmark/dummy_mt.py ADDED
@@ -0,0 +1,119 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+
+import numpy as np
+import torch
+
+from fairseq.data import Dictionary, FairseqDataset
+from fairseq.tasks import LegacyFairseqTask, register_task
+
+logger = logging.getLogger(__name__)
+
+
+@register_task("dummy_mt")
+class DummyMTTask(LegacyFairseqTask):
+    @staticmethod
+    def add_args(parser):
+        """Add task-specific arguments to the parser."""
+        parser.add_argument("--dict-size", default=49996, type=int)
+        parser.add_argument("--dataset-size", default=100000, type=int)
+        parser.add_argument("--src-len", default=30, type=int)
+        parser.add_argument("--tgt-len", default=30, type=int)
+
+    def __init__(self, args, dictionary):
+        super().__init__(args)
+        self.dictionary = dictionary
+        self.seed = args.seed
+
+        dictionary.pad_to_multiple_(8)  # often faster if divisible by 8
+
+        self.dummy_src = torch.arange(args.src_len + 1) + dictionary.pad() + 1
+        self.dummy_tgt = torch.arange(args.tgt_len + 1) + dictionary.pad() + 1
+
+    @classmethod
+    def setup_task(cls, args, **kwargs):
+        """Setup the task."""
+        dictionary = Dictionary()
+        for i in range(args.dict_size):
+            dictionary.add_symbol("word{}".format(i))
+        logger.info("dictionary: {} types".format(len(dictionary)))
+
+        args.max_source_positions = args.src_len + dictionary.pad() + 2
+        args.max_target_positions = args.tgt_len + dictionary.pad() + 2
+
+        return cls(args, dictionary)
+
+    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
+        """Load a given dataset split.
+        Args:
+            split (str): name of the split (e.g., train, valid, test)
+        """
+        item_size = max(self.args.src_len, self.args.tgt_len)
+        if self.args.batch_size is not None:
+            bsz = self.args.batch_size
+        else:
+            bsz = max(1, self.args.max_tokens // item_size)
+        tgt = torch.stack([self.dummy_tgt for _ in range(bsz)])
+        self.datasets[split] = DummyDataset(
+            {
+                "id": 1,
+                "net_input": {
+                    "src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]),
+                    "src_lengths": torch.full(
+                        (bsz,), self.args.src_len, dtype=torch.long
+                    ),
+                    "prev_output_tokens": tgt.clone(),
+                },
+                "target": tgt,
+                "nsentences": bsz,
+                "ntokens": bsz * self.args.tgt_len,
+            },
+            num_items=self.args.dataset_size,
+            item_size=item_size,
+        )
+
+    @property
+    def source_dictionary(self):
+        return self.dictionary
+
+    @property
+    def target_dictionary(self):
+        return self.dictionary
+
+
+class DummyDataset(FairseqDataset):
+    def __init__(self, batch, num_items, item_size):
+        super().__init__()
+        self.batch = batch
+        self.num_items = num_items
+        self.item_size = item_size
+
+    def __getitem__(self, index):
+        return index
+
+    def __len__(self):
+        return self.num_items
+
+    def collater(self, samples):
+        return self.batch
+
+    @property
+    def sizes(self):
+        return np.array([self.item_size] * self.num_items)
+
+    def num_tokens(self, index):
+        return self.item_size
+
+    def size(self, index):
+        return self.item_size
+
+    def ordered_indices(self):
+        return np.arange(self.num_items)
+
+    @property
+    def supports_prefetch(self):
+        return False
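
As a legacy task, dummy_mt is wired through argparse rather than a dataclass. A rough sketch of standing it up without the CLI (the attribute assignments mimic what fairseq's trainer would normally supply and are assumptions for illustration):

    import argparse

    from fairseq.benchmark.dummy_mt import DummyMTTask

    parser = argparse.ArgumentParser()
    DummyMTTask.add_args(parser)
    args = parser.parse_args([])     # defaults: 30-token source/target pairs
    args.dict_size = 100             # shrink the dictionary for a quick demo
    args.seed = 1
    args.batch_size, args.max_tokens = 8, None

    task = DummyMTTask.setup_task(args)
    task.load_dataset("train")
    batch = task.datasets["train"].collater([0])
    print(batch["net_input"]["src_tokens"].shape)  # torch.Size([8, 31])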
fairseq/fairseq/clib/cuda/ngram_repeat_block_cuda.cpp ADDED
@@ -0,0 +1,55 @@
+/*
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT License.
+*/
+
+#include <torch/extension.h>
+#include <vector>
+
+/*
+CPP Binding for CUDA OP
+*/
+
+// CUDA forward declarations
+torch::Tensor ngram_repeat_block_cuda_forward(
+    torch::Tensor tokens,
+    torch::Tensor lprobs,
+    int bsz,
+    int step,
+    int beam_size,
+    int no_repeat_ngram_size);
+
+#define CHECK_CUDA(x) \
+  TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) \
+  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) \
+  CHECK_CUDA(x);       \
+  CHECK_CONTIGUOUS(x)
+
+// Input check and call to CUDA OP
+// Backward method not required
+torch::Tensor ngram_repeat_block_forward(
+    torch::Tensor tokens,
+    torch::Tensor lprobs,
+    int bsz,
+    int step,
+    int beam_size,
+    int no_repeat_ngram_size) {
+  CHECK_INPUT(tokens);
+  CHECK_INPUT(lprobs);
+  assert(bsz > 0);
+  assert(step >= 0);
+  assert(beam_size > 0);
+  assert(no_repeat_ngram_size > 0);
+
+  return ngram_repeat_block_cuda_forward(
+      tokens, lprobs, bsz, step, beam_size, no_repeat_ngram_size);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def(
+      "forward",
+      &ngram_repeat_block_forward,
+      "No Repeat Ngram Block forward (CUDA)");
+}
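
A sketch of how this extension could be compiled and exercised with torch.utils.cpp_extension (the build wiring and tensor shapes here are assumptions for illustration; fairseq loads it through its own sequence-generator machinery):

    import torch
    from torch.utils.cpp_extension import load

    ngram_block = load(
        name="ngram_repeat_block_cuda",
        sources=[
            "fairseq/fairseq/clib/cuda/ngram_repeat_block_cuda.cpp",
            "fairseq/fairseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu",
        ],
    )

    bsz, beam_size, step, vocab = 2, 4, 5, 100
    tokens = torch.randint(2, vocab, (bsz * beam_size, 16), device="cuda")
    lprobs = torch.randn(bsz * beam_size, vocab, device="cuda")
    # any token that would complete an already-seen 3-gram gets lprob -inf
    lprobs = ngram_block.forward(tokens, lprobs, bsz, step, beam_size, 3)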
fairseq/fairseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu ADDED
@@ -0,0 +1,82 @@
+/*
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT License.
+*/
+
+/*
+Kernel implementation for blocking repeated ngrams.
+*/
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <math.h>
+#include <torch/extension.h>
+#include <vector>
+
+// Ban repeated ngrams of length = 'no_repeat_ngram_size'
+__global__ void banRepeatedTokens(
+    long* __restrict__ tokens,
+    float* __restrict__ lprobs,
+    int max_predict_len,
+    int vocab_size,
+    int no_repeat_ngram_size) {
+  auto row = blockIdx.x;
+  auto col = threadIdx.x;
+  auto start = row * (max_predict_len) + col;
+  // Each thread compares the ngram starting at its thread index with the
+  // final ngram, which starts at step - no_repeat_ngram_size + 2
+  auto check_start_pos = blockDim.x;
+  auto lprob_start = row * vocab_size;
+  bool is_banned = true;
+  extern __shared__ long tokens_shm[];
+  tokens_shm[col] = tokens[start];
+  if (col == blockDim.x - 1) {
+    for (int i = 1; i < no_repeat_ngram_size; i++) {
+      if (col + i < max_predict_len) {
+        tokens_shm[col + i] = tokens[start + i];
+      }
+    }
+  }
+  __syncthreads();
+
+  for (int k = 0; k < no_repeat_ngram_size - 1; k++) {
+    if (tokens_shm[col + k] != tokens_shm[check_start_pos + k]) {
+      is_banned = false;
+    }
+  }
+  if (is_banned == true) {
+    auto token_to_be_banned = tokens_shm[col + no_repeat_ngram_size - 1];
+    lprobs[lprob_start + token_to_be_banned] = -INFINITY;
+  }
+}
+
+// Allocate blocks and threads based on
+// batch size and sequence length and launch
+// kernel
+torch::Tensor ngram_repeat_block_cuda_forward(
+    const torch::Tensor tokens,
+    torch::Tensor lprobs,
+    int bsz,
+    int step,
+    int beam_size,
+    int no_repeat_ngram_size) {
+  int threads = step - no_repeat_ngram_size + 2;
+  if (threads <= 0)
+    return lprobs;
+  int max_predict_len = tokens.size(1);
+  int vocab_size = lprobs.size(1);
+  auto token_ptr = tokens.data_ptr<long>();
+  auto lprob_ptr = lprobs.data_ptr<float>();
+  int blocks = bsz * beam_size;
+  int shared_mem_size = (step + 1) * sizeof(long);
+
+  // Launching N blocks where N is number of samples in a batch (beams*bsz)
+  // Launching T threads where T is number of previous ngrams in a sample
+  // Allocating shared mem per block for faster access of input tokens since
+  // each token will be accessed N times to compare with current Ngram where
+  // N is Ngram size.
+  banRepeatedTokens<<<blocks, threads, shared_mem_size>>>(
+      token_ptr, lprob_ptr, max_predict_len, vocab_size, no_repeat_ngram_size);
+  return lprobs;
+}
fairseq/fairseq/clib/libbase/balanced_assignment.cpp ADDED
@@ -0,0 +1,109 @@
+/**
+ * Copyright 2017-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+/*
+C++ code for solving the linear assignment problem.
+Based on the Auction Algorithm from
+https://dspace.mit.edu/bitstream/handle/1721.1/3265/P-2108-26912652.pdf and the
+implementation from: https://github.com/bkj/auction-lap Adapted to be more
+efficient when each worker is looking for k jobs instead of 1.
+*/
+#include <torch/extension.h>
+#include <iostream>
+using namespace torch::indexing;
+torch::Tensor balanced_assignment(torch::Tensor job_and_worker_to_score) {
+  int max_iterations = 100;
+  torch::Tensor epsilon =
+      (job_and_worker_to_score.max() - job_and_worker_to_score.min()) / 50;
+  epsilon.clamp_min_(1e-04);
+  torch::Tensor worker_and_job_to_score =
+      job_and_worker_to_score.detach().transpose(0, 1).contiguous();
+  int num_workers = worker_and_job_to_score.size(0);
+  int num_jobs = worker_and_job_to_score.size(1);
+  auto device = worker_and_job_to_score.device();
+  int jobs_per_worker = num_jobs / num_workers;
+  torch::Tensor value = worker_and_job_to_score.clone();
+  int counter = 0;
+  torch::Tensor max_value = worker_and_job_to_score.max();
+
+  torch::Tensor bid_indices;
+  torch::Tensor cost = worker_and_job_to_score.new_zeros({1, num_jobs});
+  torch::Tensor bids =
+      worker_and_job_to_score.new_empty({num_workers, num_jobs});
+  torch::Tensor bid_increments =
+      worker_and_job_to_score.new_empty({num_workers, jobs_per_worker});
+  torch::Tensor top_values =
+      worker_and_job_to_score.new_empty({num_workers, jobs_per_worker + 1});
+  torch::Tensor high_bids = worker_and_job_to_score.new_empty({num_jobs});
+
+  torch::Tensor top_index = top_values.to(torch::kLong);
+  torch::Tensor high_bidders = top_index.new_empty({num_jobs});
+  torch::Tensor have_bids = high_bidders.to(torch::kBool);
+  torch::Tensor jobs_indices =
+      torch::arange({num_jobs}, torch::dtype(torch::kLong).device(device));
+  torch::Tensor true_tensor =
+      torch::ones({1}, torch::dtype(torch::kBool).device(device));
+
+  while (true) {
+    bids.zero_();
+    torch::topk_out(top_values, top_index, value, jobs_per_worker + 1, 1);
+
+    // Each worker bids the difference in value between that job and the k+1th
+    // job
+    torch::sub_out(
+        bid_increments,
+        top_values.index({Slice(None, None), Slice(0, jobs_per_worker)}),
+        top_values.index({Slice(None, None), jobs_per_worker}).unsqueeze(1));
+
+    bid_increments.add_(epsilon);
+    bids.scatter_(
+        1,
+        top_index.index({Slice(None, None), Slice(0, jobs_per_worker)}),
+        bid_increments);
+
+    if (counter < max_iterations && counter > 0) {
+      // Put in a minimal bid to retain items from the last round if no-one else
+      // bids for them this round
+      bids.view(-1).index_put_({bid_indices}, epsilon);
+    }
+
+    // Find the highest bidding worker per job
+    torch::max_out(high_bids, high_bidders, bids, 0);
+    torch::gt_out(have_bids, high_bids, 0);
+
+    if (have_bids.all().item<bool>()) {
+      // All jobs were bid for
+      break;
+    }
+
+    // Make popular items more expensive
+    cost.add_(high_bids);
+    torch::sub_out(value, worker_and_job_to_score, cost);
+
+    bid_indices = ((high_bidders * num_jobs) + jobs_indices).index({have_bids});
+
+    if (counter < max_iterations) {
+      // Make sure that this item will be in the winning worker's top-k next
+      // time.
+      value.view(-1).index_put_({bid_indices}, max_value);
+    } else {
+      // Suboptimal approximation that converges quickly from current solution
+      value.view(-1).index_put_(
+          {bid_indices}, worker_and_job_to_score.view(-1).index({bid_indices}));
+    }
+
+    counter += 1;
+  }
+
+  return top_index.index({Slice(None, None), Slice(0, jobs_per_worker)})
+      .reshape(-1);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("balanced_assignment", &balanced_assignment, "Balanced Assignment");
+}
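
A sketch of the op's contract, assuming it is built as an extension via torch.utils.cpp_extension: scores are (num_jobs x num_workers) with num_jobs an integer multiple of num_workers, and the result lists each worker's assigned jobs in worker order (the module name and build wiring below are assumptions):

    import torch
    from torch.utils.cpp_extension import load

    libbase = load(
        name="libbase",
        sources=["fairseq/fairseq/clib/libbase/balanced_assignment.cpp"],
    )

    scores = torch.randn(8, 4)  # 8 jobs scored by 4 workers -> 2 jobs per worker
    assignment = libbase.balanced_assignment(scores)
    print(assignment.view(4, 2))  # row w holds the two jobs auctioned to worker w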
fairseq/fairseq/clib/libbleu/libbleu.cpp ADDED
@@ -0,0 +1,157 @@
+/**
+ * Copyright 2017-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <array>
+#include <cstdio>
+#include <cstring>
+#include <map>
+
+// NOLINTNEXTLINE
+typedef struct {
+  size_t reflen;
+  size_t predlen;
+  size_t match1;
+  size_t count1;
+  size_t match2;
+  size_t count2;
+  size_t match3;
+  size_t count3;
+  size_t match4;
+  size_t count4;
+} bleu_stat;
+
+// left trim (remove pad)
+void bleu_ltrim(size_t* len, int** sent, int pad) {
+  size_t start = 0;
+  while (start < *len) {
+    if (*(*sent + start) != pad) {
+      break;
+    }
+    start++;
+  }
+  *sent += start;
+  *len -= start;
+}
+
+// right trim (remove eos)
+void bleu_rtrim(size_t* len, int** sent, int pad, int eos) {
+  size_t end = *len - 1;
+  while (end > 0) {
+    if (*(*sent + end) != eos && *(*sent + end) != pad) {
+      break;
+    }
+    end--;
+  }
+  *len = end + 1;
+}
+
+// left and right trim
+void bleu_trim(size_t* len, int** sent, int pad, int eos) {
+  bleu_ltrim(len, sent, pad);
+  bleu_rtrim(len, sent, pad, eos);
+}
+
+size_t bleu_hash(int len, int* data) {
+  size_t h = 14695981039346656037ul;
+  size_t prime = 0x100000001b3;
+  char* b = (char*)data;
+  size_t blen = sizeof(int) * len;
+
+  while (blen-- > 0) {
+    h ^= *b++;
+    h *= prime;
+  }
+
+  return h;
+}
+
+void bleu_addngram(
+    size_t* ntotal,
+    size_t* nmatch,
+    size_t n,
+    size_t reflen,
+    int* ref,
+    size_t predlen,
+    int* pred) {
+  if (predlen < n) {
+    return;
+  }
+
+  predlen = predlen - n + 1;
+  (*ntotal) += predlen;
+
+  if (reflen < n) {
+    return;
+  }
+
+  reflen = reflen - n + 1;
+
+  std::map<size_t, size_t> count;
+  while (predlen > 0) {
+    size_t w = bleu_hash(n, pred++);
+    count[w]++;
+    predlen--;
+  }
+
+  while (reflen > 0) {
+    size_t w = bleu_hash(n, ref++);
+    if (count[w] > 0) {
+      (*nmatch)++;
+      count[w] -= 1;
+    }
+    reflen--;
+  }
+}
+
+extern "C" {
+
+#ifdef _WIN64
+__declspec(dllexport)
+#endif
+void bleu_zero_init(bleu_stat* stat) {
+  std::memset(stat, 0, sizeof(bleu_stat));
+}
+
+#ifdef _WIN64
+__declspec(dllexport)
+#endif
+void bleu_one_init(bleu_stat* stat) {
+  bleu_zero_init(stat);
+  stat->count1 = 0;
+  stat->count2 = 1;
+  stat->count3 = 1;
+  stat->count4 = 1;
+  stat->match1 = 0;
+  stat->match2 = 1;
+  stat->match3 = 1;
+  stat->match4 = 1;
+}
+
+#ifdef _WIN64
+__declspec(dllexport)
+#endif
+void bleu_add(
+    bleu_stat* stat,
+    size_t reflen,
+    int* ref,
+    size_t predlen,
+    int* pred,
+    int pad,
+    int eos) {
+  bleu_trim(&reflen, &ref, pad, eos);
+  bleu_trim(&predlen, &pred, pad, eos);
+  stat->reflen += reflen;
+  stat->predlen += predlen;
+
+  bleu_addngram(&stat->count1, &stat->match1, 1, reflen, ref, predlen, pred);
+  bleu_addngram(&stat->count2, &stat->match2, 2, reflen, ref, predlen, pred);
+  bleu_addngram(&stat->count3, &stat->match3, 3, reflen, ref, predlen, pred);
+  bleu_addngram(&stat->count4, &stat->match4, 4, reflen, ref, predlen, pred);
+}
+}
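
The clipped n-gram counting in bleu_addngram is the core of corpus BLEU. A pure-Python rendering of the same logic, for readers who prefer it over the pointer arithmetic (this sketch is for illustration only):

    from collections import Counter

    def add_ngram_counts(n, ref, pred):
        # total candidate n-grams in the prediction
        total = max(len(pred) - n + 1, 0)
        ref_ngrams = Counter(tuple(ref[i:i + n]) for i in range(len(ref) - n + 1))
        match = 0
        for i in range(total):
            g = tuple(pred[i:i + n])
            if ref_ngrams[g] > 0:   # clip matches by reference counts
                match += 1
                ref_ngrams[g] -= 1
        return match, total

    print(add_ngram_counts(2, [1, 2, 3, 4], [1, 2, 3, 5]))  # (2, 3)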
fairseq/fairseq/clib/libbleu/module.cpp ADDED
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2017-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <Python.h>
+
+static PyMethodDef method_def[] = {{NULL, NULL, 0, NULL}}; // NOLINT
+
+static struct PyModuleDef module_def = {
+    PyModuleDef_HEAD_INIT,
+    "libbleu", /* name of module */
+    // NOLINTNEXTLINE
+    NULL, /* module documentation, may be NULL */
+    -1, /* size of per-interpreter state of the module,
+           or -1 if the module keeps state in global variables. */
+    method_def}; // NOLINT
+
+#if PY_MAJOR_VERSION == 2
+PyMODINIT_FUNC init_libbleu()
+#else
+PyMODINIT_FUNC PyInit_libbleu()
+#endif
+{
+  PyObject* m = PyModule_Create(&module_def);
+  if (!m) {
+    return NULL;
+  }
+  return m;
+}
fairseq/fairseq/clib/libnat/edit_dist.cpp ADDED
@@ -0,0 +1,231 @@
+/**
+ * Copyright 2017-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <pybind11/detail/common.h>
+#include <pybind11/pybind11.h>
+#include <torch/torch.h> // @manual=//caffe2:torch_extension
+#include <algorithm>
+#include <cstdint>
+#include <iosfwd>
+#include <memory>
+#include <new>
+#include <string>
+#include <utility>
+#include <vector>
+
+using namespace ::std;
+
+vector<vector<uint32_t>> edit_distance2_with_dp(
+    vector<uint32_t>& x,
+    vector<uint32_t>& y) {
+  uint32_t lx = x.size();
+  uint32_t ly = y.size();
+  vector<vector<uint32_t>> d(lx + 1, vector<uint32_t>(ly + 1));
+  for (uint32_t i = 0; i < lx + 1; i++) {
+    d[i][0] = i;
+  }
+  for (uint32_t j = 0; j < ly + 1; j++) {
+    d[0][j] = j;
+  }
+  for (uint32_t i = 1; i < lx + 1; i++) {
+    for (uint32_t j = 1; j < ly + 1; j++) {
+      d[i][j] =
+          min(min(d[i - 1][j], d[i][j - 1]) + 1,
+              d[i - 1][j - 1] + 2 * (x.at(i - 1) == y.at(j - 1) ? 0 : 1));
+    }
+  }
+  return d;
+}
+
+vector<vector<uint32_t>> edit_distance2_backtracking(
+    vector<vector<uint32_t>>& d,
+    vector<uint32_t>& x,
+    vector<uint32_t>& y,
+    uint32_t terminal_symbol) {
+  vector<uint32_t> seq;
+  vector<vector<uint32_t>> edit_seqs(x.size() + 2, vector<uint32_t>());
+  /*
+  edit_seqs:
+  0~x.size() cell is the insertion sequences
+  last cell is the delete sequence
+  */
+
+  if (x.size() == 0) {
+    edit_seqs.at(0) = y;
+    return edit_seqs;
+  }
+
+  uint32_t i = d.size() - 1;
+  uint32_t j = d.at(0).size() - 1;
+
+  while ((i >= 0) && (j >= 0)) {
+    if ((i == 0) && (j == 0)) {
+      break;
+    }
+
+    if ((j > 0) && (d.at(i).at(j - 1) < d.at(i).at(j))) {
+      seq.push_back(1); // insert
+      seq.push_back(y.at(j - 1));
+      j--;
+    } else if ((i > 0) && (d.at(i - 1).at(j) < d.at(i).at(j))) {
+      seq.push_back(2); // delete
+      seq.push_back(x.at(i - 1));
+      i--;
+    } else {
+      seq.push_back(3); // keep
+      seq.push_back(x.at(i - 1));
+      i--;
+      j--;
+    }
+  }
+
+  uint32_t prev_op, op, s, word;
+  prev_op = 0, s = 0;
+  for (uint32_t k = 0; k < seq.size() / 2; k++) {
+    op = seq.at(seq.size() - 2 * k - 2);
+    word = seq.at(seq.size() - 2 * k - 1);
+    if (prev_op != 1) {
+      s++;
+    }
+    if (op == 1) // insert
+    {
+      edit_seqs.at(s - 1).push_back(word);
+    } else if (op == 2) // delete
+    {
+      edit_seqs.at(x.size() + 1).push_back(1);
+    } else {
+      edit_seqs.at(x.size() + 1).push_back(0);
+    }
+
+    prev_op = op;
+  }
+
+  for (uint32_t k = 0; k < edit_seqs.size(); k++) {
+    if (edit_seqs[k].size() == 0) {
+      edit_seqs[k].push_back(terminal_symbol);
+    }
+  }
+  return edit_seqs;
+}
+
+vector<vector<uint32_t>> edit_distance2_backtracking_with_delete(
+    vector<vector<uint32_t>>& d,
+    vector<uint32_t>& x,
+    vector<uint32_t>& y,
+    uint32_t terminal_symbol,
+    uint32_t deletion_symbol) {
+  vector<uint32_t> seq;
+  vector<vector<uint32_t>> edit_seqs(x.size() + 1, vector<uint32_t>());
+  /*
+  edit_seqs:
+  0~x.size() cell is the insertion sequences
+  last cell is the delete sequence
+  */
+
+  if (x.size() == 0) {
+    edit_seqs.at(0) = y;
+    return edit_seqs;
+  }
+
+  uint32_t i = d.size() - 1;
+  uint32_t j = d.at(0).size() - 1;
+
+  while ((i >= 0) && (j >= 0)) {
+    if ((i == 0) && (j == 0)) {
+      break;
+    }
+
+    if ((j > 0) && (d.at(i).at(j - 1) < d.at(i).at(j))) {
+      seq.push_back(1); // insert
+      seq.push_back(y.at(j - 1));
+      j--;
+    } else if ((i > 0) && (d.at(i - 1).at(j) < d.at(i).at(j))) {
+      seq.push_back(2); // delete
+      seq.push_back(x.at(i - 1));
+      i--;
+    } else {
+      seq.push_back(3); // keep
+      seq.push_back(x.at(i - 1));
+      i--;
+      j--;
+    }
+  }
+
+  uint32_t prev_op, op, s, word;
+  prev_op = 0, s = 0;
+  for (uint32_t k = 0; k < seq.size() / 2; k++) {
+    op = seq.at(seq.size() - 2 * k - 2);
+    word = seq.at(seq.size() - 2 * k - 1);
+    if (prev_op != 1) {
+      s++;
+    }
+    if (op == 1) // insert
+    {
+      edit_seqs.at(s - 1).push_back(word);
+    } else if (op == 2) // delete
+    {
+      edit_seqs.at(s - 1).push_back(deletion_symbol);
+    }
+
+    prev_op = op;
+  }
+
+  for (uint32_t k = 0; k < edit_seqs.size(); k++) {
+    if (edit_seqs.at(k).size() == 0) {
+      edit_seqs.at(k).push_back(terminal_symbol);
+    }
+  }
+  return edit_seqs;
+}
+
+vector<uint32_t> compute_ed2(
+    vector<vector<uint32_t>>& xs,
+    vector<vector<uint32_t>>& ys) {
+  vector<uint32_t> distances(xs.size());
+  for (uint32_t i = 0; i < xs.size(); i++) {
+    vector<vector<uint32_t>> d = edit_distance2_with_dp(xs.at(i), ys.at(i));
+    distances.at(i) = d.at(xs.at(i).size()).at(ys.at(i).size());
+  }
+  return distances;
+}
+
+vector<vector<vector<uint32_t>>> suggested_ed2_path(
+    vector<vector<uint32_t>>& xs,
+    vector<vector<uint32_t>>& ys,
+    uint32_t terminal_symbol) {
+  vector<vector<vector<uint32_t>>> seq(xs.size());
+  for (uint32_t i = 0; i < xs.size(); i++) {
+    vector<vector<uint32_t>> d = edit_distance2_with_dp(xs.at(i), ys.at(i));
+    seq.at(i) =
+        edit_distance2_backtracking(d, xs.at(i), ys.at(i), terminal_symbol);
+  }
+  return seq;
+}
+
+vector<vector<vector<uint32_t>>> suggested_ed2_path_with_delete(
+    vector<vector<uint32_t>>& xs,
+    vector<vector<uint32_t>>& ys,
+    uint32_t terminal_symbol,
+    uint32_t deletion_symbol) {
+  vector<vector<vector<uint32_t>>> seq(xs.size());
+  for (uint32_t i = 0; i < xs.size(); i++) {
+    vector<vector<uint32_t>> d = edit_distance2_with_dp(xs.at(i), ys.at(i));
+    seq.at(i) = edit_distance2_backtracking_with_delete(
+        d, xs.at(i), ys.at(i), terminal_symbol, deletion_symbol);
+  }
+  return seq;
+}
+
+PYBIND11_MODULE(libnat, m) {
+  m.def("compute_ed2", &compute_ed2, "compute_ed2");
+  m.def("suggested_ed2_path", &suggested_ed2_path, "suggested_ed2_path");
+  m.def(
+      "suggested_ed2_path_with_delete",
+      &suggested_ed2_path_with_delete,
+      "suggested_ed2_path_with_delete");
+}
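
These helpers back the CPU side of fairseq's Levenshtein Transformer, where suggested_ed2_path supplies imitation-learning targets. A rough call pattern, assuming the module is built under its declared pybind name libnat (the build wiring is an assumption):

    from torch.utils.cpp_extension import load

    libnat = load(
        name="libnat",
        sources=["fairseq/fairseq/clib/libnat/edit_dist.cpp"],
    )

    xs = [[5, 6, 7]]  # hypothesis token ids
    ys = [[5, 7, 8]]  # reference token ids
    print(libnat.compute_ed2(xs, ys))            # [2]: drop 6, insert 8 (mismatches cost 2)
    print(libnat.suggested_ed2_path(xs, ys, 0))  # insert/delete program, 0 = terminal symbol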
fairseq/fairseq/clib/libnat_cuda/binding.cpp ADDED
@@ -0,0 +1,67 @@
+/**
+ * Copyright 2017-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+/*
+This code is partially adopted from
+https://github.com/1ytic/pytorch-edit-distance
+*/
+
+#include <torch/types.h>
+#include "edit_dist.h"
+
+#ifndef TORCH_CHECK
+#define TORCH_CHECK AT_CHECK
+#endif
+
+#define CHECK_CUDA(x) \
+  TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) \
+  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) \
+  CHECK_CUDA(x);       \
+  CHECK_CONTIGUOUS(x)
+
+torch::Tensor LevenshteinDistance(
+    torch::Tensor source,
+    torch::Tensor target,
+    torch::Tensor source_length,
+    torch::Tensor target_length) {
+  CHECK_INPUT(source);
+  CHECK_INPUT(target);
+  CHECK_INPUT(source_length);
+  CHECK_INPUT(target_length);
+  return LevenshteinDistanceCuda(source, target, source_length, target_length);
+}
+
+torch::Tensor GenerateDeletionLabel(
+    torch::Tensor source,
+    torch::Tensor operations) {
+  CHECK_INPUT(source);
+  CHECK_INPUT(operations);
+  return GenerateDeletionLabelCuda(source, operations);
+}
+
+std::pair<torch::Tensor, torch::Tensor> GenerateInsertionLabel(
+    torch::Tensor target,
+    torch::Tensor operations) {
+  CHECK_INPUT(target);
+  CHECK_INPUT(operations);
+  return GenerateInsertionLabelCuda(target, operations);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("levenshtein_distance", &LevenshteinDistance, "Levenshtein distance");
+  m.def(
+      "generate_deletion_labels",
+      &GenerateDeletionLabel,
+      "Generate Deletion Label");
+  m.def(
+      "generate_insertion_labels",
+      &GenerateInsertionLabel,
+      "Generate Insertion Label");
+}
fairseq/fairseq/clib/libnat_cuda/edit_dist.cu ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /**
2
+ * Copyright 2017-present, Facebook, Inc.
3
+ * All rights reserved.
4
+ *
5
+ * This source code is licensed under the license found in the
6
+ * LICENSE file in the root directory of this source tree.
7
+ */
8
+
9
+ #include "edit_dist.h"
10
+
11
+ #include <c10/cuda/CUDAStream.h>
12
+ #include <cuda.h>
13
+ #include <cuda_runtime.h>
14
+ #include <device_launch_parameters.h>
15
+ #include <utility> // std::pair
16
+
17
+ template <typename scalar_t>
18
+ __global__ void generate_deletion_label_kernel(
19
+ const scalar_t* __restrict__ source,
20
+ const size_t source_size,
21
+ const size_t operation_size,
22
+ int* __restrict__ operations,
23
+ int* __restrict__ labels) {
24
+ const int index = blockIdx.x;
25
+ const int offset = index * operation_size;
26
+ const int offset_label = index * source_size;
27
+
28
+ for (int i = 0; i < source_size; i++) {
29
+ labels[offset_label + i] = 0;
30
+ }
31
+
32
+ int k = 0;
33
+ for (int i = 0; i < operation_size; i++) {
34
+ if (operations[offset + i] == 0) {
35
+ break;
36
+ } else if (operations[offset + i] == 1) {
37
+ continue;
38
+ } else {
39
+ labels[offset_label + k] = 3 - operations[offset + i];
40
+ k++;
41
+ }
42
+ }
43
+ }
44
+
45
+ template <typename scalar_t>
46
+ __global__ void generate_insertion_label_kernel(
47
+ const scalar_t* __restrict__ target,
48
+ const size_t target_size,
49
+ const size_t operation_size,
50
+ int* __restrict__ operations,
51
+ int* __restrict__ labels,
52
+ int* __restrict__ masks) {
53
+ const int index = blockIdx.x;
54
+ const int offset = index * operation_size;
55
+ const int offset_label = index * target_size;
56
+
57
+ int k = 0;
58
+ int u = 0;
59
+ int m = 0;
60
+
61
+ for (int i = 0; i < target_size; i++) {
62
+ labels[offset_label + i] = 0;
63
+ masks[offset_label + i] = 0;
64
+ }
65
+
66
+ for (int i = 0; i < operation_size - 1; i++) {
67
+ if (operations[offset + i] == 0) {
68
+ break;
69
+ } else if (operations[offset + i] == 2) {
70
+ continue;
71
+ } else if (operations[offset + i] == 1) {
72
+ masks[offset_label + m] = 1;
73
+ u++;
74
+ m++;
75
+ } else {
76
+ labels[offset_label + k] = u;
77
+ masks[offset_label + m] = 0;
78
+ k++;
79
+ m++;
80
+ u = 0;
81
+ }
82
+ }
83
+ }
84
+
85
+ template <typename scalar_t>
86
+ __global__ void levenshtein_distance_kernel(
87
+ const scalar_t* __restrict__ source,
88
+ const scalar_t* __restrict__ target,
89
+ const int* __restrict__ source_length,
90
+ const int* __restrict__ target_length,
91
+ const size_t source_size,
92
+ const size_t target_size,
93
+ int* __restrict__ operations,
94
+ int* __restrict__ errors_curr) {
95
+ const int index = blockIdx.x;
96
+ const int offset = index * (source_size + target_size);
97
+ const int d = index * (source_size + 1) * (target_size + 1);
98
+ const int t = target_size + 1;
99
+
100
+ auto err_idx = [d, t](int i, int j) { return d + i * t + j; };
101
+ auto opt_idx = [offset](int k) { return offset + k; };
102
+
103
+ const int hyp_len = source_length[index];
104
+ const int ref_len = target_length[index];
105
+ const scalar_t* hyp_begin = source + index * source_size;
106
+ const scalar_t* ref_begin = target + index * target_size;
107
+
108
+ // dynamic programming
109
+ for (int i = 0; i <= hyp_len; i++) {
110
+ errors_curr[err_idx(i, 0)] = i;
111
+ }
112
+ for (int j = 0; j <= ref_len; j++) {
113
+ errors_curr[err_idx(0, j)] = j;
114
+ }
115
+ for (int i = 1; i <= hyp_len; i++) {
116
+ for (int j = 1; j <= ref_len; j++) {
117
+ errors_curr[err_idx(i, j)] = min(
118
+ min(errors_curr[err_idx(i - 1, j)], errors_curr[err_idx(i, j - 1)]) +
119
+ 1,
120
+ errors_curr[err_idx(i - 1, j - 1)] +
121
+ 2 * (*(hyp_begin + i - 1) == *(ref_begin + j - 1) ? 0 : 1));
122
+ }
123
+ }
124
+
125
+ // back-tracing
126
+ int i = hyp_len;
127
+ int j = ref_len;
128
+ int o = hyp_len + ref_len;
129
+
130
+ for (int k = 0; k < source_size + target_size; k++) {
131
+ operations[opt_idx(k)] = 0;
132
+ }
133
+
134
+ while ((i >= 0) && (j >= 0)) {
135
+ if ((i == 0) && (j == 0)) {
136
+ break;
137
+ }
138
+
139
+ if ((j > 0) &&
140
+ (errors_curr[err_idx(i, j - 1)] < errors_curr[err_idx(i, j)])) {
141
+ o--;
142
+ operations[opt_idx(o)] = 1;
143
+ j--; // insertion
144
+ } else if (
145
+ (i > 0) &&
146
+ (errors_curr[err_idx(i - 1, j)] < errors_curr[err_idx(i, j)])) {
147
+ o--;
148
+ operations[opt_idx(o)] = 2;
149
+ i--; // deletion
150
+ } else {
151
+ o--;
152
+ operations[opt_idx(o)] = 3;
153
+ i--;
154
+ j--; // do nothing
155
+ }
156
+ }
157
+
158
+ // moving to the left
159
+ for (int k = 0; k < hyp_len + ref_len; k++) {
160
+ if (k + o < hyp_len + ref_len) {
161
+ operations[opt_idx(k)] = operations[opt_idx(k + o)];
162
+ } else {
163
+ operations[opt_idx(k)] = 0; // padding
164
+ }
165
+ }
166
+ }
167
+
168
+ template <typename scalar_t>
169
+ __global__ void faster_levenshtein_distance_kernel(
170
+ const scalar_t* __restrict__ source,
171
+ const scalar_t* __restrict__ target,
172
+ const int* __restrict__ source_length,
173
+ const int* __restrict__ target_length,
174
+ const size_t source_size,
175
+ const size_t target_size,
176
+ int* __restrict__ operations) {
177
+ extern __shared__ short errors[];
178
+ auto errors_curr = errors;
179
+
180
+ const int index = blockIdx.x;
181
+ const int offset = index * (source_size + target_size);
182
+ const int t = target_size + 1;
183
+
184
+ auto err_idx = [t](int i, int j) { return i * t + j; };
185
+ auto opt_idx = [offset](int k) { return offset + k; };
186
+
187
+ const int hyp_len = source_length[index];
188
+ const int ref_len = target_length[index];
189
+ const scalar_t* hyp_begin = source + index * source_size;
190
+ const scalar_t* ref_begin = target + index * target_size;
191
+
192
+ // dynamic programming
193
+ for (int i = 0; i <= hyp_len; i++) {
194
+ errors_curr[err_idx(i, 0)] = i;
195
+ }
196
+ for (int j = 0; j <= ref_len; j++) {
197
+ errors_curr[err_idx(0, j)] = j;
198
+ }
199
+ for (int i = 1; i <= hyp_len; i++) {
200
+ for (int j = 1; j <= ref_len; j++) {
201
+ errors_curr[err_idx(i, j)] = min(
202
+ min(errors_curr[err_idx(i - 1, j)], errors_curr[err_idx(i, j - 1)]) +
203
+ 1,
204
+ errors_curr[err_idx(i - 1, j - 1)] +
205
+ 2 * (*(hyp_begin + i - 1) == *(ref_begin + j - 1) ? 0 : 1));
206
+ }
207
+ }
208
+
209
+ // back-tracing
210
+ int i = hyp_len;
211
+ int j = ref_len;
212
+ int o = hyp_len + ref_len;
213
+
214
+ for (int k = 0; k < source_size + target_size; k++) {
215
+ operations[opt_idx(k)] = 0;
216
+ }
217
+
218
+ while ((i >= 0) && (j >= 0)) {
219
+ if ((i == 0) && (j == 0)) {
220
+ break;
221
+ }
222
+
223
+ if ((j > 0) &&
224
+ (errors_curr[err_idx(i, j - 1)] < errors_curr[err_idx(i, j)])) {
225
+ o--;
226
+ operations[opt_idx(o)] = 1;
227
+ j--; // insertion
228
+ } else if (
229
+ (i > 0) &&
230
+ (errors_curr[err_idx(i - 1, j)] < errors_curr[err_idx(i, j)])) {
231
+ o--;
232
+ operations[opt_idx(o)] = 2;
233
+ i--; // deletion
234
+ } else {
235
+ o--;
236
+ operations[opt_idx(o)] = 3;
237
+ i--;
238
+ j--; // do nothing
239
+ }
240
+ }
241
+
242
+ // moving to the left
243
+ for (int k = 0; k < hyp_len + ref_len; k++) {
244
+ if (k + o < hyp_len + ref_len) {
245
+ operations[opt_idx(k)] = operations[opt_idx(k + o)];
246
+ } else {
247
+ operations[opt_idx(k)] = 0; // padding
248
+ }
249
+ }
250
+ }
251
+
252
+ torch::Tensor GenerateDeletionLabelCuda(
253
+ torch::Tensor source,
254
+ torch::Tensor operations) {
255
+ const auto batch_size = source.size(0);
256
+ at::TensorOptions options(source.device());
257
+ options = options.dtype(at::ScalarType::Int);
258
+ auto labels = torch::empty({batch_size, source.size(1)}, options);
259
+ auto stream = at::cuda::getCurrentCUDAStream(source.device().index());
260
+
261
+ AT_DISPATCH_ALL_TYPES(source.scalar_type(), "generate_deletion_labels", ([&] {
262
+ generate_deletion_label_kernel<scalar_t>
263
+ <<<batch_size, 1, 0, stream>>>(
264
+ source.data_ptr<scalar_t>(),
265
+ source.size(1),
266
+ operations.size(1),
267
+ operations.data_ptr<int>(),
268
+ labels.data_ptr<int>());
269
+ }));
270
+
271
+ return labels;
272
+ }
273
+
274
+ std::pair<torch::Tensor, torch::Tensor> GenerateInsertionLabelCuda(
275
+ torch::Tensor target,
276
+ torch::Tensor operations) {
277
+ const auto batch_size = target.size(0);
278
+ at::TensorOptions options(target.device());
279
+ options = options.dtype(at::ScalarType::Int);
280
+ auto labels = torch::empty({batch_size, target.size(1)}, options);
281
+ auto masks = torch::empty({batch_size, target.size(1)}, options);
282
+ auto stream = at::cuda::getCurrentCUDAStream(target.device().index());
283
+
284
+ AT_DISPATCH_ALL_TYPES(
285
+ target.scalar_type(), "generate_insertion_labels", ([&] {
286
+ generate_insertion_label_kernel<scalar_t><<<batch_size, 1, 0, stream>>>(
287
+ target.data_ptr<scalar_t>(),
288
+ target.size(1),
289
+ operations.size(1),
290
+ operations.data_ptr<int>(),
291
+ labels.data_ptr<int>(),
292
+ masks.data_ptr<int>());
293
+ }));
294
+
295
+ return std::make_pair(labels, masks);
296
+ }
297
+
298
+ torch::Tensor LevenshteinDistanceCuda(
299
+ torch::Tensor source,
300
+ torch::Tensor target,
301
+ torch::Tensor source_length,
302
+ torch::Tensor target_length) {
303
+ const auto batch_size = source.size(0);
304
+ const auto shared_size =
305
+ (source.size(1) + 1) * (target.size(1) + 1) * sizeof(short);
306
+
307
+ at::TensorOptions options(source.device());
308
+ options = options.dtype(at::ScalarType::Int);
309
+ auto operations =
310
+ torch::empty({batch_size, source.size(1) + target.size(1)}, options);
311
+ auto stream = at::cuda::getCurrentCUDAStream(source.device().index());
312
+
313
+ if (shared_size > 40000) { // DP table exceeds the ~48 KB default shared-memory-per-block budget; fall back to global memory
314
+ auto distances = torch::empty(
315
+ {batch_size, (source.size(1) + 1) * (target.size(1) + 1)}, options);
316
+ AT_DISPATCH_ALL_TYPES(source.scalar_type(), "levenshtein_distance", ([&] {
317
+ levenshtein_distance_kernel<scalar_t>
318
+ <<<batch_size, 1, 0, stream>>>(
319
+ source.data_ptr<scalar_t>(),
320
+ target.data_ptr<scalar_t>(),
321
+ source_length.data_ptr<int>(),
322
+ target_length.data_ptr<int>(),
323
+ source.size(1),
324
+ target.size(1),
325
+ operations.data_ptr<int>(),
326
+ distances.data_ptr<int>());
327
+ }));
328
+ } else {
329
+ AT_DISPATCH_ALL_TYPES(
330
+ source.scalar_type(), "faster_levenshtein_distance", ([&] {
331
+ faster_levenshtein_distance_kernel<scalar_t>
332
+ <<<batch_size, 1, shared_size, stream>>>(
333
+ source.data_ptr<scalar_t>(),
334
+ target.data_ptr<scalar_t>(),
335
+ source_length.data_ptr<int>(),
336
+ target_length.data_ptr<int>(),
337
+ source.size(1),
338
+ target.size(1),
339
+ operations.data_ptr<int>());
340
+ }));
341
+ }
342
+
343
+ return operations;
344
+ }
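The kernels above encode the traced edit script as integer codes (1 = insertion, 2 = deletion, 3 = keep), left-aligned and zero-padded. A minimal CPU reference in Python, mirroring the kernels' DP and traceback logic, is sketched below; it is illustrative only (not part of fairseq) and pads to len(hyp) + len(ref), where the kernels pad to the batch's padded widths.

# CPU reference for the CUDA kernels above: Levenshtein DP with substitution
# cost 2 (so the distance counts insertions + deletions), then a traceback
# emitting 1 = insertion, 2 = deletion, 3 = keep, zero-padded on the right.
def edit_script(hyp, ref):
    H, R = len(hyp), len(ref)
    err = [[0] * (R + 1) for _ in range(H + 1)]
    for i in range(H + 1):
        err[i][0] = i
    for j in range(R + 1):
        err[0][j] = j
    for i in range(1, H + 1):
        for j in range(1, R + 1):
            sub = 0 if hyp[i - 1] == ref[j - 1] else 2
            err[i][j] = min(
                err[i - 1][j] + 1,        # delete hyp[i - 1]
                err[i][j - 1] + 1,        # insert ref[j - 1]
                err[i - 1][j - 1] + sub,  # keep or substitute
            )
    ops, i, j = [], H, R
    while i > 0 or j > 0:
        if j > 0 and err[i][j - 1] < err[i][j]:
            ops.append(1)  # insertion
            j -= 1
        elif i > 0 and err[i - 1][j] < err[i][j]:
            ops.append(2)  # deletion
            i -= 1
        else:
            ops.append(3)  # keep (match)
            i -= 1
            j -= 1
    ops.reverse()
    return ops + [0] * (H + R - len(ops))

print(edit_script("abc", "abd"))  # [3, 3, 2, 1, 0, 0]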
fairseq/fairseq/clib/libnat_cuda/edit_dist.h ADDED
@@ -0,0 +1,25 @@
1
+ /**
2
+ * Copyright 2017-present, Facebook, Inc.
3
+ * All rights reserved.
4
+ *
5
+ * This source code is licensed under the license found in the
6
+ * LICENSE file in the root directory of this source tree.
7
+ */
8
+
9
+ #pragma once
10
+
11
+ #include <torch/extension.h>
12
+
13
+ torch::Tensor LevenshteinDistanceCuda(
14
+ torch::Tensor source,
15
+ torch::Tensor target,
16
+ torch::Tensor source_length,
17
+ torch::Tensor target_length);
18
+
19
+ torch::Tensor GenerateDeletionLabelCuda(
20
+ torch::Tensor source,
21
+ torch::Tensor operations);
22
+
23
+ std::pair<torch::Tensor, torch::Tensor> GenerateInsertionLabelCuda(
24
+ torch::Tensor target,
25
+ torch::Tensor operations);
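A hedged sketch of building and calling these entry points from Python with torch.utils.cpp_extension; the Python-visible names come from binding.cpp (not reproduced in this view), so the function name used below is an assumption for illustration only.

# JIT-build the extension declared in this header (requires a CUDA toolchain).
# The exposed name `levenshtein_distance` is assumed; see binding.cpp for the
# actual pybind11 bindings.
import torch
from torch.utils.cpp_extension import load

libnat_cuda = load(
    name="libnat_cuda",
    sources=[
        "fairseq/fairseq/clib/libnat_cuda/binding.cpp",
        "fairseq/fairseq/clib/libnat_cuda/edit_dist.cu",
    ],
)

src = torch.tensor([[4, 5, 6, 1]], device="cuda")   # 1 = pad
tgt = torch.tensor([[4, 7, 6, 1]], device="cuda")
src_len = torch.tensor([3], dtype=torch.int, device="cuda")
tgt_len = torch.tensor([3], dtype=torch.int, device="cuda")

ops = libnat_cuda.levenshtein_distance(src, tgt, src_len, tgt_len)
print(ops)  # one edit script per pair: 1 = insert, 2 = delete, 3 = keep, 0 = pad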
fairseq/fairseq/config/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
fairseq/fairseq/config/config.yaml ADDED
@@ -0,0 +1,19 @@
1
+ # @package _group_
2
+
3
+ hydra:
4
+ run:
5
+ dir: .
6
+
7
+ defaults:
8
+ - _self_
9
+ - task: null
10
+ - model: null
11
+ - criterion: cross_entropy
12
+ - optimizer: null
13
+ - lr_scheduler: fixed
14
+ - bpe: null
15
+ - tokenizer: null
16
+ - scoring: null
17
+ - generation: null
18
+ - common_eval: null
19
+ - eval_lm: null
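This is the root config consumed by fairseq's Hydra entry point: each slot in the defaults list is filled either by a YAML group under this tree (e.g. model/transformer_lm/*) or by a dataclass that fairseq registers in Hydra's ConfigStore at startup. A hedged sketch of composing it programmatically, assuming that registration has already run; the overrides are illustrative.

# Compose the root config above with overrides, mirroring what
# `fairseq-hydra-train model=... criterion=...` does on the command line.
# Assumes fairseq has registered its config groups in Hydra's ConfigStore.
from hydra import compose, initialize

with initialize(config_path="fairseq/config"):  # path relative to the caller
    cfg = compose(
        config_name="config",
        overrides=[
            "model=transformer_lm/transformer_lm_gpt",  # YAML group in this tree
            "criterion=cross_entropy",                  # the default, made explicit
        ],
    )

print(cfg.model.decoder_embed_dim)  # 768 for the GPT config selected above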
fairseq/fairseq/config/fb_run_config/slurm.yaml ADDED
@@ -0,0 +1,29 @@
1
+ # @package _global_
2
+
3
+ hydra:
4
+ job:
5
+ config:
6
+ override_dirname:
7
+ kv_sep: ':'
8
+ item_sep: '__'
9
+ exclude_keys:
10
+ - fb_run_config
11
+ - distributed_training.distributed_port
12
+ sweep:
13
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
14
+ launcher:
15
+ cpus_per_task: 60
16
+ gpus_per_node: ???
17
+ tasks_per_node: 1
18
+ nodes: 1
19
+ partition: learnfair
20
+ mem_gb: 400
21
+ timeout_min: 4320
22
+ max_num_timeout: 10
23
+ name: ${env:PREFIX}_${hydra.job.config_name}
24
+ submitit_folder: ${hydra.sweep.dir}
25
+
26
+ distributed_training:
27
+ ddp_backend: c10d
28
+ distributed_world_size: ???
29
+ distributed_port: ???
fairseq/fairseq/config/model/transformer_lm/transformer_lm_baevski_gbw.yaml ADDED
@@ -0,0 +1,36 @@
1
+ # @package _group_
2
+ activation_fn: "relu"
3
+ dropout: 0.1
4
+ attention_dropout: 0.1
5
+ activation_dropout: 0.0
6
+ relu_dropout: 0.0
7
+ decoder_embed_dim: 512
8
+ decoder_output_dim: 512
9
+ decoder_input_dim: 512
10
+ decoder_ffn_embed_dim: 4096
11
+ decoder_layers: 12
12
+ decoder_attention_heads: 16
13
+ decoder_normalize_before: true
14
+ no_decoder_final_norm: true
15
+ adaptive_softmax_cutoff: null
16
+ adaptive_softmax_dropout: 0
17
+ adaptive_softmax_factor: 4
18
+ no_token_positional_embeddings: false
19
+ share_decoder_input_output_embed: false
20
+ character_embeddings: false
21
+ character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
22
+ character_embedding_dim: 4
23
+ char_embedder_highway_layers: 2
24
+ adaptive_input: false
25
+ adaptive_input_factor: 4
26
+ adaptive_input_cutoff: null
27
+ tie_adaptive_weights: false
28
+ tie_adaptive_proj: false
29
+ decoder_learned_pos: false
30
+ decoder_layerdrop: 0
31
+ decoder_layers_to_keep: null
32
+ layernorm_embedding: false
33
+ no_scale_embedding: false
34
+ quant_noise_pq: 0
35
+ quant_noise_pq_block_size: 8
36
+ quant_noise_scalar: 0
fairseq/fairseq/config/model/transformer_lm/transformer_lm_baevski_wiki103.yaml ADDED
@@ -0,0 +1,36 @@
1
+ # @package _group_
2
+ activation_fn: "relu"
3
+ dropout: 0.3
4
+ attention_dropout: 0.1
5
+ activation_dropout: 0.1
6
+ relu_dropout: 0.1
7
+ decoder_embed_dim: 1024
8
+ decoder_output_dim: 1024
9
+ decoder_input_dim: 1024
10
+ decoder_ffn_embed_dim: 4096
11
+ decoder_layers: 16
12
+ decoder_attention_heads: 8
13
+ decoder_normalize_before: true
14
+ no_decoder_final_norm: true
15
+ adaptive_softmax_cutoff: "20000,60000"
16
+ adaptive_softmax_dropout: 0.2
17
+ adaptive_softmax_factor: 4
18
+ no_token_positional_embeddings: false
19
+ share_decoder_input_output_embed: false
20
+ character_embeddings: false
21
+ character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
22
+ character_embedding_dim: 4
23
+ char_embedder_highway_layers: 2
24
+ adaptive_input: true
25
+ adaptive_input_factor: 4
26
+ adaptive_input_cutoff: "20000,60000"
27
+ tie_adaptive_weights: true
28
+ tie_adaptive_proj: true
29
+ decoder_learned_pos: false
30
+ decoder_layerdrop: 0
31
+ decoder_layers_to_keep: null
32
+ layernorm_embedding: false
33
+ no_scale_embedding: false
34
+ quant_noise_pq: 0
35
+ quant_noise_pq_block_size: 8
36
+ quant_noise_scalar: 0
fairseq/fairseq/config/model/transformer_lm/transformer_lm_big.yaml ADDED
@@ -0,0 +1,36 @@
1
+ # @package _group_
2
+ activation_fn: "relu"
3
+ dropout: 0.1
4
+ attention_dropout: 0.0
5
+ activation_dropout: 0.0
6
+ relu_dropout: 0.0
7
+ decoder_embed_dim: 1024
8
+ decoder_output_dim: 1024
9
+ decoder_input_dim: 1024
10
+ decoder_ffn_embed_dim: 4096
11
+ decoder_layers: 12
12
+ decoder_attention_heads: 16
13
+ decoder_normalize_before: true
14
+ no_decoder_final_norm: false
15
+ adaptive_softmax_cutoff: null
16
+ adaptive_softmax_dropout: 0
17
+ adaptive_softmax_factor: 4
18
+ no_token_positional_embeddings: false
19
+ share_decoder_input_output_embed: false
20
+ character_embeddings: false
21
+ character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
22
+ character_embedding_dim: 4
23
+ char_embedder_highway_layers: 2
24
+ adaptive_input: false
25
+ adaptive_input_factor: 4
26
+ adaptive_input_cutoff: null
27
+ tie_adaptive_weights: false
28
+ tie_adaptive_proj: false
29
+ decoder_learned_pos: false
30
+ decoder_layerdrop: 0
31
+ decoder_layers_to_keep: null
32
+ layernorm_embedding: false
33
+ no_scale_embedding: false
34
+ quant_noise_pq: 0
35
+ quant_noise_pq_block_size: 8
36
+ quant_noise_scalar: 0
fairseq/fairseq/config/model/transformer_lm/transformer_lm_gbw.yaml ADDED
@@ -0,0 +1,36 @@
1
+ # @package _group_
2
+ activation_fn: "relu"
3
+ dropout: 0.1
4
+ attention_dropout: 0.1
5
+ activation_dropout: 0.0
6
+ relu_dropout: 0.0
7
+ decoder_embed_dim: 512
8
+ decoder_output_dim: 512
9
+ decoder_input_dim: 512
10
+ decoder_ffn_embed_dim: 4096
11
+ decoder_layers: 12
12
+ decoder_attention_heads: 16
13
+ decoder_normalize_before: true
14
+ no_decoder_final_norm: true
15
+ adaptive_softmax_cutoff: null
16
+ adaptive_softmax_dropout: 0
17
+ adaptive_softmax_factor: 4
18
+ no_token_positional_embeddings: false
19
+ share_decoder_input_output_embed: false
20
+ character_embeddings: false
21
+ character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
22
+ character_embedding_dim: 4
23
+ char_embedder_highway_layers: 2
24
+ adaptive_input: false
25
+ adaptive_input_factor: 4
26
+ adaptive_input_cutoff: null
27
+ tie_adaptive_weights: false
28
+ tie_adaptive_proj: false
29
+ decoder_learned_pos: false
30
+ decoder_layerdrop: 0
31
+ decoder_layers_to_keep: null
32
+ layernorm_embedding: false
33
+ no_scale_embedding: false
34
+ quant_noise_pq: 0
35
+ quant_noise_pq_block_size: 8
36
+ quant_noise_scalar: 0
fairseq/fairseq/config/model/transformer_lm/transformer_lm_gpt.yaml ADDED
@@ -0,0 +1,36 @@
1
+ # @package _group_
2
+ activation_fn: "gelu"
3
+ dropout: 0.1
4
+ attention_dropout: 0.1
5
+ activation_dropout: 0.0
6
+ relu_dropout: 0.0
7
+ decoder_embed_dim: 768
8
+ decoder_output_dim: 768
9
+ decoder_input_dim: 768
10
+ decoder_ffn_embed_dim: 3072
11
+ decoder_layers: 12
12
+ decoder_attention_heads: 12
13
+ decoder_normalize_before: true
14
+ no_decoder_final_norm: false
15
+ adaptive_softmax_cutoff: null
16
+ adaptive_softmax_dropout: 0
17
+ adaptive_softmax_factor: 4
18
+ no_token_positional_embeddings: false
19
+ share_decoder_input_output_embed: false
20
+ character_embeddings: false
21
+ character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
22
+ character_embedding_dim: 4
23
+ char_embedder_highway_layers: 2
24
+ adaptive_input: false
25
+ adaptive_input_factor: 4
26
+ adaptive_input_cutoff: null
27
+ tie_adaptive_weights: false
28
+ tie_adaptive_proj: false
29
+ decoder_learned_pos: false
30
+ decoder_layerdrop: 0
31
+ decoder_layers_to_keep: null
32
+ layernorm_embedding: false
33
+ no_scale_embedding: false
34
+ quant_noise_pq: 0
35
+ quant_noise_pq_block_size: 8
36
+ quant_noise_scalar: 0
fairseq/fairseq/config/model/transformer_lm/transformer_lm_gpt2_big.yaml ADDED
@@ -0,0 +1,36 @@
1
+ # @package _group_
2
+ activation_fn: "gelu"
3
+ dropout: 0.1
4
+ attention_dropout: 0.1
5
+ activation_dropout: 0.0
6
+ relu_dropout: 0.0
7
+ decoder_embed_dim: 1600
8
+ decoder_output_dim: 1600
9
+ decoder_input_dim: 1600
10
+ decoder_ffn_embed_dim: 6400
11
+ decoder_layers: 48
12
+ decoder_attention_heads: 25
13
+ decoder_normalize_before: true
14
+ no_decoder_final_norm: false
15
+ adaptive_softmax_cutoff: null
16
+ adaptive_softmax_dropout: 0
17
+ adaptive_softmax_factor: 4
18
+ no_token_positional_embeddings: false
19
+ share_decoder_input_output_embed: false
20
+ character_embeddings: false
21
+ character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
22
+ character_embedding_dim: 4
23
+ char_embedder_highway_layers: 2
24
+ adaptive_input: false
25
+ adaptive_input_factor: 4
26
+ adaptive_input_cutoff: null
27
+ tie_adaptive_weights: false
28
+ tie_adaptive_proj: false
29
+ decoder_learned_pos: false
30
+ decoder_layerdrop: 0
31
+ decoder_layers_to_keep: null
32
+ layernorm_embedding: false
33
+ no_scale_embedding: false
34
+ quant_noise_pq: 0
35
+ quant_noise_pq_block_size: 8
36
+ quant_noise_scalar: 0
fairseq/fairseq/config/model/transformer_lm/transformer_lm_gpt2_medium.yaml ADDED
@@ -0,0 +1,36 @@
1
+ # @package _group_
2
+ activation_fn: "gelu"
3
+ dropout: 0.1
4
+ attention_dropout: 0.1
5
+ activation_dropout: 0.0
6
+ relu_dropout: 0.0
7
+ decoder_embed_dim: 1280
8
+ decoder_output_dim: 1280
9
+ decoder_input_dim: 1280
10
+ decoder_ffn_embed_dim: 5120
11
+ decoder_layers: 36
12
+ decoder_attention_heads: 20
13
+ decoder_normalize_before: true
14
+ no_decoder_final_norm: false
15
+ adaptive_softmax_cutoff: null
16
+ adaptive_softmax_dropout: 0
17
+ adaptive_softmax_factor: 4
18
+ no_token_positional_embeddings: false
19
+ share_decoder_input_output_embed: false
20
+ character_embeddings: false
21
+ character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
22
+ character_embedding_dim: 4
23
+ char_embedder_highway_layers: 2
24
+ adaptive_input: false
25
+ adaptive_input_factor: 4
26
+ adaptive_input_cutoff: null
27
+ tie_adaptive_weights: false
28
+ tie_adaptive_proj: false
29
+ decoder_learned_pos: false
30
+ decoder_layerdrop: 0
31
+ decoder_layers_to_keep: null
32
+ layernorm_embedding: false
33
+ no_scale_embedding: false
34
+ quant_noise_pq: 0
35
+ quant_noise_pq_block_size: 8
36
+ quant_noise_scalar: 0
fairseq/fairseq/config/model/transformer_lm/transformer_lm_gpt2_small.yaml ADDED
@@ -0,0 +1,36 @@
1
+ # @package _group_
2
+ activation_fn: "gelu"
3
+ dropout: 0.1
4
+ attention_dropout: 0.1
5
+ activation_dropout: 0.0
6
+ relu_dropout: 0.0
7
+ decoder_embed_dim: 1024
8
+ decoder_output_dim: 1024
9
+ decoder_input_dim: 1024
10
+ decoder_ffn_embed_dim: 4096
11
+ decoder_layers: 24
12
+ decoder_attention_heads: 16
13
+ decoder_normalize_before: true
14
+ no_decoder_final_norm: false
15
+ adaptive_softmax_cutoff: null
16
+ adaptive_softmax_dropout: 0
17
+ adaptive_softmax_factor: 4
18
+ no_token_positional_embeddings: false
19
+ share_decoder_input_output_embed: false
20
+ character_embeddings: false
21
+ character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
22
+ character_embedding_dim: 4
23
+ char_embedder_highway_layers: 2
24
+ adaptive_input: false
25
+ adaptive_input_factor: 4
26
+ adaptive_input_cutoff: null
27
+ tie_adaptive_weights: false
28
+ tie_adaptive_proj: false
29
+ decoder_learned_pos: false
30
+ decoder_layerdrop: 0
31
+ decoder_layers_to_keep: null
32
+ layernorm_embedding: false
33
+ no_scale_embedding: false
34
+ quant_noise_pq: 0
35
+ quant_noise_pq_block_size: 8
36
+ quant_noise_scalar: 0
fairseq/fairseq/config/model/transformer_lm/transformer_lm_wiki103.yaml ADDED
@@ -0,0 +1,36 @@
1
+ # @package _group_
2
+ activation_fn: "relu"
3
+ dropout: 0.3
4
+ attention_dropout: 0.1
5
+ activation_dropout: 0.1
6
+ relu_dropout: 0.1
7
+ decoder_embed_dim: 1024
8
+ decoder_output_dim: 1024
9
+ decoder_input_dim: 1024
10
+ decoder_ffn_embed_dim: 4096
11
+ decoder_layers: 16
12
+ decoder_attention_heads: 8
13
+ decoder_normalize_before: true
14
+ no_decoder_final_norm: true
15
+ adaptive_softmax_cutoff: "20000,60000"
16
+ adaptive_softmax_dropout: 0.2
17
+ adaptive_softmax_factor: 4
18
+ no_token_positional_embeddings: false
19
+ share_decoder_input_output_embed: false
20
+ character_embeddings: false
21
+ character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
22
+ character_embedding_dim: 4
23
+ char_embedder_highway_layers: 2
24
+ adaptive_input: true
25
+ adaptive_input_factor: 4
26
+ adaptive_input_cutoff: "20000,60000"
27
+ tie_adaptive_weights: true
28
+ tie_adaptive_proj: true
29
+ decoder_learned_pos: false
30
+ decoder_layerdrop: 0
31
+ decoder_layers_to_keep: null
32
+ layernorm_embedding: false
33
+ no_scale_embedding: false
34
+ quant_noise_pq: 0
35
+ quant_noise_pq_block_size: 8
36
+ quant_noise_scalar: 0
fairseq/fairseq/config/model/wav2vec/vq_wav2vec_gumbel.yaml ADDED
@@ -0,0 +1,5 @@
1
+ # @package _group_
2
+ activation: gelu
3
+ vq_type: gumbel
4
+ vq_depth: 2
5
+ combine_groups: true
fairseq/fairseq/config/model/wav2vec2/wav2vec2_base.yaml ADDED
@@ -0,0 +1,8 @@
1
+ # @package _group_
2
+
3
+ quantize_targets: true
4
+ final_dim: 256
5
+ encoder_layerdrop: 0.05
6
+ dropout_input: 0.1
7
+ dropout_features: 0.1
8
+ feature_grad_mult: 0.1
fairseq/fairseq/config/model/wav2vec2/wav2vec2_large.yaml ADDED
@@ -0,0 +1,20 @@
1
+ # @package _group_
2
+
3
+ quantize_targets: true
4
+ extractor_mode: layer_norm
5
+ layer_norm_first: true
6
+ final_dim: 768
7
+ latent_temp: [2.0,0.1,0.999995]
8
+ encoder_layerdrop: 0.0
9
+ dropout_input: 0.0
10
+ dropout_features: 0.0
11
+ dropout: 0.0
12
+ attention_dropout: 0.0
13
+ conv_bias: true
14
+
15
+ encoder_layers: 24
16
+ encoder_embed_dim: 1024
17
+ encoder_ffn_embed_dim: 4096
18
+ encoder_attention_heads: 16
19
+
20
+ feature_grad_mult: 1.0
fairseq/fairseq/criterions/__init__.py ADDED
@@ -0,0 +1,36 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ """isort:skip_file"""
6
+
7
+ import importlib
8
+ import os
9
+
10
+ from fairseq import registry
11
+ from fairseq.criterions.fairseq_criterion import ( # noqa
12
+ FairseqCriterion,
13
+ LegacyFairseqCriterion,
14
+ )
15
+ from omegaconf import DictConfig
16
+
17
+
18
+ (
19
+ build_criterion_,
20
+ register_criterion,
21
+ CRITERION_REGISTRY,
22
+ CRITERION_DATACLASS_REGISTRY,
23
+ ) = registry.setup_registry(
24
+ "--criterion", base_class=FairseqCriterion, default="cross_entropy"
25
+ )
26
+
27
+
28
+ def build_criterion(cfg: DictConfig, task, from_checkpoint=False):
29
+ return build_criterion_(cfg, task, from_checkpoint=from_checkpoint)
30
+
31
+
32
+ # automatically import any Python files in the criterions/ directory
33
+ for file in sorted(os.listdir(os.path.dirname(__file__))):
34
+ if file.endswith(".py") and not file.startswith("_"):
35
+ file_name = file[: file.find(".py")]
36
+ importlib.import_module("fairseq.criterions." + file_name)
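Given the registry and the auto-import loop above, adding a criterion only requires dropping a decorated class into this directory. A minimal sketch patterned on the criteria that follow (the name and config class are illustrative):

# Save as fairseq/criterions/example_nll.py; the auto-import loop above then
# makes it available as --criterion=example_nll.
from dataclasses import dataclass

import torch.nn.functional as F
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass


@dataclass
class ExampleNllConfig(FairseqDataclass):
    pass


@register_criterion("example_nll", dataclass=ExampleNllConfig)
class ExampleNllCriterion(FairseqCriterion):
    def forward(self, model, sample, reduce=True):
        net_output = model(**sample["net_input"])
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
        lprobs = lprobs.view(-1, lprobs.size(-1))
        target = model.get_targets(sample, net_output).view(-1)
        loss = F.nll_loss(
            lprobs,
            target,
            ignore_index=self.padding_idx,
            reduction="sum" if reduce else "none",
        )
        sample_size = sample["ntokens"]
        logging_output = {
            "loss": loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }
        return loss, sample_size, logging_output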
fairseq/fairseq/criterions/__pycache__/fairseq_criterion.cpython-310.pyc ADDED
Binary file (4.53 kB).
fairseq/fairseq/criterions/adaptive_loss.py ADDED
@@ -0,0 +1,124 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ from dataclasses import dataclass
8
+
9
+ import torch.nn.functional as F
10
+ from fairseq import utils
11
+ from fairseq.logging import metrics
12
+ from fairseq.criterions import FairseqCriterion, register_criterion
13
+ from fairseq.dataclass import FairseqDataclass
14
+ from fairseq.dataclass.constants import DDP_BACKEND_CHOICES
15
+ from omegaconf import II
16
+
17
+
18
+ @dataclass
19
+ class AdaptiveLossConfig(FairseqDataclass):
20
+ sentence_avg: bool = II("optimization.sentence_avg")
21
+ ddp_backend: DDP_BACKEND_CHOICES = II("distributed_training.ddp_backend")
22
+
23
+
24
+ @register_criterion("adaptive_loss", dataclass=AdaptiveLossConfig)
25
+ class AdaptiveLoss(FairseqCriterion):
26
+ """This is an implementation of the loss function accompanying the adaptive softmax approximation for
27
+ graphics processing units (GPUs), described in the paper "Efficient softmax approximation for GPUs"
28
+ (http://arxiv.org/abs/1609.04309)."""
29
+
30
+ def __init__(self, task, sentence_avg):
31
+ super().__init__(task)
32
+ self.sentence_avg = sentence_avg
33
+
34
+ @classmethod
35
+ def build_criterion(cls, cfg: AdaptiveLossConfig, task):
36
+ if cfg.ddp_backend in {"c10d", "pytorch_ddp"}:
37
+ raise Exception(
38
+ "AdaptiveLoss is not compatible with the PyTorch "
39
+ "version of DistributedDataParallel. Please use "
40
+ "`--ddp-backend=legacy_ddp` instead."
41
+ )
42
+ return cls(task, cfg.sentence_avg)
43
+
44
+ def forward(self, model, sample, reduce=True):
45
+ """Compute the loss for the given sample.
46
+
47
+ Returns a tuple with three elements:
48
+ 1) the loss
49
+ 2) the sample size, which is used as the denominator for the gradient
50
+ 3) logging outputs to display while training
51
+ """
52
+
53
+ assert (
54
+ hasattr(model.decoder, "adaptive_softmax")
55
+ and model.decoder.adaptive_softmax is not None
56
+ )
57
+ adaptive_softmax = model.decoder.adaptive_softmax
58
+
59
+ net_output = model(**sample["net_input"])
60
+ orig_target = model.get_targets(sample, net_output)
61
+
62
+ nsentences = orig_target.size(0)
63
+ orig_target = orig_target.view(-1)
64
+
65
+ bsz = orig_target.size(0)
66
+
67
+ logits, target = adaptive_softmax(net_output[0], orig_target)
68
+ assert len(target) == len(logits)
69
+
70
+ loss = net_output[0].new(1 if reduce else bsz).zero_()
71
+
72
+ for i in range(len(target)):
73
+ if target[i] is not None:
74
+ assert target[i].min() >= 0 and target[i].max() <= logits[i].size(1)
75
+ loss += F.cross_entropy(
76
+ logits[i],
77
+ target[i],
78
+ ignore_index=self.padding_idx,
79
+ reduction="sum" if reduce else "none",
80
+ )
81
+
82
+ orig = utils.strip_pad(orig_target, self.padding_idx)
83
+ ntokens = orig.numel()
84
+ sample_size = sample["target"].size(0) if self.sentence_avg else ntokens
85
+ logging_output = {
86
+ "loss": loss.data,
87
+ "ntokens": ntokens,
88
+ "nsentences": nsentences,
89
+ "sample_size": sample_size,
90
+ }
91
+ return loss, sample_size, logging_output
92
+
93
+ @staticmethod
94
+ def reduce_metrics(logging_outputs) -> None:
95
+ """Aggregate logging outputs from data parallel training."""
96
+ loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
97
+ ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
98
+ sample_size = utils.item(
99
+ sum(log.get("sample_size", 0) for log in logging_outputs)
100
+ )
101
+
102
+ metrics.log_scalar(
103
+ "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
104
+ )
105
+ if sample_size != ntokens:
106
+ metrics.log_scalar(
107
+ "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
108
+ )
109
+ metrics.log_derived(
110
+ "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
111
+ )
112
+ else:
113
+ metrics.log_derived(
114
+ "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
115
+ )
116
+
117
+ @staticmethod
118
+ def logging_outputs_can_be_summed() -> bool:
119
+ """
120
+ Whether the logging outputs returned by `forward` can be summed
121
+ across workers prior to calling `reduce_metrics`. Setting this
122
+ to True will improve distributed training speed.
123
+ """
124
+ return True
fairseq/fairseq/criterions/composite_loss.py ADDED
@@ -0,0 +1,100 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from fairseq import utils
7
+ from fairseq.criterions import LegacyFairseqCriterion, register_criterion
8
+ from torch import nn
9
+
10
+
11
+ @register_criterion("composite_loss")
12
+ class CompositeLoss(LegacyFairseqCriterion):
13
+ """This is a composite loss that, given a list of model outputs and a list of targets,
14
+ computes an average of losses for each output-target pair"""
15
+
16
+ def __init__(self, args, task):
17
+ super().__init__(args, task)
18
+ self.underlying_criterion = args.underlying_criterion
19
+
20
+ @staticmethod
21
+ def add_args(parser):
22
+ """Add criterion-specific arguments to the parser."""
23
+ # fmt: off
24
+ parser.add_argument('--underlying-criterion', type=str, metavar='VAL', required=True,
25
+ help='underlying criterion to use for the composite loss')
26
+ # fmt: on
27
+
28
+ @staticmethod
29
+ def build_underlying_criterion(args, task):
30
+ saved_criterion = args.criterion
31
+ args.criterion = args.underlying_criterion
32
+ assert saved_criterion != args.underlying_criterion
33
+ underlying_criterion = task.build_criterion(args)
34
+ args.criterion = saved_criterion
35
+ return underlying_criterion
36
+
37
+ @classmethod
38
+ def build_criterion(cls, args, task):
39
+ underlying_criterion = CompositeLoss.build_underlying_criterion(args, task)
40
+
41
+ class FakeModel(nn.Module):
42
+ def __init__(self, model, net_out, target):
43
+ super().__init__()
44
+ self.model = model
45
+ self.net_out = net_out
46
+ self.target = target
47
+
48
+ def forward(self, **unused):
49
+ return self.net_out
50
+
51
+ def get_normalized_probs(self, net_output, log_probs, sample=None):
52
+ return self.model.get_normalized_probs(
53
+ net_output, log_probs, sample=sample
54
+ )
55
+
56
+ def get_targets(self, *unused):
57
+ return self.target
58
+
59
+ @property
60
+ def decoder(self):
61
+ return self.model.decoder
62
+
63
+ class _CompositeLoss(LegacyFairseqCriterion):
64
+ def __init__(self, args, task, underlying_criterion):
65
+ super().__init__(args, task)
66
+ self.underlying_criterion = underlying_criterion
67
+
68
+ def forward(self, model, sample, reduce=True):
69
+ net_outputs = model(**sample["net_input"])
70
+ targets = sample["target"]
71
+
72
+ bsz = targets[0].size(0)
73
+ loss = net_outputs[0][0].new(1 if reduce else bsz).float().zero_()
74
+
75
+ sample_size = 0
76
+ logging_output = {}
77
+ for o, t in zip(net_outputs[0], targets):
78
+ m = FakeModel(model, (o, net_outputs[1]), t)
79
+ sample["target"] = t
80
+ l, ss, logging_output = self.underlying_criterion(m, sample, reduce)
81
+ loss += l
82
+ sample_size += ss
83
+
84
+ loss.div_(len(targets))
85
+ sample_size /= len(targets)
86
+
87
+ logging_output["loss"] = utils.item(loss.data) if reduce else loss.data
88
+ return loss, sample_size, logging_output
89
+
90
+ @staticmethod
91
+ def aggregate_logging_outputs(logging_outputs):
92
+ return underlying_criterion.__class__.aggregate_logging_outputs(
93
+ logging_outputs
94
+ )
95
+
96
+ @staticmethod
97
+ def reduce_metrics(logging_outputs) -> None:
98
+ underlying_criterion.__class__.reduce_metrics(logging_outputs)
99
+
100
+ return _CompositeLoss(args, task, underlying_criterion)
fairseq/fairseq/criterions/cross_entropy.py ADDED
@@ -0,0 +1,91 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ from dataclasses import dataclass
8
+
9
+ import torch.nn.functional as F
10
+ from fairseq import utils
11
+ from fairseq.logging import metrics
12
+ from fairseq.criterions import FairseqCriterion, register_criterion
13
+ from fairseq.dataclass import FairseqDataclass
14
+ from omegaconf import II
15
+
16
+
17
+ @dataclass
18
+ class CrossEntropyCriterionConfig(FairseqDataclass):
19
+ sentence_avg: bool = II("optimization.sentence_avg")
20
+
21
+
22
+ @register_criterion("cross_entropy", dataclass=CrossEntropyCriterionConfig)
23
+ class CrossEntropyCriterion(FairseqCriterion):
24
+ def __init__(self, task, sentence_avg):
25
+ super().__init__(task)
26
+ self.sentence_avg = sentence_avg
27
+
28
+ def forward(self, model, sample, reduce=True):
29
+ """Compute the loss for the given sample.
30
+
31
+ Returns a tuple with three elements:
32
+ 1) the loss
33
+ 2) the sample size, which is used as the denominator for the gradient
34
+ 3) logging outputs to display while training
35
+ """
36
+ net_output = model(**sample["net_input"])
37
+ loss, _ = self.compute_loss(model, net_output, sample, reduce=reduce)
38
+ sample_size = (
39
+ sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
40
+ )
41
+ logging_output = {
42
+ "loss": loss.data,
43
+ "ntokens": sample["ntokens"],
44
+ "nsentences": sample["target"].size(0),
45
+ "sample_size": sample_size,
46
+ }
47
+ return loss, sample_size, logging_output
48
+
49
+ def compute_loss(self, model, net_output, sample, reduce=True):
50
+ lprobs = model.get_normalized_probs(net_output, log_probs=True)
51
+ lprobs = lprobs.view(-1, lprobs.size(-1))
52
+ target = model.get_targets(sample, net_output).view(-1)
53
+ loss = F.nll_loss(
54
+ lprobs,
55
+ target,
56
+ ignore_index=self.padding_idx,
57
+ reduction="sum" if reduce else "none",
58
+ )
59
+ return loss, loss
60
+
61
+ @staticmethod
62
+ def reduce_metrics(logging_outputs) -> None:
63
+ """Aggregate logging outputs from data parallel training."""
64
+ loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
65
+ ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
66
+ sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
67
+
68
+ # we divide by log(2) to convert the loss from base e to base 2
69
+ metrics.log_scalar(
70
+ "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
71
+ )
72
+ if sample_size != ntokens:
73
+ metrics.log_scalar(
74
+ "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
75
+ )
76
+ metrics.log_derived(
77
+ "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
78
+ )
79
+ else:
80
+ metrics.log_derived(
81
+ "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
82
+ )
83
+
84
+ @staticmethod
85
+ def logging_outputs_can_be_summed() -> bool:
86
+ """
87
+ Whether the logging outputs returned by `forward` can be summed
88
+ across workers prior to calling `reduce_metrics`. Setting this
89
+ to True will improve distributed training speed.
90
+ """
91
+ return True
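reduce_metrics above divides the summed loss by math.log(2) to report bits per token, so the derived perplexity is 2 ** loss (assuming utils.get_perplexity is the base-2 exponential, which the conversion implies). A quick worked check:

# Worked check of the nats -> bits conversion in reduce_metrics above.
import math

loss_nats = 2 * math.log(8)              # summed NLL: ln(8) nats/token over 2 tokens
loss_bits = loss_nats / 2 / math.log(2)  # 3.0 bits per token
print(loss_bits, 2 ** loss_bits)         # 3.0 8.0 -> perplexity 8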
fairseq/fairseq/criterions/ctc.py ADDED
@@ -0,0 +1,325 @@
1
+ # All rights reserved.
2
+ #
3
+ # This source code is licensed under the license found in the LICENSE file in
4
+ # the root directory of this source tree. An additional grant of patent rights
5
+ # can be found in the PATENTS file in the same directory.
6
+
7
+ import math
8
+ from argparse import Namespace
9
+ from dataclasses import dataclass, field
10
+ from omegaconf import II
11
+ from typing import Optional
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+
16
+ from fairseq import utils
17
+ from fairseq.logging import metrics
18
+ from fairseq.criterions import FairseqCriterion, register_criterion
19
+ from fairseq.dataclass import FairseqDataclass
20
+ from fairseq.data.data_utils import post_process
21
+ from fairseq.tasks import FairseqTask
22
+ from fairseq.logging.meters import safe_round
23
+
24
+
25
+ @dataclass
26
+ class CtcCriterionConfig(FairseqDataclass):
27
+ zero_infinity: bool = field(
28
+ default=False,
29
+ metadata={"help": "zero inf loss when source length <= target length"},
30
+ )
31
+ sentence_avg: bool = II("optimization.sentence_avg")
32
+ post_process: str = field(
33
+ default="letter",
34
+ metadata={
35
+ "help": "how to post process predictions into words. can be letter, "
36
+ "wordpiece, BPE symbols, etc. "
37
+ "See fairseq.data.data_utils.post_process() for full list of options"
38
+ },
39
+ )
40
+ wer_kenlm_model: Optional[str] = field(
41
+ default=None,
42
+ metadata={
43
+ "help": "if this is provided, use kenlm to compute wer (along with other wer_* args)"
44
+ },
45
+ )
46
+ wer_lexicon: Optional[str] = field(
47
+ default=None,
48
+ metadata={"help": "lexicon to use with wer_kenlm_model"},
49
+ )
50
+ wer_lm_weight: float = field(
51
+ default=2.0,
52
+ metadata={"help": "lm weight to use with wer_kenlm_model"},
53
+ )
54
+ wer_word_score: float = field(
55
+ default=-1.0,
56
+ metadata={"help": "lm word score to use with wer_kenlm_model"},
57
+ )
58
+ wer_sil_weight: float = field(
59
+ default=0,
60
+ metadata={"help": "silence weight to use with wer_kenlm_model"},
61
+ )
62
+
63
+ wer_args: Optional[str] = field(
64
+ default=None,
65
+ metadata={
66
+ "help": "DEPRECATED: tuple of (wer_kenlm_model, wer_lexicon, wer_lm_weight, wer_word_score)"
67
+ },
68
+ )
69
+
70
+
71
+ @register_criterion("ctc", dataclass=CtcCriterionConfig)
72
+ class CtcCriterion(FairseqCriterion):
73
+ def __init__(
74
+ self, cfg: CtcCriterionConfig, task: FairseqTask, rdrop_alpha: int = 0.0
75
+ ):
76
+ super().__init__(task)
77
+ self.blank_idx = (
78
+ task.target_dictionary.index(task.blank_symbol)
79
+ if hasattr(task, "blank_symbol")
80
+ else 0
81
+ )
82
+ self.pad_idx = task.target_dictionary.pad()
83
+ self.eos_idx = task.target_dictionary.eos()
84
+ self.post_process = cfg.post_process
85
+
86
+ self.rdrop_alpha = rdrop_alpha
87
+
88
+ if cfg.wer_args is not None:
89
+ (
90
+ cfg.wer_kenlm_model,
91
+ cfg.wer_lexicon,
92
+ cfg.wer_lm_weight,
93
+ cfg.wer_word_score,
94
+ ) = eval(cfg.wer_args)
95
+
96
+ if cfg.wer_kenlm_model is not None and cfg.wer_kenlm_model != "":
97
+ from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder
98
+
99
+ dec_args = Namespace()
100
+ dec_args.nbest = 1
101
+ dec_args.criterion = "ctc"
102
+ dec_args.kenlm_model = cfg.wer_kenlm_model
103
+ dec_args.lexicon = cfg.wer_lexicon
104
+ dec_args.beam = 50
105
+ dec_args.beam_size_token = min(50, len(task.target_dictionary))
106
+ dec_args.beam_threshold = min(50, len(task.target_dictionary))
107
+ dec_args.lm_weight = cfg.wer_lm_weight
108
+ dec_args.word_score = cfg.wer_word_score
109
+ dec_args.sil_weight = cfg.wer_sil_weight
110
+ dec_args.unk_weight = -math.inf
112
+
113
+ self.w2l_decoder = W2lKenLMDecoder(dec_args, task.target_dictionary)
114
+ else:
115
+ self.w2l_decoder = None
116
+
117
+ self.zero_infinity = cfg.zero_infinity
118
+ self.sentence_avg = cfg.sentence_avg
119
+
120
+ def forward(self, model, sample, reduce=True, **kwargs):
121
+ net_output = model(**sample["net_input"])
122
+ lprobs = model.get_normalized_probs(
123
+ net_output, log_probs=True
124
+ ).contiguous() # (T, B, C) from the encoder
125
+
126
+ # CTC loss is calculated over duplicated inputs
127
+ # sample is already duplicated for R-Drop
128
+ if self.rdrop_alpha > 0:
129
+ for k, v in sample.items():
130
+ if k in ["target", "target_lengths"]:
131
+ sample[k] = torch.cat([v, v.clone()], dim=0)
132
+ elif k == "net_input":
133
+ if sample[k]["src_tokens"].size(1) != sample[k]["src_lengths"].size(
134
+ 0
135
+ ):
136
+ # for decoder CTC loss
137
+ sample[k]["src_lengths"] = torch.cat(
138
+ [
139
+ sample[k]["src_lengths"],
140
+ sample[k]["src_lengths"].clone(),
141
+ ],
142
+ dim=0,
143
+ )
144
+
145
+ if "src_lengths" in sample["net_input"]:
146
+ input_lengths = sample["net_input"]["src_lengths"]
147
+ else:
148
+ if net_output["padding_mask"] is not None:
149
+ non_padding_mask = ~net_output["padding_mask"]
150
+ input_lengths = non_padding_mask.long().sum(-1)
151
+ else:
152
+ input_lengths = lprobs.new_full(
153
+ (lprobs.size(1),), lprobs.size(0), dtype=torch.long
154
+ )
155
+
156
+ pad_mask = (sample["target"] != self.pad_idx) & (
157
+ sample["target"] != self.eos_idx
158
+ )
159
+ targets_flat = sample["target"].masked_select(pad_mask)
160
+ if "target_lengths" in sample:
161
+ target_lengths = sample["target_lengths"]
162
+ else:
163
+ target_lengths = pad_mask.sum(-1)
164
+
165
+ with torch.backends.cudnn.flags(enabled=False):
166
+ loss = F.ctc_loss(
167
+ lprobs,
168
+ targets_flat,
169
+ input_lengths,
170
+ target_lengths,
171
+ blank=self.blank_idx,
172
+ reduction="sum",
173
+ zero_infinity=self.zero_infinity,
174
+ )
175
+
176
+ ntokens = (
177
+ sample["ntokens"] if "ntokens" in sample else target_lengths.sum().item()
178
+ )
179
+
180
+ sample_size = sample["target"].size(0) if self.sentence_avg else ntokens
181
+ logging_output = {
182
+ "loss": utils.item(loss.data), # * sample['ntokens'],
183
+ "ntokens": ntokens,
184
+ "nsentences": sample["id"].numel(),
185
+ "sample_size": sample_size,
186
+ }
187
+
188
+ if not model.training:
189
+ import editdistance
190
+
191
+ with torch.no_grad():
192
+ lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu()
193
+
194
+ c_err = 0
195
+ c_len = 0
196
+ w_errs = 0
197
+ w_len = 0
198
+ wv_errs = 0
199
+ for lp, t, inp_l in zip(
200
+ lprobs_t,
201
+ sample["target_label"]
202
+ if "target_label" in sample
203
+ else sample["target"],
204
+ input_lengths,
205
+ ):
206
+ lp = lp[:inp_l].unsqueeze(0)
207
+
208
+ decoded = None
209
+ if self.w2l_decoder is not None:
210
+ decoded = self.w2l_decoder.decode(lp)
211
+ if len(decoded) < 1:
212
+ decoded = None
213
+ else:
214
+ decoded = decoded[0]
215
+ if len(decoded) < 1:
216
+ decoded = None
217
+ else:
218
+ decoded = decoded[0]
219
+
220
+ p = (t != self.task.target_dictionary.pad()) & (
221
+ t != self.task.target_dictionary.eos()
222
+ )
223
+ targ = t[p]
224
+ targ_units = self.task.target_dictionary.string(targ)
225
+ targ_units_arr = targ.tolist()
226
+
227
+ toks = lp.argmax(dim=-1).unique_consecutive()
228
+ pred_units_arr = toks[toks != self.blank_idx].tolist()
229
+
230
+ c_err += editdistance.eval(pred_units_arr, targ_units_arr)
231
+ c_len += len(targ_units_arr)
232
+
233
+ targ_words = post_process(targ_units, self.post_process).split()
234
+
235
+ pred_units = self.task.target_dictionary.string(pred_units_arr)
236
+ pred_words_raw = post_process(pred_units, self.post_process).split()
237
+
238
+ if decoded is not None and "words" in decoded:
239
+ pred_words = decoded["words"]
240
+ w_errs += editdistance.eval(pred_words, targ_words)
241
+ wv_errs += editdistance.eval(pred_words_raw, targ_words)
242
+ else:
243
+ dist = editdistance.eval(pred_words_raw, targ_words)
244
+ w_errs += dist
245
+ wv_errs += dist
246
+
247
+ w_len += len(targ_words)
248
+
249
+ logging_output["wv_errors"] = wv_errs
250
+ logging_output["w_errors"] = w_errs
251
+ logging_output["w_total"] = w_len
252
+ logging_output["c_errors"] = c_err
253
+ logging_output["c_total"] = c_len
254
+
255
+ return loss, sample_size, logging_output
256
+
257
+ @staticmethod
258
+ def reduce_metrics(logging_outputs) -> None:
259
+ """Aggregate logging outputs from data parallel training."""
260
+
261
+ loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
262
+ ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
263
+ nsentences = utils.item(
264
+ sum(log.get("nsentences", 0) for log in logging_outputs)
265
+ )
266
+ sample_size = utils.item(
267
+ sum(log.get("sample_size", 0) for log in logging_outputs)
268
+ )
269
+
270
+ metrics.log_scalar(
271
+ "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
272
+ )
273
+ metrics.log_scalar("ntokens", ntokens)
274
+ metrics.log_scalar("nsentences", nsentences)
275
+ if sample_size != ntokens:
276
+ metrics.log_scalar(
277
+ "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
278
+ )
279
+
280
+ c_errors = sum(log.get("c_errors", 0) for log in logging_outputs)
281
+ metrics.log_scalar("_c_errors", c_errors)
282
+ c_total = sum(log.get("c_total", 0) for log in logging_outputs)
283
+ metrics.log_scalar("_c_total", c_total)
284
+ w_errors = sum(log.get("w_errors", 0) for log in logging_outputs)
285
+ metrics.log_scalar("_w_errors", w_errors)
286
+ wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs)
287
+ metrics.log_scalar("_wv_errors", wv_errors)
288
+ w_total = sum(log.get("w_total", 0) for log in logging_outputs)
289
+ metrics.log_scalar("_w_total", w_total)
290
+
291
+ if c_total > 0:
292
+ metrics.log_derived(
293
+ "uer",
294
+ lambda meters: safe_round(
295
+ meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3
296
+ )
297
+ if meters["_c_total"].sum > 0
298
+ else float("nan"),
299
+ )
300
+ if w_total > 0:
301
+ metrics.log_derived(
302
+ "wer",
303
+ lambda meters: safe_round(
304
+ meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3
305
+ )
306
+ if meters["_w_total"].sum > 0
307
+ else float("nan"),
308
+ )
309
+ metrics.log_derived(
310
+ "raw_wer",
311
+ lambda meters: safe_round(
312
+ meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
313
+ )
314
+ if meters["_w_total"].sum > 0
315
+ else float("nan"),
316
+ )
317
+
318
+ @staticmethod
319
+ def logging_outputs_can_be_summed() -> bool:
320
+ """
321
+ Whether the logging outputs returned by `forward` can be summed
322
+ across workers prior to calling `reduce_metrics`. Setting this
323
+ to True will improve distributed training speed.
324
+ """
325
+ return True
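The validation branch above scores a greedy CTC decode: per-frame argmax, collapse repeats with unique_consecutive, then drop blanks. A self-contained sketch of exactly that step, with a toy three-symbol vocabulary:

# Greedy CTC decode as used in the validation loop above (toy data assumed).
import torch

blank_idx = 0
# (T, C) log-probs for one utterance; columns: [blank, "a", "b"]
lp = torch.log_softmax(torch.tensor([
    [0.1, 2.0, 0.2],   # frame 1: "a"
    [0.2, 2.5, 0.1],   # frame 2: "a" again, collapses with frame 1
    [3.0, 0.1, 0.2],   # frame 3: blank, dropped
    [0.1, 0.3, 2.2],   # frame 4: "b"
]), dim=-1)

toks = lp.argmax(dim=-1).unique_consecutive()
pred_units_arr = toks[toks != blank_idx].tolist()
print(pred_units_arr)  # [1, 2], i.e. "ab"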
fairseq/fairseq/criterions/fairseq_criterion.py ADDED
@@ -0,0 +1,121 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import inspect
7
+ from typing import Any, Dict, List
8
+
9
+ from fairseq import utils
10
+ from fairseq.logging import metrics
11
+ from fairseq.dataclass import FairseqDataclass
12
+ from fairseq.dataclass.utils import gen_parser_from_dataclass
13
+ from torch.nn.modules.loss import _Loss
14
+
15
+
16
+ class FairseqCriterion(_Loss):
17
+ def __init__(self, task):
18
+ super().__init__()
19
+ self.task = task
20
+ if hasattr(task, "target_dictionary"):
21
+ tgt_dict = task.target_dictionary
22
+ self.padding_idx = tgt_dict.pad() if tgt_dict is not None else -100
23
+
24
+ @classmethod
25
+ def add_args(cls, parser):
26
+ """Add criterion-specific arguments to the parser."""
27
+ dc = getattr(cls, "__dataclass", None)
28
+ if dc is not None:
29
+ gen_parser_from_dataclass(parser, dc())
30
+
31
+ @classmethod
32
+ def build_criterion(cls, cfg: FairseqDataclass, task):
33
+ """Construct a criterion from command-line args."""
34
+ # collect the arguments declared in the subclass __init__.
35
+ init_args = {}
36
+ for p in inspect.signature(cls).parameters.values():
37
+ if (
38
+ p.kind == p.POSITIONAL_ONLY
39
+ or p.kind == p.VAR_POSITIONAL
40
+ or p.kind == p.VAR_KEYWORD
41
+ ):
42
+ # we haven't implemented inference for these argument types,
43
+ # but PRs welcome :)
44
+ raise NotImplementedError("{} not supported".format(p.kind))
45
+
46
+ assert p.kind in {p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY}
47
+
48
+ if p.name == "task":
49
+ init_args["task"] = task
50
+ elif p.name == "cfg":
51
+ init_args["cfg"] = cfg
52
+ elif hasattr(cfg, p.name):
53
+ init_args[p.name] = getattr(cfg, p.name)
54
+ elif p.default != p.empty:
55
+ pass # we'll use the default value
56
+ else:
57
+ raise NotImplementedError(
58
+ "Unable to infer Criterion arguments, please implement "
59
+ "{}.build_criterion".format(cls.__name__)
60
+ )
61
+ return cls(**init_args)
62
+
63
+ def forward(self, model, sample, reduce=True):
64
+ """Compute the loss for the given sample.
65
+
66
+ Returns a tuple with three elements:
67
+ 1) the loss
68
+ 2) the sample size, which is used as the denominator for the gradient
69
+ 3) logging outputs to display while training
70
+ """
71
+ raise NotImplementedError
72
+
73
+ @staticmethod
74
+ def aggregate_logging_outputs(
75
+ logging_outputs: List[Dict[str, Any]]
76
+ ) -> Dict[str, Any]:
77
+ """Aggregate logging outputs from data parallel training."""
78
+ utils.deprecation_warning(
79
+ "The aggregate_logging_outputs API is deprecated. "
80
+ "Please use the reduce_metrics API instead."
81
+ )
82
+ raise NotImplementedError
83
+
84
+ @classmethod
85
+ def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None:
86
+ """Aggregate logging outputs from data parallel training."""
87
+ utils.deprecation_warning(
88
+ "Criterions should implement the reduce_metrics API. "
89
+ "Falling back to deprecated aggregate_logging_outputs API."
90
+ )
91
+ agg_logging_outputs = cls.aggregate_logging_outputs(logging_outputs)
92
+ for k, v in agg_logging_outputs.items():
93
+ if k in {"nsentences", "ntokens", "sample_size"}:
94
+ continue
95
+ metrics.log_scalar(k, v)
96
+
97
+ @staticmethod
98
+ def logging_outputs_can_be_summed() -> bool:
99
+ """
100
+ Whether the logging outputs returned by `forward` can be summed
101
+ across workers prior to calling `reduce_metrics`. Setting this
102
+ to True will improve distributed training speed.
103
+ """
104
+ return False
105
+
106
+
107
+ class LegacyFairseqCriterion(FairseqCriterion):
108
+ def __init__(self, args, task):
109
+ super().__init__(task=task)
110
+ self.args = args
111
+
112
+ utils.deprecation_warning(
113
+ "Criterions should take explicit arguments instead of an "
114
+ "argparse.Namespace object, please update your criterion by "
115
+ "extending FairseqCriterion instead of LegacyFairseqCriterion."
116
+ )
117
+
118
+ @classmethod
119
+ def build_criterion(cls, args, task):
120
+ """Construct a criterion from command-line args."""
121
+ return cls(args, task)
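build_criterion above fills each parameter of the subclass __init__ from the task, the cfg object itself, or a same-named cfg field, otherwise falling back to the declared default. A toy sketch of that lookup rule (the task/cfg stand-ins are illustrative, not fairseq types):

# Sketch of the signature-driven argument filling in build_criterion above.
import inspect
from types import SimpleNamespace


class ToyCriterion:
    def __init__(self, task, sentence_avg, label_smoothing=0.0):
        self.task = task
        self.sentence_avg = sentence_avg
        self.label_smoothing = label_smoothing


cfg = SimpleNamespace(sentence_avg=True)  # no label_smoothing: default is kept
task = object()

init_args = {}
for p in inspect.signature(ToyCriterion).parameters.values():
    if p.name == "task":
        init_args["task"] = task
    elif hasattr(cfg, p.name):
        init_args[p.name] = getattr(cfg, p.name)
    elif p.default is not p.empty:
        pass  # use the declared default
crit = ToyCriterion(**init_args)
print(crit.sentence_avg, crit.label_smoothing)  # True 0.0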
fairseq/fairseq/criterions/fastspeech2_loss.py ADDED
@@ -0,0 +1,137 @@
1
+ # Copyright (c) 2017-present, Facebook, Inc.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the LICENSE file in
5
+ # the root directory of this source tree. An additional grant of patent rights
6
+ # can be found in the PATENTS file in the same directory.
7
+
8
+ from typing import List, Dict, Any
9
+ from dataclasses import dataclass, field
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+
14
+ from fairseq import utils
15
+ from fairseq.logging import metrics
16
+ from fairseq.criterions import FairseqCriterion, register_criterion
17
+ from fairseq.dataclass import FairseqDataclass
18
+ from fairseq.data.data_utils import lengths_to_mask
19
+ from fairseq.models.fairseq_model import FairseqEncoderModel
20
+
21
+
22
+ @dataclass
23
+ class FastSpeech2CriterionConfig(FairseqDataclass):
24
+ ctc_weight: float = field(default=0.0, metadata={"help": "weight for CTC loss"})
25
+
26
+
27
+ @register_criterion("fastspeech2", dataclass=FastSpeech2CriterionConfig)
28
+ class FastSpeech2Loss(FairseqCriterion):
29
+ def __init__(self, task, ctc_weight):
30
+ super().__init__(task)
31
+ self.ctc_weight = ctc_weight
32
+
33
+ def forward(self, model: FairseqEncoderModel, sample, reduction="mean"):
34
+ src_tokens = sample["net_input"]["src_tokens"]
35
+ src_lens = sample["net_input"]["src_lengths"]
36
+ tgt_lens = sample["target_lengths"]
37
+ _feat_out, _feat_out_post, _, log_dur_out, pitch_out, energy_out = model(
38
+ src_tokens=src_tokens,
39
+ src_lengths=src_lens,
40
+ prev_output_tokens=sample["net_input"]["prev_output_tokens"],
41
+ incremental_state=None,
42
+ target_lengths=tgt_lens,
43
+ speaker=sample["speaker"],
44
+ durations=sample["durations"],
45
+ pitches=sample["pitches"],
46
+ energies=sample["energies"],
47
+ )
48
+
49
+ src_mask = lengths_to_mask(sample["net_input"]["src_lengths"])
50
+ tgt_mask = lengths_to_mask(sample["target_lengths"])
51
+
52
+ pitches, energies = sample["pitches"], sample["energies"]
53
+ pitch_out, pitches = pitch_out[src_mask], pitches[src_mask]
54
+ energy_out, energies = energy_out[src_mask], energies[src_mask]
55
+
56
+ feat_out, feat = _feat_out[tgt_mask], sample["target"][tgt_mask]
57
+ l1_loss = F.l1_loss(feat_out, feat, reduction=reduction)
58
+ if _feat_out_post is not None:
59
+ l1_loss += F.l1_loss(_feat_out_post[tgt_mask], feat, reduction=reduction)
60
+
61
+ pitch_loss = F.mse_loss(pitch_out, pitches, reduction=reduction)
62
+ energy_loss = F.mse_loss(energy_out, energies, reduction=reduction)
63
+
64
+ log_dur_out = log_dur_out[src_mask]
65
+ dur = sample["durations"].float()
66
+ dur = dur.half() if log_dur_out.type().endswith(".HalfTensor") else dur
67
+ log_dur = torch.log(dur + 1)[src_mask]
68
+ dur_loss = F.mse_loss(log_dur_out, log_dur, reduction=reduction)
69
+
70
+ ctc_loss = torch.tensor(0.0).type_as(l1_loss)
71
+ if self.ctc_weight > 0.0:
72
+ lprobs = model.get_normalized_probs((_feat_out,), log_probs=True)
73
+ lprobs = lprobs.transpose(0, 1) # T x B x C
74
+ src_mask = lengths_to_mask(src_lens)
75
+ src_tokens_flat = src_tokens.masked_select(src_mask)
76
+ ctc_loss = (
77
+ F.ctc_loss(
78
+ lprobs,
79
+ src_tokens_flat,
80
+ tgt_lens,
81
+ src_lens,
82
+ reduction=reduction,
83
+ zero_infinity=True,
84
+ )
85
+ * self.ctc_weight
86
+ )
87
+
88
+ loss = l1_loss + dur_loss + pitch_loss + energy_loss + ctc_loss
89
+
90
+ sample_size = sample["nsentences"]
91
+ logging_output = {
92
+ "loss": utils.item(loss.data),
93
+ "ntokens": sample["ntokens"],
94
+ "nsentences": sample["nsentences"],
95
+ "sample_size": sample_size,
96
+ "l1_loss": utils.item(l1_loss.data),
97
+ "dur_loss": utils.item(dur_loss.data),
98
+ "pitch_loss": utils.item(pitch_loss.data),
99
+ "energy_loss": utils.item(energy_loss.data),
100
+ "ctc_loss": utils.item(ctc_loss.data),
101
+ }
102
+ return loss, sample_size, logging_output
103
+
104
+ @classmethod
105
+ def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None:
106
+ ns = [log.get("sample_size", 0) for log in logging_outputs]
107
+ ntot = sum(ns)
108
+ ws = [n / (ntot + 1e-8) for n in ns]
109
+ for key in [
110
+ "loss",
111
+ "l1_loss",
112
+ "dur_loss",
113
+ "pitch_loss",
114
+ "energy_loss",
115
+ "ctc_loss",
116
+ ]:
117
+ vals = [log.get(key, 0) for log in logging_outputs]
118
+ val = sum(val * w for val, w in zip(vals, ws))
119
+ metrics.log_scalar(key, val, ntot, round=3)
120
+ metrics.log_scalar("sample_size", ntot, len(logging_outputs))
121
+
122
+ # inference metrics
123
+ if "targ_frames" not in logging_outputs[0]:
124
+ return
125
+ n = sum(log.get("targ_frames", 0) for log in logging_outputs)
126
+ for key, new_key in [
127
+ ("mcd_loss", "mcd_loss"),
128
+ ("pred_frames", "pred_ratio"),
129
+ ("nins", "ins_rate"),
130
+ ("ndel", "del_rate"),
131
+ ]:
132
+ val = sum(log.get(key, 0) for log in logging_outputs)
133
+ metrics.log_scalar(new_key, val / n, n, round=3)
134
+
135
+ @staticmethod
136
+ def logging_outputs_can_be_summed() -> bool:
137
+ return False
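Since logging_outputs_can_be_summed() is False here, reduce_metrics combines per-worker means itself, weighting each by its sample size. A two-worker numeric check:

# Numeric check of the sample-size-weighted averaging in reduce_metrics above.
ns = [2, 6]          # sample sizes reported by two workers
vals = [1.0, 3.0]    # their mean losses
ntot = sum(ns)
ws = [n / (ntot + 1e-8) for n in ns]
print(sum(v * w for v, w in zip(vals, ws)))  # 2.5 == (2*1.0 + 6*3.0) / 8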
fairseq/fairseq/criterions/hubert_criterion.py ADDED
@@ -0,0 +1,195 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ import re
8
+ from dataclasses import dataclass, field
9
+from typing import List, Optional
+
+import torch
+import torch.nn.functional as F
+from fairseq import utils
+from fairseq.logging import metrics
+from fairseq.criterions import FairseqCriterion, register_criterion
+from fairseq.dataclass import FairseqDataclass
+
+
+@dataclass
+class HubertCriterionConfig(FairseqDataclass):
+    pred_masked_weight: float = field(
+        default=1.0,
+        metadata={"help": "weight for predictive loss for masked frames"},
+    )
+    pred_nomask_weight: float = field(
+        default=0.0,
+        metadata={"help": "weight for predictive loss for unmasked frames"},
+    )
+    loss_weights: Optional[List[float]] = field(
+        default=None,
+        metadata={"help": "weights for additional loss terms (not first one)"},
+    )
+    log_keys: List[str] = field(
+        default_factory=lambda: [],
+        metadata={"help": "output keys to log"},
+    )
+
+
+@register_criterion("hubert", dataclass=HubertCriterionConfig)
+class HubertCriterion(FairseqCriterion):
+    def __init__(
+        self,
+        task,
+        pred_masked_weight,
+        pred_nomask_weight,
+        loss_weights=None,
+        log_keys=None,
+    ):
+        super().__init__(task)
+        self.pred_masked_weight = pred_masked_weight
+        self.pred_nomask_weight = pred_nomask_weight
+        self.loss_weights = loss_weights
+        self.log_keys = [] if log_keys is None else log_keys
+
+    def forward(self, model, sample, reduce=True, log_pred=False):
+        """Compute the loss for the given sample.
+        Returns a tuple with three elements:
+        1) the loss
+        2) the sample size, which is used as the denominator for the gradient
+        3) logging outputs to display while training
+        """
+        net_output = model(target_list=sample["target_list"], **sample["net_input"])
+        loss = 0.0
+        sample_size = 0
+        logging_output = {}
+        reduction = "sum" if reduce else "none"
+
+        loss_m_list = []
+        logp_m_list = model.get_logits(net_output, True)
+        targ_m_list = model.get_targets(net_output, True)
+        assert self.pred_masked_weight == 0 or len(logp_m_list) > 0
+        for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)):
+            loss_m = F.cross_entropy(logp_m, targ_m, reduction=reduction)
+            loss_m_list.append(loss_m)
+            logging_output[f"loss_m_{i}"] = loss_m.detach().item()
+        if self.pred_masked_weight > 0:
+            loss += self.pred_masked_weight * sum(loss_m_list)
+            sample_size += targ_m_list[0].numel()
+
+        loss_u_list = []
+        logp_u_list = model.get_logits(net_output, False)
+        targ_u_list = model.get_targets(net_output, False)
+        assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0
+        for i, (logp_u, targ_u) in enumerate(zip(logp_u_list, targ_u_list)):
+            loss_u = F.cross_entropy(logp_u, targ_u, reduction=reduction)
+            loss_u_list.append(loss_u)
+            logging_output[f"loss_u_{i}"] = loss_u.detach().item()
+        if self.pred_nomask_weight > 0:
+            loss += self.pred_nomask_weight * sum(loss_u_list)
+            sample_size += targ_u_list[0].numel()
+
+        if self.loss_weights is not None:
+            assert hasattr(model, "get_extra_losses")
+            extra_losses, names = model.get_extra_losses(net_output)
+            if torch.is_tensor(extra_losses):
+                extra_losses = [extra_losses]
+                names = [names]
+            if len(self.loss_weights) == 1 and len(extra_losses) != 1:
+                self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
+            assert len(extra_losses) == len(
+                self.loss_weights
+            ), f"{len(extra_losses)}, {len(self.loss_weights)}"
+            for p, n, coef in zip(extra_losses, names, self.loss_weights):
+                if coef != 0 and p is not None:
+                    p = coef * p.float() * sample_size
+                    loss += p
+                    logging_output[f"loss_{n}"] = p.item()
+
+        logging_output = {
+            "loss": loss.item() if reduce else loss,
+            "ntokens": sample_size,
+            "nsentences": sample["id"].numel(),
+            "sample_size": sample_size,
+            **logging_output,
+        }
+
+        for lk in self.log_keys:
+            if lk in net_output:
+                logging_output[lk] = float(net_output[lk])
+
+        def compute_correct(logits):
+            if logits.numel() == 0:
+                return 0, 0
+            else:
+                assert logits.dim() > 1, logits.shape
+                max = logits.argmax(-1) == 0
+                min = logits.argmin(-1) == 0
+                both = max & min
+                corr = max.long().sum().item() - both.long().sum().item()
+                count = max.numel()
+                return corr, count
+
+        with torch.no_grad():
+            for i, logp_m in enumerate(logp_m_list):
+                corr_m, count_m = compute_correct(logp_m)
+                logging_output[f"correct_m_{i}"] = corr_m
+                logging_output[f"count_m_{i}"] = count_m
+
+            for i, logp_u in enumerate(logp_u_list):
+                corr_u, count_u = compute_correct(logp_u)
+                logging_output[f"correct_u_{i}"] = corr_u
+                logging_output[f"count_u_{i}"] = count_u
+
+        return loss, sample_size, logging_output
+
+    @staticmethod
+    def reduce_metrics(logging_outputs) -> None:
+        """Aggregate logging outputs from data parallel training (copied from normal cross entropy)."""
+        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
+        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
+        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
+
+        metrics.log_scalar(
+            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
+        )
+        if sample_size != ntokens:
+            metrics.log_scalar(
+                "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
+            )
+            metrics.log_derived(
+                "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
+            )
+        else:
+            metrics.log_derived(
+                "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
+            )
+
+        counts = {}
+        for lk in logging_outputs[0].keys():
+            if lk.startswith("count_"):
+                val = sum(log[lk] for log in logging_outputs)
+                metrics.log_scalar(lk, val)
+                counts[lk] = val
+
+        for lk in logging_outputs[0].keys():
+            if lk.startswith("loss_"):
+                val = sum(log[lk] for log in logging_outputs)
+                metrics.log_scalar(lk, val / sample_size / math.log(2), round=3)
+            elif lk.startswith("correct_"):
+                val = sum(log[lk] for log in logging_outputs)
+                metrics.log_scalar(lk, val / counts[re.sub("correct", "count", lk)])
+
+    @staticmethod
+    def aggregate_logging_outputs(logging_outputs):
+        """Aggregate logging outputs from data parallel training."""
+        raise NotImplementedError()
+
+    @staticmethod
+    def logging_outputs_can_be_summed() -> bool:
+        """
+        Whether the logging outputs returned by `forward` can be summed
+        across workers prior to calling `reduce_metrics`. Setting this
+        to True will improve distributed training speed.
+        """
+        return False
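A note on the accuracy bookkeeping above: `compute_correct` assumes the model arranges each frame's logits with the positive (target) class at index 0, so a frame counts as correct when `argmax(-1) == 0`, while degenerate rows whose max and min coincide (all logits equal) are subtracted back out via the `both` mask. A minimal sketch of that convention on a toy tensor (the values are illustrative only, not taken from the model):

import torch

logits = torch.tensor(
    [
        [2.0, 0.5, -1.0],  # correct: the positive logit at index 0 wins
        [0.1, 3.0, 0.2],   # incorrect: a distractor wins
        [1.0, 1.0, 1.0],   # degenerate: all equal, backed out via `both`
    ]
)
is_max = logits.argmax(-1) == 0
is_min = logits.argmin(-1) == 0
both = is_max & is_min
corr = is_max.long().sum().item() - both.long().sum().item()
count = is_max.numel()
print(corr, count)  # 1 3

Whichever index `argmax` picks on the all-equal row, the subtraction leaves `corr = 1` here: the tie row contributes to the count but never to the correct total.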
fairseq/fairseq/criterions/label_smoothed_cross_entropy.py ADDED
@@ -0,0 +1,168 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from dataclasses import dataclass, field
+
+import torch
+from fairseq import utils
+from fairseq.logging import metrics
+from fairseq.criterions import FairseqCriterion, register_criterion
+from fairseq.dataclass import FairseqDataclass
+from omegaconf import II
+
+
+@dataclass
+class LabelSmoothedCrossEntropyCriterionConfig(FairseqDataclass):
+    label_smoothing: float = field(
+        default=0.0,
+        metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
+    )
+    report_accuracy: bool = field(
+        default=False,
+        metadata={"help": "report accuracy metric"},
+    )
+    ignore_prefix_size: int = field(
+        default=0,
+        metadata={"help": "Ignore first N tokens"},
+    )
+    sentence_avg: bool = II("optimization.sentence_avg")
+
+
+def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True):
+    if target.dim() == lprobs.dim() - 1:
+        target = target.unsqueeze(-1)
+    nll_loss = -lprobs.gather(dim=-1, index=target)
+    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
+    if ignore_index is not None:
+        pad_mask = target.eq(ignore_index)
+        nll_loss.masked_fill_(pad_mask, 0.0)
+        smooth_loss.masked_fill_(pad_mask, 0.0)
+    else:
+        nll_loss = nll_loss.squeeze(-1)
+        smooth_loss = smooth_loss.squeeze(-1)
+    if reduce:
+        nll_loss = nll_loss.sum()
+        smooth_loss = smooth_loss.sum()
+    eps_i = epsilon / (lprobs.size(-1) - 1)
+    loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss
+    return loss, nll_loss
+
+
+@register_criterion(
+    "label_smoothed_cross_entropy", dataclass=LabelSmoothedCrossEntropyCriterionConfig
+)
+class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
+    def __init__(
+        self,
+        task,
+        sentence_avg,
+        label_smoothing,
+        ignore_prefix_size=0,
+        report_accuracy=False,
+    ):
+        super().__init__(task)
+        self.sentence_avg = sentence_avg
+        self.eps = label_smoothing
+        self.ignore_prefix_size = ignore_prefix_size
+        self.report_accuracy = report_accuracy
+
+    def forward(self, model, sample, reduce=True):
+        """Compute the loss for the given sample.
+
+        Returns a tuple with three elements:
+        1) the loss
+        2) the sample size, which is used as the denominator for the gradient
+        3) logging outputs to display while training
+        """
+        net_output = model(**sample["net_input"])
+        loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
+        sample_size = (
+            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
+        )
+        logging_output = {
+            "loss": loss.data,
+            "nll_loss": nll_loss.data,
+            "ntokens": sample["ntokens"],
+            "nsentences": sample["target"].size(0),
+            "sample_size": sample_size,
+        }
+        if self.report_accuracy:
+            n_correct, total = self.compute_accuracy(model, net_output, sample)
+            logging_output["n_correct"] = utils.item(n_correct.data)
+            logging_output["total"] = utils.item(total.data)
+        return loss, sample_size, logging_output
+
+    def get_lprobs_and_target(self, model, net_output, sample):
+        lprobs = model.get_normalized_probs(net_output, log_probs=True)
+        target = model.get_targets(sample, net_output)
+        if self.ignore_prefix_size > 0:
+            # lprobs: B x T x C
+            lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
+            target = target[:, self.ignore_prefix_size :].contiguous()
+        return lprobs.view(-1, lprobs.size(-1)), target.view(-1)
+
+    def compute_loss(self, model, net_output, sample, reduce=True):
+        lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
+        loss, nll_loss = label_smoothed_nll_loss(
+            lprobs,
+            target,
+            self.eps,
+            ignore_index=self.padding_idx,
+            reduce=reduce,
+        )
+        return loss, nll_loss
+
+    def compute_accuracy(self, model, net_output, sample):
+        lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
+        mask = target.ne(self.padding_idx)
+        n_correct = torch.sum(
+            lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
+        )
+        total = torch.sum(mask)
+        return n_correct, total
+
+    @classmethod
+    def reduce_metrics(cls, logging_outputs) -> None:
+        """Aggregate logging outputs from data parallel training."""
+        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
+        nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs)
+        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
+        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
+
+        metrics.log_scalar(
+            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
+        )
+        metrics.log_scalar(
+            "nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=3
+        )
+        metrics.log_derived(
+            "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
+        )
+
+        total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
+        if total > 0:
+            metrics.log_scalar("total", total)
+            n_correct = utils.item(
+                sum(log.get("n_correct", 0) for log in logging_outputs)
+            )
+            metrics.log_scalar("n_correct", n_correct)
+            metrics.log_derived(
+                "accuracy",
+                lambda meters: round(
+                    meters["n_correct"].sum * 100.0 / meters["total"].sum, 3
+                )
+                if meters["total"].sum > 0
+                else float("nan"),
+            )
+
+    @staticmethod
+    def logging_outputs_can_be_summed() -> bool:
+        """
+        Whether the logging outputs returned by `forward` can be summed
+        across workers prior to calling `reduce_metrics`. Setting this
+        to True will improve distributed training speed.
+        """
+        return True
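For reference, `label_smoothed_nll_loss` above computes loss = (1 − ε − ε_i) · nll + ε_i · smooth with ε_i = ε / (V − 1): it keeps 1 − ε of the probability mass on the gold label and spreads the remaining ε uniformly over the other V − 1 classes. A minimal self-check of that identity against the loss computed directly from the explicit smoothed target distribution (toy shapes and values, not tied to any fairseq model):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
lprobs = F.log_softmax(torch.randn(2, 4), dim=-1)  # 2 tokens, vocab of 4
target = torch.tensor([1, 3])
epsilon = 0.1

# The two reduced terms used by label_smoothed_nll_loss.
nll = -lprobs.gather(dim=-1, index=target.unsqueeze(-1)).sum()
smooth = -lprobs.sum()
eps_i = epsilon / (lprobs.size(-1) - 1)
loss = (1.0 - epsilon - eps_i) * nll + eps_i * smooth

# Same quantity from the explicit smoothed distribution:
# 1 - epsilon on the gold class, epsilon / (V - 1) elsewhere.
smoothed = torch.full_like(lprobs, eps_i)
smoothed.scatter_(1, target.unsqueeze(-1), 1.0 - epsilon)
assert torch.allclose(loss, -(smoothed * lprobs).sum())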
fairseq/fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py ADDED
@@ -0,0 +1,221 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass, field
+import torch
+from fairseq import utils
+from fairseq.logging import metrics
+from fairseq.criterions import register_criterion
+from fairseq.criterions.label_smoothed_cross_entropy import (
+    LabelSmoothedCrossEntropyCriterion,
+    LabelSmoothedCrossEntropyCriterionConfig,
+)
+
+try:
+    from simuleval.metrics.latency import (
+        AverageLagging,
+        AverageProportion,
+        DifferentiableAverageLagging,
+    )
+
+    LATENCY_METRICS = {
+        "average_lagging": AverageLagging,
+        "average_proportion": AverageProportion,
+        "differentiable_average_lagging": DifferentiableAverageLagging,
+    }
+except ImportError:
+    LATENCY_METRICS = None
+
+
+@dataclass
+class LabelSmoothedCrossEntropyCriterionLatencyAugmentConfig(
+    LabelSmoothedCrossEntropyCriterionConfig
+):
+    latency_avg_weight: float = field(
+        default=0.0,
+        metadata={"help": "weight for average latency loss."},
+    )
+    latency_var_weight: float = field(
+        default=0.0,
+        metadata={"help": "weight for variance latency loss."},
+    )
+    latency_avg_type: str = field(
+        default="differentiable_average_lagging",
+        metadata={"help": "latency type for average loss"},
+    )
+    latency_var_type: str = field(
+        default="variance_delay",
+        metadata={"help": "latency type for variance loss"},
+    )
+    latency_gather_method: str = field(
+        default="weighted_average",
+        metadata={"help": "method to gather latency loss for all heads"},
+    )
+    latency_update_after: int = field(
+        default=0,
+        metadata={"help": "Add latency loss after certain steps"},
+    )
+
+
+@register_criterion(
+    "latency_augmented_label_smoothed_cross_entropy",
+    dataclass=LabelSmoothedCrossEntropyCriterionLatencyAugmentConfig,
+)
+class LatencyAugmentedLabelSmoothedCrossEntropyCriterion(
+    LabelSmoothedCrossEntropyCriterion
+):
+    def __init__(
+        self,
+        task,
+        sentence_avg,
+        label_smoothing,
+        ignore_prefix_size,
+        report_accuracy,
+        latency_avg_weight,
+        latency_var_weight,
+        latency_avg_type,
+        latency_var_type,
+        latency_gather_method,
+        latency_update_after,
+    ):
+        super().__init__(
+            task, sentence_avg, label_smoothing, ignore_prefix_size, report_accuracy
+        )
+        assert LATENCY_METRICS is not None, "Please make sure SimulEval is installed."
+
+        self.latency_avg_weight = latency_avg_weight
+        self.latency_var_weight = latency_var_weight
+        self.latency_avg_type = latency_avg_type
+        self.latency_var_type = latency_var_type
+        self.latency_gather_method = latency_gather_method
+        self.latency_update_after = latency_update_after
+
+    def forward(self, model, sample, reduce=True):
+        net_output = model(**sample["net_input"])
+        # 1. Compute cross entropy loss
+        loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
+
+        # 2. Compute latency loss
+        latency_loss, expected_latency, expected_delays_var = self.compute_latency_loss(
+            model, sample, net_output
+        )
+
+        if self.latency_update_after > 0:
+            num_updates = getattr(model.decoder, "num_updates", None)
+            assert (
+                num_updates is not None
+            ), "model.decoder doesn't have attribute 'num_updates'"
+            if num_updates <= self.latency_update_after:
+                latency_loss = 0
+
+        loss += latency_loss
+
+        sample_size = (
+            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
+        )
+
+        logging_output = {
+            "loss": loss.data,
+            "nll_loss": nll_loss.data,
+            "ntokens": sample["ntokens"],
+            "nsentences": sample["target"].size(0),
+            "sample_size": sample_size,
+            "latency": expected_latency,
+            "delays_var": expected_delays_var,
+            "latency_loss": latency_loss,
+        }
+
+        if self.report_accuracy:
+            n_correct, total = self.compute_accuracy(model, net_output, sample)
+            logging_output["n_correct"] = utils.item(n_correct.data)
+            logging_output["total"] = utils.item(total.data)
+        return loss, sample_size, logging_output
+
+    def compute_latency_loss(self, model, sample, net_output):
+        assert (
+            net_output[-1].encoder_padding_mask is None
+            or not net_output[-1].encoder_padding_mask[:, 0].any()
+        ), "Only right padding on source is supported."
+        # 1. Obtain the expected alignment
+        alpha_list = [item["alpha"] for item in net_output[1].attn_list]
+        num_layers = len(alpha_list)
+        bsz, num_heads, tgt_len, src_len = alpha_list[0].size()
+
+        # bsz * num_layers * num_heads, tgt_len, src_len
+        alpha_all = torch.cat(alpha_list, dim=1).view(-1, tgt_len, src_len)
+
+        # 2. Compute expected delays
+        # bsz * num_heads * num_layers, tgt_len, src_len for MMA
+        steps = (
+            torch.arange(1, 1 + src_len)
+            .unsqueeze(0)
+            .unsqueeze(1)
+            .expand_as(alpha_all)
+            .type_as(alpha_all)
+        )
+
+        expected_delays = torch.sum(steps * alpha_all, dim=-1)
+
+        target_padding_mask = (
+            model.get_targets(sample, net_output)
+            .eq(self.padding_idx)
+            .unsqueeze(1)
+            .expand(bsz, num_layers * num_heads, tgt_len)
+            .contiguous()
+            .view(-1, tgt_len)
+        )
+
+        src_lengths = (
+            sample["net_input"]["src_lengths"]
+            .unsqueeze(1)
+            .expand(bsz, num_layers * num_heads)
+            .contiguous()
+            .view(-1)
+        )
+        expected_latency = LATENCY_METRICS[self.latency_avg_type](
+            expected_delays, src_lengths, None, target_padding_mask=target_padding_mask
+        )
+
+        # 2.1 Average the expected latency over heads
+        # bsz, num_layers * num_heads
+        expected_latency = expected_latency.view(bsz, -1)
+        if self.latency_gather_method == "average":
+            # bsz * tgt_len
+            expected_latency = expected_delays.mean(dim=1)
+        elif self.latency_gather_method == "weighted_average":
+            weights = torch.nn.functional.softmax(expected_latency, dim=1)
+            expected_latency = torch.sum(expected_latency * weights, dim=1)
+        elif self.latency_gather_method == "max":
+            expected_latency = expected_latency.max(dim=1)[0]
+        else:
+            raise NotImplementedError
+
+        expected_latency = expected_latency.sum()
+        avg_loss = self.latency_avg_weight * expected_latency
+
+        # 2.2 Variance of expected delays
+        expected_delays_var = (
+            expected_delays.view(bsz, -1, tgt_len).var(dim=1).mean(dim=1)
+        )
+        expected_delays_var = expected_delays_var.sum()
+        var_loss = self.latency_var_weight * expected_delays_var
+
+        # 3. Final loss
+        latency_loss = avg_loss + var_loss
+
+        return latency_loss, expected_latency, expected_delays_var
+
+    @classmethod
+    def reduce_metrics(cls, logging_outputs) -> None:
+        super().reduce_metrics(logging_outputs)
+        latency = sum(log.get("latency", 0) for log in logging_outputs)
+        delays_var = sum(log.get("delays_var", 0) for log in logging_outputs)
+        latency_loss = sum(log.get("latency_loss", 0) for log in logging_outputs)
+        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
+        metrics.log_scalar("latency", latency.float() / nsentences, nsentences, round=3)
+        metrics.log_scalar("delays_var", delays_var / nsentences, nsentences, round=3)
+        metrics.log_scalar(
+            "latency_loss", latency_loss / nsentences, nsentences, round=3
+        )
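The heart of `compute_latency_loss` above is the expected-delay reduction: given attention weights α_{t,s} of target step t over source positions s, the expected delay is d_t = Σ_s s · α_{t,s}, the average (1-indexed) source position consumed before emitting step t; the SimulEval metrics then turn these delays into lagging scores. A minimal sketch of just that reduction, with layers and heads omitted for clarity (shapes and values are illustrative only):

import torch

bsz, tgt_len, src_len = 2, 3, 5
alpha = torch.rand(bsz, tgt_len, src_len)
alpha = alpha / alpha.sum(dim=-1, keepdim=True)  # rows sum to 1, like attention

# 1-indexed source positions broadcast to alpha's shape, as `steps` does above.
steps = torch.arange(1, 1 + src_len).type_as(alpha).expand_as(alpha)
expected_delays = (steps * alpha).sum(dim=-1)  # bsz x tgt_len

# Each entry lies in [1, src_len]: attention peaked on late source positions
# yields a large expected delay, and hence a larger latency penalty.
assert ((1.0 <= expected_delays) & (expected_delays <= src_len)).all()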