ssl-aasist · PyTorch · custom_code

ash56 committed (verified)
Commit 96ffdcd · 1 Parent(s): 4d0021e

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +1 -0
  2. fairseq/fairseq/criterions/__pycache__/__init__.cpython-310.pyc +0 -0
  3. fairseq/fairseq/criterions/__pycache__/adaptive_loss.cpython-310.pyc +0 -0
  4. fairseq/fairseq/criterions/__pycache__/composite_loss.cpython-310.pyc +0 -0
  5. fairseq/fairseq/criterions/__pycache__/cross_entropy.cpython-310.pyc +0 -0
  6. fairseq/fairseq/criterions/__pycache__/ctc.cpython-310.pyc +0 -0
  7. fairseq/fairseq/criterions/__pycache__/fastspeech2_loss.cpython-310.pyc +0 -0
  8. fairseq/fairseq/criterions/__pycache__/hubert_criterion.cpython-310.pyc +0 -0
  9. fairseq/fairseq/criterions/__pycache__/label_smoothed_cross_entropy.cpython-310.pyc +0 -0
  10. fairseq/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_latency_augmented.cpython-310.pyc +0 -0
  11. fairseq/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_with_alignment.cpython-310.pyc +0 -0
  12. fairseq/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_with_ctc.cpython-310.pyc +0 -0
  13. fairseq/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_with_rdrop.cpython-310.pyc +0 -0
  14. fairseq/fairseq/criterions/__pycache__/legacy_masked_lm.cpython-310.pyc +0 -0
  15. fairseq/fairseq/criterions/__pycache__/masked_lm.cpython-310.pyc +0 -0
  16. fairseq/fairseq/criterions/__pycache__/model_criterion.cpython-310.pyc +0 -0
  17. fairseq/fairseq/criterions/__pycache__/nat_loss.cpython-310.pyc +0 -0
  18. fairseq/fairseq/criterions/__pycache__/sentence_prediction.cpython-310.pyc +0 -0
  19. fairseq/fairseq/criterions/__pycache__/sentence_prediction_adapters.cpython-310.pyc +0 -0
  20. fairseq/fairseq/criterions/__pycache__/sentence_ranking.cpython-310.pyc +0 -0
  21. fairseq/fairseq/criterions/__pycache__/speech_dlm_criterion.cpython-310.pyc +0 -0
  22. fairseq/fairseq/criterions/__pycache__/speech_to_speech_criterion.cpython-310.pyc +0 -0
  23. fairseq/fairseq/criterions/__pycache__/speech_ulm_criterion.cpython-310.pyc +0 -0
  24. fairseq/fairseq/criterions/__pycache__/tacotron2_loss.cpython-310.pyc +0 -0
  25. fairseq/fairseq/criterions/__pycache__/wav2vec_criterion.cpython-310.pyc +0 -0
  26. fairseq/fairseq/criterions/speech_to_speech_criterion.py +517 -0
  27. fairseq/fairseq/criterions/tacotron2_loss.py +227 -0
  28. fairseq/fairseq/criterions/wav2vec_criterion.py +231 -0
  29. fairseq/fairseq/data/__init__.py +137 -0
  30. fairseq/fairseq/data/add_class_target_dataset.py +79 -0
  31. fairseq/fairseq/data/add_target_dataset.py +83 -0
  32. fairseq/fairseq/data/append_token_dataset.py +41 -0
  33. fairseq/fairseq/data/audio/dataset_transforms/__pycache__/concataugment.cpython-310.pyc +0 -0
  34. fairseq/fairseq/data/audio/feature_transforms/__pycache__/delta_deltas.cpython-310.pyc +0 -0
  35. fairseq/fairseq/data/audio/feature_transforms/global_cmvn.py +29 -0
  36. fairseq/fairseq/data/audio/speech_to_text_dataset.py +733 -0
  37. fairseq/fairseq/data/backtranslation_dataset.py +165 -0
  38. fairseq/fairseq/data/base_wrapper_dataset.py +78 -0
  39. fairseq/fairseq/data/bucket_pad_length_dataset.py +78 -0
  40. fairseq/fairseq/data/codedataset.py +576 -0
  41. fairseq/fairseq/data/colorize_dataset.py +25 -0
  42. fairseq/fairseq/data/concat_dataset.py +124 -0
  43. fairseq/fairseq/data/concat_sentences_dataset.py +54 -0
  44. fairseq/fairseq/data/data_utils.py +1144 -0
  45. fairseq/fairseq/data/data_utils_fast.pyx +178 -0
  46. fairseq/fairseq/data/denoising_dataset.py +443 -0
  47. fairseq/fairseq/data/dictionary.py +403 -0
  48. fairseq/fairseq/data/fairseq_dataset.py +205 -0
  49. fairseq/fairseq/data/fasta_dataset.py +107 -0
  50. fairseq/fairseq/data/id_dataset.py +19 -0
.gitattributes CHANGED
@@ -40,3 +40,4 @@ fairseq/alignment_train_cpu_binding.cpython-310-x86_64-linux-gnu.so filter=lfs d
  fairseq/docs/fairseq.gif filter=lfs diff=lfs merge=lfs -text
  fairseq/examples/hubert/tests/6313-76958-0021.flac filter=lfs diff=lfs merge=lfs -text
  fairseq/examples/textless_nlp/speech-resynth/img/fig.png filter=lfs diff=lfs merge=lfs -text
+ fairseq/fairseq/libbase.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
fairseq/fairseq/criterions/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.06 kB).
fairseq/fairseq/criterions/__pycache__/adaptive_loss.cpython-310.pyc ADDED
Binary file (4.76 kB).
fairseq/fairseq/criterions/__pycache__/composite_loss.cpython-310.pyc ADDED
Binary file (4.49 kB).
fairseq/fairseq/criterions/__pycache__/cross_entropy.cpython-310.pyc ADDED
Binary file (3.89 kB).
fairseq/fairseq/criterions/__pycache__/ctc.cpython-310.pyc ADDED
Binary file (8.4 kB).
fairseq/fairseq/criterions/__pycache__/fastspeech2_loss.cpython-310.pyc ADDED
Binary file (4.76 kB).
fairseq/fairseq/criterions/__pycache__/hubert_criterion.cpython-310.pyc ADDED
Binary file (6.77 kB).
fairseq/fairseq/criterions/__pycache__/label_smoothed_cross_entropy.cpython-310.pyc ADDED
Binary file (6.26 kB).
fairseq/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_latency_augmented.cpython-310.pyc ADDED
Binary file (5.89 kB).
fairseq/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_with_alignment.cpython-310.pyc ADDED
Binary file (4.78 kB).
fairseq/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_with_ctc.cpython-310.pyc ADDED
Binary file (3.33 kB).
fairseq/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_with_rdrop.cpython-310.pyc ADDED
Binary file (5.3 kB).
fairseq/fairseq/criterions/__pycache__/legacy_masked_lm.cpython-310.pyc ADDED
Binary file (5.53 kB).
fairseq/fairseq/criterions/__pycache__/masked_lm.cpython-310.pyc ADDED
Binary file (3.53 kB).
fairseq/fairseq/criterions/__pycache__/model_criterion.cpython-310.pyc ADDED
Binary file (5.71 kB).
fairseq/fairseq/criterions/__pycache__/nat_loss.cpython-310.pyc ADDED
Binary file (6.1 kB).
fairseq/fairseq/criterions/__pycache__/sentence_prediction.cpython-310.pyc ADDED
Binary file (8.21 kB).
fairseq/fairseq/criterions/__pycache__/sentence_prediction_adapters.cpython-310.pyc ADDED
Binary file (1.99 kB).
fairseq/fairseq/criterions/__pycache__/sentence_ranking.cpython-310.pyc ADDED
Binary file (4.6 kB).
fairseq/fairseq/criterions/__pycache__/speech_dlm_criterion.cpython-310.pyc ADDED
Binary file (9.03 kB).
fairseq/fairseq/criterions/__pycache__/speech_to_speech_criterion.cpython-310.pyc ADDED
Binary file (11.5 kB).
fairseq/fairseq/criterions/__pycache__/speech_ulm_criterion.cpython-310.pyc ADDED
Binary file (4.95 kB).
fairseq/fairseq/criterions/__pycache__/tacotron2_loss.cpython-310.pyc ADDED
Binary file (7.46 kB).
fairseq/fairseq/criterions/__pycache__/wav2vec_criterion.cpython-310.pyc ADDED
Binary file (6.5 kB).
fairseq/fairseq/criterions/speech_to_speech_criterion.py ADDED
@@ -0,0 +1,517 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ import math
8
+ from collections import OrderedDict
9
+
10
+ import torch
11
+
12
+ from fairseq import utils
13
+ from fairseq.logging import metrics
14
+ from fairseq.criterions import register_criterion
15
+ from fairseq.criterions.ctc import CtcCriterion
16
+ from fairseq.criterions.label_smoothed_cross_entropy_with_rdrop import (
17
+ RdropLabelSmoothedCrossEntropyCriterion,
18
+ RdropLabelSmoothedCrossEntropyCriterionConfig,
19
+ duplicate_input,
20
+ )
21
+ from fairseq.criterions.tacotron2_loss import (
22
+ Tacotron2Criterion,
23
+ Tacotron2CriterionConfig,
24
+ )
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
+ class MultitaskCriterion:
30
+ def __init__(self, multitask_tasks, rdrop_alpha=0.0):
31
+ self.rdrop_alpha = rdrop_alpha
32
+ self.rdrop_alpha_mtl = rdrop_alpha
33
+
34
+ self.multitask_criterion = OrderedDict()
35
+ self.multitask_loss_weight = OrderedDict()
36
+ for task_name, task_obj in multitask_tasks.items():
37
+ if task_obj.args.get_loss_weight(0) == 0:
38
+ logger.info(f"Skip {task_name} loss criterion")
39
+ continue
40
+
41
+ rdrop_alpha_task = task_obj.args.rdrop_alpha
42
+ if rdrop_alpha_task is None:
43
+ rdrop_alpha_task = rdrop_alpha
44
+ self.rdrop_alpha_mtl = rdrop_alpha_task
45
+ logger.info(f"rdrop_alpha is set to {rdrop_alpha_task} for {task_name}")
46
+
47
+ if task_obj.args.decoder_type == "ctc":
48
+ self.multitask_criterion[task_name] = CtcCriterion(
49
+ task_obj.args.criterion_cfg,
50
+ task_obj,
51
+ rdrop_alpha=rdrop_alpha_task,
52
+ )
53
+ else:
54
+ self.multitask_criterion[
55
+ task_name
56
+ ] = RdropLabelSmoothedCrossEntropyCriterion(
57
+ task_obj,
58
+ task_obj.args.criterion_cfg.sentence_avg,
59
+ label_smoothing=task_obj.args.criterion_cfg.label_smoothing,
60
+ rdrop_alpha=rdrop_alpha_task,
61
+ )
62
+
63
+ def set_multitask_loss_weight(self, task_name, weight=0.0):
64
+ self.multitask_loss_weight[task_name] = weight
65
+
66
+ def get_multitask_loss(self, model, sample, model_out):
67
+ logging_output = {}
68
+ loss = 0.0
69
+ for task_name, task_criterion in self.multitask_criterion.items():
70
+ layer_id = task_criterion.task.args.input_layer
71
+ if isinstance(task_criterion, CtcCriterion):
72
+ if task_criterion.task.args.input_from == "encoder":
73
+ if len(model_out["encoder_padding_mask"]) > 0:
74
+ non_padding_mask = ~model_out["encoder_padding_mask"][0]
75
+ input_lengths = non_padding_mask.long().sum(-1)
76
+ else:
77
+ out = model_out["encoder_states"][layer_id]
78
+ input_lengths = out.new_full(
79
+ (out.shape[1],), out.shape[0]
80
+ ).long()
81
+
82
+ task_sample = {
83
+ "net_input": {
84
+ "src_tokens": model_out["encoder_states"][
85
+ layer_id
86
+ ], # check batch idx
87
+ "src_lengths": input_lengths,
88
+ },
89
+ "id": sample["id"],
90
+ }
91
+ else:
92
+ task_sample = {
93
+ "net_input": {
94
+ "src_tokens": model_out["inner_states"][layer_id],
95
+ "src_lengths": sample["target_lengths"],
96
+ },
97
+ "id": sample["id"],
98
+ }
99
+ else:
100
+ task_sample = {
101
+ "net_input": {
102
+ "src_tokens": sample["multitask"][task_name]["net_input"][
103
+ "prev_output_tokens"
104
+ ],
105
+ "encoder_out": {
106
+ "encoder_out": [model_out["encoder_states"][layer_id]],
107
+ "encoder_padding_mask": model_out["encoder_padding_mask"],
108
+ },
109
+ }
110
+ }
111
+
112
+ for key in ["target", "target_lengths", "ntokens"]:
113
+ task_sample[key] = sample["multitask"][task_name][key]
114
+
115
+ if task_name == getattr(model, "mt_task_name", None):
116
+ decoder_out = model_out["mt_decoder_out"]
117
+ else:
118
+ decoder_out = None
119
+ task_loss, task_sample_size, task_logging_output = task_criterion(
120
+ model.multitask_decoders[task_name], task_sample, net_output=decoder_out
121
+ )
122
+
123
+ loss = loss + self.multitask_loss_weight[task_name] * task_loss
124
+ task_logging_output["loss_weight"] = self.multitask_loss_weight[task_name]
125
+ logging_output[task_name] = task_logging_output
126
+ return loss, logging_output
127
+
128
+ @classmethod
129
+ def reduce_metrics(cls, logging_outputs) -> None:
130
+ for task_name in logging_outputs[0]["multitask"].keys():
131
+ # different criterion may return different logging
132
+ # currently only reduce on loss, the most common one
133
+ # ideally the way that losses are reduced should also depend on the task type
134
+ loss_sum = sum(
135
+ log["multitask"][task_name].get("loss", 0) for log in logging_outputs
136
+ )
137
+ sample_size = sum(
138
+ log["multitask"][task_name].get("sample_size", 0)
139
+ for log in logging_outputs
140
+ )
141
+
142
+ metrics.log_scalar(
143
+ f"multitask_{task_name}_loss",
144
+ loss_sum / sample_size / math.log(2),
145
+ sample_size,
146
+ round=3,
147
+ )
148
+
149
+ loss_weight = logging_outputs[0]["multitask"][task_name].get(
150
+ "loss_weight", 0
151
+ )
152
+ metrics.log_scalar(
153
+ f"multitask_{task_name}_loss_weight",
154
+ loss_weight,
155
+ weight=0,
156
+ priority=250,
157
+ )
158
+
159
+
160
+ @register_criterion(
161
+ "speech_to_unit", dataclass=RdropLabelSmoothedCrossEntropyCriterionConfig
162
+ )
163
+ class SpeechToUnitMultitaskTaskCriterion(
164
+ RdropLabelSmoothedCrossEntropyCriterion, MultitaskCriterion
165
+ ):
166
+ def __init__(
167
+ self,
168
+ task,
169
+ sentence_avg,
170
+ label_smoothing,
171
+ ignore_prefix_size=0,
172
+ report_accuracy=False,
173
+ rdrop_alpha=0.0,
174
+ ):
175
+ super().__init__(
176
+ task,
177
+ sentence_avg,
178
+ label_smoothing,
179
+ ignore_prefix_size,
180
+ report_accuracy,
181
+ rdrop_alpha,
182
+ )
183
+ MultitaskCriterion.__init__(self, task.multitask_tasks, rdrop_alpha)
184
+
185
+ def forward(self, model, sample, reduce=True):
186
+ net_input_concat = {
187
+ "src_tokens": sample["net_input"]["src_tokens"],
188
+ "src_lengths": sample["net_input"]["src_lengths"],
189
+ "prev_output_tokens": sample["net_input"]["prev_output_tokens"],
190
+ "tgt_speaker": sample["net_input"].get("tgt_speaker", None),
191
+ "return_all_hiddens": True,
192
+ }
193
+
194
+ if self.rdrop_alpha > 0 or self.rdrop_alpha_mtl > 0:
195
+ net_input_concat = duplicate_input(net_input_concat)
196
+
197
+ net_output, extra = model(**net_input_concat)
198
+ loss, nll_loss, rdrop_kl_loss = self.compute_loss(
199
+ model, [net_output], sample, reduce=reduce
200
+ )
201
+ sample_size = (
202
+ sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
203
+ )
204
+ logging_output = {
205
+ "loss": loss.data,
206
+ "nll_loss": nll_loss.data,
207
+ "ntokens": sample["ntokens"],
208
+ "nsentences": sample["target"].size(0),
209
+ "sample_size": sample_size,
210
+ }
211
+ if self.report_accuracy:
212
+ n_correct, total = self.compute_accuracy(model, [net_output], sample)
213
+ logging_output["n_correct"] = utils.item(n_correct.data)
214
+ logging_output["total"] = utils.item(total.data)
215
+ if self.rdrop_alpha > 0:
216
+ logging_output["rdrop_kl_loss"] = utils.item(rdrop_kl_loss.data)
217
+
218
+ if len(self.multitask_criterion) == 0:
219
+ return loss, sample_size, logging_output
220
+
221
+ # multitask
222
+ multitask_loss, multitask_log = self.get_multitask_loss(model, sample, extra)
223
+ loss += multitask_loss
224
+ logging_output["multitask"] = multitask_log
225
+
226
+ return loss, sample_size, logging_output
227
+
228
+ @classmethod
229
+ def reduce_metrics(cls, logging_outputs) -> None:
230
+ super().reduce_metrics(logging_outputs)
231
+
232
+ # inference metrics
233
+ if "targ_frames" in logging_outputs[0]:
234
+ n = sum(log.get("norm_frames", 0) for log in logging_outputs)
235
+ for key, new_key in [
236
+ ("mcd_loss", "mcd_loss"),
237
+ ("pred_frames", "pred_ratio"),
238
+ ("nins", "ins_rate"),
239
+ ("ndel", "del_rate"),
240
+ ]:
241
+ val = sum(log.get(key, 0) for log in logging_outputs)
242
+ metrics.log_scalar(new_key, val / n, n, round=3)
243
+
244
+ if "multitask" not in logging_outputs[0]:
245
+ return
246
+
247
+ MultitaskCriterion.reduce_metrics(logging_outputs)
248
+
249
+ @staticmethod
250
+ def logging_outputs_can_be_summed() -> bool:
251
+ """
252
+ Whether the logging outputs returned by `forward` can be summed
253
+ across workers prior to calling `reduce_metrics`. Setting this
254
+ to True will improves distributed training speed.
255
+ """
256
+ return False
257
+
258
+
259
+ @register_criterion(
260
+ "speech_to_unit_2pass", dataclass=RdropLabelSmoothedCrossEntropyCriterionConfig
261
+ )
262
+ class SpeechToUnit2passMultitaskTaskCriterion(SpeechToUnitMultitaskTaskCriterion):
263
+ def __init__(
264
+ self,
265
+ task,
266
+ sentence_avg,
267
+ label_smoothing,
268
+ ignore_prefix_size=0,
269
+ report_accuracy=False,
270
+ rdrop_alpha=0.0,
271
+ ):
272
+ super().__init__(
273
+ task,
274
+ sentence_avg,
275
+ label_smoothing,
276
+ ignore_prefix_size,
277
+ report_accuracy,
278
+ rdrop_alpha,
279
+ )
280
+
281
+ def forward(self, model, sample, reduce=True):
282
+ net_input_concat = {
283
+ "src_tokens": sample["net_input"]["src_tokens"],
284
+ "src_lengths": sample["net_input"]["src_lengths"],
285
+ "prev_output_tokens": sample["net_input"]["prev_output_tokens"],
286
+ "prev_output_tokens_mt": sample["multitask"][model.mt_task_name][
287
+ "net_input"
288
+ ]["prev_output_tokens"],
289
+ "tgt_speaker": sample["net_input"].get("tgt_speaker", None),
290
+ "return_all_hiddens": True,
291
+ }
292
+ if getattr(model, "asr_task_name", None) is not None:
293
+ net_input_concat["prev_output_tokens_asr"] = sample["multitask"][
294
+ model.asr_task_name
295
+ ]["net_input"]["prev_output_tokens"]
296
+
297
+ if self.rdrop_alpha > 0 or self.rdrop_alpha_mtl > 0:
298
+ net_input_concat = duplicate_input(net_input_concat)
299
+
300
+ net_output, extra = model(**net_input_concat)
301
+ loss, nll_loss, rdrop_kl_loss = self.compute_loss(
302
+ model, [net_output], sample, reduce=reduce
303
+ )
304
+
305
+ sample_size = (
306
+ sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
307
+ )
308
+ logging_output = {
309
+ "loss": loss.data,
310
+ "nll_loss": nll_loss.data,
311
+ "ntokens": sample["ntokens"],
312
+ "nsentences": sample["target"].size(0),
313
+ "sample_size": sample_size,
314
+ }
315
+ if self.report_accuracy:
316
+ n_correct, total = self.compute_accuracy(model, [net_output], sample)
317
+ logging_output["n_correct"] = utils.item(n_correct.data)
318
+ logging_output["total"] = utils.item(total.data)
319
+ if self.rdrop_alpha > 0:
320
+ logging_output["rdrop_kl_loss"] = utils.item(rdrop_kl_loss.data)
321
+
322
+ if len(self.multitask_criterion) == 0:
323
+ return loss, sample_size, logging_output
324
+
325
+ # multitask
326
+ multitask_loss, multitask_log = self.get_multitask_loss(model, sample, extra)
327
+ loss += multitask_loss
328
+ logging_output["multitask"] = multitask_log
329
+
330
+ return loss, sample_size, logging_output
331
+
332
+
333
+ @register_criterion("speech_to_spectrogram", dataclass=Tacotron2CriterionConfig)
334
+ class SpeechToSpectrogramMultitaskTaskCriterion(Tacotron2Criterion, MultitaskCriterion):
335
+ def __init__(
336
+ self,
337
+ task,
338
+ sentence_avg,
339
+ use_guided_attention_loss,
340
+ guided_attention_loss_sigma,
341
+ bce_pos_weight,
342
+ ctc_weight,
343
+ ):
344
+ super().__init__(
345
+ task,
346
+ sentence_avg,
347
+ use_guided_attention_loss,
348
+ guided_attention_loss_sigma,
349
+ bce_pos_weight,
350
+ ctc_weight,
351
+ )
352
+ MultitaskCriterion.__init__(self, task.multitask_tasks)
353
+
354
+ def forward(self, model, sample, reduction="mean"):
355
+ bsz, max_len, _ = sample["target"].size()
356
+ feat_tgt = sample["target"]
357
+ feat_len = sample["target_lengths"].view(bsz, 1).expand(-1, max_len)
358
+ eos_tgt = torch.arange(max_len).to(sample["target"].device)
359
+ eos_tgt = eos_tgt.view(1, max_len).expand(bsz, -1)
360
+ eos_tgt = (eos_tgt == (feat_len - 1)).float()
361
+
362
+ feat_out, eos_out, extra = model(
363
+ src_tokens=sample["net_input"]["src_tokens"],
364
+ src_lengths=sample["net_input"]["src_lengths"],
365
+ prev_output_tokens=sample["net_input"]["prev_output_tokens"],
366
+ tgt_speaker=sample["net_input"]["tgt_speaker"],
367
+ target_lengths=sample["target_lengths"],
368
+ return_all_hiddens=True,
369
+ )
370
+
371
+ l1_loss, mse_loss, eos_loss = self.compute_loss(
372
+ extra["feature_out"],
373
+ feat_out,
374
+ eos_out,
375
+ feat_tgt,
376
+ eos_tgt,
377
+ sample["target_lengths"],
378
+ reduction,
379
+ )
380
+ attn_loss = torch.tensor(0.0).type_as(l1_loss)
381
+ if self.guided_attn is not None:
382
+ attn_loss = self.guided_attn(
383
+ extra["attn"],
384
+ sample["net_input"]["src_lengths"],
385
+ sample["target_lengths"],
386
+ reduction,
387
+ )
388
+ loss = (
389
+ l1_loss + mse_loss + eos_loss + attn_loss
390
+ ) # do not include ctc loss as there's no text target
391
+
392
+ sample_size = sample["nsentences"] if self.sentence_avg else sample["ntokens"]
393
+ logging_output = {
394
+ "loss": utils.item(loss.data),
395
+ "ntokens": sample["ntokens"],
396
+ "nsentences": sample["nsentences"],
397
+ "sample_size": sample_size,
398
+ "l1_loss": utils.item(l1_loss.data),
399
+ "mse_loss": utils.item(mse_loss.data),
400
+ "eos_loss": utils.item(eos_loss.data),
401
+ "attn_loss": utils.item(attn_loss.data),
402
+ }
403
+
404
+ if len(self.multitask_criterion) == 0:
405
+ return loss, sample_size, logging_output
406
+
407
+ # multitask
408
+ multitask_loss, multitask_log = self.get_multitask_loss(model, sample, extra)
409
+ loss += multitask_loss
410
+ logging_output["multitask"] = multitask_log
411
+ return loss, sample_size, logging_output
412
+
413
+ @classmethod
414
+ def reduce_metrics(cls, logging_outputs) -> None:
415
+ super().reduce_metrics(logging_outputs)
416
+
417
+ # inference metrics
418
+ if "targ_frames" in logging_outputs[0]:
419
+ n = sum(log.get("norm_frames", 0) for log in logging_outputs)
420
+ for key, new_key in [
421
+ ("mcd_loss", "mcd_loss"),
422
+ ("pred_frames", "pred_ratio"),
423
+ ("nins", "ins_rate"),
424
+ ("ndel", "del_rate"),
425
+ ]:
426
+ val = sum(log.get(key, 0) for log in logging_outputs)
427
+ metrics.log_scalar(new_key, val / n, n, round=3)
428
+
429
+ if "multitask" not in logging_outputs[0]:
430
+ return
431
+
432
+ MultitaskCriterion.reduce_metrics(logging_outputs)
433
+
434
+
435
+ @register_criterion("speech_to_spectrogram_2pass", dataclass=Tacotron2CriterionConfig)
436
+ class SpeechToSpectrogram2passMultitaskTaskCriterion(
437
+ SpeechToSpectrogramMultitaskTaskCriterion
438
+ ):
439
+ def __init__(
440
+ self,
441
+ task,
442
+ sentence_avg,
443
+ use_guided_attention_loss,
444
+ guided_attention_loss_sigma,
445
+ bce_pos_weight,
446
+ ctc_weight,
447
+ ):
448
+ super().__init__(
449
+ task,
450
+ sentence_avg,
451
+ use_guided_attention_loss,
452
+ guided_attention_loss_sigma,
453
+ bce_pos_weight,
454
+ ctc_weight,
455
+ )
456
+
457
+ def forward(self, model, sample, reduction="mean"):
458
+ bsz, max_len, _ = sample["target"].size()
459
+ feat_tgt = sample["target"]
460
+ feat_len = sample["target_lengths"].view(bsz, 1).expand(-1, max_len)
461
+ eos_tgt = torch.arange(max_len).to(sample["target"].device)
462
+ eos_tgt = eos_tgt.view(1, max_len).expand(bsz, -1)
463
+ eos_tgt = (eos_tgt == (feat_len - 1)).float()
464
+
465
+ feat_out, eos_out, extra = model(
466
+ src_tokens=sample["net_input"]["src_tokens"],
467
+ src_lengths=sample["net_input"]["src_lengths"],
468
+ prev_output_tokens=sample["net_input"]["prev_output_tokens"],
469
+ prev_output_tokens_mt=sample["multitask"][model.mt_task_name]["net_input"][
470
+ "prev_output_tokens"
471
+ ],
472
+ tgt_speaker=sample["net_input"]["tgt_speaker"],
473
+ target_lengths=sample["target_lengths"],
474
+ return_all_hiddens=True,
475
+ )
476
+
477
+ l1_loss, mse_loss, eos_loss = self.compute_loss(
478
+ extra["feature_out"],
479
+ feat_out,
480
+ eos_out,
481
+ feat_tgt,
482
+ eos_tgt,
483
+ sample["target_lengths"],
484
+ reduction,
485
+ )
486
+ attn_loss = torch.tensor(0.0).type_as(l1_loss)
487
+ if self.guided_attn is not None:
488
+ attn_loss = self.guided_attn(
489
+ extra["attn"],
490
+ sample["net_input"]["src_lengths"],
491
+ sample["target_lengths"],
492
+ reduction,
493
+ )
494
+ loss = (
495
+ l1_loss + mse_loss + eos_loss + attn_loss
496
+ ) # do not include ctc loss as there's no text target
497
+
498
+ sample_size = sample["nsentences"] if self.sentence_avg else sample["ntokens"]
499
+ logging_output = {
500
+ "loss": utils.item(loss.data),
501
+ "ntokens": sample["ntokens"],
502
+ "nsentences": sample["nsentences"],
503
+ "sample_size": sample_size,
504
+ "l1_loss": utils.item(l1_loss.data),
505
+ "mse_loss": utils.item(mse_loss.data),
506
+ "eos_loss": utils.item(eos_loss.data),
507
+ "attn_loss": utils.item(attn_loss.data),
508
+ }
509
+
510
+ if len(self.multitask_criterion) == 0:
511
+ return loss, sample_size, logging_output
512
+
513
+ # multitask
514
+ multitask_loss, multitask_log = self.get_multitask_loss(model, sample, extra)
515
+ loss += multitask_loss
516
+ logging_output["multitask"] = multitask_log
517
+ return loss, sample_size, logging_output
fairseq/fairseq/criterions/tacotron2_loss.py ADDED
@@ -0,0 +1,227 @@
1
+ # Copyright (c) 2017-present, Facebook, Inc.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the LICENSE file in
5
+ # the root directory of this source tree. An additional grant of patent rights
6
+ # can be found in the PATENTS file in the same directory.
7
+
8
+ import logging
9
+ from dataclasses import dataclass, field
10
+ from functools import lru_cache
11
+ from typing import Any, Dict, List
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+ from omegaconf import II
16
+
17
+ from fairseq import utils
18
+ from fairseq.logging import metrics
19
+ from fairseq.criterions import FairseqCriterion, register_criterion
20
+ from fairseq.data.data_utils import lengths_to_mask
21
+ from fairseq.dataclass import FairseqDataclass
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ @dataclass
27
+ class Tacotron2CriterionConfig(FairseqDataclass):
28
+ bce_pos_weight: float = field(
29
+ default=1.0,
30
+ metadata={"help": "weight of positive examples for BCE loss"},
31
+ )
32
+ use_guided_attention_loss: bool = field(
33
+ default=False,
34
+ metadata={"help": "use guided attention loss"},
35
+ )
36
+ guided_attention_loss_sigma: float = field(
37
+ default=0.4,
38
+ metadata={"help": "weight of positive examples for BCE loss"},
39
+ )
40
+ ctc_weight: float = field(default=0.0, metadata={"help": "weight for CTC loss"})
41
+ sentence_avg: bool = II("optimization.sentence_avg")
42
+
43
+
44
+ class GuidedAttentionLoss(torch.nn.Module):
45
+ """
46
+ Efficiently Trainable Text-to-Speech System Based on Deep Convolutional
47
+ Networks with Guided Attention (https://arxiv.org/abs/1710.08969)
48
+ """
49
+
50
+ def __init__(self, sigma):
51
+ super().__init__()
52
+ self.sigma = sigma
53
+
54
+ @staticmethod
55
+ @lru_cache(maxsize=8)
56
+ def _get_weight(s_len, t_len, sigma):
57
+ grid_x, grid_y = torch.meshgrid(torch.arange(t_len), torch.arange(s_len))
58
+ grid_x = grid_x.to(s_len.device)
59
+ grid_y = grid_y.to(s_len.device)
60
+ w = (grid_y.float() / s_len - grid_x.float() / t_len) ** 2
61
+ return 1.0 - torch.exp(-w / (2 * (sigma**2)))
62
+
63
+ def _get_weights(self, src_lens, tgt_lens):
64
+ bsz, max_s_len, max_t_len = len(src_lens), max(src_lens), max(tgt_lens)
65
+ weights = torch.zeros((bsz, max_t_len, max_s_len))
66
+ for i, (s_len, t_len) in enumerate(zip(src_lens, tgt_lens)):
67
+ weights[i, :t_len, :s_len] = self._get_weight(s_len, t_len, self.sigma)
68
+ return weights
69
+
70
+ @staticmethod
71
+ def _get_masks(src_lens, tgt_lens):
72
+ in_masks = lengths_to_mask(src_lens)
73
+ out_masks = lengths_to_mask(tgt_lens)
74
+ return out_masks.unsqueeze(2) & in_masks.unsqueeze(1)
75
+
76
+ def forward(self, attn, src_lens, tgt_lens, reduction="mean"):
77
+ weights = self._get_weights(src_lens, tgt_lens).to(attn.device)
78
+ masks = self._get_masks(src_lens, tgt_lens).to(attn.device)
79
+ loss = (weights * attn.transpose(1, 2)).masked_select(masks)
80
+ loss = torch.sum(loss) if reduction == "sum" else torch.mean(loss)
81
+ return loss
82
+
83
+
84
+ @register_criterion("tacotron2", dataclass=Tacotron2CriterionConfig)
85
+ class Tacotron2Criterion(FairseqCriterion):
86
+ def __init__(
87
+ self,
88
+ task,
89
+ sentence_avg,
90
+ use_guided_attention_loss,
91
+ guided_attention_loss_sigma,
92
+ bce_pos_weight,
93
+ ctc_weight,
94
+ ):
95
+ super().__init__(task)
96
+ self.sentence_avg = sentence_avg
97
+ self.bce_pos_weight = bce_pos_weight
98
+
99
+ self.guided_attn = None
100
+ if use_guided_attention_loss:
101
+ self.guided_attn = GuidedAttentionLoss(guided_attention_loss_sigma)
102
+ self.ctc_weight = ctc_weight
103
+
104
+ def forward(self, model, sample, reduction="mean"):
105
+ bsz, max_len, _ = sample["target"].size()
106
+ feat_tgt = sample["target"]
107
+ feat_len = sample["target_lengths"].view(bsz, 1).expand(-1, max_len)
108
+ eos_tgt = torch.arange(max_len).to(sample["target"].device)
109
+ eos_tgt = eos_tgt.view(1, max_len).expand(bsz, -1)
110
+ eos_tgt = (eos_tgt == (feat_len - 1)).float()
111
+ src_tokens = sample["net_input"]["src_tokens"]
112
+ src_lens = sample["net_input"]["src_lengths"]
113
+ tgt_lens = sample["target_lengths"]
114
+
115
+ feat_out, eos_out, extra = model(
116
+ src_tokens=src_tokens,
117
+ src_lengths=src_lens,
118
+ prev_output_tokens=sample["net_input"]["prev_output_tokens"],
119
+ incremental_state=None,
120
+ target_lengths=tgt_lens,
121
+ speaker=sample["speaker"],
122
+ )
123
+
124
+ l1_loss, mse_loss, eos_loss = self.compute_loss(
125
+ extra["feature_out"],
126
+ feat_out,
127
+ eos_out,
128
+ feat_tgt,
129
+ eos_tgt,
130
+ tgt_lens,
131
+ reduction,
132
+ )
133
+ attn_loss = torch.tensor(0.0).type_as(l1_loss)
134
+ if self.guided_attn is not None:
135
+ attn_loss = self.guided_attn(extra["attn"], src_lens, tgt_lens, reduction)
136
+ ctc_loss = torch.tensor(0.0).type_as(l1_loss)
137
+ if self.ctc_weight > 0.0:
138
+ net_output = (feat_out, eos_out, extra)
139
+ lprobs = model.get_normalized_probs(net_output, log_probs=True)
140
+ lprobs = lprobs.transpose(0, 1) # T x B x C
141
+ src_mask = lengths_to_mask(src_lens)
142
+ src_tokens_flat = src_tokens.masked_select(src_mask)
143
+ ctc_loss = (
144
+ F.ctc_loss(
145
+ lprobs,
146
+ src_tokens_flat,
147
+ tgt_lens,
148
+ src_lens,
149
+ reduction=reduction,
150
+ zero_infinity=True,
151
+ )
152
+ * self.ctc_weight
153
+ )
154
+ loss = l1_loss + mse_loss + eos_loss + attn_loss + ctc_loss
155
+
156
+ sample_size = sample["nsentences"] if self.sentence_avg else sample["ntokens"]
157
+ logging_output = {
158
+ "loss": utils.item(loss.data),
159
+ "ntokens": sample["ntokens"],
160
+ "nsentences": sample["nsentences"],
161
+ "sample_size": sample_size,
162
+ "l1_loss": utils.item(l1_loss.data),
163
+ "mse_loss": utils.item(mse_loss.data),
164
+ "eos_loss": utils.item(eos_loss.data),
165
+ "attn_loss": utils.item(attn_loss.data),
166
+ "ctc_loss": utils.item(ctc_loss.data),
167
+ }
168
+ return loss, sample_size, logging_output
169
+
170
+ def compute_loss(
171
+ self,
172
+ feat_out,
173
+ feat_out_post,
174
+ eos_out,
175
+ feat_tgt,
176
+ eos_tgt,
177
+ tgt_lens,
178
+ reduction="mean",
179
+ ):
180
+ mask = lengths_to_mask(tgt_lens)
181
+ _eos_out = eos_out[mask].squeeze()
182
+ _eos_tgt = eos_tgt[mask]
183
+ _feat_tgt = feat_tgt[mask]
184
+ _feat_out = feat_out[mask]
185
+ _feat_out_post = feat_out_post[mask]
186
+
187
+ l1_loss = F.l1_loss(_feat_out, _feat_tgt, reduction=reduction) + F.l1_loss(
188
+ _feat_out_post, _feat_tgt, reduction=reduction
189
+ )
190
+ mse_loss = F.mse_loss(_feat_out, _feat_tgt, reduction=reduction) + F.mse_loss(
191
+ _feat_out_post, _feat_tgt, reduction=reduction
192
+ )
193
+ eos_loss = F.binary_cross_entropy_with_logits(
194
+ _eos_out,
195
+ _eos_tgt,
196
+ pos_weight=torch.tensor(self.bce_pos_weight),
197
+ reduction=reduction,
198
+ )
199
+ return l1_loss, mse_loss, eos_loss
200
+
201
+ @classmethod
202
+ def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None:
203
+ ns = [log.get("sample_size", 0) for log in logging_outputs]
204
+ ntot = sum(ns)
205
+ ws = [n / (ntot + 1e-8) for n in ns]
206
+ for key in ["loss", "l1_loss", "mse_loss", "eos_loss", "attn_loss", "ctc_loss"]:
207
+ vals = [log.get(key, 0) for log in logging_outputs]
208
+ val = sum(val * w for val, w in zip(vals, ws))
209
+ metrics.log_scalar(key, val, ntot, round=3)
210
+ metrics.log_scalar("sample_size", ntot, len(logging_outputs))
211
+
212
+ # inference metrics
213
+ if "targ_frames" not in logging_outputs[0]:
214
+ return
215
+ n = sum(log.get("targ_frames", 0) for log in logging_outputs)
216
+ for key, new_key in [
217
+ ("mcd_loss", "mcd_loss"),
218
+ ("pred_frames", "pred_ratio"),
219
+ ("nins", "ins_rate"),
220
+ ("ndel", "del_rate"),
221
+ ]:
222
+ val = sum(log.get(key, 0) for log in logging_outputs)
223
+ metrics.log_scalar(new_key, val / n, n, round=3)
224
+
225
+ @staticmethod
226
+ def logging_outputs_can_be_summed() -> bool:
227
+ return False
fairseq/fairseq/criterions/wav2vec_criterion.py ADDED
@@ -0,0 +1,231 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ from dataclasses import dataclass, field
8
+ from typing import List, Optional
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from fairseq import utils
13
+ from fairseq.logging import metrics
14
+ from fairseq.criterions import FairseqCriterion, register_criterion
15
+ from fairseq.dataclass import FairseqDataclass
16
+ from fairseq.logging.meters import safe_round
17
+ from fairseq.utils import is_xla_tensor
18
+
19
+
20
+ @dataclass
21
+ class Wav2VecCriterionConfig(FairseqDataclass):
22
+ infonce: bool = field(
23
+ default=False,
24
+ metadata={
25
+ "help": "if set, uses cross entropy instead of binary cross entropy (i.e. InfoNCE loss)"
26
+ },
27
+ )
28
+ loss_weights: Optional[List[float]] = field(
29
+ default=None,
30
+ metadata={"help": "weights for additional loss terms (not first one)"},
31
+ )
32
+ log_keys: List[str] = field(
33
+ default_factory=lambda: [],
34
+ metadata={"help": "output keys to log"},
35
+ )
36
+
37
+
38
+ @register_criterion("wav2vec", dataclass=Wav2VecCriterionConfig)
39
+ class Wav2vecCriterion(FairseqCriterion):
40
+ def __init__(self, task, infonce=False, loss_weights=None, log_keys=None):
41
+ super().__init__(task)
42
+ self.infonce = infonce
43
+ self.loss_weights = loss_weights
44
+ self.log_keys = [] if log_keys is None else log_keys
45
+
46
+ def forward(self, model, sample, reduce=True):
47
+ """Compute the loss for the given sample.
48
+
49
+ Returns a tuple with three elements:
50
+ 1) the loss
51
+ 2) the sample size, which is used as the denominator for the gradient
52
+ 3) logging outputs to display while training
53
+ """
54
+ net_output = model(**sample["net_input"])
55
+ logits = model.get_logits(net_output).float()
56
+ target = model.get_targets(sample, net_output)
57
+ self.xla = is_xla_tensor(logits)
58
+
59
+ # XXX: handle weights on xla.
60
+ weights = None
61
+ if hasattr(model, "get_target_weights") and not self.infonce:
62
+ weights = model.get_target_weights(target, net_output)
63
+ if torch.is_tensor(weights):
64
+ weights = weights.float()
65
+
66
+ losses = []
67
+
68
+ reduction = "none" if ((not reduce) or self.xla) else "sum"
69
+ if self.infonce:
70
+ loss = F.cross_entropy(logits, target, reduction=reduction)
71
+ else:
72
+ loss = F.binary_cross_entropy_with_logits(
73
+ logits, target.float(), weights, reduction=reduction
74
+ )
75
+
76
+ if self.xla:
77
+ # tpu-comment: since dynamic shapes lead to recompilations on xla,
78
+ # we don't shrink tensors using mask_indices.
79
+ # Instead, we use mask indices to adjust loss.
80
+ mi = (
81
+ sample["net_input"]["mask_indices"]
82
+ .transpose(0, 1) # logits are transposed in `model.get_logits`
83
+ .reshape(logits.size(0))
84
+ )
85
+ loss = (loss * mi).sum() if reduce else (loss * mi)
86
+
87
+ if "sample_size" in sample:
88
+ sample_size = sample["sample_size"]
89
+ elif "mask_indices" in sample["net_input"]:
90
+ sample_size = sample["net_input"]["mask_indices"].sum()
91
+ else:
92
+ sample_size = target.numel() if self.infonce else target.long().sum().item()
93
+ losses.append(loss.detach().clone())
94
+
95
+ if self.loss_weights is not None:
96
+ assert hasattr(model, "get_extra_losses")
97
+ extra_losses = model.get_extra_losses(net_output)
98
+ if torch.is_tensor(extra_losses):
99
+ extra_losses = [extra_losses]
100
+ if len(self.loss_weights) == 1 and len(extra_losses) != 1:
101
+ self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
102
+ assert len(extra_losses) == len(
103
+ self.loss_weights
104
+ ), f"{len(extra_losses)}, {len(self.loss_weights)}"
105
+ for p, coef in zip(extra_losses, self.loss_weights):
106
+ if coef != 0 and p is not None:
107
+ p = coef * p.float() * sample_size
108
+ loss += p
109
+ losses.append(p)
110
+
111
+ logging_output = {
112
+ "loss": loss.item() if (reduce and not self.xla) else loss.detach(),
113
+ "ntokens": sample_size,
114
+ "nsentences": sample["id"].numel(),
115
+ "sample_size": sample_size,
116
+ }
117
+
118
+ for lk in self.log_keys:
119
+ # Only store "logits" and "target" for computing MAP and MAUC
120
+ # during validation
121
+ if lk == "logits":
122
+ if not self.training:
123
+ logging_output["logits"] = logits.cpu().numpy()
124
+ elif lk == "target":
125
+ if not self.training:
126
+ # If the targets have been mixed with the predictions of
127
+ # teacher models, find the original targets
128
+ if hasattr(model, "get_original_targets"):
129
+ original_target = model.get_original_targets(sample, net_output)
130
+ else:
131
+ original_target = target
132
+ logging_output["target"] = original_target.cpu().numpy()
133
+ elif lk in net_output:
134
+ value = net_output[lk]
135
+ if not is_xla_tensor(value):
136
+ value = float(value)
137
+ logging_output[lk] = value
138
+
139
+ if len(losses) > 1:
140
+ for i, l in enumerate(losses):
141
+ logging_output[f"loss_{i}"] = l.item() if not self.xla else l.detach()
142
+
143
+ if self.infonce:
144
+ with torch.no_grad():
145
+ if logits.numel() == 0:
146
+ corr = 0
147
+ count = 0
148
+ else:
149
+ assert logits.dim() > 1, logits.shape
150
+ max = logits.argmax(-1) == 0
151
+ min = logits.argmin(-1) == 0
152
+ if is_xla_tensor(logits):
153
+ max, min = max * mi, min * mi
154
+ both = max & min
155
+ corr = max.long().sum() - both.long().sum()
156
+ count = mi.sum()
157
+ else:
158
+ both = max & min
159
+ corr = max.long().sum().item() - both.long().sum().item()
160
+ count = float(max.numel())
161
+
162
+ logging_output["correct"] = corr
163
+ logging_output["count"] = count
164
+
165
+ return loss, sample_size, logging_output
166
+
167
+ @staticmethod
168
+ def reduce_metrics(logging_outputs) -> None:
169
+ """Aggregate logging outputs from data parallel training."""
170
+ loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
171
+ ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
172
+ nsentences = utils.item(
173
+ sum(log.get("nsentences", 0) for log in logging_outputs)
174
+ )
175
+ sample_size = utils.item(
176
+ sum(log.get("sample_size", 0) for log in logging_outputs)
177
+ )
178
+
179
+ metrics.log_scalar(
180
+ "loss", loss_sum / (sample_size or 1) / math.log(2), sample_size, round=3
181
+ )
182
+ metrics.log_scalar("ntokens", ntokens)
183
+ metrics.log_scalar("nsentences", nsentences)
184
+
185
+ correct = sum(log.get("correct", 0) for log in logging_outputs)
186
+ metrics.log_scalar("_correct", correct)
187
+
188
+ total = sum(log.get("count", 0) for log in logging_outputs)
189
+ metrics.log_scalar("_total", total)
190
+
191
+ if total > 0:
192
+ metrics.log_derived(
193
+ "accuracy",
194
+ lambda meters: safe_round(
195
+ meters["_correct"].sum / meters["_total"].sum, 5
196
+ )
197
+ if meters["_total"].sum > 0
198
+ else float("nan"),
199
+ )
200
+
201
+ builtin_keys = {
202
+ "loss",
203
+ "ntokens",
204
+ "nsentences",
205
+ "sample_size",
206
+ "correct",
207
+ "count",
208
+ }
209
+
210
+ for k in logging_outputs[0]:
211
+ if k not in builtin_keys:
212
+ val = sum(log.get(k, 0) for log in logging_outputs)
213
+ if k.startswith("loss"):
214
+ metrics.log_scalar(
215
+ k, val / (sample_size or 1) / math.log(2), sample_size, round=3
216
+ )
217
+ else:
218
+ metrics.log_scalar(k, val / len(logging_outputs), round=3)
219
+
220
+ # FIXME: revert when gather based xla reduction is implemented
221
+ # @staticmethod
222
+ # def logging_outputs_can_be_summed() -> bool:
223
+ def logging_outputs_can_be_summed(self) -> bool:
224
+ """
225
+ Whether the logging outputs returned by `forward` can be summed
226
+ across workers prior to calling `reduce_metrics`. Setting this
227
+ to True will improves distributed training speed.
228
+ """
229
+ # XXX: Gather based reduction not implemented for xla yet.
230
+ # So we fall to sum based reduction for xla.
231
+ return self.xla
fairseq/fairseq/data/__init__.py ADDED
@@ -0,0 +1,137 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ """isort:skip_file"""
6
+
7
+ from .dictionary import Dictionary, TruncatedDictionary
8
+
9
+ from .fairseq_dataset import FairseqDataset, FairseqIterableDataset
10
+
11
+ from .base_wrapper_dataset import BaseWrapperDataset
12
+
13
+ from .add_target_dataset import AddTargetDataset
14
+ from .append_token_dataset import AppendTokenDataset
15
+ from .audio.raw_audio_dataset import BinarizedAudioDataset, FileAudioDataset
16
+ from .audio.hubert_dataset import HubertDataset
17
+ from .backtranslation_dataset import BacktranslationDataset
18
+ from .bucket_pad_length_dataset import BucketPadLengthDataset
19
+ from .colorize_dataset import ColorizeDataset
20
+ from .concat_dataset import ConcatDataset
21
+ from .concat_sentences_dataset import ConcatSentencesDataset
22
+ from .denoising_dataset import DenoisingDataset
23
+ from .id_dataset import IdDataset
24
+ from .indexed_dataset import (
25
+ IndexedCachedDataset,
26
+ IndexedDataset,
27
+ IndexedRawTextDataset,
28
+ MMapIndexedDataset,
29
+ )
30
+ from .language_pair_dataset import LanguagePairDataset
31
+ from .list_dataset import ListDataset
32
+ from .lm_context_window_dataset import LMContextWindowDataset
33
+ from .lru_cache_dataset import LRUCacheDataset
34
+ from .mask_tokens_dataset import MaskTokensDataset
35
+ from .monolingual_dataset import MonolingualDataset
36
+ from .multi_corpus_sampled_dataset import MultiCorpusSampledDataset
37
+ from .nested_dictionary_dataset import NestedDictionaryDataset
38
+ from .noising import NoisingDataset
39
+ from .numel_dataset import NumelDataset
40
+ from .num_samples_dataset import NumSamplesDataset
41
+ from .offset_tokens_dataset import OffsetTokensDataset
42
+ from .padding_mask_dataset import (
43
+ LeftPaddingMaskDataset,
44
+ PaddingMaskDataset,
45
+ RightPaddingMaskDataset,
46
+ )
47
+ from .pad_dataset import LeftPadDataset, PadDataset, RightPadDataset
48
+ from .prepend_dataset import PrependDataset
49
+ from .prepend_token_dataset import PrependTokenDataset
50
+ from .raw_label_dataset import RawLabelDataset
51
+ from .replace_dataset import ReplaceDataset
52
+ from .resampling_dataset import ResamplingDataset
53
+ from .roll_dataset import RollDataset
54
+ from .round_robin_zip_datasets import RoundRobinZipDatasets
55
+ from .sort_dataset import SortDataset
56
+ from .speech_dlm_dataset import SpeechDLMDataset
57
+ from .strip_token_dataset import StripTokenDataset
58
+ from .subsample_dataset import SubsampleDataset
59
+ from .token_block_dataset import TokenBlockDataset
60
+ from .transform_eos_dataset import TransformEosDataset
61
+ from .transform_eos_lang_pair_dataset import TransformEosLangPairDataset
62
+ from .shorten_dataset import TruncateDataset, RandomCropDataset
63
+ from .multilingual.sampled_multi_dataset import SampledMultiDataset
64
+ from .multilingual.sampled_multi_epoch_dataset import SampledMultiEpochDataset
65
+ from .fasta_dataset import FastaDataset, EncodedFastaDataset
66
+ from .transform_eos_concat_langpair_dataset import TransformEosConcatLangPairDataset
67
+
68
+ from .iterators import (
69
+ CountingIterator,
70
+ EpochBatchIterator,
71
+ GroupedIterator,
72
+ ShardedIterator,
73
+ )
74
+
75
+ __all__ = [
76
+ "AddTargetDataset",
77
+ "AppendTokenDataset",
78
+ "BacktranslationDataset",
79
+ "BaseWrapperDataset",
80
+ "BinarizedAudioDataset",
81
+ "BucketPadLengthDataset",
82
+ "ColorizeDataset",
83
+ "ConcatDataset",
84
+ "ConcatSentencesDataset",
85
+ "CountingIterator",
86
+ "DenoisingDataset",
87
+ "Dictionary",
88
+ "EncodedFastaDataset",
89
+ "EpochBatchIterator",
90
+ "FairseqDataset",
91
+ "FairseqIterableDataset",
92
+ "FastaDataset",
93
+ "FileAudioDataset",
94
+ "GroupedIterator",
95
+ "HubertDataset",
96
+ "IdDataset",
97
+ "IndexedCachedDataset",
98
+ "IndexedDataset",
99
+ "IndexedRawTextDataset",
100
+ "LanguagePairDataset",
101
+ "LeftPadDataset",
102
+ "ListDataset",
103
+ "LMContextWindowDataset",
104
+ "LRUCacheDataset",
105
+ "MaskTokensDataset",
106
+ "MMapIndexedDataset",
107
+ "MonolingualDataset",
108
+ "MultiCorpusSampledDataset",
109
+ "NestedDictionaryDataset",
110
+ "NoisingDataset",
111
+ "NumelDataset",
112
+ "NumSamplesDataset",
113
+ "OffsetTokensDataset",
114
+ "PadDataset",
115
+ "PrependDataset",
116
+ "PrependTokenDataset",
117
+ "RandomCropDataset",
118
+ "RawLabelDataset",
119
+ "ResamplingDataset",
120
+ "ReplaceDataset",
121
+ "RightPadDataset",
122
+ "RollDataset",
123
+ "RoundRobinZipDatasets",
124
+ "SampledMultiDataset",
125
+ "SampledMultiEpochDataset",
126
+ "ShardedIterator",
127
+ "SortDataset",
128
+ "SpeechDLMDataset",
129
+ "StripTokenDataset",
130
+ "SubsampleDataset",
131
+ "TokenBlockDataset",
132
+ "TransformEosDataset",
133
+ "TransformEosLangPairDataset",
134
+ "TransformEosConcatLangPairDataset",
135
+ "TruncateDataset",
136
+ "TruncatedDictionary",
137
+ ]
fairseq/fairseq/data/add_class_target_dataset.py ADDED
@@ -0,0 +1,79 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import torch
+
+ from . import BaseWrapperDataset, data_utils
+ from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel
+
+
+ class AddTargetDataset(BaseWrapperDataset):
+     def __init__(
+         self,
+         dataset,
+         labels,
+         pad,
+         eos,
+         batch_targets,
+         process_label=None,
+         label_len_fn=None,
+         add_to_input=False,
+         text_compression_level=TextCompressionLevel.none,
+     ):
+         super().__init__(dataset)
+         self.labels = labels
+         self.batch_targets = batch_targets
+         self.pad = pad
+         self.eos = eos
+         self.process_label = process_label
+         self.label_len_fn = label_len_fn
+         self.add_to_input = add_to_input
+         self.text_compressor = TextCompressor(level=text_compression_level)
+
+     def get_label(self, index, process_fn=None):
+         lbl = self.labels[index]
+         lbl = self.text_compressor.decompress(lbl)
+         return lbl if process_fn is None else process_fn(lbl)
+
+     def __getitem__(self, index):
+         item = self.dataset[index]
+         item["label"] = self.get_label(index, process_fn=self.process_label)
+         return item
+
+     def size(self, index):
+         sz = self.dataset.size(index)
+         own_sz = self.label_len_fn(self.get_label(index))
+         return sz, own_sz
+
+     def collater(self, samples):
+         collated = self.dataset.collater(samples)
+         if len(collated) == 0:
+             return collated
+         indices = set(collated["id"].tolist())
+         target = [s["label"] for s in samples if s["id"] in indices]
+
+         if self.batch_targets:
+             collated["target_lengths"] = torch.LongTensor([len(t) for t in target])
+             target = data_utils.collate_tokens(target, pad_idx=self.pad, left_pad=False)
+             collated["ntokens"] = collated["target_lengths"].sum().item()
+         else:
+             collated["ntokens"] = sum([len(t) for t in target])
+
+         collated["target"] = target
+
+         if self.add_to_input:
+             eos = target.new_full((target.size(0), 1), self.eos)
+             collated["target"] = torch.cat([target, eos], dim=-1).long()
+             collated["net_input"]["prev_output_tokens"] = torch.cat(
+                 [eos, target], dim=-1
+             ).long()
+             collated["ntokens"] += target.size(0)
+         return collated
+
+     def filter_indices_by_size(self, indices, max_sizes):
+         indices, ignored = data_utils._filter_by_size_dynamic(
+             indices, self.size, max_sizes
+         )
+         return indices, ignored
fairseq/fairseq/data/add_target_dataset.py ADDED
@@ -0,0 +1,83 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import torch
+
+ from . import BaseWrapperDataset, data_utils
+ from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel
+
+
+ class AddTargetDataset(BaseWrapperDataset):
+     def __init__(
+         self,
+         dataset,
+         labels,
+         pad,
+         eos,
+         batch_targets,
+         process_label=None,
+         label_len_fn=None,
+         add_to_input=False,
+         text_compression_level=TextCompressionLevel.none,
+     ):
+         super().__init__(dataset)
+         self.labels = labels
+         self.batch_targets = batch_targets
+         self.pad = pad
+         self.eos = eos
+         self.process_label = process_label
+         self.label_len_fn = label_len_fn
+         self.add_to_input = add_to_input
+         self.text_compressor = TextCompressor(level=text_compression_level)
+
+     def get_label(self, index, process_fn=None):
+         lbl = self.labels[index]
+         lbl = self.text_compressor.decompress(lbl)
+         return lbl if process_fn is None else process_fn(lbl)
+
+     def __getitem__(self, index):
+         item = self.dataset[index]
+         item["label"] = self.get_label(index, process_fn=self.process_label)
+         return item
+
+     def size(self, index):
+         sz = self.dataset.size(index)
+         own_sz = self.label_len_fn(self.get_label(index))
+         return sz, own_sz
+
+     def collater(self, samples):
+         collated = self.dataset.collater(samples)
+         if len(collated) == 0:
+             return collated
+         indices = set(collated["id"].tolist())
+         target = [s["label"] for s in samples if s["id"] in indices]
+
+         if self.add_to_input:
+             eos = torch.LongTensor([self.eos])
+             prev_output_tokens = [torch.cat([eos, t], axis=-1) for t in target]
+             target = [torch.cat([t, eos], axis=-1) for t in target]
+             collated["net_input"]["prev_output_tokens"] = prev_output_tokens
+
+         if self.batch_targets:
+             collated["target_lengths"] = torch.LongTensor([len(t) for t in target])
+             target = data_utils.collate_tokens(target, pad_idx=self.pad, left_pad=False)
+             collated["ntokens"] = collated["target_lengths"].sum().item()
+             if getattr(collated["net_input"], "prev_output_tokens", None):
+                 collated["net_input"]["prev_output_tokens"] = data_utils.collate_tokens(
+                     collated["net_input"]["prev_output_tokens"],
+                     pad_idx=self.pad,
+                     left_pad=False,
+                 )
+         else:
+             collated["ntokens"] = sum([len(t) for t in target])
+
+         collated["target"] = target
+         return collated
+
+     def filter_indices_by_size(self, indices, max_sizes):
+         indices, ignored = data_utils._filter_by_size_dynamic(
+             indices, self.size, max_sizes
+         )
+         return indices, ignored
fairseq/fairseq/data/append_token_dataset.py ADDED
@@ -0,0 +1,41 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import numpy as np
+ import torch
+
+ from . import BaseWrapperDataset
+
+
+ class AppendTokenDataset(BaseWrapperDataset):
+     def __init__(self, dataset, token=None):
+         super().__init__(dataset)
+         self.token = token
+         if token is not None:
+             self._sizes = np.array(dataset.sizes) + 1
+         else:
+             self._sizes = dataset.sizes
+
+     def __getitem__(self, idx):
+         item = self.dataset[idx]
+         if self.token is not None:
+             item = torch.cat([item, item.new([self.token])])
+         return item
+
+     @property
+     def sizes(self):
+         return self._sizes
+
+     def num_tokens(self, index):
+         n = self.dataset.num_tokens(index)
+         if self.token is not None:
+             n += 1
+         return n
+
+     def size(self, index):
+         n = self.dataset.size(index)
+         if self.token is not None:
+             n += 1
+         return n
fairseq/fairseq/data/audio/dataset_transforms/__pycache__/concataugment.cpython-310.pyc ADDED
Binary file (1.85 kB).
fairseq/fairseq/data/audio/feature_transforms/__pycache__/delta_deltas.cpython-310.pyc ADDED
Binary file (1.63 kB).
fairseq/fairseq/data/audio/feature_transforms/global_cmvn.py ADDED
@@ -0,0 +1,29 @@
+ import numpy as np
+ from fairseq.data.audio.feature_transforms import (
+     AudioFeatureTransform,
+     register_audio_feature_transform,
+ )
+
+
+ @register_audio_feature_transform("global_cmvn")
+ class GlobalCMVN(AudioFeatureTransform):
+     """Global CMVN (cepstral mean and variance normalization). The global mean
+     and variance need to be pre-computed and stored in NumPy format (.npz)."""
+
+     @classmethod
+     def from_config_dict(cls, config=None):
+         _config = {} if config is None else config
+         return GlobalCMVN(_config.get("stats_npz_path"))
+
+     def __init__(self, stats_npz_path):
+         self.stats_npz_path = stats_npz_path
+         stats = np.load(stats_npz_path)
+         self.mean, self.std = stats["mean"], stats["std"]
+
+     def __repr__(self):
+         return self.__class__.__name__ + f'(stats_npz_path="{self.stats_npz_path}")'
+
+     def __call__(self, x):
+         x = np.subtract(x, self.mean)
+         x = np.divide(x, self.std)
+         return x
fairseq/fairseq/data/audio/speech_to_text_dataset.py ADDED
@@ -0,0 +1,733 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import csv
7
+ import logging
8
+ import re
9
+ from argparse import Namespace
10
+ from collections import defaultdict
11
+ from dataclasses import dataclass
12
+ from pathlib import Path
13
+ from typing import Dict, List, Optional, Tuple, Union
14
+
15
+ import numpy as np
16
+ import torch
17
+ import torch.nn.functional as F
18
+
19
+ from fairseq.data import ConcatDataset, Dictionary, FairseqDataset, ResamplingDataset
20
+ from fairseq.data import data_utils as fairseq_data_utils
21
+ from fairseq.data import encoders
22
+ from fairseq.data.audio.audio_utils import get_features_or_waveform
23
+ from fairseq.data.audio.data_cfg import S2TDataConfig
24
+ from fairseq.data.audio.dataset_transforms import CompositeAudioDatasetTransform
25
+ from fairseq.data.audio.dataset_transforms.concataugment import ConcatAugment
26
+ from fairseq.data.audio.dataset_transforms.noisyoverlapaugment import (
27
+ NoisyOverlapAugment,
28
+ )
29
+ from fairseq.data.audio.feature_transforms import CompositeAudioFeatureTransform
30
+ from fairseq.data.audio.waveform_transforms import CompositeAudioWaveformTransform
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ def _collate_frames(
36
+ frames: List[torch.Tensor], is_audio_input: bool = False
37
+ ) -> torch.Tensor:
38
+ """
39
+ Convert a list of 2D frames into a padded 3D tensor
40
+ Args:
41
+ frames (list): list of 2D frames of size L[i]*f_dim. Where L[i] is
42
+ length of i-th frame and f_dim is static dimension of features
43
+ Returns:
44
+ 3D tensor of size len(frames)*len_max*f_dim where len_max is max of L[i]
45
+ """
46
+ max_len = max(frame.size(0) for frame in frames)
47
+ if is_audio_input:
48
+ out = frames[0].new_zeros((len(frames), max_len))
49
+ else:
50
+ out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1)))
51
+ for i, v in enumerate(frames):
52
+ out[i, : v.size(0)] = v
53
+ return out
54
+
55
+
56
+ def _is_int_or_np_int(n):
57
+ return isinstance(n, int) or (
58
+ isinstance(n, np.generic) and isinstance(n.item(), int)
59
+ )
60
+
61
+
62
+ @dataclass
63
+ class SpeechToTextDatasetItem(object):
64
+ index: int
65
+ source: torch.Tensor
66
+ target: Optional[torch.Tensor] = None
67
+ speaker_id: Optional[int] = None
68
+
69
+
70
+ class SpeechToTextDataset(FairseqDataset):
71
+ LANG_TAG_TEMPLATE = "<lang:{}>"
72
+
73
+ def __init__(
74
+ self,
75
+ split: str,
76
+ is_train_split: bool,
77
+ cfg: S2TDataConfig,
78
+ audio_paths: List[str],
79
+ n_frames: List[int],
80
+ src_texts: Optional[List[str]] = None,
81
+ tgt_texts: Optional[List[str]] = None,
82
+ speakers: Optional[List[str]] = None,
83
+ src_langs: Optional[List[str]] = None,
84
+ tgt_langs: Optional[List[str]] = None,
85
+ ids: Optional[List[str]] = None,
86
+ tgt_dict: Optional[Dictionary] = None,
87
+ pre_tokenizer=None,
88
+ bpe_tokenizer=None,
89
+ n_frames_per_step=1,
90
+ speaker_to_id=None,
91
+ append_eos=True,
92
+ ):
93
+ self.split, self.is_train_split = split, is_train_split
94
+ self.cfg = cfg
95
+ self.audio_paths, self.n_frames = audio_paths, n_frames
96
+ self.n_samples = len(audio_paths)
97
+ assert len(n_frames) == self.n_samples > 0
98
+ assert src_texts is None or len(src_texts) == self.n_samples
99
+ assert tgt_texts is None or len(tgt_texts) == self.n_samples
100
+ assert speakers is None or len(speakers) == self.n_samples
101
+ assert src_langs is None or len(src_langs) == self.n_samples
102
+ assert tgt_langs is None or len(tgt_langs) == self.n_samples
103
+ assert ids is None or len(ids) == self.n_samples
104
+ assert (tgt_dict is None and tgt_texts is None) or (
105
+ tgt_dict is not None and tgt_texts is not None
106
+ )
107
+ self.src_texts, self.tgt_texts = src_texts, tgt_texts
108
+ self.src_langs, self.tgt_langs = src_langs, tgt_langs
109
+ self.speakers = speakers
110
+ self.tgt_dict = tgt_dict
111
+ self.check_tgt_lang_tag()
112
+ self.ids = ids
113
+ self.shuffle = cfg.shuffle if is_train_split else False
114
+
115
+ self.feature_transforms = CompositeAudioFeatureTransform.from_config_dict(
116
+ self.cfg.get_feature_transforms(split, is_train_split)
117
+ )
118
+ self.waveform_transforms = CompositeAudioWaveformTransform.from_config_dict(
119
+ self.cfg.get_waveform_transforms(split, is_train_split)
120
+ )
121
+ # TODO: add these to data_cfg.py
122
+ self.dataset_transforms = CompositeAudioDatasetTransform.from_config_dict(
123
+ self.cfg.get_dataset_transforms(split, is_train_split)
124
+ )
125
+
126
+ # check proper usage of transforms
127
+ if self.feature_transforms and self.cfg.use_audio_input:
128
+ logger.warning(
129
+ "Feature transforms will not be applied. To use feature transforms, "
130
+ "set use_audio_input as False in config."
131
+ )
132
+
133
+ self.pre_tokenizer = pre_tokenizer
134
+ self.bpe_tokenizer = bpe_tokenizer
135
+ self.n_frames_per_step = n_frames_per_step
136
+ self.speaker_to_id = speaker_to_id
137
+
138
+ self.tgt_lens = self.get_tgt_lens_and_check_oov()
139
+ self.append_eos = append_eos
140
+
141
+ logger.info(self.__repr__())
142
+
143
+ def get_tgt_lens_and_check_oov(self):
144
+ if self.tgt_texts is None:
145
+ return [0 for _ in range(self.n_samples)]
146
+ tgt_lens = []
147
+ n_tokens, n_oov_tokens = 0, 0
148
+ for i in range(self.n_samples):
149
+ tokenized = self.get_tokenized_tgt_text(i).split(" ")
150
+ oov_tokens = [
151
+ t
152
+ for t in tokenized
153
+ if self.tgt_dict.index(t) == self.tgt_dict.unk_index
154
+ ]
155
+ n_tokens += len(tokenized)
156
+ n_oov_tokens += len(oov_tokens)
157
+ tgt_lens.append(len(tokenized))
158
+ logger.info(f"'{self.split}' has {n_oov_tokens / n_tokens * 100:.2f}% OOV")
159
+ return tgt_lens
160
+
161
+ def __repr__(self):
162
+ return (
163
+ self.__class__.__name__
164
+ + f'(split="{self.split}", n_samples={self.n_samples:_}, '
165
+ f"prepend_tgt_lang_tag={self.cfg.prepend_tgt_lang_tag}, "
166
+ f"n_frames_per_step={self.n_frames_per_step}, "
167
+ f"shuffle={self.shuffle}, "
168
+ f"feature_transforms={self.feature_transforms}, "
169
+ f"waveform_transforms={self.waveform_transforms}, "
170
+ f"dataset_transforms={self.dataset_transforms})"
171
+ )
172
+
173
+ @classmethod
174
+ def is_lang_tag(cls, token):
175
+ pattern = cls.LANG_TAG_TEMPLATE.replace("{}", "(.*)")
176
+ return re.match(pattern, token)
177
+
178
+ def check_tgt_lang_tag(self):
179
+ if self.cfg.prepend_tgt_lang_tag:
180
+ assert self.tgt_langs is not None and self.tgt_dict is not None
181
+ tgt_lang_tags = [
182
+ self.LANG_TAG_TEMPLATE.format(t) for t in set(self.tgt_langs)
183
+ ]
184
+ assert all(t in self.tgt_dict for t in tgt_lang_tags)
185
+
186
+ @classmethod
187
+ def tokenize(cls, tokenizer, text: str):
188
+ return text if tokenizer is None else tokenizer.encode(text)
189
+
190
+ def get_tokenized_tgt_text(self, index: Union[int, List[int]]):
191
+ if _is_int_or_np_int(index):
192
+ text = self.tgt_texts[index]
193
+ else:
194
+ text = " ".join([self.tgt_texts[i] for i in index])
195
+
196
+ text = self.tokenize(self.pre_tokenizer, text)
197
+ text = self.tokenize(self.bpe_tokenizer, text)
198
+ return text
199
+
200
+ def pack_frames(self, feature: torch.Tensor):
201
+ if self.n_frames_per_step == 1:
202
+ return feature
203
+ n_packed_frames = feature.shape[0] // self.n_frames_per_step
204
+ feature = feature[: self.n_frames_per_step * n_packed_frames]
205
+ return feature.reshape(n_packed_frames, -1)
206
+
207
+ @classmethod
208
+ def get_lang_tag_idx(cls, lang: str, dictionary: Dictionary):
209
+ lang_tag_idx = dictionary.index(cls.LANG_TAG_TEMPLATE.format(lang))
210
+ assert lang_tag_idx != dictionary.unk()
211
+ return lang_tag_idx
212
+
213
+ def _get_source_audio(self, index: Union[int, List[int]]) -> torch.Tensor:
214
+ """
215
+ Gives source audio for given index with any relevant transforms
216
+ applied. For ConcatAug, source audios for given indices are
217
+ concatenated in given order.
218
+ Args:
219
+ index (int or List[int]): index—or in the case of ConcatAug,
220
+ indices—to pull the source audio for
221
+ Returns:
222
+ source audios concatenated for given indices with
223
+ relevant transforms applied
224
+ """
225
+ if _is_int_or_np_int(index):
226
+ source = get_features_or_waveform(
227
+ self.audio_paths[index],
228
+ need_waveform=self.cfg.use_audio_input,
229
+ use_sample_rate=self.cfg.use_sample_rate,
230
+ waveform_transforms=self.waveform_transforms,
231
+ )
232
+ else:
233
+ source = np.concatenate(
234
+ [
235
+ get_features_or_waveform(
236
+ self.audio_paths[i],
237
+ need_waveform=self.cfg.use_audio_input,
238
+ use_sample_rate=self.cfg.use_sample_rate,
239
+ waveform_transforms=self.waveform_transforms,
240
+ )
241
+ for i in index
242
+ ]
243
+ )
244
+ if self.cfg.use_audio_input:
245
+ source = torch.from_numpy(source).float()
246
+ if self.cfg.standardize_audio:
247
+ with torch.no_grad():
248
+ source = F.layer_norm(source, source.shape)
249
+ else:
250
+ if self.feature_transforms is not None:
251
+ source = self.feature_transforms(source)
252
+ source = torch.from_numpy(source).float()
253
+ return source
254
+
255
+ def __getitem__(self, index: int) -> SpeechToTextDatasetItem:
256
+ has_concat = self.dataset_transforms.has_transform(ConcatAugment)
257
+ if has_concat:
258
+ concat = self.dataset_transforms.get_transform(ConcatAugment)
259
+ indices = concat.find_indices(index, self.n_frames, self.n_samples)
260
+
261
+ source = self._get_source_audio(indices if has_concat else index)
262
+ source = self.pack_frames(source)
263
+
264
+ target = None
265
+ if self.tgt_texts is not None:
266
+ tokenized = self.get_tokenized_tgt_text(indices if has_concat else index)
267
+ target = self.tgt_dict.encode_line(
268
+ tokenized, add_if_not_exist=False, append_eos=self.append_eos
269
+ ).long()
270
+ if self.cfg.prepend_tgt_lang_tag:
271
+ lang_tag_idx = self.get_lang_tag_idx(
272
+ self.tgt_langs[index], self.tgt_dict
273
+ )
274
+ target = torch.cat((torch.LongTensor([lang_tag_idx]), target), 0)
275
+
276
+ if self.cfg.prepend_bos_and_append_tgt_lang_tag:
277
+ bos = torch.LongTensor([self.tgt_dict.bos()])
278
+ lang_tag_idx = self.get_lang_tag_idx(self.tgt_langs[index], self.tgt_dict)
279
+ assert lang_tag_idx != self.tgt_dict.unk()
280
+ lang_tag_idx = torch.LongTensor([lang_tag_idx])
281
+ target = torch.cat((bos, target, lang_tag_idx), 0)
282
+
283
+ speaker_id = None
284
+ if self.speaker_to_id is not None:
285
+ speaker_id = self.speaker_to_id[self.speakers[index]]
286
+ return SpeechToTextDatasetItem(
287
+ index=index, source=source, target=target, speaker_id=speaker_id
288
+ )
289
+
290
+ def __len__(self):
291
+ return self.n_samples
292
+
293
+ def collater(
294
+ self, samples: List[SpeechToTextDatasetItem], return_order: bool = False
295
+ ) -> Dict:
296
+ if len(samples) == 0:
297
+ return {}
298
+ indices = torch.tensor([x.index for x in samples], dtype=torch.long)
299
+
300
+ sources = [x.source for x in samples]
301
+ has_NOAug = self.dataset_transforms.has_transform(NoisyOverlapAugment)
302
+ if has_NOAug and self.cfg.use_audio_input:
303
+ NOAug = self.dataset_transforms.get_transform(NoisyOverlapAugment)
304
+ sources = NOAug(sources)
305
+
306
+ frames = _collate_frames(sources, self.cfg.use_audio_input)
307
+ # sort samples by descending number of frames
308
+ n_frames = torch.tensor([x.size(0) for x in sources], dtype=torch.long)
309
+ n_frames, order = n_frames.sort(descending=True)
310
+ indices = indices.index_select(0, order)
311
+ frames = frames.index_select(0, order)
312
+
313
+ target, target_lengths = None, None
314
+ prev_output_tokens = None
315
+ ntokens = None
316
+ if self.tgt_texts is not None:
317
+ target = fairseq_data_utils.collate_tokens(
318
+ [x.target for x in samples],
319
+ self.tgt_dict.pad(),
320
+ self.tgt_dict.eos(),
321
+ left_pad=False,
322
+ move_eos_to_beginning=False,
323
+ )
324
+ target = target.index_select(0, order)
325
+ target_lengths = torch.tensor(
326
+ [x.target.size(0) for x in samples], dtype=torch.long
327
+ ).index_select(0, order)
328
+ prev_output_tokens = fairseq_data_utils.collate_tokens(
329
+ [x.target for x in samples],
330
+ self.tgt_dict.pad(),
331
+ eos_idx=None,
332
+ left_pad=False,
333
+ move_eos_to_beginning=True,
334
+ )
335
+ prev_output_tokens = prev_output_tokens.index_select(0, order)
336
+ ntokens = sum(x.target.size(0) for x in samples)
337
+
338
+ speaker = None
339
+ if self.speaker_to_id is not None:
340
+ speaker = (
341
+ torch.tensor([s.speaker_id for s in samples], dtype=torch.long)
342
+ .index_select(0, order)
343
+ .view(-1, 1)
344
+ )
345
+
346
+ net_input = {
347
+ "src_tokens": frames,
348
+ "src_lengths": n_frames,
349
+ "prev_output_tokens": prev_output_tokens,
350
+ }
351
+ out = {
352
+ "id": indices,
353
+ "net_input": net_input,
354
+ "speaker": speaker,
355
+ "target": target,
356
+ "target_lengths": target_lengths,
357
+ "ntokens": ntokens,
358
+ "nsentences": len(samples),
359
+ }
360
+ if return_order:
361
+ out["order"] = order
362
+ return out
363
+
364
+ def num_tokens(self, index):
365
+ return self.n_frames[index]
366
+
367
+ def size(self, index):
368
+ return self.n_frames[index], self.tgt_lens[index]
369
+
370
+ @property
371
+ def sizes(self):
372
+ return np.array(self.n_frames)
373
+
374
+ @property
375
+ def can_reuse_epoch_itr_across_epochs(self):
376
+ return True
377
+
378
+ def ordered_indices(self):
379
+ if self.shuffle:
380
+ order = [np.random.permutation(len(self))]
381
+ else:
382
+ order = [np.arange(len(self))]
383
+ # first by descending order of # of frames then by original/random order
384
+ order.append([-n for n in self.n_frames])
385
+ return np.lexsort(order)
386
+
387
+ def prefetch(self, indices):
388
+ raise NotImplementedError
389
+
390
+
391
+ class TextTargetMultitaskData(object):
392
+ # mandatory columns
393
+ KEY_ID, KEY_TEXT = "id", "tgt_text"
394
+ LANG_TAG_TEMPLATE = "<lang:{}>"
395
+
396
+ def __init__(self, args, split, tgt_dict):
397
+ samples = SpeechToTextDatasetCreator._load_samples_from_tsv(args.data, split)
398
+ self.data = {s[self.KEY_ID]: s[self.KEY_TEXT] for s in samples}
399
+ self.dict = tgt_dict
400
+ self.append_eos = args.decoder_type != "ctc"
401
+ self.pre_tokenizer = self.build_tokenizer(args)
402
+ self.bpe_tokenizer = self.build_bpe(args)
403
+ self.prepend_bos_and_append_tgt_lang_tag = (
404
+ args.prepend_bos_and_append_tgt_lang_tag
405
+ )
406
+ self.eos_token = args.eos_token
407
+ self.lang_tag_mapping = args.get_lang_tag_mapping
408
+
409
+ @classmethod
410
+ def is_lang_tag(cls, token):
411
+ pattern = cls.LANG_TAG_TEMPLATE.replace("{}", "(.*)")
412
+ return re.match(pattern, token)
413
+
414
+ @classmethod
415
+ def tokenize(cls, tokenizer, text: str):
416
+ return text if tokenizer is None else tokenizer.encode(text)
417
+
418
+ def get_tokenized_tgt_text(self, index: int):
419
+ text = self.tokenize(self.pre_tokenizer, self.data[index])
420
+ text = self.tokenize(self.bpe_tokenizer, text)
421
+ return text
422
+
423
+ def get_lang_tag_idx(self, lang: str, dictionary: Dictionary):
424
+ lang_tag = self.LANG_TAG_TEMPLATE.format(lang)
425
+ lang_tag = self.lang_tag_mapping.get(lang_tag, lang_tag)
426
+ lang_tag_idx = dictionary.index(lang_tag)
427
+ assert lang_tag_idx != dictionary.unk(), (lang, lang_tag)
428
+ return lang_tag_idx
429
+
430
+ def build_tokenizer(self, args):
431
+ pre_tokenizer = args.config.get("pre_tokenizer")
432
+ if pre_tokenizer is not None:
433
+ logger.info(f"pre-tokenizer: {pre_tokenizer}")
434
+ return encoders.build_tokenizer(Namespace(**pre_tokenizer))
435
+ else:
436
+ return None
437
+
438
+ def build_bpe(self, args):
439
+ bpe_tokenizer = args.config.get("bpe_tokenizer")
440
+ if bpe_tokenizer is not None:
441
+ logger.info(f"tokenizer: {bpe_tokenizer}")
442
+ return encoders.build_bpe(Namespace(**bpe_tokenizer))
443
+ else:
444
+ return None
445
+
446
+ def get(self, sample_id, tgt_lang=None):
447
+ if sample_id in self.data:
448
+ tokenized = self.get_tokenized_tgt_text(sample_id)
449
+ target = self.dict.encode_line(
450
+ tokenized,
451
+ add_if_not_exist=False,
452
+ append_eos=self.append_eos,
453
+ )
454
+ if self.prepend_bos_and_append_tgt_lang_tag:
455
+ bos = torch.LongTensor([self.dict.bos()])
456
+ lang_tag_idx = self.get_lang_tag_idx(tgt_lang, self.dict)
457
+ assert lang_tag_idx != self.dict.unk()
458
+ lang_tag_idx = torch.LongTensor([lang_tag_idx])
459
+ target = torch.cat((bos, target, lang_tag_idx), 0)
460
+ return target
461
+ else:
462
+ logger.warning(f"no target for {sample_id}")
463
+ return torch.IntTensor([])
464
+
465
+ def collater(self, samples: List[torch.Tensor]) -> torch.Tensor:
466
+ out = fairseq_data_utils.collate_tokens(
467
+ samples,
468
+ self.dict.pad(),
469
+ eos_idx=None,
470
+ left_pad=False,
471
+ move_eos_to_beginning=False,
472
+ ).long()
473
+
474
+ prev_out = fairseq_data_utils.collate_tokens(
475
+ samples,
476
+ self.dict.pad(),
477
+ eos_idx=None,
478
+ left_pad=False,
479
+ move_eos_to_beginning=True,
480
+ ).long()
481
+
482
+ target_lengths = torch.tensor([t.size(0) for t in samples], dtype=torch.long)
483
+ ntokens = sum(t.size(0) for t in samples)
484
+
485
+ output = {
486
+ "prev_output_tokens": prev_out,
487
+ "target": out,
488
+ "target_lengths": target_lengths,
489
+ "ntokens": ntokens,
490
+ }
491
+
492
+ return output
493
+
494
+
495
+ class SpeechToTextMultitaskDataset(SpeechToTextDataset):
496
+ def __init__(self, **kwargs):
497
+ super().__init__(**kwargs)
498
+ self.multitask_data = {}
499
+
500
+ def add_multitask_dataset(self, task_name, task_data):
501
+ self.multitask_data[task_name] = task_data
502
+
503
+ def __getitem__(
504
+ self, index: int
505
+ ) -> Tuple[SpeechToTextDatasetItem, Dict[str, torch.Tensor]]:
506
+ s2t_data = super().__getitem__(index)
507
+
508
+ multitask_target = {}
509
+ sample_id = self.ids[index]
510
+ tgt_lang = self.tgt_langs[index]
511
+ for task_name, task_dataset in self.multitask_data.items():
512
+ multitask_target[task_name] = task_dataset.get(sample_id, tgt_lang)
513
+
514
+ return s2t_data, multitask_target
515
+
516
+ def collater(
517
+ self, samples: List[Tuple[SpeechToTextDatasetItem, Dict[str, torch.Tensor]]]
518
+ ) -> Dict:
519
+ if len(samples) == 0:
520
+ return {}
521
+
522
+ out = super().collater([s for s, _ in samples], return_order=True)
523
+ order = out["order"]
524
+ del out["order"]
525
+
526
+ for task_name, task_dataset in self.multitask_data.items():
527
+ if "multitask" not in out:
528
+ out["multitask"] = {}
529
+ d = [s[task_name] for _, s in samples]
530
+ task_target = task_dataset.collater(d)
531
+ out["multitask"][task_name] = {
532
+ "target": task_target["target"].index_select(0, order),
533
+ "target_lengths": task_target["target_lengths"].index_select(0, order),
534
+ "ntokens": task_target["ntokens"],
535
+ }
536
+ out["multitask"][task_name]["net_input"] = {
537
+ "prev_output_tokens": task_target["prev_output_tokens"].index_select(
538
+ 0, order
539
+ ),
540
+ }
541
+
542
+ return out
543
+
544
+
545
+ class SpeechToTextDatasetCreator(object):
546
+ # mandatory columns
547
+ KEY_ID, KEY_AUDIO, KEY_N_FRAMES = "id", "audio", "n_frames"
548
+ KEY_TGT_TEXT = "tgt_text"
549
+ # optional columns
550
+ KEY_SPEAKER, KEY_SRC_TEXT = "speaker", "src_text"
551
+ KEY_SRC_LANG, KEY_TGT_LANG = "src_lang", "tgt_lang"
552
+ # default values
553
+ DEFAULT_SPEAKER = DEFAULT_SRC_TEXT = DEFAULT_LANG = ""
554
+
555
+ @classmethod
556
+ def _from_list(
557
+ cls,
558
+ split_name: str,
559
+ is_train_split,
560
+ samples: List[Dict],
561
+ cfg: S2TDataConfig,
562
+ tgt_dict,
563
+ pre_tokenizer,
564
+ bpe_tokenizer,
565
+ n_frames_per_step,
566
+ speaker_to_id,
567
+ multitask: Optional[Dict] = None,
568
+ ) -> SpeechToTextDataset:
569
+ audio_root = Path(cfg.audio_root)
570
+ ids = [s[cls.KEY_ID] for s in samples]
571
+ audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples]
572
+ n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples]
573
+ tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples]
574
+ src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples]
575
+ speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples]
576
+ src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
577
+ tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]
578
+
579
+ has_multitask = multitask is not None and len(multitask.keys()) > 0
580
+ dataset_cls = (
581
+ SpeechToTextMultitaskDataset if has_multitask else SpeechToTextDataset
582
+ )
583
+
584
+ ds = dataset_cls(
585
+ split=split_name,
586
+ is_train_split=is_train_split,
587
+ cfg=cfg,
588
+ audio_paths=audio_paths,
589
+ n_frames=n_frames,
590
+ src_texts=src_texts,
591
+ tgt_texts=tgt_texts,
592
+ speakers=speakers,
593
+ src_langs=src_langs,
594
+ tgt_langs=tgt_langs,
595
+ ids=ids,
596
+ tgt_dict=tgt_dict,
597
+ pre_tokenizer=pre_tokenizer,
598
+ bpe_tokenizer=bpe_tokenizer,
599
+ n_frames_per_step=n_frames_per_step,
600
+ speaker_to_id=speaker_to_id,
601
+ )
602
+
603
+ if has_multitask:
604
+ for task_name, task_obj in multitask.items():
605
+ task_data = TextTargetMultitaskData(
606
+ task_obj.args, split_name, task_obj.target_dictionary
607
+ )
608
+ ds.add_multitask_dataset(task_name, task_data)
609
+ return ds
610
+
611
+ @classmethod
612
+ def get_size_ratios(
613
+ cls, datasets: List[SpeechToTextDataset], alpha: float = 1.0
614
+ ) -> List[float]:
615
+ """Size ratios for temperature-based sampling
616
+ (https://arxiv.org/abs/1907.05019)"""
617
+
618
+ id_to_lp, lp_to_sz = {}, defaultdict(int)
619
+ for ds in datasets:
620
+ lang_pairs = {f"{s}->{t}" for s, t in zip(ds.src_langs, ds.tgt_langs)}
621
+ assert len(lang_pairs) == 1
622
+ lang_pair = list(lang_pairs)[0]
623
+ id_to_lp[ds.split] = lang_pair
624
+ lp_to_sz[lang_pair] += sum(ds.n_frames)
625
+
626
+ sz_sum = sum(v for v in lp_to_sz.values())
627
+ lp_to_prob = {k: v / sz_sum for k, v in lp_to_sz.items()}
628
+ lp_to_tgt_prob = {k: v**alpha for k, v in lp_to_prob.items()}
629
+ prob_sum = sum(v for v in lp_to_tgt_prob.values())
630
+ lp_to_tgt_prob = {k: v / prob_sum for k, v in lp_to_tgt_prob.items()}
631
+ lp_to_sz_ratio = {
632
+ k: (lp_to_tgt_prob[k] * sz_sum) / v for k, v in lp_to_sz.items()
633
+ }
634
+ size_ratio = [lp_to_sz_ratio[id_to_lp[ds.split]] for ds in datasets]
635
+
636
+ p_formatted = {
637
+ k: f"{lp_to_prob[k]:.3f}->{lp_to_tgt_prob[k]:.3f}" for k in lp_to_sz
638
+ }
639
+ logger.info(f"sampling probability balancing: {p_formatted}")
640
+ sr_formatted = {ds.split: f"{r:.3f}" for ds, r in zip(datasets, size_ratio)}
641
+ logger.info(f"balanced sampling size ratio: {sr_formatted}")
642
+ return size_ratio
643
+
644
+ @classmethod
645
+ def _load_samples_from_tsv(cls, root: str, split: str):
646
+ tsv_path = Path(root) / f"{split}.tsv"
647
+ if not tsv_path.is_file():
648
+ raise FileNotFoundError(f"Dataset not found: {tsv_path}")
649
+ with open(tsv_path) as f:
650
+ reader = csv.DictReader(
651
+ f,
652
+ delimiter="\t",
653
+ quotechar=None,
654
+ doublequote=False,
655
+ lineterminator="\n",
656
+ quoting=csv.QUOTE_NONE,
657
+ )
658
+ samples = [dict(e) for e in reader]
659
+ if len(samples) == 0:
660
+ raise ValueError(f"Empty manifest: {tsv_path}")
661
+ return samples
662
+
663
+ @classmethod
664
+ def _from_tsv(
665
+ cls,
666
+ root: str,
667
+ cfg: S2TDataConfig,
668
+ split: str,
669
+ tgt_dict,
670
+ is_train_split: bool,
671
+ pre_tokenizer,
672
+ bpe_tokenizer,
673
+ n_frames_per_step,
674
+ speaker_to_id,
675
+ multitask: Optional[Dict] = None,
676
+ ) -> SpeechToTextDataset:
677
+ samples = cls._load_samples_from_tsv(root, split)
678
+ return cls._from_list(
679
+ split,
680
+ is_train_split,
681
+ samples,
682
+ cfg,
683
+ tgt_dict,
684
+ pre_tokenizer,
685
+ bpe_tokenizer,
686
+ n_frames_per_step,
687
+ speaker_to_id,
688
+ multitask,
689
+ )
690
+
691
+ @classmethod
692
+ def from_tsv(
693
+ cls,
694
+ root: str,
695
+ cfg: S2TDataConfig,
696
+ splits: str,
697
+ tgt_dict,
698
+ pre_tokenizer,
699
+ bpe_tokenizer,
700
+ is_train_split: bool,
701
+ epoch: int,
702
+ seed: int,
703
+ n_frames_per_step: int = 1,
704
+ speaker_to_id=None,
705
+ multitask: Optional[Dict] = None,
706
+ ) -> SpeechToTextDataset:
707
+ datasets = [
708
+ cls._from_tsv(
709
+ root=root,
710
+ cfg=cfg,
711
+ split=split,
712
+ tgt_dict=tgt_dict,
713
+ is_train_split=is_train_split,
714
+ pre_tokenizer=pre_tokenizer,
715
+ bpe_tokenizer=bpe_tokenizer,
716
+ n_frames_per_step=n_frames_per_step,
717
+ speaker_to_id=speaker_to_id,
718
+ multitask=multitask,
719
+ )
720
+ for split in splits.split(",")
721
+ ]
722
+
723
+ if is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0:
724
+ # temperature-based sampling
725
+ size_ratios = cls.get_size_ratios(datasets, alpha=cfg.sampling_alpha)
726
+ datasets = [
727
+ ResamplingDataset(
728
+ d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0)
729
+ )
730
+ for r, d in zip(size_ratios, datasets)
731
+ ]
732
+
733
+ return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
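A hedged end-to-end sketch of building a dataset from TSV manifests with the creator above; the paths, vocabulary file, and config contents are assumptions about a typical S2T setup rather than part of this change:

from pathlib import Path
from fairseq.data import Dictionary
from fairseq.data.audio.data_cfg import S2TDataConfig
from fairseq.data.audio.speech_to_text_dataset import SpeechToTextDatasetCreator

cfg = S2TDataConfig(Path("data/config.yaml"))             # audio_root, transforms, sampling_alpha, ...
tgt_dict = Dictionary.load("data/spm_unigram_vocab.txt")   # one symbol per line
train_set = SpeechToTextDatasetCreator.from_tsv(
    root="data",            # expects data/train.tsv with id/audio/n_frames/tgt_text columns
    cfg=cfg,
    splits="train",
    tgt_dict=tgt_dict,
    pre_tokenizer=None,
    bpe_tokenizer=None,
    is_train_split=True,
    epoch=1,
    seed=1,
)
batch = train_set.collater([train_set[i] for i in range(2)])
# batch["net_input"]["src_tokens"]: padded frames/waveforms sorted by length,
# batch["target"] / batch["target_lengths"] / batch["ntokens"]: collated labels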
fairseq/fairseq/data/backtranslation_dataset.py ADDED
@@ -0,0 +1,165 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ from fairseq import utils
8
+
9
+ from . import FairseqDataset
10
+
11
+
12
+ def backtranslate_samples(samples, collate_fn, generate_fn, cuda=True):
13
+ """Backtranslate a list of samples.
14
+
15
+ Given an input (*samples*) of the form:
16
+
17
+ [{'id': 1, 'source': 'hallo welt'}]
18
+
19
+ this will return:
20
+
21
+ [{'id': 1, 'source': 'hello world', 'target': 'hallo welt'}]
22
+
23
+ Args:
24
+ samples (List[dict]): samples to backtranslate. Individual samples are
25
+ expected to have a 'source' key, which will become the 'target'
26
+ after backtranslation.
27
+ collate_fn (callable): function to collate samples into a mini-batch
28
+ generate_fn (callable): function to generate backtranslations
29
+ cuda (bool): use GPU for generation (default: ``True``)
30
+
31
+ Returns:
32
+ List[dict]: an updated list of samples with a backtranslated source
33
+ """
34
+ collated_samples = collate_fn(samples)
35
+ s = utils.move_to_cuda(collated_samples) if cuda else collated_samples
36
+ generated_sources = generate_fn(s)
37
+
38
+ id_to_src = {sample["id"]: sample["source"] for sample in samples}
39
+
40
+ # Go through each tgt sentence in batch and its corresponding best
41
+ # generated hypothesis and create a backtranslation data pair
42
+ # {id: id, source: generated backtranslation, target: original tgt}
43
+ return [
44
+ {
45
+ "id": id.item(),
46
+ "target": id_to_src[id.item()],
47
+ "source": hypos[0]["tokens"].cpu(),
48
+ }
49
+ for id, hypos in zip(collated_samples["id"], generated_sources)
50
+ ]
51
+
52
+
53
+ class BacktranslationDataset(FairseqDataset):
54
+ """
55
+ Sets up a backtranslation dataset which takes a tgt batch, generates
56
+ a src using a tgt-src backtranslation function (*backtranslation_fn*),
57
+ and returns the corresponding `{generated src, input tgt}` batch.
58
+
59
+ Args:
60
+ tgt_dataset (~fairseq.data.FairseqDataset): the dataset to be
61
+ backtranslated. Only the source side of this dataset will be used.
62
+ After backtranslation, the source sentences in this dataset will be
63
+ returned as the targets.
64
+ src_dict (~fairseq.data.Dictionary): the dictionary of backtranslated
65
+ sentences.
66
+ tgt_dict (~fairseq.data.Dictionary, optional): the dictionary of
67
+ sentences to be backtranslated.
68
+ backtranslation_fn (callable, optional): function to call to generate
69
+ backtranslations. This is typically the `generate` method of a
70
+ :class:`~fairseq.sequence_generator.SequenceGenerator` object.
71
+ Pass in None when it is not available at initialization time, and
72
+ use set_backtranslation_fn function to set it when available.
73
+ output_collater (callable, optional): function to call on the
74
+ backtranslated samples to create the final batch
75
+ (default: ``tgt_dataset.collater``).
76
+ cuda: use GPU for generation
77
+ """
78
+
79
+ def __init__(
80
+ self,
81
+ tgt_dataset,
82
+ src_dict,
83
+ tgt_dict=None,
84
+ backtranslation_fn=None,
85
+ output_collater=None,
86
+ cuda=True,
87
+ **kwargs
88
+ ):
89
+ self.tgt_dataset = tgt_dataset
90
+ self.backtranslation_fn = backtranslation_fn
91
+ self.output_collater = (
92
+ output_collater if output_collater is not None else tgt_dataset.collater
93
+ )
94
+ self.cuda = cuda if torch.cuda.is_available() else False
95
+ self.src_dict = src_dict
96
+ self.tgt_dict = tgt_dict
97
+
98
+ def __getitem__(self, index):
99
+ """
100
+ Returns a single sample from *tgt_dataset*. Note that backtranslation is
101
+ not applied in this step; use :func:`collater` instead to backtranslate
102
+ a batch of samples.
103
+ """
104
+ return self.tgt_dataset[index]
105
+
106
+ def __len__(self):
107
+ return len(self.tgt_dataset)
108
+
109
+ def set_backtranslation_fn(self, backtranslation_fn):
110
+ self.backtranslation_fn = backtranslation_fn
111
+
112
+ def collater(self, samples):
113
+ """Merge and backtranslate a list of samples to form a mini-batch.
114
+
115
+ Using the samples from *tgt_dataset*, load a collated target sample to
116
+ feed to the backtranslation model. Then take the backtranslation with
117
+ the best score as the source and the original input as the target.
118
+
119
+ Note: we expect *tgt_dataset* to provide a function `collater()` that
120
+ will collate samples into the format expected by *backtranslation_fn*.
121
+ After backtranslation, we will feed the new list of samples (i.e., the
122
+ `(backtranslated source, original source)` pairs) to *output_collater*
123
+ and return the result.
124
+
125
+ Args:
126
+ samples (List[dict]): samples to backtranslate and collate
127
+
128
+ Returns:
129
+ dict: a mini-batch with keys coming from *output_collater*
130
+ """
131
+ if samples[0].get("is_dummy", False):
132
+ return samples
133
+ samples = backtranslate_samples(
134
+ samples=samples,
135
+ collate_fn=self.tgt_dataset.collater,
136
+ generate_fn=(lambda net_input: self.backtranslation_fn(net_input)),
137
+ cuda=self.cuda,
138
+ )
139
+ return self.output_collater(samples)
140
+
141
+ def num_tokens(self, index):
142
+ """Just use the tgt dataset num_tokens"""
143
+ return self.tgt_dataset.num_tokens(index)
144
+
145
+ def ordered_indices(self):
146
+ """Just use the tgt dataset ordered_indices"""
147
+ return self.tgt_dataset.ordered_indices()
148
+
149
+ def size(self, index):
150
+ """Return an example's size as a float or tuple. This value is used
151
+ when filtering a dataset with ``--max-positions``.
152
+
153
+ Note: we use *tgt_dataset* to approximate the length of the source
154
+ sentence, since we do not know the actual length until after
155
+ backtranslation.
156
+ """
157
+ tgt_size = self.tgt_dataset.size(index)[0]
158
+ return (tgt_size, tgt_size)
159
+
160
+ @property
161
+ def supports_prefetch(self):
162
+ return getattr(self.tgt_dataset, "supports_prefetch", False)
163
+
164
+ def prefetch(self, indices):
165
+ return self.tgt_dataset.prefetch(indices)
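A wiring sketch for BacktranslationDataset; `mono_tgt_dataset`, `sequence_generator`, and `tgt2src_model` are placeholders for objects a task would normally construct:

from fairseq.data import BacktranslationDataset

bt_dataset = BacktranslationDataset(
    tgt_dataset=mono_tgt_dataset,   # items look like {"id": ..., "source": tgt tokens}
    src_dict=src_dict,
    tgt_dict=tgt_dict,
    backtranslation_fn=None,        # the generator may not exist yet at dataset-build time
    cuda=True,
)
# Once the tgt->src generator is available:
bt_dataset.set_backtranslation_fn(
    lambda sample: sequence_generator.generate([tgt2src_model], sample)
)
batch = bt_dataset.collater([bt_dataset[i] for i in range(8)])  # {generated src, original tgt} pairs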
fairseq/fairseq/data/base_wrapper_dataset.py ADDED
@@ -0,0 +1,78 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from torch.utils.data.dataloader import default_collate
7
+
8
+ from . import FairseqDataset
9
+
10
+
11
+ class BaseWrapperDataset(FairseqDataset):
12
+ def __init__(self, dataset):
13
+ super().__init__()
14
+ self.dataset = dataset
15
+
16
+ def __getitem__(self, index):
17
+ return self.dataset[index]
18
+
19
+ def __len__(self):
20
+ return len(self.dataset)
21
+
22
+ def collater(self, samples):
23
+ if hasattr(self.dataset, "collater"):
24
+ return self.dataset.collater(samples)
25
+ else:
26
+ return default_collate(samples)
27
+
28
+ @property
29
+ def sizes(self):
30
+ return self.dataset.sizes
31
+
32
+ def num_tokens(self, index):
33
+ return self.dataset.num_tokens(index)
34
+
35
+ def size(self, index):
36
+ return self.dataset.size(index)
37
+
38
+ def ordered_indices(self):
39
+ return self.dataset.ordered_indices()
40
+
41
+ @property
42
+ def supports_prefetch(self):
43
+ return getattr(self.dataset, "supports_prefetch", False)
44
+
45
+ def attr(self, attr: str, index: int):
46
+ return self.dataset.attr(attr, index)
47
+
48
+ def prefetch(self, indices):
49
+ self.dataset.prefetch(indices)
50
+
51
+ def get_batch_shapes(self):
52
+ return self.dataset.get_batch_shapes()
53
+
54
+ def batch_by_size(
55
+ self,
56
+ indices,
57
+ max_tokens=None,
58
+ max_sentences=None,
59
+ required_batch_size_multiple=1,
60
+ ):
61
+ return self.dataset.batch_by_size(
62
+ indices,
63
+ max_tokens=max_tokens,
64
+ max_sentences=max_sentences,
65
+ required_batch_size_multiple=required_batch_size_multiple,
66
+ )
67
+
68
+ def filter_indices_by_size(self, indices, max_sizes):
69
+ return self.dataset.filter_indices_by_size(indices, max_sizes)
70
+
71
+ @property
72
+ def can_reuse_epoch_itr_across_epochs(self):
73
+ return self.dataset.can_reuse_epoch_itr_across_epochs
74
+
75
+ def set_epoch(self, epoch):
76
+ super().set_epoch(epoch)
77
+ if hasattr(self.dataset, "set_epoch"):
78
+ self.dataset.set_epoch(epoch)
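BaseWrapperDataset is meant to be subclassed; a toy sketch of the pattern, overriding only the accessor while every other method keeps delegating to the wrapped dataset:

from fairseq.data import BaseWrapperDataset


class OffsetTokensDataset(BaseWrapperDataset):
    """Toy wrapper that shifts every token id by a constant offset."""

    def __init__(self, dataset, offset):
        super().__init__(dataset)
        self.offset = offset

    def __getitem__(self, index):
        # sizes, collater, ordered_indices, etc. stay inherited from the wrapper
        return self.dataset[index] + self.offset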
fairseq/fairseq/data/bucket_pad_length_dataset.py ADDED
@@ -0,0 +1,78 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import numpy as np
7
+ import torch.nn.functional as F
8
+ from fairseq.data import BaseWrapperDataset
9
+ from fairseq.data.data_utils import get_buckets, get_bucketed_sizes
10
+
11
+
12
+ class BucketPadLengthDataset(BaseWrapperDataset):
13
+ """
14
+ Bucket and pad item lengths to the nearest bucket size. This can be used to
15
+ reduce the number of unique batch shapes, which is important on TPUs since
16
+ each new batch shape requires a recompilation.
17
+
18
+ Args:
19
+ dataset (FairseqDataset): dataset to bucket
20
+ sizes (List[int]): all item sizes
21
+ num_buckets (int): number of buckets to create
22
+ pad_idx (int): padding symbol
23
+ left_pad (bool): if True, pad on the left; otherwise right pad
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ dataset,
29
+ sizes,
30
+ num_buckets,
31
+ pad_idx,
32
+ left_pad,
33
+ tensor_key=None,
34
+ ):
35
+ super().__init__(dataset)
36
+ self.pad_idx = pad_idx
37
+ self.left_pad = left_pad
38
+
39
+ assert num_buckets > 0
40
+ self.buckets = get_buckets(sizes, num_buckets)
41
+ self._bucketed_sizes = get_bucketed_sizes(sizes, self.buckets)
42
+ self._tensor_key = tensor_key
43
+
44
+ def _set_tensor(self, item, val):
45
+ if self._tensor_key is None:
46
+ return val
47
+ item[self._tensor_key] = val
48
+ return item
49
+
50
+ def _get_tensor(self, item):
51
+ if self._tensor_key is None:
52
+ return item
53
+ return item[self._tensor_key]
54
+
55
+ def _pad(self, tensor, bucket_size, dim=-1):
56
+ num_pad = bucket_size - tensor.size(dim)
57
+ return F.pad(
58
+ tensor,
59
+ (num_pad if self.left_pad else 0, 0 if self.left_pad else num_pad),
60
+ value=self.pad_idx,
61
+ )
62
+
63
+ def __getitem__(self, index):
64
+ item = self.dataset[index]
65
+ bucket_size = self._bucketed_sizes[index]
66
+ tensor = self._get_tensor(item)
67
+ padded = self._pad(tensor, bucket_size)
68
+ return self._set_tensor(item, padded)
69
+
70
+ @property
71
+ def sizes(self):
72
+ return self._bucketed_sizes
73
+
74
+ def num_tokens(self, index):
75
+ return self._bucketed_sizes[index]
76
+
77
+ def size(self, index):
78
+ return self._bucketed_sizes[index]
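A usage sketch, with `token_dataset` and `dictionary` as placeholders; with num_buckets=3 every item is padded up to one of three bucket lengths, so batches reuse a small set of shapes:

from fairseq.data import BucketPadLengthDataset

bucketed = BucketPadLengthDataset(
    dataset=token_dataset,        # any dataset of 1-D LongTensors
    sizes=token_dataset.sizes,
    num_buckets=3,
    pad_idx=dictionary.pad(),
    left_pad=False,               # right-pad, matching most encoder inputs
)
# bucketed.sizes now holds the bucketed (padded) lengths used for batching.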
fairseq/fairseq/data/codedataset.py ADDED
@@ -0,0 +1,576 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ import json
8
+ import logging
9
+ import os
10
+ import random
11
+ from pathlib import Path
12
+
13
+ import numpy as np
14
+ import torch
15
+ import torch.utils.data
16
+
17
+ from . import data_utils
18
+ from fairseq.data.fairseq_dataset import FairseqDataset
19
+
20
+ F0_FRAME_SPACE = 0.005 # sec
21
+
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ class ExpressiveCodeDataConfig(object):
27
+ def __init__(self, json_path):
28
+ with open(json_path, "r") as f:
29
+ self.config = json.load(f)
30
+ self._manifests = self.config["manifests"]
31
+
32
+ @property
33
+ def manifests(self):
34
+ return self._manifests
35
+
36
+ @property
37
+ def n_units(self):
38
+ return self.config["n_units"]
39
+
40
+ @property
41
+ def sampling_rate(self):
42
+ return self.config["sampling_rate"]
43
+
44
+ @property
45
+ def code_hop_size(self):
46
+ return self.config["code_hop_size"]
47
+
48
+ @property
49
+ def f0_stats(self):
50
+ """pre-computed f0 statistics path"""
51
+ return self.config.get("f0_stats", None)
52
+
53
+ @property
54
+ def f0_vq_type(self):
55
+ """naive or precomp"""
56
+ return self.config["f0_vq_type"]
57
+
58
+ @property
59
+ def f0_vq_name(self):
60
+ return self.config["f0_vq_name"]
61
+
62
+ def get_f0_vq_naive_quantizer(self, log, norm_mean, norm_std):
63
+ key = "log" if log else "linear"
64
+ if norm_mean and norm_std:
65
+ key += "_mean_std_norm"
66
+ elif norm_mean:
67
+ key += "_mean_norm"
68
+ else:
69
+ key += "_none_norm"
70
+ return self.config["f0_vq_naive_quantizer"][key]
71
+
72
+ @property
73
+ def f0_vq_n_units(self):
74
+ return self.config["f0_vq_n_units"]
75
+
76
+ @property
77
+ def multispkr(self):
78
+ """how to parse speaker label from audio path"""
79
+ return self.config.get("multispkr", None)
80
+
81
+
82
+ def get_f0(audio, rate=16000):
83
+ try:
84
+ import amfm_decompy.basic_tools as basic
85
+ import amfm_decompy.pYAAPT as pYAAPT
86
+ from librosa.util import normalize
87
+ except ImportError:
88
+ raise ImportError("Please install amfm_decompy (`pip install AMFM-decompy`) and librosa (`pip install librosa`).")
89
+
90
+ assert audio.ndim == 1
91
+ frame_length = 20.0 # ms
92
+ to_pad = int(frame_length / 1000 * rate) // 2
93
+
94
+ audio = normalize(audio) * 0.95
95
+ audio = np.pad(audio, (to_pad, to_pad), "constant", constant_values=0)
96
+ audio = basic.SignalObj(audio, rate)
97
+ pitch = pYAAPT.yaapt(
98
+ audio,
99
+ frame_length=frame_length,
100
+ frame_space=F0_FRAME_SPACE * 1000,
101
+ nccf_thresh1=0.25,
102
+ tda_frame_length=25.0,
103
+ )
104
+ f0 = pitch.samp_values
105
+ return f0
106
+
107
+
108
+ def interpolate_f0(f0):
109
+ try:
110
+ from scipy.interpolate import interp1d
111
+ except ImportError:
112
+ raise ImportError("Please install scipy (`pip install scipy`)")
113
+
114
+ orig_t = np.arange(f0.shape[0])
115
+ f0_interp = f0[:]
116
+ ii = f0_interp != 0
117
+ if ii.sum() > 1:
118
+ f0_interp = interp1d(
119
+ orig_t[ii], f0_interp[ii], bounds_error=False, kind="linear", fill_value=0
120
+ )(orig_t)
121
+ f0_interp = torch.Tensor(f0_interp).type_as(f0).to(f0.device)
122
+ return f0_interp
123
+
124
+
125
+ def naive_quantize(x, edges):
126
+ bin_idx = (x.view(-1, 1) > edges.view(1, -1)).long().sum(dim=1)
127
+ return bin_idx
128
+
129
+
130
+ def load_wav(full_path):
131
+ try:
132
+ import soundfile as sf
133
+ except ImportError:
134
+ raise ImportError("Please install soundfile (`pip install SoundFile`)")
135
+ data, sampling_rate = sf.read(full_path)
136
+ return data, sampling_rate
137
+
138
+
139
+ def parse_code(code_str, dictionary, append_eos):
140
+ code, duration = torch.unique_consecutive(
141
+ torch.ShortTensor(list(map(int, code_str.split()))), return_counts=True
142
+ )
143
+ code = " ".join(map(str, code.tolist()))
144
+ code = dictionary.encode_line(code, append_eos).short()
145
+
146
+ if append_eos:
147
+ duration = torch.cat((duration, duration.new_zeros((1,))), dim=0) # eos
148
+ duration = duration.short()
149
+ return code, duration
150
+
151
+
152
+ def parse_manifest(manifest, dictionary):
153
+ audio_files = []
154
+ codes = []
155
+ durations = []
156
+ speakers = []
157
+
158
+ with open(manifest) as info:
159
+ for line in info.readlines():
160
+ sample = eval(line.strip())
161
+ if "cpc_km100" in sample:
162
+ k = "cpc_km100"
163
+ elif "hubert_km100" in sample:
164
+ k = "hubert_km100"
165
+ elif "phone" in sample:
166
+ k = "phone"
167
+ else:
168
+ assert False, "unknown format"
169
+ code = sample[k]
170
+ code, duration = parse_code(code, dictionary, append_eos=True)
171
+
172
+ codes.append(code)
173
+ durations.append(duration)
174
+ audio_files.append(sample["audio"])
175
+ speakers.append(sample.get("speaker", None))
176
+
177
+ return audio_files, codes, durations, speakers
178
+
179
+
180
+ def parse_speaker(path, method):
181
+ if type(path) == str:
182
+ path = Path(path)
183
+
184
+ if method == "parent_name":
185
+ return path.parent.name
186
+ elif method == "parent_parent_name":
187
+ return path.parent.parent.name
188
+ elif method == "_":
189
+ return path.name.split("_")[0]
190
+ elif method == "single":
191
+ return "A"
192
+ elif callable(method):
193
+ return method(path)
194
+ else:
195
+ raise NotImplementedError()
196
+
197
+
198
+ def get_f0_by_filename(filename, tgt_sampling_rate):
199
+ audio, sampling_rate = load_wav(filename)
200
+ if sampling_rate != tgt_sampling_rate:
201
+ raise ValueError(
202
+ "{} SR doesn't match target {} SR".format(sampling_rate, tgt_sampling_rate)
203
+ )
204
+
205
+ # compute un-interpolated f0, and use Ann's interp in __getitem__ if set
206
+ f0 = get_f0(audio, rate=tgt_sampling_rate)
207
+ f0 = torch.from_numpy(f0.astype(np.float32))
208
+ return f0
209
+
210
+
211
+ def align_f0_to_durations(f0, durations, f0_code_ratio, tol=1):
212
+ code_len = durations.sum()
213
+ targ_len = int(f0_code_ratio * code_len)
214
+ diff = f0.size(0) - targ_len
215
+ assert abs(diff) <= tol, (
216
+ f"Cannot subsample F0: |{f0.size(0)} - {f0_code_ratio}*{code_len}|"
217
+ f" > {tol} (dur=\n{durations})"
218
+ )
219
+ if diff > 0:
220
+ f0 = f0[:targ_len]
221
+ elif diff < 0:
222
+ f0 = torch.cat((f0, f0.new_full((-diff,), f0[-1])), 0)
223
+
224
+ f0_offset = 0.0
225
+ seg_f0s = []
226
+ for dur in durations:
227
+ f0_dur = dur.item() * f0_code_ratio
228
+ seg_f0 = f0[int(f0_offset) : int(f0_offset + f0_dur)]
229
+ seg_f0 = seg_f0[seg_f0 != 0]
230
+ if len(seg_f0) == 0:
231
+ seg_f0 = torch.tensor(0).type(seg_f0.type())
232
+ else:
233
+ seg_f0 = seg_f0.mean()
234
+ seg_f0s.append(seg_f0)
235
+ f0_offset += f0_dur
236
+
237
+ assert int(f0_offset) == f0.size(0), f"{f0_offset} {f0.size()} {durations.sum()}"
238
+ return torch.tensor(seg_f0s)
239
+
240
+
241
+ class Paddings(object):
242
+ def __init__(self, code_val, dur_val=0, f0_val=-2.0):
243
+ self.code = code_val
244
+ self.dur = dur_val
245
+ self.f0 = f0_val
246
+
247
+
248
+ class Shifts(object):
249
+ def __init__(self, shifts_str, pads):
250
+ self._shifts = list(map(int, shifts_str.split(",")))
251
+ assert len(self._shifts) == 2, self._shifts
252
+ assert all(s >= 0 for s in self._shifts)
253
+ self.extra_length = max(s for s in self._shifts)
254
+ self.pads = pads
255
+
256
+ @property
257
+ def dur(self):
258
+ return self._shifts[0]
259
+
260
+ @property
261
+ def f0(self):
262
+ return self._shifts[1]
263
+
264
+ @staticmethod
265
+ def shift_one(seq, left_pad_num, right_pad_num, pad):
266
+ assert seq.ndim == 1
267
+ bos = seq.new_full((left_pad_num,), pad)
268
+ eos = seq.new_full((right_pad_num,), pad)
269
+ seq = torch.cat([bos, seq, eos])
270
+ mask = torch.ones_like(seq).bool()
271
+ mask[left_pad_num : len(seq) - right_pad_num] = 0
272
+ return seq, mask
273
+
274
+ def __call__(self, code, dur, f0):
275
+ if self.extra_length == 0:
276
+ code_mask = torch.zeros_like(code).bool()
277
+ dur_mask = torch.zeros_like(dur).bool()
278
+ f0_mask = torch.zeros_like(f0).bool()
279
+ return code, code_mask, dur, dur_mask, f0, f0_mask
280
+
281
+ code, code_mask = self.shift_one(code, 0, self.extra_length, self.pads.code)
282
+ dur, dur_mask = self.shift_one(
283
+ dur, self.dur, self.extra_length - self.dur, self.pads.dur
284
+ )
285
+ f0, f0_mask = self.shift_one(
286
+ f0, self.f0, self.extra_length - self.f0, self.pads.f0
287
+ )
288
+ return code, code_mask, dur, dur_mask, f0, f0_mask
289
+
290
+
291
+ class CodeDataset(FairseqDataset):
292
+ def __init__(
293
+ self,
294
+ manifest,
295
+ dictionary,
296
+ dur_dictionary,
297
+ f0_dictionary,
298
+ config,
299
+ discrete_dur,
300
+ discrete_f0,
301
+ log_f0,
302
+ normalize_f0_mean,
303
+ normalize_f0_std,
304
+ interpolate_f0,
305
+ return_filename=False,
306
+ strip_filename=True,
307
+ shifts="0,0",
308
+ return_continuous_f0=False,
309
+ ):
310
+ random.seed(1234)
311
+ self.dictionary = dictionary
312
+ self.dur_dictionary = dur_dictionary
313
+ self.f0_dictionary = f0_dictionary
314
+ self.config = config
315
+
316
+ # duration config
317
+ self.discrete_dur = discrete_dur
318
+
319
+ # pitch config
320
+ self.discrete_f0 = discrete_f0
321
+ self.log_f0 = log_f0
322
+ self.normalize_f0_mean = normalize_f0_mean
323
+ self.normalize_f0_std = normalize_f0_std
324
+ self.interpolate_f0 = interpolate_f0
325
+
326
+ self.return_filename = return_filename
327
+ self.strip_filename = strip_filename
328
+ self.f0_code_ratio = config.code_hop_size / (
329
+ config.sampling_rate * F0_FRAME_SPACE
330
+ )
331
+
332
+ # use lazy loading to avoid sharing file handlers across workers
333
+ self.manifest = manifest
334
+ self._codes = None
335
+ self._durs = None
336
+ self._f0s = None
337
+ with open(f"{manifest}.leng.txt", "r") as f:
338
+ lengs = [int(line.rstrip()) for line in f]
339
+ edges = np.cumsum([0] + lengs)
340
+ self.starts, self.ends = edges[:-1], edges[1:]
341
+ with open(f"{manifest}.path.txt", "r") as f:
342
+ self.file_names = [line.rstrip() for line in f]
343
+ logger.info(f"num entries: {len(self.starts)}")
344
+
345
+ if os.path.exists(f"{manifest}.f0_stat.pt"):
346
+ self.f0_stats = torch.load(f"{manifest}.f0_stat.pt")
347
+ elif config.f0_stats:
348
+ self.f0_stats = torch.load(config.f0_stats)
349
+
350
+ self.multispkr = config.multispkr
351
+ if config.multispkr:
352
+ with open(f"{manifest}.speaker.txt", "r") as f:
353
+ self.spkrs = [line.rstrip() for line in f]
354
+ self.id_to_spkr = sorted(self.spkrs)
355
+ self.spkr_to_id = {k: v for v, k in enumerate(self.id_to_spkr)}
356
+
357
+ self.pads = Paddings(
358
+ dictionary.pad(),
359
+ 0, # use 0 for duration padding
360
+ f0_dictionary.pad() if discrete_f0 else -5.0,
361
+ )
362
+ self.shifts = Shifts(shifts, pads=self.pads)
363
+ self.return_continuous_f0 = return_continuous_f0
364
+
365
+ def get_data_handlers(self):
366
+ logging.info(f"loading data for {self.manifest}")
367
+ self._codes = np.load(f"{self.manifest}.code.npy", mmap_mode="r")
368
+ self._durs = np.load(f"{self.manifest}.dur.npy", mmap_mode="r")
369
+
370
+ if self.discrete_f0:
371
+ if self.config.f0_vq_type == "precomp":
372
+ self._f0s = np.load(
373
+ f"{self.manifest}.{self.config.f0_vq_name}.npy", mmap_mode="r"
374
+ )
375
+ elif self.config.f0_vq_type == "naive":
376
+ self._f0s = np.load(f"{self.manifest}.f0.npy", mmap_mode="r")
377
+ quantizers_path = self.config.get_f0_vq_naive_quantizer(
378
+ self.log_f0, self.normalize_f0_mean, self.normalize_f0_std
379
+ )
380
+ quantizers = torch.load(quantizers_path)
381
+ n_units = self.config.f0_vq_n_units
382
+ self._f0_quantizer = torch.from_numpy(quantizers[n_units])
383
+ else:
384
+ raise ValueError(f"f0_vq_type {self.config.f0_vq_type} not supported")
385
+ else:
386
+ self._f0s = np.load(f"{self.manifest}.f0.npy", mmap_mode="r")
387
+
388
+ def preprocess_f0(self, f0, stats):
389
+ """
390
+ 1. interpolate
391
+ 2. log transform (keep unvoiced frame 0)
392
+ """
393
+ # TODO: change this to be dependent on config for naive quantizer
394
+ f0 = f0.clone()
395
+ if self.interpolate_f0:
396
+ f0 = interpolate_f0(f0)
397
+
398
+ mask = f0 != 0 # only process voiced frames
399
+ if self.log_f0:
400
+ f0[mask] = f0[mask].log()
401
+ if self.normalize_f0_mean:
402
+ mean = stats["logf0_mean"] if self.log_f0 else stats["f0_mean"]
403
+ f0[mask] = f0[mask] - mean
404
+ if self.normalize_f0_std:
405
+ std = stats["logf0_std"] if self.log_f0 else stats["f0_std"]
406
+ f0[mask] = f0[mask] / std
407
+ return f0
408
+
409
+ def _get_raw_item(self, index):
410
+ start, end = self.starts[index], self.ends[index]
411
+ if self._codes is None:
412
+ self.get_data_handlers()
413
+ code = torch.from_numpy(np.array(self._codes[start:end])).long()
414
+ dur = torch.from_numpy(np.array(self._durs[start:end]))
415
+ f0 = torch.from_numpy(np.array(self._f0s[start:end]))
416
+ return code, dur, f0
417
+
418
+ def __getitem__(self, index):
419
+ code, dur, f0 = self._get_raw_item(index)
420
+ code = torch.cat([code.new([self.dictionary.bos()]), code])
421
+
422
+ # use 0 for eos and bos
423
+ dur = torch.cat([dur.new([0]), dur])
424
+ if self.discrete_dur:
425
+ dur = self.dur_dictionary.encode_line(
426
+ " ".join(map(str, dur.tolist())), append_eos=False
427
+ ).long()
428
+ else:
429
+ dur = dur.float()
430
+
431
+ # TODO: find a more elegant approach
432
+ raw_f0 = None
433
+ if self.discrete_f0:
434
+ if self.config.f0_vq_type == "precomp":
435
+ f0 = self.f0_dictionary.encode_line(
436
+ " ".join(map(str, f0.tolist())), append_eos=False
437
+ ).long()
438
+ else:
439
+ f0 = f0.float()
440
+ f0 = self.preprocess_f0(f0, self.f0_stats[self.spkrs[index]])
441
+ if self.return_continuous_f0:
442
+ raw_f0 = f0
443
+ raw_f0 = torch.cat([raw_f0.new([self.f0_dictionary.bos()]), raw_f0])
444
+ f0 = naive_quantize(f0, self._f0_quantizer)
445
+ f0 = torch.cat([f0.new([self.f0_dictionary.bos()]), f0])
446
+ else:
447
+ f0 = f0.float()
448
+ if self.multispkr:
449
+ f0 = self.preprocess_f0(f0, self.f0_stats[self.spkrs[index]])
450
+ else:
451
+ f0 = self.preprocess_f0(f0, self.f0_stats)
452
+ f0 = torch.cat([f0.new([0]), f0])
453
+
454
+ if raw_f0 is not None:
455
+ *_, raw_f0, raw_f0_mask = self.shifts(code, dur, raw_f0)
456
+ else:
457
+ raw_f0_mask = None
458
+
459
+ code, code_mask, dur, dur_mask, f0, f0_mask = self.shifts(code, dur, f0)
460
+ if raw_f0_mask is not None:
461
+ assert (raw_f0_mask == f0_mask).all()
462
+
463
+ # is a padded frame if either input or output is padded
464
+ feats = {
465
+ "source": code[:-1],
466
+ "target": code[1:],
467
+ "mask": code_mask[1:].logical_or(code_mask[:-1]),
468
+ "dur_source": dur[:-1],
469
+ "dur_target": dur[1:],
470
+ "dur_mask": dur_mask[1:].logical_or(dur_mask[:-1]),
471
+ "f0_source": f0[:-1],
472
+ "f0_target": f0[1:],
473
+ "f0_mask": f0_mask[1:].logical_or(f0_mask[:-1]),
474
+ }
475
+
476
+ if raw_f0 is not None:
477
+ feats["raw_f0"] = raw_f0[1:]
478
+
479
+ if self.return_filename:
480
+ fname = self.file_names[index]
481
+ feats["filename"] = (
482
+ fname if not self.strip_filename else Path(fname).with_suffix("").name
483
+ )
484
+ return feats
485
+
486
+ def __len__(self):
487
+ return len(self.starts)
488
+
489
+ def size(self, index):
490
+ return self.ends[index] - self.starts[index] + self.shifts.extra_length
491
+
492
+ def num_tokens(self, index):
493
+ return self.size(index)
494
+
495
+ def collater(self, samples):
496
+ pad_idx, eos_idx = self.dictionary.pad(), self.dictionary.eos()
497
+ if len(samples) == 0:
498
+ return {}
499
+
500
+ src_tokens = data_utils.collate_tokens(
501
+ [s["source"] for s in samples], pad_idx, eos_idx, left_pad=False
502
+ )
503
+
504
+ tgt_tokens = data_utils.collate_tokens(
505
+ [s["target"] for s in samples],
506
+ pad_idx=pad_idx,
507
+ eos_idx=pad_idx, # appending padding, eos is there already
508
+ left_pad=False,
509
+ )
510
+
511
+ src_durs, tgt_durs = [
512
+ data_utils.collate_tokens(
513
+ [s[k] for s in samples],
514
+ pad_idx=self.pads.dur,
515
+ eos_idx=self.pads.dur,
516
+ left_pad=False,
517
+ )
518
+ for k in ["dur_source", "dur_target"]
519
+ ]
520
+
521
+ src_f0s, tgt_f0s = [
522
+ data_utils.collate_tokens(
523
+ [s[k] for s in samples],
524
+ pad_idx=self.pads.f0,
525
+ eos_idx=self.pads.f0,
526
+ left_pad=False,
527
+ )
528
+ for k in ["f0_source", "f0_target"]
529
+ ]
530
+
531
+ mask, dur_mask, f0_mask = [
532
+ data_utils.collate_tokens(
533
+ [s[k] for s in samples],
534
+ pad_idx=1,
535
+ eos_idx=1,
536
+ left_pad=False,
537
+ )
538
+ for k in ["mask", "dur_mask", "f0_mask"]
539
+ ]
540
+
541
+ src_lengths = torch.LongTensor([s["source"].numel() for s in samples])
542
+ n_tokens = sum(len(s["source"]) for s in samples)
543
+
544
+ result = {
545
+ "nsentences": len(samples),
546
+ "ntokens": n_tokens,
547
+ "net_input": {
548
+ "src_tokens": src_tokens,
549
+ "src_lengths": src_lengths,
550
+ "dur_src": src_durs,
551
+ "f0_src": src_f0s,
552
+ },
553
+ "target": tgt_tokens,
554
+ "dur_target": tgt_durs,
555
+ "f0_target": tgt_f0s,
556
+ "mask": mask,
557
+ "dur_mask": dur_mask,
558
+ "f0_mask": f0_mask,
559
+ }
560
+
561
+ if "filename" in samples[0]:
562
+ result["filename"] = [s["filename"] for s in samples]
563
+
564
+ # TODO: move this hack into the inference dataset
565
+ if "prefix" in samples[0]:
566
+ result["prefix"] = [s["prefix"] for s in samples]
567
+
568
+ if "raw_f0" in samples[0]:
569
+ raw_f0s = data_utils.collate_tokens(
570
+ [s["raw_f0"] for s in samples],
571
+ pad_idx=self.pads.f0,
572
+ eos_idx=self.pads.f0,
573
+ left_pad=False,
574
+ )
575
+ result["raw_f0"] = raw_f0s
576
+ return result
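The feats/collater code above builds next-step prediction pairs: each stream (code, duration, f0) is shifted by one position so the model predicts the next frame, and a frame counts as padded if either side of the shift is padded. An illustrative sketch with toy tensors (not part of the committed file):

import torch

code = torch.tensor([7, 11, 12, 13, 14])               # toy token stream after self.shifts()
code_mask = torch.tensor([True, False, False, False, False])

source = code[:-1]                                       # tensor([ 7, 11, 12, 13])
target = code[1:]                                        # tensor([11, 12, 13, 14])
mask = code_mask[1:].logical_or(code_mask[:-1])          # padded if input or output is padded

print(source.tolist(), target.tolist(), mask.tolist())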
fairseq/fairseq/data/colorize_dataset.py ADDED
@@ -0,0 +1,25 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+
8
+ from . import BaseWrapperDataset
9
+
10
+
11
+ class ColorizeDataset(BaseWrapperDataset):
12
+ """Adds 'colors' property to net input that is obtained from the provided color getter for use by models"""
13
+
14
+ def __init__(self, dataset, color_getter):
15
+ super().__init__(dataset)
16
+ self.color_getter = color_getter
17
+
18
+ def collater(self, samples):
19
+ base_collate = super().collater(samples)
20
+ if len(base_collate) > 0:
21
+ base_collate["net_input"]["colors"] = torch.tensor(
22
+ list(self.color_getter(self.dataset, s["id"]) for s in samples),
23
+ dtype=torch.long,
24
+ )
25
+ return base_collate
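An illustrative usage sketch for the wrapper above (not part of the committed file; `base_dataset` is a hypothetical FairseqDataset whose samples carry an "id" field and whose collater builds a "net_input" dict):

from fairseq.data import ColorizeDataset

def speaker_of(dataset, sample_id):
    # hypothetical lookup: map each example id to a speaker/"color" id
    return sample_id % 4

colored = ColorizeDataset(base_dataset, color_getter=speaker_of)
batch = colored.collater([colored[i] for i in range(3)])
# batch["net_input"]["colors"] is a LongTensor with one color id per example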
fairseq/fairseq/data/concat_dataset.py ADDED
@@ -0,0 +1,124 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import bisect
7
+
8
+ import numpy as np
9
+ from torch.utils.data.dataloader import default_collate
10
+
11
+ from . import FairseqDataset
12
+
13
+
14
+ class ConcatDataset(FairseqDataset):
15
+ @staticmethod
16
+ def cumsum(sequence, sample_ratios):
17
+ r, s = [], 0
18
+ for e, ratio in zip(sequence, sample_ratios):
19
+ curr_len = int(ratio * len(e))
20
+ r.append(curr_len + s)
21
+ s += curr_len
22
+ return r
23
+
24
+ def __init__(self, datasets, sample_ratios=1):
25
+ super(ConcatDataset, self).__init__()
26
+ assert len(datasets) > 0, "datasets should not be an empty iterable"
27
+ self.datasets = list(datasets)
28
+ if isinstance(sample_ratios, int):
29
+ sample_ratios = [sample_ratios] * len(self.datasets)
30
+ self.sample_ratios = sample_ratios
31
+ self.cumulative_sizes = self.cumsum(self.datasets, sample_ratios)
32
+ self.real_sizes = [len(d) for d in self.datasets]
33
+
34
+ def __len__(self):
35
+ return self.cumulative_sizes[-1]
36
+
37
+ def __getitem__(self, idx):
38
+ dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
39
+ return self.datasets[dataset_idx][sample_idx]
40
+
41
+ def _get_dataset_and_sample_index(self, idx: int):
42
+ dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
43
+ if dataset_idx == 0:
44
+ sample_idx = idx
45
+ else:
46
+ sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
47
+ sample_idx = sample_idx % self.real_sizes[dataset_idx]
48
+ return dataset_idx, sample_idx
49
+
50
+ def collater(self, samples, **extra_args):
51
+ # For now only supports datasets with same underlying collater implementations
52
+ if hasattr(self.datasets[0], "collater"):
53
+ return self.datasets[0].collater(samples, **extra_args)
54
+ else:
55
+ return default_collate(samples, **extra_args)
56
+
57
+ def size(self, idx: int):
58
+ """
59
+ Return an example's size as a float or tuple.
60
+ """
61
+ dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
62
+ return self.datasets[dataset_idx].size(sample_idx)
63
+
64
+ def num_tokens(self, index: int):
65
+ return np.max(self.size(index))
66
+
67
+ def attr(self, attr: str, index: int):
68
+ dataset_idx = bisect.bisect_right(self.cumulative_sizes, index)
69
+ return getattr(self.datasets[dataset_idx], attr, None)
70
+
71
+ @property
72
+ def sizes(self):
73
+ _dataset_sizes = []
74
+ for ds, sr in zip(self.datasets, self.sample_ratios):
75
+ if isinstance(ds.sizes, np.ndarray):
76
+ _dataset_sizes.append(np.tile(ds.sizes, sr))
77
+ else:
78
+ # Only support underlying dataset with single size array.
79
+ assert isinstance(ds.sizes, list)
80
+ _dataset_sizes.append(np.tile(ds.sizes[0], sr))
81
+ return np.concatenate(_dataset_sizes)
82
+
83
+ @property
84
+ def supports_prefetch(self):
85
+ return all(d.supports_prefetch for d in self.datasets)
86
+
87
+ def ordered_indices(self):
88
+ """
89
+ Returns indices sorted by length. So less padding is needed.
90
+ """
91
+ if isinstance(self.sizes, np.ndarray) and len(self.sizes.shape) > 1:
92
+ # special handling for concatenating lang_pair_datasets
93
+ indices = np.arange(len(self))
94
+ sizes = self.sizes
95
+ tgt_sizes = (
96
+ sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None
97
+ )
98
+ src_sizes = (
99
+ sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes
100
+ )
101
+ # sort by target length, then source length
102
+ if tgt_sizes is not None:
103
+ indices = indices[np.argsort(tgt_sizes[indices], kind="mergesort")]
104
+ return indices[np.argsort(src_sizes[indices], kind="mergesort")]
105
+ else:
106
+ return np.argsort(self.sizes)
107
+
108
+ def prefetch(self, indices):
109
+ frm = 0
110
+ for to, ds in zip(self.cumulative_sizes, self.datasets):
111
+ real_size = len(ds)
112
+ if getattr(ds, "supports_prefetch", False):
113
+ ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to])
114
+ frm = to
115
+
116
+ @property
117
+ def can_reuse_epoch_itr_across_epochs(self):
118
+ return all(d.can_reuse_epoch_itr_across_epochs for d in self.datasets)
119
+
120
+ def set_epoch(self, epoch):
121
+ super().set_epoch(epoch)
122
+ for ds in self.datasets:
123
+ if hasattr(ds, "set_epoch"):
124
+ ds.set_epoch(epoch)
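The cumsum/bisect logic above maps a global index to a (dataset, local index) pair, with sample_ratios repeating short datasets. A minimal worked sketch with hypothetical datasets of lengths 3 and 2 and sample_ratios=[2, 1]:

import bisect

real_sizes = [3, 2]
cumulative_sizes = [6, 8]      # 2*3, then 2*3 + 1*2

idx = 7                         # global index
dataset_idx = bisect.bisect_right(cumulative_sizes, idx)                            # -> 1
sample_idx = (idx - cumulative_sizes[dataset_idx - 1]) % real_sizes[dataset_idx]    # -> 1
print(dataset_idx, sample_idx)  # second dataset, local item 1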
fairseq/fairseq/data/concat_sentences_dataset.py ADDED
@@ -0,0 +1,54 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+
8
+ from . import FairseqDataset
9
+
10
+
11
+ class ConcatSentencesDataset(FairseqDataset):
12
+ def __init__(self, *datasets):
13
+ super().__init__()
14
+ self.datasets = datasets
15
+ assert all(
16
+ len(ds) == len(datasets[0]) for ds in datasets
17
+ ), "datasets must have the same length"
18
+
19
+ def __getitem__(self, index):
20
+ return torch.cat([ds[index] for ds in self.datasets])
21
+
22
+ def __len__(self):
23
+ return len(self.datasets[0])
24
+
25
+ def collater(self, samples):
26
+ return self.datasets[0].collater(samples)
27
+
28
+ @property
29
+ def sizes(self):
30
+ return sum(ds.sizes for ds in self.datasets)
31
+
32
+ def num_tokens(self, index):
33
+ return sum(ds.num_tokens(index) for ds in self.datasets)
34
+
35
+ def size(self, index):
36
+ return sum(ds.size(index) for ds in self.datasets)
37
+
38
+ def ordered_indices(self):
39
+ return self.datasets[0].ordered_indices()
40
+
41
+ @property
42
+ def supports_prefetch(self):
43
+ return any(getattr(ds, "supports_prefetch", False) for ds in self.datasets)
44
+
45
+ def prefetch(self, indices):
46
+ for ds in self.datasets:
47
+ if getattr(ds, "supports_prefetch", False):
48
+ ds.prefetch(indices)
49
+
50
+ def set_epoch(self, epoch):
51
+ super().set_epoch(epoch)
52
+ for ds in self.datasets:
53
+ if hasattr(ds, "set_epoch"):
54
+ ds.set_epoch(epoch)
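Item i of the wrapper above is the concatenation of item i from each constituent dataset; an illustrative sketch with toy tensors standing in for two aligned datasets:

import torch

ds_a = [torch.tensor([1, 2]), torch.tensor([3])]
ds_b = [torch.tensor([9]),    torch.tensor([8, 7])]

index = 1
item = torch.cat([ds[index] for ds in (ds_a, ds_b)])   # tensor([3, 8, 7])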
fairseq/fairseq/data/data_utils.py ADDED
@@ -0,0 +1,1144 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ try:
7
+ from collections.abc import Iterable
8
+ except ImportError:
9
+ from collections import Iterable
10
+ import contextlib
11
+ import itertools
12
+ import logging
13
+ import re
14
+ import warnings
15
+ from typing import Optional, Tuple
16
+
17
+ import math
18
+ import numpy as np
19
+ import torch
20
+
21
+ from fairseq.file_io import PathManager
22
+ from fairseq import utils
23
+ import os
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ def infer_language_pair(path):
29
+ """Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
30
+ src, dst = None, None
31
+ for filename in PathManager.ls(path):
32
+ parts = filename.split(".")
33
+ if len(parts) >= 3 and len(parts[1].split("-")) == 2:
34
+ return parts[1].split("-")
35
+ return src, dst
36
+
37
+
38
+ def collate_tokens(
39
+ values,
40
+ pad_idx,
41
+ eos_idx=None,
42
+ left_pad=False,
43
+ move_eos_to_beginning=False,
44
+ pad_to_length=None,
45
+ pad_to_multiple=1,
46
+ pad_to_bsz=None,
47
+ ):
48
+ """Convert a list of 1d tensors into a padded 2d tensor."""
49
+ size = max(v.size(0) for v in values)
50
+ size = size if pad_to_length is None else max(size, pad_to_length)
51
+ if pad_to_multiple != 1 and size % pad_to_multiple != 0:
52
+ size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple)
53
+
54
+ batch_size = len(values) if pad_to_bsz is None else max(len(values), pad_to_bsz)
55
+ res = values[0].new(batch_size, size).fill_(pad_idx)
56
+
57
+ def copy_tensor(src, dst):
58
+ assert dst.numel() == src.numel()
59
+ if move_eos_to_beginning:
60
+ if eos_idx is None:
61
+ # if no eos_idx is specified, then use the last token in src
62
+ dst[0] = src[-1]
63
+ else:
64
+ dst[0] = eos_idx
65
+ dst[1:] = src[:-1]
66
+ else:
67
+ dst.copy_(src)
68
+
69
+ for i, v in enumerate(values):
70
+ copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)])
71
+ return res
72
+
73
+
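An illustrative sketch of collate_tokens above: a ragged list of 1-D tensors is padded into a (batch, max_len) matrix filled with pad_idx (toy values only):

import torch
from fairseq.data import data_utils

values = [torch.tensor([5, 6, 7]), torch.tensor([8])]
batch = data_utils.collate_tokens(values, pad_idx=1, left_pad=False)
# tensor([[5, 6, 7],
#         [8, 1, 1]])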
74
+ def load_indexed_dataset(
75
+ path, dictionary=None, dataset_impl=None, combine=False, default="cached"
76
+ ):
77
+ """A helper function for loading indexed datasets.
78
+
79
+ Args:
80
+ path (str): path to indexed dataset (e.g., 'data-bin/train')
81
+ dictionary (~fairseq.data.Dictionary): data dictionary
82
+ dataset_impl (str, optional): which dataset implementation to use. If
83
+ not provided, it will be inferred automatically. For legacy indexed
84
+ data we use the 'cached' implementation by default.
85
+ combine (bool, optional): automatically load and combine multiple
86
+ datasets. For example, if *path* is 'data-bin/train', then we will
87
+ combine 'data-bin/train', 'data-bin/train1', ... and return a
88
+ single ConcatDataset instance.
89
+ """
90
+ import fairseq.data.indexed_dataset as indexed_dataset
91
+ from fairseq.data.concat_dataset import ConcatDataset
92
+
93
+ datasets = []
94
+ for k in itertools.count():
95
+ path_k = path + (str(k) if k > 0 else "")
96
+ try:
97
+ path_k = indexed_dataset.get_indexed_dataset_to_local(path_k)
98
+ except Exception as e:
99
+ if "StorageException: [404] Path not found" in str(e):
100
+ logger.warning(f"path_k: {e} not found")
101
+ else:
102
+ raise e
103
+
104
+ dataset_impl_k = dataset_impl
105
+ if dataset_impl_k is None:
106
+ dataset_impl_k = indexed_dataset.infer_dataset_impl(path_k)
107
+ dataset = indexed_dataset.make_dataset(
108
+ path_k,
109
+ impl=dataset_impl_k or default,
110
+ fix_lua_indexing=True,
111
+ dictionary=dictionary,
112
+ )
113
+ if dataset is None:
114
+ break
115
+ logger.info("loaded {:,} examples from: {}".format(len(dataset), path_k))
116
+ datasets.append(dataset)
117
+ if not combine:
118
+ break
119
+ if len(datasets) == 0:
120
+ return None
121
+ elif len(datasets) == 1:
122
+ return datasets[0]
123
+ else:
124
+ return ConcatDataset(datasets)
125
+
126
+
127
+ @contextlib.contextmanager
128
+ def numpy_seed(seed, *addl_seeds):
129
+ """Context manager which seeds the NumPy PRNG with the specified seed and
130
+ restores the state afterward"""
131
+ if seed is None:
132
+ yield
133
+ return
134
+ if len(addl_seeds) > 0:
135
+ seed = int(hash((seed, *addl_seeds)) % 1e6)
136
+ state = np.random.get_state()
137
+ np.random.seed(seed)
138
+ try:
139
+ yield
140
+ finally:
141
+ np.random.set_state(state)
142
+
143
+
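An illustrative sketch of the numpy_seed context manager above: it scopes the NumPy PRNG so ordering decisions are reproducible for a given (seed, epoch) pair without disturbing the global RNG state afterwards:

import numpy as np
from fairseq.data import data_utils

with data_utils.numpy_seed(3, 1):        # e.g. seed=3, epoch=1
    order_a = np.random.permutation(5)
with data_utils.numpy_seed(3, 1):
    order_b = np.random.permutation(5)
assert (order_a == order_b).all()        # identical inside the same seeded scope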
144
+ def collect_filtered(function, iterable, filtered):
145
+ """
146
+ Similar to :func:`filter` but collects filtered elements in ``filtered``.
147
+
148
+ Args:
149
+ function (callable): function that returns ``False`` for elements that
150
+ should be filtered
151
+ iterable (iterable): iterable to filter
152
+ filtered (list): list to store filtered elements
153
+ """
154
+ for el in iterable:
155
+ if function(el):
156
+ yield el
157
+ else:
158
+ filtered.append(el)
159
+
160
+
161
+ def _filter_by_size_dynamic(indices, size_fn, max_positions, raise_exception=False):
162
+ def compare_leq(a, b):
163
+ return a <= b if not isinstance(a, tuple) else max(a) <= b
164
+
165
+ def check_size(idx):
166
+ if isinstance(max_positions, float) or isinstance(max_positions, int):
167
+ return size_fn(idx) <= max_positions
168
+ elif isinstance(max_positions, dict):
169
+ idx_size = size_fn(idx)
170
+ assert isinstance(idx_size, dict)
171
+ intersect_keys = set(max_positions.keys()) & set(idx_size.keys())
172
+ return all(
173
+ all(
174
+ a is None or b is None or a <= b
175
+ for a, b in zip(idx_size[key], max_positions[key])
176
+ )
177
+ for key in intersect_keys
178
+ )
179
+ else:
180
+ # For MultiCorpusSampledDataset, will generalize it later
181
+ if not isinstance(size_fn(idx), Iterable):
182
+ return all(size_fn(idx) <= b for b in max_positions)
183
+ return all(
184
+ a is None or b is None or a <= b
185
+ for a, b in zip(size_fn(idx), max_positions)
186
+ )
187
+
188
+ ignored = []
189
+ itr = collect_filtered(check_size, indices, ignored)
190
+ indices = np.fromiter(itr, dtype=np.int64, count=-1)
191
+ return indices, ignored
192
+
193
+
194
+ def filter_by_size(indices, dataset, max_positions, raise_exception=False):
195
+ """
196
+ [deprecated] Filter indices based on their size.
197
+ Use `FairseqDataset::filter_indices_by_size` instead.
198
+
199
+ Args:
200
+ indices (List[int]): ordered list of dataset indices
201
+ dataset (FairseqDataset): fairseq dataset instance
202
+ max_positions (tuple): filter elements larger than this size.
203
+ Comparisons are done component-wise.
204
+ raise_exception (bool, optional): if ``True``, raise an exception if
205
+ any elements are filtered (default: False).
206
+ """
207
+ warnings.warn(
208
+ "data_utils.filter_by_size is deprecated. "
209
+ "Use `FairseqDataset::filter_indices_by_size` instead.",
210
+ stacklevel=2,
211
+ )
212
+ if isinstance(max_positions, float) or isinstance(max_positions, int):
213
+ if hasattr(dataset, "sizes") and isinstance(dataset.sizes, np.ndarray):
214
+ ignored = indices[dataset.sizes[indices] > max_positions].tolist()
215
+ indices = indices[dataset.sizes[indices] <= max_positions]
216
+ elif (
217
+ hasattr(dataset, "sizes")
218
+ and isinstance(dataset.sizes, list)
219
+ and len(dataset.sizes) == 1
220
+ ):
221
+ ignored = indices[dataset.sizes[0][indices] > max_positions].tolist()
222
+ indices = indices[dataset.sizes[0][indices] <= max_positions]
223
+ else:
224
+ indices, ignored = _filter_by_size_dynamic(
225
+ indices, dataset.size, max_positions
226
+ )
227
+ else:
228
+ indices, ignored = _filter_by_size_dynamic(indices, dataset.size, max_positions)
229
+
230
+ if len(ignored) > 0 and raise_exception:
231
+ raise Exception(
232
+ (
233
+ "Size of sample #{} is invalid (={}) since max_positions={}, "
234
+ "skip this example with --skip-invalid-size-inputs-valid-test"
235
+ ).format(ignored[0], dataset.size(ignored[0]), max_positions)
236
+ )
237
+ if len(ignored) > 0:
238
+ logger.warning(
239
+ (
240
+ "{} samples have invalid sizes and will be skipped, "
241
+ "max_positions={}, first few sample ids={}"
242
+ ).format(len(ignored), max_positions, ignored[:10])
243
+ )
244
+ return indices
245
+
246
+
247
+ def filter_paired_dataset_indices_by_size(src_sizes, tgt_sizes, indices, max_sizes):
248
+ """Filter a list of sample indices. Remove those that are longer
249
+ than specified in max_sizes.
250
+
251
+ Args:
252
+ indices (np.array): original array of sample indices
253
+ max_sizes (int or list[int] or tuple[int]): max sample size,
254
+ can be defined separately for src and tgt (then list or tuple)
255
+
256
+ Returns:
257
+ np.array: filtered sample array
258
+ list: list of removed indices
259
+ """
260
+ if max_sizes is None:
261
+ return indices, []
262
+ if type(max_sizes) in (int, float):
263
+ max_src_size, max_tgt_size = max_sizes, max_sizes
264
+ else:
265
+ max_src_size, max_tgt_size = max_sizes
266
+ if tgt_sizes is None:
267
+ ignored = indices[src_sizes[indices] > max_src_size]
268
+ else:
269
+ ignored = indices[
270
+ (src_sizes[indices] > max_src_size) | (tgt_sizes[indices] > max_tgt_size)
271
+ ]
272
+ if len(ignored) > 0:
273
+ if tgt_sizes is None:
274
+ indices = indices[src_sizes[indices] <= max_src_size]
275
+ else:
276
+ indices = indices[
277
+ (src_sizes[indices] <= max_src_size)
278
+ & (tgt_sizes[indices] <= max_tgt_size)
279
+ ]
280
+ return indices, ignored.tolist()
281
+
282
+
283
+ def batch_by_size(
284
+ indices,
285
+ num_tokens_fn,
286
+ num_tokens_vec=None,
287
+ max_tokens=None,
288
+ max_sentences=None,
289
+ required_batch_size_multiple=1,
290
+ fixed_shapes=None,
291
+ ):
292
+ """
293
+ Yield mini-batches of indices bucketed by size. Batches may contain
294
+ sequences of different lengths.
295
+
296
+ Args:
297
+ indices (List[int]): ordered list of dataset indices
298
+ num_tokens_fn (callable): function that returns the number of tokens at
299
+ a given index
300
+ num_tokens_vec (List[int], optional): precomputed vector of the number
301
+ of tokens for each index in indices (to enable faster batch generation)
302
+ max_tokens (int, optional): max number of tokens in each batch
303
+ (default: None).
304
+ max_sentences (int, optional): max number of sentences in each
305
+ batch (default: None).
306
+ required_batch_size_multiple (int, optional): require batch size to
307
+ be less than N or a multiple of N (default: 1).
308
+ fixed_shapes (List[Tuple[int, int]], optional): if given, batches will
309
+ only be created with the given shapes. *max_sentences* and
310
+ *required_batch_size_multiple* will be ignored (default: None).
311
+ """
312
+ try:
313
+ from fairseq.data.data_utils_fast import (
314
+ batch_by_size_fn,
315
+ batch_by_size_vec,
316
+ batch_fixed_shapes_fast,
317
+ )
318
+ except ImportError:
319
+ raise ImportError(
320
+ "Please build Cython components with: "
321
+ "`python setup.py build_ext --inplace`"
322
+ )
323
+ except ValueError:
324
+ raise ValueError(
325
+ "Please build (or rebuild) Cython components with `python setup.py build_ext --inplace`."
326
+ )
327
+
328
+ # added int() to avoid TypeError: an integer is required
329
+ max_tokens = int(max_tokens) if max_tokens is not None else -1
330
+ max_sentences = max_sentences if max_sentences is not None else -1
331
+ bsz_mult = required_batch_size_multiple
332
+
333
+ if not isinstance(indices, np.ndarray):
334
+ indices = np.fromiter(indices, dtype=np.int64, count=-1)
335
+
336
+ if num_tokens_vec is not None and not isinstance(num_tokens_vec, np.ndarray):
337
+ num_tokens_vec = np.fromiter(num_tokens_vec, dtype=np.int64, count=-1)
338
+
339
+ if fixed_shapes is None:
340
+ if num_tokens_vec is None:
341
+ b = batch_by_size_fn(
342
+ indices,
343
+ num_tokens_fn,
344
+ max_tokens,
345
+ max_sentences,
346
+ bsz_mult,
347
+ )
348
+ else:
349
+ b = batch_by_size_vec(
350
+ indices,
351
+ num_tokens_vec,
352
+ max_tokens,
353
+ max_sentences,
354
+ bsz_mult,
355
+ )
356
+
357
+ if bsz_mult > 1 and len(b[-1]) % bsz_mult != 0:
358
+ b = b[:-1]
359
+
360
+ return b
361
+
362
+ else:
363
+ fixed_shapes = np.array(fixed_shapes, dtype=np.int64)
364
+ sort_order = np.lexsort(
365
+ [
366
+ fixed_shapes[:, 1].argsort(), # length
367
+ fixed_shapes[:, 0].argsort(), # bsz
368
+ ]
369
+ )
370
+ fixed_shapes_sorted = fixed_shapes[sort_order]
371
+ return batch_fixed_shapes_fast(indices, num_tokens_fn, fixed_shapes_sorted)
372
+
373
+
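An illustrative sketch of batch_by_size above (assuming the Cython extensions have been built with `python setup.py build_ext --inplace`): toy indices are grouped so each batch stays under max_tokens padded tokens:

import numpy as np
from fairseq.data import data_utils

lengths = np.array([4, 5, 3, 9, 2], dtype=np.int64)
batches = data_utils.batch_by_size(
    indices=np.argsort(lengths),               # sort by length so padding is minimal
    num_tokens_fn=lambda i: int(lengths[i]),
    max_tokens=10,
)
# e.g. [array([4, 2]), array([0, 1]), array([3])] with these toy lengths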
374
+ def post_process(sentence: str, symbol: str):
375
+ if symbol == "sentencepiece":
376
+ sentence = sentence.replace(" ", "").replace("\u2581", " ").strip()
377
+ elif symbol == "wordpiece":
378
+ sentence = sentence.replace(" ", "").replace("_", " ").strip()
379
+ elif symbol == "letter":
380
+ sentence = sentence.replace(" ", "").replace("|", " ").strip()
381
+ elif symbol == "silence":
382
+ import re
383
+
384
+ sentence = sentence.replace("<SIL>", "")
385
+ sentence = re.sub(" +", " ", sentence).strip()
386
+ elif symbol == "_EOW":
387
+ sentence = sentence.replace(" ", "").replace("_EOW", " ").strip()
388
+ elif symbol in {"subword_nmt", "@@ ", "@@"}:
389
+ if symbol == "subword_nmt":
390
+ symbol = "@@ "
391
+ sentence = (sentence + " ").replace(symbol, "").rstrip()
392
+ elif symbol == "none":
393
+ pass
394
+ elif symbol is not None:
395
+ raise NotImplementedError(f"Unknown post_process option: {symbol}")
396
+ return sentence
397
+
398
+
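An illustrative sketch of post_process above for the sentencepiece convention: spaces inside the encoded string are removed and "\u2581" markers become word boundaries:

from fairseq.data import data_utils

print(data_utils.post_process("\u2581he llo \u2581wor ld", "sentencepiece"))
# -> "hello world"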
399
+ def compute_mask_indices(
400
+ shape: Tuple[int, int],
401
+ padding_mask: Optional[torch.Tensor],
402
+ mask_prob: float,
403
+ mask_length: int,
404
+ mask_type: str = "static",
405
+ mask_other: float = 0.0,
406
+ min_masks: int = 0,
407
+ no_overlap: bool = False,
408
+ min_space: int = 0,
409
+ require_same_masks: bool = True,
410
+ mask_dropout: float = 0.0,
411
+ add_masks: bool = False,
412
+ seed: Optional[int] = None,
413
+ epoch: Optional[int] = None,
414
+ indices: Optional[torch.Tensor] = None,
415
+ idc_select_ver: int = 1, # 2 to reproduce mask_tokens_dataset
416
+ num_mask_ver: int = 2, # 2 to reproduce mask_tokens_dataset
417
+ ) -> np.ndarray:
418
+ """
419
+ Computes random mask spans for a given shape
420
+
421
+ Args:
422
+ shape: the shape for which to compute masks.
423
+ should be of size 2 where first element is batch size and 2nd is timesteps
424
+ padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
425
+ mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
426
+ number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
427
+ however due to overlaps, the actual number will be smaller (unless no_overlap is True)
428
+ mask_type: how to compute mask lengths
429
+ static = fixed size
430
+ uniform = sample from uniform distribution [mask_other, mask_length*2]
431
+ normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
432
+ poisson = sample from poisson distribution with lambda = mask length
433
+ min_masks: minimum number of masked spans
434
+ no_overlap: if true, will switch to an alternative recursive algorithm that prevents spans from overlapping
435
+ min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
436
+ require_same_masks: if true, will randomly drop out masks until the same number of masks remains in each sample
437
+ mask_dropout: randomly drop out this percentage of masks in each example
438
+ """
439
+
440
+ bsz, all_sz = shape
441
+ mask = np.full((bsz, all_sz), False)
442
+
443
+ if num_mask_ver == 1:
444
+ all_num_mask = int(
445
+ # add a random number for probabilistic rounding
446
+ mask_prob * all_sz / float(mask_length)
447
+ + np.random.rand()
448
+ )
449
+ all_num_mask = max(min_masks, all_num_mask)
450
+
451
+ mask_idcs = []
452
+ for i in range(bsz):
453
+ if seed is not None and epoch is not None and indices is not None:
454
+ seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6)
455
+ else:
456
+ seed_i = None
457
+
458
+ rng = np.random.default_rng(seed_i)
459
+
460
+ if padding_mask is not None:
461
+ sz = all_sz - padding_mask[i].long().sum().item()
462
+ assert sz >= 0, sz
463
+ else:
464
+ sz = all_sz
465
+
466
+ if num_mask_ver == 1:
467
+ if padding_mask is not None:
468
+ num_mask = int(
469
+ # add a random number for probabilistic rounding
470
+ mask_prob * sz / float(mask_length)
471
+ + np.random.rand()
472
+ )
473
+ num_mask = max(min_masks, num_mask)
474
+ else:
475
+ num_mask = all_num_mask
476
+ elif num_mask_ver == 2:
477
+ num_mask = int(
478
+ # add a random number for probabilistic rounding
479
+ mask_prob * sz / float(mask_length)
480
+ + rng.random()
481
+ )
482
+ num_mask = max(min_masks, num_mask)
483
+ else:
484
+ raise ValueError()
485
+
486
+ if mask_type == "static":
487
+ lengths = np.full(num_mask, mask_length)
488
+ elif mask_type == "uniform":
489
+ lengths = rng.randint(mask_other, mask_length * 2 + 1, size=num_mask)
490
+ elif mask_type == "normal":
491
+ lengths = rng.normal(mask_length, mask_other, size=num_mask)
492
+ lengths = [max(1, int(round(x))) for x in lengths]
493
+ elif mask_type == "poisson":
494
+ lengths = rng.poisson(mask_length, size=num_mask)
495
+ lengths = [int(round(x)) for x in lengths]
496
+ else:
497
+ raise Exception("unknown mask selection " + mask_type)
498
+
499
+ if sum(lengths) == 0:
500
+ if mask_type == "static":
501
+ raise ValueError("this should never happen")
502
+ else:
503
+ lengths = [min(mask_length, sz - 1)]
504
+
505
+ if no_overlap:
506
+ mask_idc = []
507
+
508
+ def arrange(s, e, length, keep_length):
509
+ span_start = rng.randint(s, e - length)
510
+ mask_idc.extend(span_start + i for i in range(length))
511
+
512
+ new_parts = []
513
+ if span_start - s - min_space >= keep_length:
514
+ new_parts.append((s, span_start - min_space + 1))
515
+ if e - span_start - length - min_space > keep_length:
516
+ new_parts.append((span_start + length + min_space, e))
517
+ return new_parts
518
+
519
+ parts = [(0, sz)]
520
+ min_length = min(lengths)
521
+ for length in sorted(lengths, reverse=True):
522
+ lens = np.fromiter(
523
+ (e - s if e - s >= length + min_space else 0 for s, e in parts),
524
+ np.int64,
525
+ )
526
+ l_sum = np.sum(lens)
527
+ if l_sum == 0:
528
+ break
529
+ probs = lens / np.sum(lens)
530
+ c = rng.choice(len(parts), p=probs)
531
+ s, e = parts.pop(c)
532
+ parts.extend(arrange(s, e, length, min_length))
533
+ mask_idc = np.asarray(mask_idc)
534
+ else:
535
+ if idc_select_ver == 1:
536
+ min_len = min(lengths)
537
+ if sz - min_len <= num_mask:
538
+ min_len = sz - num_mask - 1
539
+ mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
540
+ elif idc_select_ver == 2:
541
+ mask_idc = rng.choice(sz, num_mask, replace=False)
542
+ else:
543
+ raise ValueError()
544
+
545
+ mask_idc = np.asarray(
546
+ [
547
+ mask_idc[j] + offset
548
+ for j in range(len(mask_idc))
549
+ for offset in range(lengths[j])
550
+ ]
551
+ )
552
+
553
+ mask_idc = np.unique(mask_idc[mask_idc < sz])
554
+ if len(mask_idc) >= sz:
555
+ raise ValueError(
556
+ (
557
+ f"the entire sequence is masked. "
558
+ f"sz={sz}; mask_idc={mask_idc}; "
559
+ f"index={indices[i] if indices is not None else None}"
560
+ )
561
+ )
562
+ mask_idcs.append(mask_idc)
563
+
564
+ target_len = None
565
+ if require_same_masks:
566
+ if add_masks:
567
+ target_len = max([len(m) for m in mask_idcs])
568
+ else:
569
+ target_len = min([len(m) for m in mask_idcs])
570
+
571
+ for i, mask_idc in enumerate(mask_idcs):
572
+ if target_len is not None and len(mask_idc) > target_len:
573
+ mask_idc = rng.choice(mask_idc, target_len, replace=False)
574
+
575
+ mask[i, mask_idc] = True
576
+
577
+ if target_len is not None and len(mask_idc) < target_len:
578
+ unmasked = np.flatnonzero(~mask[i])
579
+ to_mask = rng.choice(unmasked, target_len - len(mask_idc), replace=False)
580
+ mask[i, to_mask] = True
581
+
582
+ if mask_dropout > 0:
583
+ masked = np.flatnonzero(mask[i])
584
+ num_holes = np.rint(len(masked) * mask_dropout).astype(int)
585
+ to_drop = rng.choice(masked, num_holes, replace=False)
586
+ mask[i, to_drop] = False
587
+
588
+ return mask
589
+
590
+
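An illustrative sketch of compute_mask_indices above: draw span masks for a toy (batch=2, timesteps=10) input, masking roughly half of the frames with spans of length 2:

import numpy as np
from fairseq.data import data_utils

mask = data_utils.compute_mask_indices(
    shape=(2, 10),
    padding_mask=None,
    mask_prob=0.5,
    mask_length=2,
)
# mask is a (2, 10) boolean np.ndarray; with require_same_masks=True (the default)
# both rows end up with the same number of masked positions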
591
+ def compute_block_mask_2d(
592
+ shape: Tuple[int, int],
593
+ mask_prob: float,
594
+ mask_length: int,
595
+ mask_prob_adjust: float = 0,
596
+ inverse_mask: bool = False,
597
+ require_same_masks: bool = True,
598
+ expand_adjcent: bool = False,
599
+ mask_dropout: float = 0,
600
+ non_overlapping: bool = False,
601
+ ) -> torch.Tensor:
602
+
603
+ assert mask_length > 1
604
+
605
+ B, L = shape
606
+
607
+ d = int(L**0.5)
608
+
609
+ if inverse_mask:
610
+ mask_prob = 1 - mask_prob
611
+
612
+ if non_overlapping:
613
+ sz = math.ceil(d / mask_length)
614
+ inp_len = sz * sz
615
+
616
+ inp = torch.zeros((B, 1, sz, sz))
617
+ w = torch.ones((1, 1, mask_length, mask_length))
618
+
619
+ mask_inds = torch.multinomial(
620
+ 1 - inp.view(B, -1),
621
+ int(inp_len * (mask_prob + mask_prob_adjust) * (1 + mask_dropout)),
622
+ replacement=False,
623
+ )
624
+ inp.view(B, -1).scatter_(1, mask_inds, 1)
625
+
626
+ mask = torch.nn.functional.conv_transpose2d(inp, w, stride=mask_length).squeeze(
627
+ 1
628
+ )
629
+ if mask.size(-1) > d:
630
+ mask = mask[..., :d, :d]
631
+ else:
632
+ mask = torch.zeros((B, d, d))
633
+ mask_inds = torch.randint(
634
+ 0,
635
+ L,
636
+ size=(
637
+ B,
638
+ int(
639
+ L
640
+ * ((mask_prob + mask_prob_adjust) / mask_length**2)
641
+ * (1 + mask_dropout)
642
+ ),
643
+ ),
644
+ )
645
+ mask.view(B, -1).scatter_(1, mask_inds, 1)
646
+ centers = mask.nonzero(as_tuple=True)
647
+
648
+ inds = ([], [], [])
649
+
650
+ offset = mask_length // 2
651
+ for i in range(mask_length):
652
+ for j in range(mask_length):
653
+ k1 = i - offset
654
+ k2 = j - offset
655
+ inds[0].append(centers[0])
656
+ inds[1].append(centers[1] + k1)
657
+ inds[2].append(centers[2] + k2)
658
+
659
+ i0 = torch.cat(inds[0])
660
+ i1 = torch.cat(inds[1]).clamp_(min=0, max=d - 1)
661
+ i2 = torch.cat(inds[2]).clamp_(min=0, max=d - 1)
662
+
663
+ mask[(i0, i1, i2)] = 1
664
+
665
+ def get_nbs(b, m, w):
666
+ all_nbs = torch.nn.functional.conv2d(m.unsqueeze(1), w, padding="same")
667
+ all_nbs = all_nbs.clamp_max_(1).view(b, -1)
668
+ return all_nbs
669
+
670
+ if require_same_masks and expand_adjcent:
671
+ w = torch.zeros((1, 1, 3, 3))
672
+ w[..., 0, 1] = 1
673
+ w[..., 2, 1] = 1
674
+ w[..., 1, 0] = 1
675
+ w[..., 1, 2] = 1
676
+
677
+ all_nbs = get_nbs(B, mask, w)
678
+
679
+ mask = mask.reshape(B, -1)
680
+
681
+ if require_same_masks:
682
+ n_masks = mask.sum(dim=-1)
683
+ final_target_len = int(L * (mask_prob))
684
+ target_len = int(final_target_len * (1 + mask_dropout))
685
+
686
+ for i in range(len(mask)):
687
+ n = n_masks[i]
688
+ m = mask[i]
689
+ r = 0
690
+ while expand_adjcent and n < target_len:
691
+ if r == 0:
692
+ nbs = all_nbs[i]
693
+ else:
694
+ nbs = get_nbs(1, m.view(1, d, d), w).flatten()
695
+
696
+ cands = (1 - m + nbs) > 1
697
+ cand_sz = int(cands.sum().item())
698
+
699
+ assert cand_sz > 0, f"{nbs} {cand_sz}"
700
+
701
+ to_mask = torch.multinomial(
702
+ cands.float(), min(cand_sz, int(target_len - n)), replacement=False
703
+ )
704
+ m[to_mask] = 1
705
+ assert to_mask.numel() > 0
706
+ n += to_mask.numel()
707
+ r += 1
708
+
709
+ if n > final_target_len:
710
+ to_unmask = torch.multinomial(
711
+ m, int(n - final_target_len), replacement=False
712
+ )
713
+ m[to_unmask] = 0
714
+ elif n < final_target_len:
715
+ to_mask = torch.multinomial(
716
+ (1 - m), int(final_target_len - n), replacement=False
717
+ )
718
+ m[to_mask] = 1
719
+
720
+ if inverse_mask:
721
+ mask = 1 - mask
722
+
723
+ return mask
724
+
725
+
726
+ def compute_block_mask_1d(
727
+ shape: Tuple[int, int],
728
+ mask_prob: float,
729
+ mask_length: int,
730
+ mask_prob_adjust: float = 0,
731
+ inverse_mask: bool = False,
732
+ require_same_masks: bool = True,
733
+ expand_adjcent: bool = False,
734
+ mask_dropout: float = 0,
735
+ non_overlapping: bool = False,
736
+ ) -> torch.Tensor:
737
+
738
+ B, L = shape
739
+
740
+ if inverse_mask:
741
+ mask_prob = 1 - mask_prob
742
+
743
+ if non_overlapping:
744
+ sz = math.ceil(L / mask_length)
745
+
746
+ inp = torch.zeros((B, 1, sz))
747
+ w = torch.ones((1, 1, mask_length))
748
+
749
+ mask_inds = torch.multinomial(
750
+ 1 - inp.view(B, -1),
751
+ int(sz * (mask_prob + mask_prob_adjust) * (1 + mask_dropout)),
752
+ replacement=False,
753
+ )
754
+ inp.view(B, -1).scatter_(1, mask_inds, 1)
755
+
756
+ mask = torch.nn.functional.conv_transpose1d(inp, w, stride=mask_length).squeeze(
757
+ 1
758
+ )
759
+ if mask.size(-1) > L:
760
+ mask = mask[..., :L]
761
+
762
+ else:
763
+ mask = torch.zeros((B, L))
764
+ mask_inds = torch.randint(
765
+ 0,
766
+ L,
767
+ size=(
768
+ B,
769
+ int(
770
+ L
771
+ * ((mask_prob + mask_prob_adjust) / mask_length)
772
+ * (1 + mask_dropout)
773
+ ),
774
+ ),
775
+ )
776
+
777
+ mask.view(B, -1).scatter_(1, mask_inds, 1)
778
+ centers = mask.nonzero(as_tuple=True)
779
+
780
+ inds = ([], [])
781
+
782
+ offset = mask_length // 2
783
+ for i in range(mask_length):
784
+ k1 = i - offset
785
+ inds[0].append(centers[0])
786
+ inds[1].append(centers[1] + k1)
787
+
788
+ i0 = torch.cat(inds[0])
789
+ i1 = torch.cat(inds[1]).clamp_(min=0, max=L - 1)
790
+
791
+ mask[(i0, i1)] = 1
792
+
793
+ def get_nbs(b, m, w):
794
+ all_nbs = torch.nn.functional.conv1d(m.unsqueeze(1), w, padding="same")
795
+ all_nbs = all_nbs.clamp_max_(1).view(b, -1)
796
+ return all_nbs
797
+
798
+ if require_same_masks and expand_adjcent:
799
+ w = torch.ones((1, 1, 3))
800
+ w[..., 1] = 0
801
+ all_nbs = get_nbs(B, mask, w)
802
+
803
+ mask = mask.view(B, -1)
804
+
805
+ if require_same_masks:
806
+ n_masks = mask.sum(dim=-1)
807
+ final_target_len = int(L * (mask_prob))
808
+ target_len = int(final_target_len * (1 + mask_dropout))
809
+
810
+ for i in range(len(mask)):
811
+ n = n_masks[i]
812
+ m = mask[i]
813
+ r = 0
814
+ while expand_adjcent and n < target_len:
815
+ if r == 0:
816
+ nbs = all_nbs[i]
817
+ else:
818
+ nbs = get_nbs(1, m.unsqueeze(0), w).squeeze(0)
819
+
820
+ cands = (1 - m + nbs) > 1
821
+ cand_sz = int(cands.sum().item())
822
+
823
+ assert cand_sz > 0, f"{nbs} {cand_sz}"
824
+
825
+ to_mask = torch.multinomial(
826
+ cands.float(), min(cand_sz, int(target_len - n)), replacement=False
827
+ )
828
+ m[to_mask] = 1
829
+ assert to_mask.numel() > 0
830
+ n += to_mask.numel()
831
+ r += 1
832
+
833
+ if n > final_target_len:
834
+ to_unmask = torch.multinomial(
835
+ m, int(n - final_target_len), replacement=False
836
+ )
837
+ m[to_unmask] = 0
838
+ elif n < final_target_len:
839
+ to_mask = torch.multinomial(
840
+ (1 - m), int(final_target_len - n), replacement=False
841
+ )
842
+ m[to_mask] = 1
843
+
844
+ if inverse_mask:
845
+ mask = 1 - mask
846
+
847
+ return mask
848
+
849
+
850
+ def get_mem_usage():
851
+ try:
852
+ import psutil
853
+
854
+ mb = 1024 * 1024
855
+ return f"used={psutil.virtual_memory().used / mb}Mb; avail={psutil.virtual_memory().available / mb}Mb"
856
+ except ImportError:
857
+ return "N/A"
858
+
859
+
860
+ # lens: torch.LongTensor
861
+ # returns: torch.BoolTensor
862
+ def lengths_to_padding_mask(lens):
863
+ bsz, max_lens = lens.size(0), torch.max(lens).item()
864
+ mask = torch.arange(max_lens).to(lens.device).view(1, max_lens)
865
+ mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens)
866
+ return mask
867
+
868
+
869
+ # lens: torch.LongTensor
870
+ # returns: torch.BoolTensor
871
+ def lengths_to_mask(lens):
872
+ return ~lengths_to_padding_mask(lens)
873
+
874
+
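An illustrative sketch of lengths_to_padding_mask above: True marks padded positions beyond each sequence length:

import torch
from fairseq.data import data_utils

lens = torch.tensor([3, 1, 2])
print(data_utils.lengths_to_padding_mask(lens))
# tensor([[False, False, False],
#         [False,  True,  True],
#         [False, False,  True]])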
875
+ def get_buckets(sizes, num_buckets):
876
+ buckets = np.unique(
877
+ np.percentile(
878
+ sizes,
879
+ np.linspace(0, 100, num_buckets + 1),
880
+ interpolation="lower",
881
+ )[1:]
882
+ )
883
+ return buckets
884
+
885
+
886
+ def get_bucketed_sizes(orig_sizes, buckets):
887
+ sizes = np.copy(orig_sizes)
888
+ assert np.min(sizes) >= 0
889
+ start_val = -1
890
+ for end_val in buckets:
891
+ mask = (sizes > start_val) & (sizes <= end_val)
892
+ sizes[mask] = end_val
893
+ start_val = end_val
894
+ return sizes
895
+
896
+
897
+ def _find_extra_valid_paths(dataset_path: str) -> set:
898
+ paths = utils.split_paths(dataset_path)
899
+ all_valid_paths = set()
900
+ for sub_dir in paths:
901
+ contents = PathManager.ls(sub_dir)
902
+ valid_paths = [c for c in contents if re.match("valid*[0-9].*", c) is not None]
903
+ all_valid_paths |= {os.path.basename(p) for p in valid_paths}
904
+ # Remove .bin, .idx etc
905
+ roots = {os.path.splitext(p)[0] for p in all_valid_paths}
906
+ return roots
907
+
908
+
909
+ def raise_if_valid_subsets_unintentionally_ignored(train_cfg) -> None:
910
+ """Raises if there are paths matching 'valid*[0-9].*' which are not combined or ignored."""
911
+ if (
912
+ train_cfg.dataset.ignore_unused_valid_subsets
913
+ or train_cfg.dataset.combine_valid_subsets
914
+ or train_cfg.dataset.disable_validation
915
+ or not hasattr(train_cfg.task, "data")
916
+ ):
917
+ return
918
+ other_paths = _find_extra_valid_paths(train_cfg.task.data)
919
+ specified_subsets = train_cfg.dataset.valid_subset.split(",")
920
+ ignored_paths = [p for p in other_paths if p not in specified_subsets]
921
+ if ignored_paths:
922
+ advice = "Set --combine-val to combine them or --ignore-unused-valid-subsets to ignore them."
923
+ msg = f"Valid paths {ignored_paths} will be ignored. {advice}"
924
+ raise ValueError(msg)
925
+
926
+
927
+ def compute_mask_indices_for_one(
928
+ sz,
929
+ mask_prob: float,
930
+ mask_length: int,
931
+ seed=None,
932
+ epoch=None,
933
+ index=None,
934
+ min_masks=0,
935
+ ):
936
+ """
937
+ set seed, epoch, index for deterministic masking
938
+ """
939
+ seed = int(hash((seed, epoch, index)) % 1e6) if seed else None
940
+ rng = np.random.default_rng(seed)
941
+
942
+ # decide elements to mask
943
+ mask = np.full(sz, False)
944
+ num_mask = int(
945
+ # add a random number for probabilistic rounding
946
+ mask_prob * sz / float(mask_length)
947
+ + rng.random()
948
+ )
949
+ num_mask = max(min_masks, num_mask)
950
+
951
+ # multiple masking as described in the vq-wav2vec paper (https://arxiv.org/abs/1910.05453)
952
+ mask_idc = rng.choice(sz, num_mask, replace=False)
953
+ mask_idc = np.concatenate([mask_idc + i for i in range(mask_length)])
954
+ mask_idc = mask_idc[mask_idc < len(mask)]
955
+ try:
956
+ mask[mask_idc] = True
957
+ except: # something wrong
958
+ print(f"Assigning mask indexes {mask_idc} to mask {mask} failed!")
959
+ raise
960
+
961
+ return mask
962
+
963
+
964
+ def compute_mask_indices_v2(
965
+ shape: Tuple[int, int],
966
+ padding_mask: Optional[torch.Tensor],
967
+ mask_prob: float,
968
+ mask_length: int,
969
+ min_masks: int = 0,
970
+ require_same_masks: bool = True,
971
+ seed: Optional[int] = None,
972
+ epoch: Optional[int] = None,
973
+ indices: Optional[torch.Tensor] = None,
974
+ ) -> np.ndarray:
975
+ bsz, all_sz = shape
976
+ mask = np.full((bsz, all_sz), False)
977
+ for i in range(bsz):
978
+ if padding_mask is not None:
979
+ sz = all_sz - padding_mask[i].long().sum().item()
980
+ else:
981
+ sz = all_sz
982
+ index = indices[i].item() if indices is not None else None
983
+ mask_for_one = compute_mask_indices_for_one(
984
+ sz, mask_prob, mask_length, seed, epoch, index, min_masks
985
+ )
986
+ mask[i, :sz] = mask_for_one
987
+
988
+ if require_same_masks:
989
+ index_sum = indices.sum().item() if indices is not None else None
990
+ seed = int(hash((seed, epoch, index_sum)) % 1e6) if seed else None
991
+ rng = np.random.default_rng(seed)
992
+
993
+ num_mask = mask.sum(-1).min()
994
+ for i in range(bsz):
995
+ extra = mask[i].sum() - num_mask
996
+ if extra > 0:
997
+ to_unmask = rng.choice(np.nonzero(mask[i])[0], extra, replace=False)
998
+ mask[i, to_unmask] = False
999
+
1000
+ return mask
1001
+
1002
+
1003
+ # TODO: a copy of the original compute_mask_indices
1004
+ def compute_mask_indices_v3(
1005
+ shape: Tuple[int, int],
1006
+ padding_mask: Optional[torch.Tensor],
1007
+ mask_prob: float,
1008
+ mask_length: int,
1009
+ mask_type: str = "static",
1010
+ mask_other: float = 0.0,
1011
+ min_masks: int = 0,
1012
+ no_overlap: bool = False,
1013
+ min_space: int = 0,
1014
+ require_same_masks: bool = True,
1015
+ mask_dropout: float = 0.0,
1016
+ seed: Optional[int] = None,
1017
+ epoch: Optional[int] = None,
1018
+ indices: Optional[torch.Tensor] = None,
1019
+ ) -> np.ndarray:
1020
+ """
1021
+ Computes random mask spans for a given shape
1022
+
1023
+ Args:
1024
+ shape: the shape for which to compute masks.
1025
+ should be of size 2 where first element is batch size and 2nd is timesteps
1026
+ padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
1027
+ mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
1028
+ number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
1029
+ however due to overlaps, the actual number will be smaller (unless no_overlap is True)
1030
+ mask_type: how to compute mask lengths
1031
+ static = fixed size
1032
+ uniform = sample from uniform distribution [mask_other, mask_length*2]
1033
+ normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
1034
+ poisson = sample from poisson distribution with lambda = mask length
1035
+ min_masks: minimum number of masked spans
1036
+ no_overlap: if true, will switch to an alternative recursive algorithm that prevents spans from overlapping
1037
+ min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
1038
+ require_same_masks: if true, will randomly drop out masks until the same number of masks remains in each sample
1039
+ mask_dropout: randomly drop out this percentage of masks in each example
1040
+ """
1041
+ bsz, all_sz = shape
1042
+ mask = np.full((bsz, all_sz), False)
1043
+
1044
+ all_num_mask = int(
1045
+ # add a random number for probabilistic rounding
1046
+ mask_prob * all_sz / float(mask_length)
1047
+ + np.random.rand()
1048
+ )
1049
+
1050
+ all_num_mask = max(min_masks, all_num_mask)
1051
+
1052
+ mask_idcs = []
1053
+ for i in range(bsz):
1054
+ if seed is not None and epoch is not None and indices is not None:
1055
+ seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6)
1056
+ else:
1057
+ seed_i = None
1058
+ rng = np.random.default_rng(seed_i)
1059
+
1060
+ if padding_mask is not None:
1061
+ sz = all_sz - padding_mask[i].long().sum().item()
1062
+ num_mask = int(
1063
+ # add a random number for probabilistic rounding
1064
+ mask_prob * sz / float(mask_length)
1065
+ + rng.random()
1066
+ )
1067
+ num_mask = max(min_masks, num_mask)
1068
+ else:
1069
+ sz = all_sz
1070
+ num_mask = all_num_mask
1071
+
1072
+ if mask_type == "static":
1073
+ lengths = np.full(num_mask, mask_length)
1074
+ elif mask_type == "uniform":
1075
+ lengths = rng.randint(mask_other, mask_length * 2 + 1, size=num_mask)
1076
+ elif mask_type == "normal":
1077
+ lengths = rng.normal(mask_length, mask_other, size=num_mask)
1078
+ lengths = [max(1, int(round(x))) for x in lengths]
1079
+ elif mask_type == "poisson":
1080
+ lengths = rng.poisson(mask_length, size=num_mask)
1081
+ lengths = [int(round(x)) for x in lengths]
1082
+ else:
1083
+ raise Exception("unknown mask selection " + mask_type)
1084
+
1085
+ if sum(lengths) == 0:
1086
+ lengths[0] = min(mask_length, sz - 1)
1087
+
1088
+ if no_overlap:
1089
+ mask_idc = []
1090
+
1091
+ def arrange(s, e, length, keep_length):
1092
+ span_start = rng.randint(s, e - length)
1093
+ mask_idc.extend(span_start + i for i in range(length))
1094
+
1095
+ new_parts = []
1096
+ if span_start - s - min_space >= keep_length:
1097
+ new_parts.append((s, span_start - min_space + 1))
1098
+ if e - span_start - length - min_space > keep_length:
1099
+ new_parts.append((span_start + length + min_space, e))
1100
+ return new_parts
1101
+
1102
+ parts = [(0, sz)]
1103
+ min_length = min(lengths)
1104
+ for length in sorted(lengths, reverse=True):
1105
+ lens = np.fromiter(
1106
+ (e - s if e - s >= length + min_space else 0 for s, e in parts),
1107
+ np.int64,
1108
+ )
1109
+ l_sum = np.sum(lens)
1110
+ if l_sum == 0:
1111
+ break
1112
+ probs = lens / np.sum(lens)
1113
+ c = rng.choice(len(parts), p=probs)
1114
+ s, e = parts.pop(c)
1115
+ parts.extend(arrange(s, e, length, min_length))
1116
+ mask_idc = np.asarray(mask_idc)
1117
+ else:
1118
+ min_len = min(lengths)
1119
+ if sz - min_len <= num_mask:
1120
+ min_len = sz - num_mask - 1
1121
+
1122
+ mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
1123
+
1124
+ mask_idc = np.asarray(
1125
+ [
1126
+ mask_idc[j] + offset
1127
+ for j in range(len(mask_idc))
1128
+ for offset in range(lengths[j])
1129
+ ]
1130
+ )
1131
+
1132
+ mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
1133
+
1134
+ min_len = min([len(m) for m in mask_idcs])
1135
+ for i, mask_idc in enumerate(mask_idcs):
1136
+ if len(mask_idc) > min_len and require_same_masks:
1137
+ mask_idc = rng.choice(mask_idc, min_len, replace=False)
1138
+ if mask_dropout > 0:
1139
+ num_holes = np.rint(len(mask_idc) * mask_dropout).astype(int)
1140
+ mask_idc = rng.choice(mask_idc, len(mask_idc) - num_holes, replace=False)
1141
+
1142
+ mask[i, mask_idc] = True
1143
+
1144
+ return mask
fairseq/fairseq/data/data_utils_fast.pyx ADDED
@@ -0,0 +1,178 @@
1
+ # cython: language_level=3
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+
9
+ cimport cython
10
+ cimport numpy as np
11
+
12
+ from libc.stdint cimport int32_t, int64_t
13
+ from libcpp cimport bool as bool_t
14
+
15
+ ctypedef int64_t DTYPE_t
16
+
17
+ @cython.cdivision(True)
18
+ @cython.boundscheck(False)
19
+ @cython.wraparound(False)
20
+ cpdef list batch_by_size_vec(
21
+ np.ndarray[int64_t, ndim=1] indices,
22
+ np.ndarray[int64_t, ndim=1] num_tokens_vec,
23
+ int64_t max_tokens,
24
+ int64_t max_sentences,
25
+ int32_t bsz_mult,
26
+ ):
27
+ if indices.shape[0] == 0:
28
+ return []
29
+
30
+ assert max_tokens <= 0 or np.max(num_tokens_vec) <= max_tokens, (
31
+ f"Sentences lengths should not exceed max_tokens={max_tokens}"
32
+ )
33
+
34
+ cdef int32_t indices_len = indices.shape[0]
35
+ cdef np.ndarray[int32_t, ndim=1] batches_ends = \
36
+ np.zeros(indices_len, dtype=np.int32)
37
+ cdef int32_t[:] batches_ends_view = batches_ends
38
+ cdef int64_t[:] num_tokens_view = num_tokens_vec
39
+
40
+ cdef int32_t pos = 0
41
+ cdef int32_t new_batch_end = 0
42
+
43
+ cdef int64_t new_batch_max_tokens = 0
44
+ cdef int32_t new_batch_sentences = 0
45
+ cdef int64_t new_batch_num_tokens = 0
46
+
47
+ cdef bool_t overflow = False
48
+ cdef bool_t size_matches_with_bsz_mult = False
49
+
50
+ cdef int32_t batches_count = 0
51
+ cdef int32_t batch_start = 0
52
+ cdef int64_t tail_max_tokens = 0
53
+ cdef int64_t batch_max_tokens = 0
54
+
55
+ for pos in range(indices_len):
56
+ # At every pos we keep stats about the last complete batch [batch_start:batch_end),
57
+ # and tail [batch_end:pos].
58
+ # 1) Every time when (batch + tail) forms a valid batch
59
+ # (according to max_tokens, max_sentences and bsz_mult) we append tail to batch.
60
+ # 2) When (batch+tail) violates max_tokens or max_sentences constraints
61
+ # we finalize running batch, and tail becomes a new batch.
62
+ # 3) There is a corner case when tail also violates constraints.
63
+ # In that situation [batch_end:pos-1] (tail without the current pos)
64
+ # gets added to the finalized batches, while [pos:pos] becomes a new tail.
65
+ #
66
+ # Important: For the sake of performance try to avoid using function calls within this loop.
67
+
68
+ tail_max_tokens = tail_max_tokens \
69
+ if tail_max_tokens > num_tokens_view[pos] \
70
+ else num_tokens_view[pos]
71
+ new_batch_end = pos + 1
72
+ new_batch_max_tokens = batch_max_tokens \
73
+ if batch_max_tokens > tail_max_tokens \
74
+ else tail_max_tokens
75
+ new_batch_sentences = new_batch_end - batch_start
76
+ new_batch_num_tokens = new_batch_sentences * new_batch_max_tokens
77
+
78
+ overflow = (new_batch_sentences > max_sentences > 0 or
79
+ new_batch_num_tokens > max_tokens > 0)
80
+ size_matches_with_bsz_mult = (new_batch_sentences < bsz_mult or
81
+ new_batch_sentences % bsz_mult == 0)
82
+
83
+ if overflow:
84
+ tail_num_tokens = tail_max_tokens * \
85
+ (new_batch_end - batches_ends_view[batches_count])
86
+ tail_overflow = tail_num_tokens > max_tokens > 0
87
+ # In case of a tail overflow finalize two batches
88
+ if tail_overflow:
89
+ batches_count += 1
90
+ batches_ends_view[batches_count] = pos
91
+ tail_max_tokens = num_tokens_view[pos]
92
+ batch_start = batches_ends_view[batches_count]
93
+ batches_count += 1
94
+ new_batch_max_tokens = tail_max_tokens
95
+
96
+ if overflow or size_matches_with_bsz_mult:
97
+ batches_ends_view[batches_count] = new_batch_end
98
+ batch_max_tokens = new_batch_max_tokens
99
+ tail_max_tokens = 0
100
+ if batches_ends_view[batches_count] != indices_len:
101
+ batches_count += 1
102
+ # Memory and time-efficient split
103
+ return np.split(indices, batches_ends[:batches_count])
104
+
105
+
106
+ @cython.boundscheck(False)
107
+ @cython.wraparound(False)
108
+ cpdef list batch_by_size_fn(
109
+ np.ndarray[DTYPE_t, ndim=1] indices,
110
+ num_tokens_fn,
111
+ int64_t max_tokens,
112
+ int64_t max_sentences,
113
+ int32_t bsz_mult,
114
+ ):
115
+ cdef int32_t indices_len = indices.shape[0]
116
+ cdef np.ndarray[int64_t, ndim=1] num_tokens_vec = np.zeros(indices_len,
117
+ dtype=np.int64)
118
+ cdef DTYPE_t[:] indices_view = indices
119
+ cdef DTYPE_t[:] num_tokens_vec_view = num_tokens_vec
120
+ cdef int64_t pos
121
+ for pos in range(indices_len):
122
+ num_tokens_vec[pos] = num_tokens_fn(indices_view[pos])
123
+ return batch_by_size_vec(indices, num_tokens_vec, max_tokens,
124
+ max_sentences, bsz_mult,)
125
+
126
+
127
+ cdef _find_valid_shape(
128
+ DTYPE_t[:, :] shapes_view,
129
+ int64_t num_sentences,
130
+ int64_t num_tokens,
131
+ ):
132
+ """Return index of first valid shape, or -1 if none is found."""
133
+ for i in range(shapes_view.shape[0]):
134
+ if num_sentences <= shapes_view[i][0] and num_tokens <= shapes_view[i][1]:
135
+ return i
136
+ return -1
137
+
138
+
139
+ @cython.cdivision(True)
140
+ cpdef list batch_fixed_shapes_fast(
141
+ np.ndarray[DTYPE_t, ndim=1] indices,
142
+ num_tokens_fn,
143
+ np.ndarray[DTYPE_t, ndim=2] fixed_shapes_sorted,
144
+ ):
145
+ cdef int64_t sample_len = 0
146
+ cdef list sample_lens = []
147
+ cdef list batch = []
148
+ cdef list batches = []
149
+ cdef int64_t mod_len
150
+ cdef int64_t i
151
+ cdef int64_t idx
152
+ cdef int64_t num_tokens
153
+ cdef DTYPE_t[:] indices_view = indices
154
+ cdef DTYPE_t[:, :] shapes_view = fixed_shapes_sorted
155
+
156
+ for i in range(len(indices_view)):
157
+ idx = indices_view[i]
158
+ num_tokens = num_tokens_fn(idx)
159
+ sample_lens.append(num_tokens)
160
+ sample_len = max(sample_len, num_tokens)
161
+
162
+ shape_idx = _find_valid_shape(shapes_view, len(batch) + 1, sample_len)
163
+ if shape_idx == -1:
164
+ batches.append(batch)
165
+ batch = []
166
+ sample_lens = []
167
+ sample_len = 0
168
+ shapes_view = fixed_shapes_sorted
169
+ elif shape_idx > 0:
170
+ # small optimization for the next call to _find_valid_shape
171
+ shapes_view = shapes_view[shape_idx:]
172
+
173
+ batch.append(idx)
174
+
175
+ if len(batch) > 0:
176
+ batches.append(batch)
177
+
178
+ return batches
fairseq/fairseq/data/denoising_dataset.py ADDED
@@ -0,0 +1,443 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+
8
+ import numpy as np
9
+ import torch
10
+
11
+ from . import FairseqDataset, data_utils
12
+
13
+
14
+ def collate(
15
+ samples,
16
+ pad_idx,
17
+ eos_idx,
18
+ vocab,
19
+ left_pad_source=False,
20
+ left_pad_target=False,
21
+ input_feeding=True,
22
+ pad_to_length=None,
23
+ ):
24
+ assert input_feeding
25
+ if len(samples) == 0:
26
+ return {}
27
+
28
+ def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
29
+ return data_utils.collate_tokens(
30
+ [s[key] for s in samples],
31
+ pad_idx,
32
+ eos_idx=None, # use eos_idx of each sample instead of vocab.eos()
33
+ left_pad=left_pad,
34
+ move_eos_to_beginning=move_eos_to_beginning,
35
+ pad_to_length=pad_to_length,
36
+ )
37
+
38
+ id = torch.LongTensor([s["id"] for s in samples])
39
+ src_tokens = merge(
40
+ "source",
41
+ left_pad=left_pad_source,
42
+ pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
43
+ )
44
+ # sort by descending source length
45
+ src_lengths = torch.LongTensor([s["source"].numel() for s in samples])
46
+ src_lengths, sort_order = src_lengths.sort(descending=True)
47
+ id = id.index_select(0, sort_order)
48
+ src_tokens = src_tokens.index_select(0, sort_order)
49
+
50
+ prev_output_tokens = None
51
+ target = None
52
+ if samples[0].get("target", None) is not None:
53
+ target = merge(
54
+ "target",
55
+ left_pad=left_pad_target,
56
+ pad_to_length=pad_to_length["target"]
57
+ if pad_to_length is not None
58
+ else None,
59
+ )
60
+ target = target.index_select(0, sort_order)
61
+ ntokens = sum(len(s["target"]) for s in samples)
62
+
63
+ if input_feeding:
64
+ # we create a shifted version of targets for feeding the
65
+ # previous output token(s) into the next decoder step
66
+ prev_output_tokens = merge(
67
+ "target",
68
+ left_pad=left_pad_target,
69
+ move_eos_to_beginning=True,
70
+ pad_to_length=pad_to_length["target"]
71
+ if pad_to_length is not None
72
+ else None,
73
+ )
74
+ prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
75
+ else:
76
+ ntokens = sum(len(s["source"]) for s in samples)
77
+
78
+ batch = {
79
+ "id": id,
80
+ "ntokens": ntokens,
81
+ "net_input": {
82
+ "src_tokens": src_tokens,
83
+ "src_lengths": src_lengths,
84
+ },
85
+ "target": target,
86
+ "nsentences": samples[0]["source"].size(0),
87
+ "sort_order": sort_order,
88
+ }
89
+ if prev_output_tokens is not None:
90
+ batch["net_input"]["prev_output_tokens"] = prev_output_tokens
91
+
92
+ return batch
93
+
94
+
95
+ class DenoisingDataset(FairseqDataset):
96
+ """
97
+ A wrapper around TokenBlockDataset for BART dataset.
98
+
99
+ Args:
100
+ dataset (TokenBlockDataset): dataset to wrap
101
+ sizes (List[int]): sentence lengths
102
+ vocab (~fairseq.data.Dictionary): vocabulary
103
+ mask_idx (int): dictionary index used for masked token
104
+ mask_whole_words: only mask whole words. This should be a byte mask
105
+ over vocab indices, indicating whether it is the beginning of a
106
+ word. We will extend any mask to encompass the whole word.
107
+ shuffle (bool, optional): shuffle the elements before batching.
108
+ Default: ``True``
109
+ seed: Seed for random number generator for reproducibility.
110
+ """
111
+
112
+ def __init__(
113
+ self,
114
+ dataset,
115
+ sizes,
116
+ vocab,
117
+ mask_idx,
118
+ mask_whole_words,
119
+ shuffle,
120
+ seed,
121
+ mask,
122
+ mask_random,
123
+ insert,
124
+ rotate,
125
+ permute_sentences,
126
+ bpe,
127
+ replace_length,
128
+ mask_length,
129
+ poisson_lambda,
130
+ eos=None,
131
+ item_transform_func=None,
132
+ ):
133
+ self.dataset = dataset
134
+
135
+ self.sizes = sizes
136
+
137
+ self.vocab = vocab
138
+ self.shuffle = shuffle
139
+ self.seed = seed
140
+ self.mask_idx = mask_idx
141
+ self.mask_whole_word = mask_whole_words
142
+ self.mask_ratio = mask
143
+ self.random_ratio = mask_random
144
+ self.insert_ratio = insert
145
+ self.rotate_ratio = rotate
146
+ self.permute_sentence_ratio = permute_sentences
147
+ self.eos = eos if eos is not None else vocab.eos()
148
+ self.item_transform_func = item_transform_func
149
+
150
+ if bpe != "gpt2":
151
+ self.full_stop_index = self.vocab.eos()
152
+ else:
153
+ assert bpe == "gpt2"
154
+ self.full_stop_index = self.vocab.index("13")
155
+
156
+ self.replace_length = replace_length
157
+ if self.replace_length not in [-1, 0, 1]:
158
+ raise ValueError(f"invalid arg: replace_length={self.replace_length}")
159
+ if mask_length not in ["subword", "word", "span-poisson"]:
160
+ raise ValueError(f"invalid arg: mask-length={mask_length}")
161
+ if mask_length == "subword" and replace_length not in [0, 1]:
162
+ raise ValueError("if using subwords, use replace-length=1 or 0")
163
+
164
+ self.mask_span_distribution = None
165
+ if mask_length == "span-poisson":
166
+ _lambda = poisson_lambda
167
+
168
+ lambda_to_the_k = 1
169
+ e_to_the_minus_lambda = math.exp(-_lambda)
170
+ k_factorial = 1
171
+ ps = []
172
+ for k in range(0, 128):
173
+ ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial)
174
+ lambda_to_the_k *= _lambda
175
+ k_factorial *= k + 1
176
+ if ps[-1] < 0.0000001:
177
+ break
178
+ ps = torch.FloatTensor(ps)
179
+ self.mask_span_distribution = torch.distributions.Categorical(ps)
180
+
181
+ self.epoch = 0
182
+
183
+ @property
184
+ def can_reuse_epoch_itr_across_epochs(self):
185
+ return True # only the noise changes, not item sizes
186
+
187
+ def set_epoch(self, epoch, **unused):
188
+ self.epoch = epoch
189
+
190
+ def __getitem__(self, index):
191
+ with data_utils.numpy_seed(self.seed, self.epoch, index):
192
+ tokens = self.dataset[index]
193
+ assert tokens[-1] == self.eos
194
+ source, target = tokens, tokens.clone()
195
+
196
+ if self.permute_sentence_ratio > 0.0:
197
+ source = self.permute_sentences(source, self.permute_sentence_ratio)
198
+
199
+ if self.mask_ratio > 0:
200
+ source = self.add_whole_word_mask(source, self.mask_ratio)
201
+
202
+ if self.insert_ratio > 0:
203
+ source = self.add_insertion_noise(source, self.insert_ratio)
204
+
205
+ if self.rotate_ratio > 0.0 and np.random.random() < self.rotate_ratio:
206
+ source = self.add_rolling_noise(source)
207
+ # there can be additional changes to make:
208
+ if self.item_transform_func is not None:
209
+ source, target = self.item_transform_func(source, target)
210
+
211
+ assert (source >= 0).all()
212
+ assert (source[1:-1] >= 1).all()
213
+ assert (source <= len(self.vocab)).all()
214
+ assert source[0] == self.vocab.bos()
215
+ assert source[-1] == self.eos
216
+ return {
217
+ "id": index,
218
+ "source": source,
219
+ "target": target,
220
+ }
221
+
222
+ def __len__(self):
223
+ return len(self.dataset)
224
+
225
+ def permute_sentences(self, source, p=1.0):
226
+ full_stops = source == self.full_stop_index
227
+ # Pretend it ends with a full stop so last span is a sentence
228
+ full_stops[-2] = 1
229
+
230
+ # Tokens that are full stops, where the previous token is not
231
+ sentence_ends = (full_stops[1:] * ~full_stops[:-1]).nonzero(as_tuple=False) + 2
232
+ result = source.clone()
233
+
234
+ num_sentences = sentence_ends.size(0)
235
+ num_to_permute = math.ceil((num_sentences * 2 * p) / 2.0)
236
+ substitutions = torch.randperm(num_sentences)[:num_to_permute]
237
+ ordering = torch.arange(0, num_sentences)
238
+ ordering[substitutions] = substitutions[torch.randperm(num_to_permute)]
239
+
240
+ # Ignore <bos> at start
241
+ index = 1
242
+ for i in ordering:
243
+ sentence = source[(sentence_ends[i - 1] if i > 0 else 1) : sentence_ends[i]]
244
+ result[index : index + sentence.size(0)] = sentence
245
+ index += sentence.size(0)
246
+ return result
247
+
248
+ def word_starts(self, source):
249
+ if self.mask_whole_word is not None:
250
+ is_word_start = self.mask_whole_word.gather(0, source)
251
+ else:
252
+ is_word_start = torch.ones(source.size())
253
+ is_word_start[0] = 0
254
+ is_word_start[-1] = 0
255
+ return is_word_start
256
+
257
+ def add_whole_word_mask(self, source, p):
258
+ is_word_start = self.word_starts(source)
259
+ num_to_mask = int(math.ceil(is_word_start.float().sum() * p))
260
+ num_inserts = 0
261
+ if num_to_mask == 0:
262
+ return source
263
+
264
+ if self.mask_span_distribution is not None:
265
+ lengths = self.mask_span_distribution.sample(sample_shape=(num_to_mask,))
266
+
267
+ # Make sure we have enough to mask
268
+ cum_length = torch.cumsum(lengths, 0)
269
+ while cum_length[-1] < num_to_mask:
270
+ lengths = torch.cat(
271
+ [
272
+ lengths,
273
+ self.mask_span_distribution.sample(sample_shape=(num_to_mask,)),
274
+ ],
275
+ dim=0,
276
+ )
277
+ cum_length = torch.cumsum(lengths, 0)
278
+
279
+ # Trim to masking budget
280
+ i = 0
281
+ while cum_length[i] < num_to_mask:
282
+ i += 1
283
+ lengths[i] = num_to_mask - (0 if i == 0 else cum_length[i - 1])
284
+ num_to_mask = i + 1
285
+ lengths = lengths[:num_to_mask]
286
+
287
+ # Handle 0-length mask (inserts) separately
288
+ lengths = lengths[lengths > 0]
289
+ num_inserts = num_to_mask - lengths.size(0)
290
+ num_to_mask -= num_inserts
291
+ if num_to_mask == 0:
292
+ return self.add_insertion_noise(source, num_inserts / source.size(0))
293
+
294
+ assert (lengths > 0).all()
295
+ else:
296
+ lengths = torch.ones((num_to_mask,)).long()
297
+ assert is_word_start[-1] == 0
298
+ word_starts = is_word_start.nonzero(as_tuple=False)
299
+ indices = word_starts[
300
+ torch.randperm(word_starts.size(0))[:num_to_mask]
301
+ ].squeeze(1)
302
+ mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio
303
+
304
+ source_length = source.size(0)
305
+ assert source_length - 1 not in indices
306
+ to_keep = torch.ones(source_length, dtype=torch.bool)
307
+ is_word_start[
308
+ -1
309
+ ] = 255 # acts as a long length, so spans don't go over the end of doc
310
+ if self.replace_length == 0:
311
+ to_keep[indices] = 0
312
+ else:
313
+ # keep index, but replace it with [MASK]
314
+ source[indices] = self.mask_idx
315
+ source[indices[mask_random]] = torch.randint(
316
+ 1, len(self.vocab), size=(mask_random.sum(),)
317
+ )
318
+
319
+ if self.mask_span_distribution is not None:
320
+ assert len(lengths.size()) == 1
321
+ assert lengths.size() == indices.size()
322
+ lengths -= 1
323
+ while indices.size(0) > 0:
324
+ assert lengths.size() == indices.size()
325
+ lengths -= is_word_start[indices + 1].long()
326
+ uncompleted = lengths >= 0
327
+ indices = indices[uncompleted] + 1
328
+ mask_random = mask_random[uncompleted]
329
+ lengths = lengths[uncompleted]
330
+ if self.replace_length != -1:
331
+ # delete token
332
+ to_keep[indices] = 0
333
+ else:
334
+ # keep index, but replace it with [MASK]
335
+ source[indices] = self.mask_idx
336
+ source[indices[mask_random]] = torch.randint(
337
+ 1, len(self.vocab), size=(mask_random.sum(),)
338
+ )
339
+ else:
340
+ # A bit faster when all lengths are 1
341
+ while indices.size(0) > 0:
342
+ uncompleted = is_word_start[indices + 1] == 0
343
+ indices = indices[uncompleted] + 1
344
+ mask_random = mask_random[uncompleted]
345
+ if self.replace_length != -1:
346
+ # delete token
347
+ to_keep[indices] = 0
348
+ else:
349
+ # keep index, but replace it with [MASK]
350
+ source[indices] = self.mask_idx
351
+ source[indices[mask_random]] = torch.randint(
352
+ 1, len(self.vocab), size=(mask_random.sum(),)
353
+ )
354
+
355
+ assert source_length - 1 not in indices
356
+
357
+ source = source[to_keep]
358
+
359
+ if num_inserts > 0:
360
+ source = self.add_insertion_noise(source, num_inserts / source.size(0))
361
+
362
+ return source
363
+
364
+ def add_permuted_noise(self, tokens, p):
365
+ num_words = len(tokens)
366
+ num_to_permute = math.ceil(((num_words * 2) * p) / 2.0)
367
+ substitutions = torch.randperm(num_words - 2)[:num_to_permute] + 1
368
+ tokens[substitutions] = tokens[substitutions[torch.randperm(num_to_permute)]]
369
+ return tokens
370
+
371
+ def add_rolling_noise(self, tokens):
372
+ offset = np.random.randint(1, max(1, tokens.size(-1) - 1) + 1)
373
+ tokens = torch.cat(
374
+ (tokens[0:1], tokens[offset:-1], tokens[1:offset], tokens[-1:]),
375
+ dim=0,
376
+ )
377
+ return tokens
378
+
379
+ def add_insertion_noise(self, tokens, p):
380
+ if p == 0.0:
381
+ return tokens
382
+
383
+ num_tokens = len(tokens)
384
+ n = int(math.ceil(num_tokens * p))
385
+
386
+ noise_indices = torch.randperm(num_tokens + n - 2)[:n] + 1
387
+ noise_mask = torch.zeros(size=(num_tokens + n,), dtype=torch.bool)
388
+ noise_mask[noise_indices] = 1
389
+ result = torch.LongTensor(n + len(tokens)).fill_(-1)
390
+
391
+ num_random = int(math.ceil(n * self.random_ratio))
392
+ result[noise_indices[num_random:]] = self.mask_idx
393
+ result[noise_indices[:num_random]] = torch.randint(
394
+ low=1, high=len(self.vocab), size=(num_random,)
395
+ )
396
+
397
+ result[~noise_mask] = tokens
398
+
399
+ assert (result >= 0).all()
400
+ return result
401
+
402
+ def collater(self, samples, pad_to_length=None):
403
+ """Merge a list of samples to form a mini-batch.
404
+ Args:
405
+ samples (List[dict]): samples to collate
406
+ Returns:
407
+ dict: a mini-batch of data
408
+ """
409
+ return collate(
410
+ samples, self.vocab.pad(), self.eos, self.vocab, pad_to_length=pad_to_length
411
+ )
412
+
413
+ def num_tokens(self, index):
414
+ """Return the number of tokens in a sample. This value is used to
415
+ enforce ``--max-tokens`` during batching."""
416
+ return self.sizes[index]
417
+
418
+ def size(self, index):
419
+ """Return an example's size as a float or tuple. This value is used when
420
+ filtering a dataset with ``--max-positions``."""
421
+ return self.sizes[index]
422
+
423
+ def ordered_indices(self):
424
+ """Return an ordered list of indices. Batches will be constructed based
425
+ on this order."""
426
+ if self.shuffle:
427
+ indices = np.random.permutation(len(self))
428
+ else:
429
+ indices = np.arange(len(self))
430
+ return indices[np.argsort(self.sizes[indices], kind="mergesort")]
431
+
432
+ def prefetch(self, indices):
433
+ self.src.prefetch(indices)
434
+ self.tgt.prefetch(indices)
435
+
436
+ @property
437
+ def supports_prefetch(self):
438
+ return (
439
+ hasattr(self.src, "supports_prefetch")
440
+ and self.src.supports_prefetch
441
+ and hasattr(self.tgt, "supports_prefetch")
442
+ and self.tgt.supports_prefetch
443
+ )
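The ``span-poisson`` branch in ``DenoisingDataset.__init__`` builds a truncated Poisson table by hand rather than calling a library routine. A standalone sketch of the same computation (the lambda value is illustrative; BART-style setups typically use a small value around 3):

import math
import torch

def truncated_poisson(_lambda=3.0, max_k=128, tol=1e-7):
    # p(k) = exp(-lambda) * lambda^k / k!, accumulated until negligible,
    # mirroring the loop in DenoisingDataset.__init__.
    lambda_to_the_k, k_factorial = 1.0, 1.0
    e_to_the_minus_lambda = math.exp(-_lambda)
    ps = []
    for k in range(max_k):
        ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial)
        lambda_to_the_k *= _lambda
        k_factorial *= k + 1
        if ps[-1] < tol:
            break
    # Categorical renormalizes, so the truncated tail is folded away.
    return torch.distributions.Categorical(torch.tensor(ps))

dist = truncated_poisson()
span_lengths = dist.sample(sample_shape=(5,))  # 0-length spans become pure insertions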
fairseq/fairseq/data/dictionary.py ADDED
@@ -0,0 +1,403 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ from collections import Counter
8
+ from multiprocessing import Pool
9
+
10
+ import torch
11
+ from fairseq import utils
12
+ from fairseq.data import data_utils
13
+ from fairseq.file_chunker_utils import Chunker, find_offsets
14
+ from fairseq.file_io import PathManager
15
+ from fairseq.tokenizer import tokenize_line
16
+
17
+
18
+ class Dictionary:
19
+ """A mapping from symbols to consecutive integers"""
20
+
21
+ def __init__(
22
+ self,
23
+ *, # begin keyword-only arguments
24
+ bos="<s>",
25
+ pad="<pad>",
26
+ eos="</s>",
27
+ unk="<unk>",
28
+ extra_special_symbols=None,
29
+ add_special_symbols=True,
30
+ ):
31
+ self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
32
+ self.symbols = []
33
+ self.count = []
34
+ self.indices = {}
35
+ if add_special_symbols:
36
+ self.bos_index = self.add_symbol(bos)
37
+ self.pad_index = self.add_symbol(pad)
38
+ self.eos_index = self.add_symbol(eos)
39
+ self.unk_index = self.add_symbol(unk)
40
+ if extra_special_symbols:
41
+ for s in extra_special_symbols:
42
+ self.add_symbol(s)
43
+ self.nspecial = len(self.symbols)
44
+
45
+ def __eq__(self, other):
46
+ return self.indices == other.indices
47
+
48
+ def __getitem__(self, idx):
49
+ if idx < len(self.symbols):
50
+ return self.symbols[idx]
51
+ return self.unk_word
52
+
53
+ def get_count(self, idx):
54
+ return self.count[idx]
55
+
56
+ def __len__(self):
57
+ """Returns the number of symbols in the dictionary"""
58
+ return len(self.symbols)
59
+
60
+ def __contains__(self, sym):
61
+ return sym in self.indices
62
+
63
+ def index(self, sym):
64
+ """Returns the index of the specified symbol"""
65
+ assert isinstance(sym, str)
66
+ if sym in self.indices:
67
+ return self.indices[sym]
68
+ return self.unk_index
69
+
70
+ def string(
71
+ self,
72
+ tensor,
73
+ bpe_symbol=None,
74
+ escape_unk=False,
75
+ extra_symbols_to_ignore=None,
76
+ unk_string=None,
77
+ include_eos=False,
78
+ separator=" ",
79
+ ):
80
+ """Helper for converting a tensor of token indices to a string.
81
+
82
+ Can optionally remove BPE symbols or escape <unk> words.
83
+ """
84
+ if torch.is_tensor(tensor) and tensor.dim() == 2:
85
+ return "\n".join(
86
+ self.string(
87
+ t,
88
+ bpe_symbol,
89
+ escape_unk,
90
+ extra_symbols_to_ignore,
91
+ include_eos=include_eos,
92
+ )
93
+ for t in tensor
94
+ )
95
+
96
+ extra_symbols_to_ignore = set(extra_symbols_to_ignore or [])
97
+ if not include_eos:
98
+ extra_symbols_to_ignore.add(self.eos())
99
+
100
+ def token_string(i):
101
+ if i == self.unk():
102
+ if unk_string is not None:
103
+ return unk_string
104
+ else:
105
+ return self.unk_string(escape_unk)
106
+ else:
107
+ return self[i]
108
+
109
+ if hasattr(self, "bos_index"):
110
+ extra_symbols_to_ignore.add(self.bos())
111
+
112
+ sent = separator.join(
113
+ token_string(i)
114
+ for i in tensor
115
+ if utils.item(i) not in extra_symbols_to_ignore
116
+ )
117
+
118
+ return data_utils.post_process(sent, bpe_symbol)
119
+
120
+ def unk_string(self, escape=False):
121
+ """Return unknown string, optionally escaped as: <<unk>>"""
122
+ if escape:
123
+ return "<{}>".format(self.unk_word)
124
+ else:
125
+ return self.unk_word
126
+
127
+ def add_symbol(self, word, n=1, overwrite=False):
128
+ """Adds a word to the dictionary"""
129
+ if word in self.indices and not overwrite:
130
+ idx = self.indices[word]
131
+ self.count[idx] = self.count[idx] + n
132
+ return idx
133
+ else:
134
+ idx = len(self.symbols)
135
+ self.indices[word] = idx
136
+ self.symbols.append(word)
137
+ self.count.append(n)
138
+ return idx
139
+
140
+ def update(self, new_dict):
141
+ """Updates counts from new dictionary."""
142
+ for word in new_dict.symbols:
143
+ idx2 = new_dict.indices[word]
144
+ if word in self.indices:
145
+ idx = self.indices[word]
146
+ self.count[idx] = self.count[idx] + new_dict.count[idx2]
147
+ else:
148
+ idx = len(self.symbols)
149
+ self.indices[word] = idx
150
+ self.symbols.append(word)
151
+ self.count.append(new_dict.count[idx2])
152
+
153
+ def finalize(self, threshold=-1, nwords=-1, padding_factor=8):
154
+ """Sort symbols by frequency in descending order, ignoring special ones.
155
+
156
+ Args:
157
+ - threshold defines the minimum word count
158
+ - nwords defines the total number of words in the final dictionary,
159
+ including special symbols
160
+ - padding_factor can be used to pad the dictionary size to be a
161
+ multiple of 8, which is important on some hardware (e.g., Nvidia
162
+ Tensor Cores).
163
+ """
164
+ if nwords <= 0:
165
+ nwords = len(self)
166
+
167
+ new_indices = dict(zip(self.symbols[: self.nspecial], range(self.nspecial)))
168
+ new_symbols = self.symbols[: self.nspecial]
169
+ new_count = self.count[: self.nspecial]
170
+
171
+ c = Counter(
172
+ dict(
173
+ sorted(zip(self.symbols[self.nspecial :], self.count[self.nspecial :]))
174
+ )
175
+ )
176
+ for symbol, count in c.most_common(nwords - self.nspecial):
177
+ if count >= threshold:
178
+ new_indices[symbol] = len(new_symbols)
179
+ new_symbols.append(symbol)
180
+ new_count.append(count)
181
+ else:
182
+ break
183
+
184
+ assert len(new_symbols) == len(new_indices)
185
+
186
+ self.count = list(new_count)
187
+ self.symbols = list(new_symbols)
188
+ self.indices = new_indices
189
+
190
+ self.pad_to_multiple_(padding_factor)
191
+
192
+ def pad_to_multiple_(self, padding_factor):
193
+ """Pad Dictionary size to be a multiple of *padding_factor*."""
194
+ if padding_factor > 1:
195
+ i = 0
196
+ while len(self) % padding_factor != 0:
197
+ symbol = "madeupword{:04d}".format(i)
198
+ self.add_symbol(symbol, n=0)
199
+ i += 1
200
+
201
+ def bos(self):
202
+ """Helper to get index of beginning-of-sentence symbol"""
203
+ return self.bos_index
204
+
205
+ def pad(self):
206
+ """Helper to get index of pad symbol"""
207
+ return self.pad_index
208
+
209
+ def eos(self):
210
+ """Helper to get index of end-of-sentence symbol"""
211
+ return self.eos_index
212
+
213
+ def unk(self):
214
+ """Helper to get index of unk symbol"""
215
+ return self.unk_index
216
+
217
+ @classmethod
218
+ def load(cls, f, add_special_symbols=True):
219
+ """Loads the dictionary from a text file with the format:
220
+
221
+ ```
222
+ <symbol0> <count0>
223
+ <symbol1> <count1>
224
+ ...
225
+ ```
226
+ """
227
+ d = cls(add_special_symbols=add_special_symbols)
228
+ d.add_from_file(f)
229
+ return d
230
+
231
+ def add_from_file(self, f):
232
+ """
233
+ Loads a pre-existing dictionary from a text file and adds its symbols
234
+ to this instance.
235
+ """
236
+ if isinstance(f, str):
237
+ try:
238
+ with open(PathManager.get_local_path(f), "r", encoding="utf-8") as fd:
239
+ self.add_from_file(fd)
240
+ except FileNotFoundError as fnfe:
241
+ raise fnfe
242
+ except UnicodeError:
243
+ raise Exception(
244
+ "Incorrect encoding detected in {}, please "
245
+ "rebuild the dataset".format(f)
246
+ )
247
+ return
248
+
249
+ lines = f.readlines()
250
+ indices_start_line = self._load_meta(lines)
251
+
252
+ for line in lines[indices_start_line:]:
253
+ try:
254
+ line, field = line.rstrip().rsplit(" ", 1)
255
+ if field == "#fairseq:overwrite":
256
+ overwrite = True
257
+ line, field = line.rsplit(" ", 1)
258
+ else:
259
+ overwrite = False
260
+ count = int(field)
261
+ word = line
262
+ if word in self and not overwrite:
263
+ raise RuntimeError(
264
+ "Duplicate word found when loading Dictionary: '{}'. "
265
+ "Duplicate words can overwrite earlier ones by adding the "
266
+ "#fairseq:overwrite flag at the end of the corresponding row "
267
+ "in the dictionary file. If using the Camembert model, please "
268
+ "download an updated copy of the model file.".format(word)
269
+ )
270
+ self.add_symbol(word, n=count, overwrite=overwrite)
271
+ except ValueError:
272
+ raise ValueError(
273
+ f"Incorrect dictionary format, expected '<token> <cnt> [flags]': \"{line}\""
274
+ )
275
+
276
+ def _save(self, f, kv_iterator):
277
+ if isinstance(f, str):
278
+ PathManager.mkdirs(os.path.dirname(f))
279
+ with PathManager.open(f, "w", encoding="utf-8") as fd:
280
+ return self.save(fd)
281
+ for k, v in kv_iterator:
282
+ print("{} {}".format(k, v), file=f)
283
+
284
+ def _get_meta(self):
285
+ return [], []
286
+
287
+ def _load_meta(self, lines):
288
+ return 0
289
+
290
+ def save(self, f):
291
+ """Stores dictionary into a text file"""
292
+ ex_keys, ex_vals = self._get_meta()
293
+ self._save(
294
+ f,
295
+ zip(
296
+ ex_keys + self.symbols[self.nspecial :],
297
+ ex_vals + self.count[self.nspecial :],
298
+ ),
299
+ )
300
+
301
+ def dummy_sentence(self, length):
302
+ t = torch.Tensor(length).uniform_(self.nspecial + 1, len(self)).long()
303
+ t[-1] = self.eos()
304
+ return t
305
+
306
+ def encode_line(
307
+ self,
308
+ line,
309
+ line_tokenizer=tokenize_line,
310
+ add_if_not_exist=True,
311
+ consumer=None,
312
+ append_eos=True,
313
+ reverse_order=False,
314
+ ) -> torch.IntTensor:
315
+ words = line_tokenizer(line)
316
+ if reverse_order:
317
+ words = list(reversed(words))
318
+ nwords = len(words)
319
+ ids = torch.IntTensor(nwords + 1 if append_eos else nwords)
320
+
321
+ for i, word in enumerate(words):
322
+ if add_if_not_exist:
323
+ idx = self.add_symbol(word)
324
+ else:
325
+ idx = self.index(word)
326
+ if consumer is not None:
327
+ consumer(word, idx)
328
+ ids[i] = idx
329
+ if append_eos:
330
+ ids[nwords] = self.eos_index
331
+ return ids
332
+
333
+ @staticmethod
334
+ def _add_file_to_dictionary_single_worker(
335
+ filename,
336
+ tokenize,
337
+ eos_word,
338
+ start_offset,
339
+ end_offset,
340
+ ):
341
+ counter = Counter()
342
+ with Chunker(filename, start_offset, end_offset) as line_iterator:
343
+ for line in line_iterator:
344
+ for word in tokenize(line):
345
+ counter.update([word])
346
+ counter.update([eos_word])
347
+ return counter
348
+
349
+ @staticmethod
350
+ def add_file_to_dictionary(filename, dict, tokenize, num_workers):
351
+ def merge_result(counter):
352
+ for w, c in sorted(counter.items()):
353
+ dict.add_symbol(w, c)
354
+
355
+ local_file = PathManager.get_local_path(filename)
356
+ offsets = find_offsets(local_file, num_workers)
357
+ if num_workers > 1:
358
+ chunks = zip(offsets, offsets[1:])
359
+ pool = Pool(processes=num_workers)
360
+ results = []
361
+ for (start_offset, end_offset) in chunks:
362
+ results.append(
363
+ pool.apply_async(
364
+ Dictionary._add_file_to_dictionary_single_worker,
365
+ (
366
+ local_file,
367
+ tokenize,
368
+ dict.eos_word,
369
+ start_offset,
370
+ end_offset,
371
+ ),
372
+ )
373
+ )
374
+ pool.close()
375
+ pool.join()
376
+ for r in results:
377
+ merge_result(r.get())
378
+ else:
379
+ merge_result(
380
+ Dictionary._add_file_to_dictionary_single_worker(
381
+ local_file, tokenize, dict.eos_word, offsets[0], offsets[1]
382
+ )
383
+ )
384
+
385
+
386
+ class TruncatedDictionary(object):
387
+ def __init__(self, wrapped_dict, length):
388
+ self.__class__ = type(
389
+ wrapped_dict.__class__.__name__,
390
+ (self.__class__, wrapped_dict.__class__),
391
+ {},
392
+ )
393
+ self.__dict__ = wrapped_dict.__dict__
394
+ self.wrapped_dict = wrapped_dict
395
+ self.length = min(len(self.wrapped_dict), length)
396
+
397
+ def __len__(self):
398
+ return self.length
399
+
400
+ def __getitem__(self, i):
401
+ if i < self.length:
402
+ return self.wrapped_dict[i]
403
+ return self.wrapped_dict.unk()
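A small usage sketch for the Dictionary class above (toy tokens; assumes Dictionary is re-exported from fairseq.data as in the package __init__, and uses the default whitespace tokenizer):

from fairseq.data import Dictionary

d = Dictionary()
for line in ["hello world", "hello fairseq"]:
    for tok in line.split():
        d.add_symbol(tok)
d.finalize(padding_factor=1)  # sort by frequency; no madeupword padding symbols

ids = d.encode_line("hello world", add_if_not_exist=False)  # IntTensor with eos appended
print(d.string(ids))  # "hello world" (eos/bos are ignored by default)

TruncatedDictionary simply caps __len__ and maps any out-of-range index back to unk, which is useful when scoring against a smaller output vocabulary.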
fairseq/fairseq/data/fairseq_dataset.py ADDED
@@ -0,0 +1,205 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ import numpy as np
8
+ import torch.utils.data
9
+ from fairseq.data import data_utils
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ class EpochListening:
15
+ """Mixin for receiving updates whenever the epoch increments."""
16
+
17
+ @property
18
+ def can_reuse_epoch_itr_across_epochs(self):
19
+ """
20
+ Whether we can reuse the :class:`fairseq.data.EpochBatchIterator` for
21
+ this dataset across epochs.
22
+
23
+ This needs to return ``False`` if the sample sizes can change across
24
+ epochs, in which case we may need to regenerate batches at each epoch.
25
+ If your dataset relies on ``set_epoch`` then you should consider setting
26
+ this to ``False``.
27
+ """
28
+ return True
29
+
30
+ def set_epoch(self, epoch):
31
+ """Will receive the updated epoch number at the beginning of the epoch."""
32
+ pass
33
+
34
+
35
+ class FairseqDataset(torch.utils.data.Dataset, EpochListening):
36
+ """A dataset that provides helpers for batching."""
37
+
38
+ def __getitem__(self, index):
39
+ raise NotImplementedError
40
+
41
+ def __len__(self):
42
+ raise NotImplementedError
43
+
44
+ def collater(self, samples):
45
+ """Merge a list of samples to form a mini-batch.
46
+
47
+ Args:
48
+ samples (List[dict]): samples to collate
49
+
50
+ Returns:
51
+ dict: a mini-batch suitable for forwarding with a Model
52
+ """
53
+ raise NotImplementedError
54
+
55
+ def num_tokens(self, index):
56
+ """Return the number of tokens in a sample. This value is used to
57
+ enforce ``--max-tokens`` during batching."""
58
+ raise NotImplementedError
59
+
60
+ def num_tokens_vec(self, indices):
61
+ """Return the number of tokens for a set of positions defined by indices.
62
+ This value is used to enforce ``--max-tokens`` during batching."""
63
+ raise NotImplementedError
64
+
65
+ def size(self, index):
66
+ """Return an example's size as a float or tuple. This value is used when
67
+ filtering a dataset with ``--max-positions``."""
68
+ raise NotImplementedError
69
+
70
+ def ordered_indices(self):
71
+ """Return an ordered list of indices. Batches will be constructed based
72
+ on this order."""
73
+ return np.arange(len(self), dtype=np.int64)
74
+
75
+ @property
76
+ def supports_prefetch(self):
77
+ """Whether this dataset supports prefetching."""
78
+ return False
79
+
80
+ def attr(self, attr: str, index: int):
81
+ return getattr(self, attr, None)
82
+
83
+ def prefetch(self, indices):
84
+ """Prefetch the data required for this epoch."""
85
+ raise NotImplementedError
86
+
87
+ def get_batch_shapes(self):
88
+ """
89
+ Return a list of valid batch shapes, for example::
90
+
91
+ [(8, 512), (16, 256), (32, 128)]
92
+
93
+ The first dimension of each tuple is the batch size and can be ``None``
94
+ to automatically infer the max batch size based on ``--max-tokens``.
95
+ The second dimension of each tuple is the max supported length as given
96
+ by :func:`fairseq.data.FairseqDataset.num_tokens`.
97
+
98
+ This will be used by :func:`fairseq.data.FairseqDataset.batch_by_size`
99
+ to restrict batch shapes. This is useful on TPUs to avoid too many
100
+ dynamic shapes (and recompilations).
101
+ """
102
+ return None
103
+
104
+ def batch_by_size(
105
+ self,
106
+ indices,
107
+ max_tokens=None,
108
+ max_sentences=None,
109
+ required_batch_size_multiple=1,
110
+ ):
111
+ """
112
+ Given an ordered set of indices, return batches according to
113
+ *max_tokens*, *max_sentences* and *required_batch_size_multiple*.
114
+ """
115
+ from fairseq.data import data_utils
116
+
117
+ fixed_shapes = self.get_batch_shapes()
118
+ if fixed_shapes is not None:
119
+
120
+ def adjust_bsz(bsz, num_tokens):
121
+ if bsz is None:
122
+ assert max_tokens is not None, "Must specify --max-tokens"
123
+ bsz = max_tokens // num_tokens
124
+ if max_sentences is not None:
125
+ bsz = min(bsz, max_sentences)
126
+ elif (
127
+ bsz >= required_batch_size_multiple
128
+ and bsz % required_batch_size_multiple != 0
129
+ ):
130
+ bsz -= bsz % required_batch_size_multiple
131
+ return bsz
132
+
133
+ fixed_shapes = np.array(
134
+ [
135
+ [adjust_bsz(bsz, num_tokens), num_tokens]
136
+ for (bsz, num_tokens) in fixed_shapes
137
+ ]
138
+ )
139
+
140
+ try:
141
+ num_tokens_vec = self.num_tokens_vec(indices).astype("int64")
142
+ except NotImplementedError:
143
+ num_tokens_vec = None
144
+
145
+ return data_utils.batch_by_size(
146
+ indices,
147
+ num_tokens_fn=self.num_tokens,
148
+ num_tokens_vec=num_tokens_vec,
149
+ max_tokens=max_tokens,
150
+ max_sentences=max_sentences,
151
+ required_batch_size_multiple=required_batch_size_multiple,
152
+ fixed_shapes=fixed_shapes,
153
+ )
154
+
155
+ def filter_indices_by_size(self, indices, max_sizes):
156
+ """
157
+ Filter a list of sample indices. Remove those that are longer than
158
+ specified in *max_sizes*.
159
+
160
+ WARNING: don't update this base implementation; override the method in child classes instead
161
+
162
+ Args:
163
+ indices (np.array): original array of sample indices
164
+ max_sizes (int or list[int] or tuple[int]): max sample size,
165
+ can be defined separately for src and tgt (then list or tuple)
166
+
167
+ Returns:
168
+ np.array: filtered sample array
169
+ list: list of removed indices
170
+ """
171
+ if isinstance(max_sizes, float) or isinstance(max_sizes, int):
172
+ if hasattr(self, "sizes") and isinstance(self.sizes, np.ndarray):
173
+ ignored = indices[self.sizes[indices] > max_sizes].tolist()
174
+ indices = indices[self.sizes[indices] <= max_sizes]
175
+ elif (
176
+ hasattr(self, "sizes")
177
+ and isinstance(self.sizes, list)
178
+ and len(self.sizes) == 1
179
+ ):
180
+ ignored = indices[self.sizes[0][indices] > max_sizes].tolist()
181
+ indices = indices[self.sizes[0][indices] <= max_sizes]
182
+ else:
183
+ indices, ignored = data_utils._filter_by_size_dynamic(
184
+ indices, self.size, max_sizes
185
+ )
186
+ else:
187
+ indices, ignored = data_utils._filter_by_size_dynamic(
188
+ indices, self.size, max_sizes
189
+ )
190
+ return indices, ignored
191
+
192
+ @property
193
+ def supports_fetch_outside_dataloader(self):
194
+ """Whether this dataset supports fetching outside the workers of the dataloader."""
195
+ return True
196
+
197
+
198
+ class FairseqIterableDataset(torch.utils.data.IterableDataset, EpochListening):
199
+ """
200
+ For datasets that need to be read sequentially, usually because the data is
201
+ being streamed or otherwise can't be manipulated on a single machine.
202
+ """
203
+
204
+ def __iter__(self):
205
+ raise NotImplementedError
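A minimal sketch of a concrete FairseqDataset subclass filling in the abstract hooks above (toy data; collate_tokens is the same padding helper the other datasets in this commit use):

import numpy as np
import torch
from fairseq.data import FairseqDataset, data_utils

class ToyTokenDataset(FairseqDataset):
    """Each item is a 1-D LongTensor of token ids."""

    def __init__(self, items, pad_idx):
        self.items = items
        self.pad_idx = pad_idx
        self.sizes = np.array([len(t) for t in items])

    def __getitem__(self, index):
        return self.items[index]

    def __len__(self):
        return len(self.items)

    def num_tokens(self, index):
        return self.sizes[index]

    def size(self, index):
        return self.sizes[index]

    def collater(self, samples):
        # pad on the right with pad_idx, like the built-in datasets
        return data_utils.collate_tokens(samples, self.pad_idx)

ds = ToyTokenDataset([torch.LongTensor([5, 6, 7]), torch.LongTensor([8, 9])], pad_idx=1)
batch = ds.collater([ds[0], ds[1]])  # shape (2, 3); second row right-padded with 1

Because sizes is a NumPy array, filter_indices_by_size above takes its fast vectorized path.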
fairseq/fairseq/data/fasta_dataset.py ADDED
@@ -0,0 +1,107 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ import subprocess
8
+ import threading
9
+ from pathlib import Path
10
+
11
+ import numpy as np
12
+ import torch
13
+
14
+
15
+ def fasta_file_path(prefix_path):
16
+ return prefix_path + ".fasta"
17
+
18
+
19
+ class FastaDataset(torch.utils.data.Dataset):
20
+ """
21
+ For loading protein sequence datasets in the common FASTA data format
22
+ """
23
+
24
+ def __init__(self, path: str, cache_indices=False):
25
+ self.fn = fasta_file_path(path)
26
+ self.threadlocal = threading.local()
27
+ self.cache = Path(f"{path}.fasta.idx.npy")
28
+ if cache_indices:
29
+ if self.cache.exists():
30
+ self.offsets, self.sizes = np.load(self.cache)
31
+ else:
32
+ self.offsets, self.sizes = self._build_index(path)
33
+ np.save(self.cache, np.stack([self.offsets, self.sizes]))
34
+ else:
35
+ self.offsets, self.sizes = self._build_index(path)
36
+
37
+ def _get_file(self):
38
+ if not hasattr(self.threadlocal, "f"):
39
+ self.threadlocal.f = open(self.fn, "r")
40
+ return self.threadlocal.f
41
+
42
+ def __getitem__(self, idx):
43
+ f = self._get_file()
44
+ f.seek(self.offsets[idx])
45
+ desc = f.readline().strip()
46
+ line = f.readline()
47
+ seq = ""
48
+ while line != "" and line[0] != ">":
49
+ seq += line.strip()
50
+ line = f.readline()
51
+ return desc, seq
52
+
53
+ def __len__(self):
54
+ return self.offsets.size
55
+
56
+ def _build_index(self, path: str):
57
+ # Use grep and awk to get 100M/s on local SSD.
58
+ # Should process your enormous 100G fasta in ~10 min single core...
59
+ path = fasta_file_path(path)
60
+ bytes_offsets = subprocess.check_output(
61
+ f"cat {path} | tqdm --bytes --total $(wc -c < {path})"
62
+ "| grep --byte-offset '^>' -o | cut -d: -f1",
63
+ shell=True,
64
+ )
65
+ fasta_lengths = subprocess.check_output(
66
+ f"cat {path} | tqdm --bytes --total $(wc -c < {path})"
67
+ "| awk '/^>/ {print \"\";next;} { printf(\"%s\",$0);}' | tail -n+2 | awk '{print length($1)}'",
68
+ shell=True,
69
+ )
70
+ bytes_np = np.fromstring(bytes_offsets, dtype=np.int64, sep=" ")
71
+ sizes_np = np.fromstring(fasta_lengths, dtype=np.int64, sep=" ")
72
+ return bytes_np, sizes_np
73
+
74
+ def __setstate__(self, state):
75
+ self.__dict__ = state
76
+ self.threadlocal = threading.local()
77
+
78
+ def __getstate__(self):
79
+ d = {}
80
+ for i, v in self.__dict__.items():
81
+ if i != "threadlocal":
82
+ d[i] = v
83
+ return d
84
+
85
+ def __del__(self):
86
+ if hasattr(self.threadlocal, "f"):
87
+ self.threadlocal.f.close()
88
+ del self.threadlocal.f
89
+
90
+ @staticmethod
91
+ def exists(path):
92
+ return os.path.exists(fasta_file_path(path))
93
+
94
+
95
+ class EncodedFastaDataset(FastaDataset):
96
+ """
97
+ The FastaDataset returns raw sequences - this allows us to return
98
+ indices with a dictionary instead.
99
+ """
100
+
101
+ def __init__(self, path, dictionary):
102
+ super().__init__(path, cache_indices=True)
103
+ self.dictionary = dictionary
104
+
105
+ def __getitem__(self, idx):
106
+ desc, seq = super().__getitem__(idx)
107
+ return self.dictionary.encode_line(seq, line_tokenizer=list).long()
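A quick sanity check for FastaDataset (the file name and contents are hypothetical; note that _build_index shells out to grep, awk and the tqdm CLI, so those must be available):

# toy.fasta
# >protein_1 some description
# MKTAYIAKQR
# QISFVKSHFS
# >protein_2
# GSHMSLYNP

from fairseq.data.fasta_dataset import FastaDataset

ds = FastaDataset("toy")  # opens toy.fasta and indexes the '>' record offsets
desc, seq = ds[0]
# desc == ">protein_1 some description"
# seq  == "MKTAYIAKQRQISFVKSHFS"  (wrapped sequence lines are concatenated)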
fairseq/fairseq/data/id_dataset.py ADDED
@@ -0,0 +1,19 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+
8
+ from . import FairseqDataset
9
+
10
+
11
+ class IdDataset(FairseqDataset):
12
+ def __getitem__(self, index):
13
+ return index
14
+
15
+ def __len__(self):
16
+ return 0
17
+
18
+ def collater(self, samples):
19
+ return torch.tensor(samples)
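IdDataset is typically nested alongside other datasets so that example ids travel with each mini-batch; on its own it simply echoes indices (a trivial sketch, assuming the fairseq.data re-export):

from fairseq.data import IdDataset

ids = IdDataset()
print(ids[7])                   # 7 -- __getitem__ returns the index itself
print(ids.collater([3, 1, 4]))  # tensor([3, 1, 4])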