Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- fairseq/fairseq/criterions/__pycache__/__init__.cpython-310.pyc +0 -0
- fairseq/fairseq/criterions/__pycache__/adaptive_loss.cpython-310.pyc +0 -0
- fairseq/fairseq/criterions/__pycache__/composite_loss.cpython-310.pyc +0 -0
- fairseq/fairseq/criterions/__pycache__/cross_entropy.cpython-310.pyc +0 -0
- fairseq/fairseq/criterions/__pycache__/ctc.cpython-310.pyc +0 -0
- fairseq/fairseq/criterions/__pycache__/fastspeech2_loss.cpython-310.pyc +0 -0
- fairseq/fairseq/criterions/__pycache__/hubert_criterion.cpython-310.pyc +0 -0
- fairseq/fairseq/criterions/__pycache__/label_smoothed_cross_entropy.cpython-310.pyc +0 -0
- fairseq/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_latency_augmented.cpython-310.pyc +0 -0
- fairseq/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_with_alignment.cpython-310.pyc +0 -0
- fairseq/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_with_ctc.cpython-310.pyc +0 -0
- fairseq/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_with_rdrop.cpython-310.pyc +0 -0
- fairseq/fairseq/criterions/__pycache__/legacy_masked_lm.cpython-310.pyc +0 -0
- fairseq/fairseq/criterions/__pycache__/masked_lm.cpython-310.pyc +0 -0
- fairseq/fairseq/criterions/__pycache__/model_criterion.cpython-310.pyc +0 -0
- fairseq/fairseq/criterions/__pycache__/nat_loss.cpython-310.pyc +0 -0
- fairseq/fairseq/criterions/__pycache__/sentence_prediction.cpython-310.pyc +0 -0
- fairseq/fairseq/criterions/__pycache__/sentence_prediction_adapters.cpython-310.pyc +0 -0
- fairseq/fairseq/criterions/__pycache__/sentence_ranking.cpython-310.pyc +0 -0
- fairseq/fairseq/criterions/__pycache__/speech_dlm_criterion.cpython-310.pyc +0 -0
- fairseq/fairseq/criterions/__pycache__/speech_to_speech_criterion.cpython-310.pyc +0 -0
- fairseq/fairseq/criterions/__pycache__/speech_ulm_criterion.cpython-310.pyc +0 -0
- fairseq/fairseq/criterions/__pycache__/tacotron2_loss.cpython-310.pyc +0 -0
- fairseq/fairseq/criterions/__pycache__/wav2vec_criterion.cpython-310.pyc +0 -0
- fairseq/fairseq/criterions/speech_to_speech_criterion.py +517 -0
- fairseq/fairseq/criterions/tacotron2_loss.py +227 -0
- fairseq/fairseq/criterions/wav2vec_criterion.py +231 -0
- fairseq/fairseq/data/__init__.py +137 -0
- fairseq/fairseq/data/add_class_target_dataset.py +79 -0
- fairseq/fairseq/data/add_target_dataset.py +83 -0
- fairseq/fairseq/data/append_token_dataset.py +41 -0
- fairseq/fairseq/data/audio/dataset_transforms/__pycache__/concataugment.cpython-310.pyc +0 -0
- fairseq/fairseq/data/audio/feature_transforms/__pycache__/delta_deltas.cpython-310.pyc +0 -0
- fairseq/fairseq/data/audio/feature_transforms/global_cmvn.py +29 -0
- fairseq/fairseq/data/audio/speech_to_text_dataset.py +733 -0
- fairseq/fairseq/data/backtranslation_dataset.py +165 -0
- fairseq/fairseq/data/base_wrapper_dataset.py +78 -0
- fairseq/fairseq/data/bucket_pad_length_dataset.py +78 -0
- fairseq/fairseq/data/codedataset.py +576 -0
- fairseq/fairseq/data/colorize_dataset.py +25 -0
- fairseq/fairseq/data/concat_dataset.py +124 -0
- fairseq/fairseq/data/concat_sentences_dataset.py +54 -0
- fairseq/fairseq/data/data_utils.py +1144 -0
- fairseq/fairseq/data/data_utils_fast.pyx +178 -0
- fairseq/fairseq/data/denoising_dataset.py +443 -0
- fairseq/fairseq/data/dictionary.py +403 -0
- fairseq/fairseq/data/fairseq_dataset.py +205 -0
- fairseq/fairseq/data/fasta_dataset.py +107 -0
- fairseq/fairseq/data/id_dataset.py +19 -0
.gitattributes
CHANGED
@@ -40,3 +40,4 @@ fairseq/alignment_train_cpu_binding.cpython-310-x86_64-linux-gnu.so filter=lfs d
|
|
40 |
fairseq/docs/fairseq.gif filter=lfs diff=lfs merge=lfs -text
|
41 |
fairseq/examples/hubert/tests/6313-76958-0021.flac filter=lfs diff=lfs merge=lfs -text
|
42 |
fairseq/examples/textless_nlp/speech-resynth/img/fig.png filter=lfs diff=lfs merge=lfs -text
|
|
|
|
40 |
fairseq/docs/fairseq.gif filter=lfs diff=lfs merge=lfs -text
|
41 |
fairseq/examples/hubert/tests/6313-76958-0021.flac filter=lfs diff=lfs merge=lfs -text
|
42 |
fairseq/examples/textless_nlp/speech-resynth/img/fig.png filter=lfs diff=lfs merge=lfs -text
|
43 |
+
fairseq/fairseq/libbase.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
fairseq/fairseq/criterions/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (1.06 kB). View file
|
|
fairseq/fairseq/criterions/__pycache__/adaptive_loss.cpython-310.pyc
ADDED
Binary file (4.76 kB). View file
|
|
fairseq/fairseq/criterions/__pycache__/composite_loss.cpython-310.pyc
ADDED
Binary file (4.49 kB). View file
|
|
fairseq/fairseq/criterions/__pycache__/cross_entropy.cpython-310.pyc
ADDED
Binary file (3.89 kB). View file
|
|
fairseq/fairseq/criterions/__pycache__/ctc.cpython-310.pyc
ADDED
Binary file (8.4 kB). View file
|
|
fairseq/fairseq/criterions/__pycache__/fastspeech2_loss.cpython-310.pyc
ADDED
Binary file (4.76 kB). View file
|
|
fairseq/fairseq/criterions/__pycache__/hubert_criterion.cpython-310.pyc
ADDED
Binary file (6.77 kB). View file
|
|
fairseq/fairseq/criterions/__pycache__/label_smoothed_cross_entropy.cpython-310.pyc
ADDED
Binary file (6.26 kB). View file
|
|
fairseq/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_latency_augmented.cpython-310.pyc
ADDED
Binary file (5.89 kB). View file
|
|
fairseq/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_with_alignment.cpython-310.pyc
ADDED
Binary file (4.78 kB). View file
|
|
fairseq/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_with_ctc.cpython-310.pyc
ADDED
Binary file (3.33 kB). View file
|
|
fairseq/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_with_rdrop.cpython-310.pyc
ADDED
Binary file (5.3 kB). View file
|
|
fairseq/fairseq/criterions/__pycache__/legacy_masked_lm.cpython-310.pyc
ADDED
Binary file (5.53 kB). View file
|
|
fairseq/fairseq/criterions/__pycache__/masked_lm.cpython-310.pyc
ADDED
Binary file (3.53 kB). View file
|
|
fairseq/fairseq/criterions/__pycache__/model_criterion.cpython-310.pyc
ADDED
Binary file (5.71 kB). View file
|
|
fairseq/fairseq/criterions/__pycache__/nat_loss.cpython-310.pyc
ADDED
Binary file (6.1 kB). View file
|
|
fairseq/fairseq/criterions/__pycache__/sentence_prediction.cpython-310.pyc
ADDED
Binary file (8.21 kB). View file
|
|
fairseq/fairseq/criterions/__pycache__/sentence_prediction_adapters.cpython-310.pyc
ADDED
Binary file (1.99 kB). View file
|
|
fairseq/fairseq/criterions/__pycache__/sentence_ranking.cpython-310.pyc
ADDED
Binary file (4.6 kB). View file
|
|
fairseq/fairseq/criterions/__pycache__/speech_dlm_criterion.cpython-310.pyc
ADDED
Binary file (9.03 kB). View file
|
|
fairseq/fairseq/criterions/__pycache__/speech_to_speech_criterion.cpython-310.pyc
ADDED
Binary file (11.5 kB). View file
|
|
fairseq/fairseq/criterions/__pycache__/speech_ulm_criterion.cpython-310.pyc
ADDED
Binary file (4.95 kB). View file
|
|
fairseq/fairseq/criterions/__pycache__/tacotron2_loss.cpython-310.pyc
ADDED
Binary file (7.46 kB). View file
|
|
fairseq/fairseq/criterions/__pycache__/wav2vec_criterion.cpython-310.pyc
ADDED
Binary file (6.5 kB). View file
|
|
fairseq/fairseq/criterions/speech_to_speech_criterion.py
ADDED
@@ -0,0 +1,517 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import logging
|
7 |
+
import math
|
8 |
+
from collections import OrderedDict
|
9 |
+
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from fairseq import utils
|
13 |
+
from fairseq.logging import metrics
|
14 |
+
from fairseq.criterions import register_criterion
|
15 |
+
from fairseq.criterions.ctc import CtcCriterion
|
16 |
+
from fairseq.criterions.label_smoothed_cross_entropy_with_rdrop import (
|
17 |
+
RdropLabelSmoothedCrossEntropyCriterion,
|
18 |
+
RdropLabelSmoothedCrossEntropyCriterionConfig,
|
19 |
+
duplicate_input,
|
20 |
+
)
|
21 |
+
from fairseq.criterions.tacotron2_loss import (
|
22 |
+
Tacotron2Criterion,
|
23 |
+
Tacotron2CriterionConfig,
|
24 |
+
)
|
25 |
+
|
26 |
+
logger = logging.getLogger(__name__)
|
27 |
+
|
28 |
+
|
29 |
+
class MultitaskCriterion:
|
30 |
+
def __init__(self, multitask_tasks, rdrop_alpha=0.0):
|
31 |
+
self.rdrop_alpha = rdrop_alpha
|
32 |
+
self.rdrop_alpha_mtl = rdrop_alpha
|
33 |
+
|
34 |
+
self.multitask_criterion = OrderedDict()
|
35 |
+
self.multitask_loss_weight = OrderedDict()
|
36 |
+
for task_name, task_obj in multitask_tasks.items():
|
37 |
+
if task_obj.args.get_loss_weight(0) == 0:
|
38 |
+
logger.info(f"Skip {task_name} loss criterion")
|
39 |
+
continue
|
40 |
+
|
41 |
+
rdrop_alpha_task = task_obj.args.rdrop_alpha
|
42 |
+
if rdrop_alpha_task is None:
|
43 |
+
rdrop_alpha_task = rdrop_alpha
|
44 |
+
self.rdrop_alpha_mtl = rdrop_alpha_task
|
45 |
+
logger.info(f"rdrop_alpha is set to {rdrop_alpha_task} for {task_name}")
|
46 |
+
|
47 |
+
if task_obj.args.decoder_type == "ctc":
|
48 |
+
self.multitask_criterion[task_name] = CtcCriterion(
|
49 |
+
task_obj.args.criterion_cfg,
|
50 |
+
task_obj,
|
51 |
+
rdrop_alpha=rdrop_alpha_task,
|
52 |
+
)
|
53 |
+
else:
|
54 |
+
self.multitask_criterion[
|
55 |
+
task_name
|
56 |
+
] = RdropLabelSmoothedCrossEntropyCriterion(
|
57 |
+
task_obj,
|
58 |
+
task_obj.args.criterion_cfg.sentence_avg,
|
59 |
+
label_smoothing=task_obj.args.criterion_cfg.label_smoothing,
|
60 |
+
rdrop_alpha=rdrop_alpha_task,
|
61 |
+
)
|
62 |
+
|
63 |
+
def set_multitask_loss_weight(self, task_name, weight=0.0):
|
64 |
+
self.multitask_loss_weight[task_name] = weight
|
65 |
+
|
66 |
+
def get_multitask_loss(self, model, sample, model_out):
|
67 |
+
logging_output = {}
|
68 |
+
loss = 0.0
|
69 |
+
for task_name, task_criterion in self.multitask_criterion.items():
|
70 |
+
layer_id = task_criterion.task.args.input_layer
|
71 |
+
if isinstance(task_criterion, CtcCriterion):
|
72 |
+
if task_criterion.task.args.input_from == "encoder":
|
73 |
+
if len(model_out["encoder_padding_mask"]) > 0:
|
74 |
+
non_padding_mask = ~model_out["encoder_padding_mask"][0]
|
75 |
+
input_lengths = non_padding_mask.long().sum(-1)
|
76 |
+
else:
|
77 |
+
out = model_out["encoder_states"][layer_id]
|
78 |
+
input_lengths = out.new_full(
|
79 |
+
(out.shape[1],), out.shape[0]
|
80 |
+
).long()
|
81 |
+
|
82 |
+
task_sample = {
|
83 |
+
"net_input": {
|
84 |
+
"src_tokens": model_out["encoder_states"][
|
85 |
+
layer_id
|
86 |
+
], # check batch idx
|
87 |
+
"src_lengths": input_lengths,
|
88 |
+
},
|
89 |
+
"id": sample["id"],
|
90 |
+
}
|
91 |
+
else:
|
92 |
+
task_sample = {
|
93 |
+
"net_input": {
|
94 |
+
"src_tokens": model_out["inner_states"][layer_id],
|
95 |
+
"src_lengths": sample["target_lengths"],
|
96 |
+
},
|
97 |
+
"id": sample["id"],
|
98 |
+
}
|
99 |
+
else:
|
100 |
+
task_sample = {
|
101 |
+
"net_input": {
|
102 |
+
"src_tokens": sample["multitask"][task_name]["net_input"][
|
103 |
+
"prev_output_tokens"
|
104 |
+
],
|
105 |
+
"encoder_out": {
|
106 |
+
"encoder_out": [model_out["encoder_states"][layer_id]],
|
107 |
+
"encoder_padding_mask": model_out["encoder_padding_mask"],
|
108 |
+
},
|
109 |
+
}
|
110 |
+
}
|
111 |
+
|
112 |
+
for key in ["target", "target_lengths", "ntokens"]:
|
113 |
+
task_sample[key] = sample["multitask"][task_name][key]
|
114 |
+
|
115 |
+
if task_name == getattr(model, "mt_task_name", None):
|
116 |
+
decoder_out = model_out["mt_decoder_out"]
|
117 |
+
else:
|
118 |
+
decoder_out = None
|
119 |
+
task_loss, task_sample_size, task_logging_output = task_criterion(
|
120 |
+
model.multitask_decoders[task_name], task_sample, net_output=decoder_out
|
121 |
+
)
|
122 |
+
|
123 |
+
loss = loss + self.multitask_loss_weight[task_name] * task_loss
|
124 |
+
task_logging_output["loss_weight"] = self.multitask_loss_weight[task_name]
|
125 |
+
logging_output[task_name] = task_logging_output
|
126 |
+
return loss, logging_output
|
127 |
+
|
128 |
+
@classmethod
|
129 |
+
def reduce_metrics(cls, logging_outputs) -> None:
|
130 |
+
for task_name in logging_outputs[0]["multitask"].keys():
|
131 |
+
# different criterion may return different logging
|
132 |
+
# currently only reduce on loss, the most common one
|
133 |
+
# ideally the way that losses are reduced should also depend on the task type
|
134 |
+
loss_sum = sum(
|
135 |
+
log["multitask"][task_name].get("loss", 0) for log in logging_outputs
|
136 |
+
)
|
137 |
+
sample_size = sum(
|
138 |
+
log["multitask"][task_name].get("sample_size", 0)
|
139 |
+
for log in logging_outputs
|
140 |
+
)
|
141 |
+
|
142 |
+
metrics.log_scalar(
|
143 |
+
f"multitask_{task_name}_loss",
|
144 |
+
loss_sum / sample_size / math.log(2),
|
145 |
+
sample_size,
|
146 |
+
round=3,
|
147 |
+
)
|
148 |
+
|
149 |
+
loss_weight = logging_outputs[0]["multitask"][task_name].get(
|
150 |
+
"loss_weight", 0
|
151 |
+
)
|
152 |
+
metrics.log_scalar(
|
153 |
+
f"multitask_{task_name}_loss_weight",
|
154 |
+
loss_weight,
|
155 |
+
weight=0,
|
156 |
+
priority=250,
|
157 |
+
)
|
158 |
+
|
159 |
+
|
160 |
+
@register_criterion(
|
161 |
+
"speech_to_unit", dataclass=RdropLabelSmoothedCrossEntropyCriterionConfig
|
162 |
+
)
|
163 |
+
class SpeechToUnitMultitaskTaskCriterion(
|
164 |
+
RdropLabelSmoothedCrossEntropyCriterion, MultitaskCriterion
|
165 |
+
):
|
166 |
+
def __init__(
|
167 |
+
self,
|
168 |
+
task,
|
169 |
+
sentence_avg,
|
170 |
+
label_smoothing,
|
171 |
+
ignore_prefix_size=0,
|
172 |
+
report_accuracy=False,
|
173 |
+
rdrop_alpha=0.0,
|
174 |
+
):
|
175 |
+
super().__init__(
|
176 |
+
task,
|
177 |
+
sentence_avg,
|
178 |
+
label_smoothing,
|
179 |
+
ignore_prefix_size,
|
180 |
+
report_accuracy,
|
181 |
+
rdrop_alpha,
|
182 |
+
)
|
183 |
+
MultitaskCriterion.__init__(self, task.multitask_tasks, rdrop_alpha)
|
184 |
+
|
185 |
+
def forward(self, model, sample, reduce=True):
|
186 |
+
net_input_concat = {
|
187 |
+
"src_tokens": sample["net_input"]["src_tokens"],
|
188 |
+
"src_lengths": sample["net_input"]["src_lengths"],
|
189 |
+
"prev_output_tokens": sample["net_input"]["prev_output_tokens"],
|
190 |
+
"tgt_speaker": sample["net_input"].get("tgt_speaker", None),
|
191 |
+
"return_all_hiddens": True,
|
192 |
+
}
|
193 |
+
|
194 |
+
if self.rdrop_alpha > 0 or self.rdrop_alpha_mtl > 0:
|
195 |
+
net_input_concat = duplicate_input(net_input_concat)
|
196 |
+
|
197 |
+
net_output, extra = model(**net_input_concat)
|
198 |
+
loss, nll_loss, rdrop_kl_loss = self.compute_loss(
|
199 |
+
model, [net_output], sample, reduce=reduce
|
200 |
+
)
|
201 |
+
sample_size = (
|
202 |
+
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
|
203 |
+
)
|
204 |
+
logging_output = {
|
205 |
+
"loss": loss.data,
|
206 |
+
"nll_loss": nll_loss.data,
|
207 |
+
"ntokens": sample["ntokens"],
|
208 |
+
"nsentences": sample["target"].size(0),
|
209 |
+
"sample_size": sample_size,
|
210 |
+
}
|
211 |
+
if self.report_accuracy:
|
212 |
+
n_correct, total = self.compute_accuracy(model, [net_output], sample)
|
213 |
+
logging_output["n_correct"] = utils.item(n_correct.data)
|
214 |
+
logging_output["total"] = utils.item(total.data)
|
215 |
+
if self.rdrop_alpha > 0:
|
216 |
+
logging_output["rdrop_kl_loss"] = utils.item(rdrop_kl_loss.data)
|
217 |
+
|
218 |
+
if len(self.multitask_criterion) == 0:
|
219 |
+
return loss, sample_size, logging_output
|
220 |
+
|
221 |
+
# multitask
|
222 |
+
multitask_loss, multitask_log = self.get_multitask_loss(model, sample, extra)
|
223 |
+
loss += multitask_loss
|
224 |
+
logging_output["multitask"] = multitask_log
|
225 |
+
|
226 |
+
return loss, sample_size, logging_output
|
227 |
+
|
228 |
+
@classmethod
|
229 |
+
def reduce_metrics(cls, logging_outputs) -> None:
|
230 |
+
super().reduce_metrics(logging_outputs)
|
231 |
+
|
232 |
+
# inference metrics
|
233 |
+
if "targ_frames" in logging_outputs[0]:
|
234 |
+
n = sum(log.get("norm_frames", 0) for log in logging_outputs)
|
235 |
+
for key, new_key in [
|
236 |
+
("mcd_loss", "mcd_loss"),
|
237 |
+
("pred_frames", "pred_ratio"),
|
238 |
+
("nins", "ins_rate"),
|
239 |
+
("ndel", "del_rate"),
|
240 |
+
]:
|
241 |
+
val = sum(log.get(key, 0) for log in logging_outputs)
|
242 |
+
metrics.log_scalar(new_key, val / n, n, round=3)
|
243 |
+
|
244 |
+
if "multitask" not in logging_outputs[0]:
|
245 |
+
return
|
246 |
+
|
247 |
+
MultitaskCriterion.reduce_metrics(logging_outputs)
|
248 |
+
|
249 |
+
@staticmethod
|
250 |
+
def logging_outputs_can_be_summed() -> bool:
|
251 |
+
"""
|
252 |
+
Whether the logging outputs returned by `forward` can be summed
|
253 |
+
across workers prior to calling `reduce_metrics`. Setting this
|
254 |
+
to True will improves distributed training speed.
|
255 |
+
"""
|
256 |
+
return False
|
257 |
+
|
258 |
+
|
259 |
+
@register_criterion(
|
260 |
+
"speech_to_unit_2pass", dataclass=RdropLabelSmoothedCrossEntropyCriterionConfig
|
261 |
+
)
|
262 |
+
class SpeechToUnit2passMultitaskTaskCriterion(SpeechToUnitMultitaskTaskCriterion):
|
263 |
+
def __init__(
|
264 |
+
self,
|
265 |
+
task,
|
266 |
+
sentence_avg,
|
267 |
+
label_smoothing,
|
268 |
+
ignore_prefix_size=0,
|
269 |
+
report_accuracy=False,
|
270 |
+
rdrop_alpha=0.0,
|
271 |
+
):
|
272 |
+
super().__init__(
|
273 |
+
task,
|
274 |
+
sentence_avg,
|
275 |
+
label_smoothing,
|
276 |
+
ignore_prefix_size,
|
277 |
+
report_accuracy,
|
278 |
+
rdrop_alpha,
|
279 |
+
)
|
280 |
+
|
281 |
+
def forward(self, model, sample, reduce=True):
|
282 |
+
net_input_concat = {
|
283 |
+
"src_tokens": sample["net_input"]["src_tokens"],
|
284 |
+
"src_lengths": sample["net_input"]["src_lengths"],
|
285 |
+
"prev_output_tokens": sample["net_input"]["prev_output_tokens"],
|
286 |
+
"prev_output_tokens_mt": sample["multitask"][model.mt_task_name][
|
287 |
+
"net_input"
|
288 |
+
]["prev_output_tokens"],
|
289 |
+
"tgt_speaker": sample["net_input"].get("tgt_speaker", None),
|
290 |
+
"return_all_hiddens": True,
|
291 |
+
}
|
292 |
+
if getattr(model, "asr_task_name", None) is not None:
|
293 |
+
net_input_concat["prev_output_tokens_asr"] = sample["multitask"][
|
294 |
+
model.asr_task_name
|
295 |
+
]["net_input"]["prev_output_tokens"]
|
296 |
+
|
297 |
+
if self.rdrop_alpha > 0 or self.rdrop_alpha_mtl > 0:
|
298 |
+
net_input_concat = duplicate_input(net_input_concat)
|
299 |
+
|
300 |
+
net_output, extra = model(**net_input_concat)
|
301 |
+
loss, nll_loss, rdrop_kl_loss = self.compute_loss(
|
302 |
+
model, [net_output], sample, reduce=reduce
|
303 |
+
)
|
304 |
+
|
305 |
+
sample_size = (
|
306 |
+
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
|
307 |
+
)
|
308 |
+
logging_output = {
|
309 |
+
"loss": loss.data,
|
310 |
+
"nll_loss": nll_loss.data,
|
311 |
+
"ntokens": sample["ntokens"],
|
312 |
+
"nsentences": sample["target"].size(0),
|
313 |
+
"sample_size": sample_size,
|
314 |
+
}
|
315 |
+
if self.report_accuracy:
|
316 |
+
n_correct, total = self.compute_accuracy(model, [net_output], sample)
|
317 |
+
logging_output["n_correct"] = utils.item(n_correct.data)
|
318 |
+
logging_output["total"] = utils.item(total.data)
|
319 |
+
if self.rdrop_alpha > 0:
|
320 |
+
logging_output["rdrop_kl_loss"] = utils.item(rdrop_kl_loss.data)
|
321 |
+
|
322 |
+
if len(self.multitask_criterion) == 0:
|
323 |
+
return loss, sample_size, logging_output
|
324 |
+
|
325 |
+
# multitask
|
326 |
+
multitask_loss, multitask_log = self.get_multitask_loss(model, sample, extra)
|
327 |
+
loss += multitask_loss
|
328 |
+
logging_output["multitask"] = multitask_log
|
329 |
+
|
330 |
+
return loss, sample_size, logging_output
|
331 |
+
|
332 |
+
|
333 |
+
@register_criterion("speech_to_spectrogram", dataclass=Tacotron2CriterionConfig)
|
334 |
+
class SpeechToSpectrogramMultitaskTaskCriterion(Tacotron2Criterion, MultitaskCriterion):
|
335 |
+
def __init__(
|
336 |
+
self,
|
337 |
+
task,
|
338 |
+
sentence_avg,
|
339 |
+
use_guided_attention_loss,
|
340 |
+
guided_attention_loss_sigma,
|
341 |
+
bce_pos_weight,
|
342 |
+
ctc_weight,
|
343 |
+
):
|
344 |
+
super().__init__(
|
345 |
+
task,
|
346 |
+
sentence_avg,
|
347 |
+
use_guided_attention_loss,
|
348 |
+
guided_attention_loss_sigma,
|
349 |
+
bce_pos_weight,
|
350 |
+
ctc_weight,
|
351 |
+
)
|
352 |
+
MultitaskCriterion.__init__(self, task.multitask_tasks)
|
353 |
+
|
354 |
+
def forward(self, model, sample, reduction="mean"):
|
355 |
+
bsz, max_len, _ = sample["target"].size()
|
356 |
+
feat_tgt = sample["target"]
|
357 |
+
feat_len = sample["target_lengths"].view(bsz, 1).expand(-1, max_len)
|
358 |
+
eos_tgt = torch.arange(max_len).to(sample["target"].device)
|
359 |
+
eos_tgt = eos_tgt.view(1, max_len).expand(bsz, -1)
|
360 |
+
eos_tgt = (eos_tgt == (feat_len - 1)).float()
|
361 |
+
|
362 |
+
feat_out, eos_out, extra = model(
|
363 |
+
src_tokens=sample["net_input"]["src_tokens"],
|
364 |
+
src_lengths=sample["net_input"]["src_lengths"],
|
365 |
+
prev_output_tokens=sample["net_input"]["prev_output_tokens"],
|
366 |
+
tgt_speaker=sample["net_input"]["tgt_speaker"],
|
367 |
+
target_lengths=sample["target_lengths"],
|
368 |
+
return_all_hiddens=True,
|
369 |
+
)
|
370 |
+
|
371 |
+
l1_loss, mse_loss, eos_loss = self.compute_loss(
|
372 |
+
extra["feature_out"],
|
373 |
+
feat_out,
|
374 |
+
eos_out,
|
375 |
+
feat_tgt,
|
376 |
+
eos_tgt,
|
377 |
+
sample["target_lengths"],
|
378 |
+
reduction,
|
379 |
+
)
|
380 |
+
attn_loss = torch.tensor(0.0).type_as(l1_loss)
|
381 |
+
if self.guided_attn is not None:
|
382 |
+
attn_loss = self.guided_attn(
|
383 |
+
extra["attn"],
|
384 |
+
sample["net_input"]["src_lengths"],
|
385 |
+
sample["target_lengths"],
|
386 |
+
reduction,
|
387 |
+
)
|
388 |
+
loss = (
|
389 |
+
l1_loss + mse_loss + eos_loss + attn_loss
|
390 |
+
) # do not include ctc loss as there's no text target
|
391 |
+
|
392 |
+
sample_size = sample["nsentences"] if self.sentence_avg else sample["ntokens"]
|
393 |
+
logging_output = {
|
394 |
+
"loss": utils.item(loss.data),
|
395 |
+
"ntokens": sample["ntokens"],
|
396 |
+
"nsentences": sample["nsentences"],
|
397 |
+
"sample_size": sample_size,
|
398 |
+
"l1_loss": utils.item(l1_loss.data),
|
399 |
+
"mse_loss": utils.item(mse_loss.data),
|
400 |
+
"eos_loss": utils.item(eos_loss.data),
|
401 |
+
"attn_loss": utils.item(attn_loss.data),
|
402 |
+
}
|
403 |
+
|
404 |
+
if len(self.multitask_criterion) == 0:
|
405 |
+
return loss, sample_size, logging_output
|
406 |
+
|
407 |
+
# multitask
|
408 |
+
multitask_loss, multitask_log = self.get_multitask_loss(model, sample, extra)
|
409 |
+
loss += multitask_loss
|
410 |
+
logging_output["multitask"] = multitask_log
|
411 |
+
return loss, sample_size, logging_output
|
412 |
+
|
413 |
+
@classmethod
|
414 |
+
def reduce_metrics(cls, logging_outputs) -> None:
|
415 |
+
super().reduce_metrics(logging_outputs)
|
416 |
+
|
417 |
+
# inference metrics
|
418 |
+
if "targ_frames" in logging_outputs[0]:
|
419 |
+
n = sum(log.get("norm_frames", 0) for log in logging_outputs)
|
420 |
+
for key, new_key in [
|
421 |
+
("mcd_loss", "mcd_loss"),
|
422 |
+
("pred_frames", "pred_ratio"),
|
423 |
+
("nins", "ins_rate"),
|
424 |
+
("ndel", "del_rate"),
|
425 |
+
]:
|
426 |
+
val = sum(log.get(key, 0) for log in logging_outputs)
|
427 |
+
metrics.log_scalar(new_key, val / n, n, round=3)
|
428 |
+
|
429 |
+
if "multitask" not in logging_outputs[0]:
|
430 |
+
return
|
431 |
+
|
432 |
+
MultitaskCriterion.reduce_metrics(logging_outputs)
|
433 |
+
|
434 |
+
|
435 |
+
@register_criterion("speech_to_spectrogram_2pass", dataclass=Tacotron2CriterionConfig)
|
436 |
+
class SpeechToSpectrogram2passMultitaskTaskCriterion(
|
437 |
+
SpeechToSpectrogramMultitaskTaskCriterion
|
438 |
+
):
|
439 |
+
def __init__(
|
440 |
+
self,
|
441 |
+
task,
|
442 |
+
sentence_avg,
|
443 |
+
use_guided_attention_loss,
|
444 |
+
guided_attention_loss_sigma,
|
445 |
+
bce_pos_weight,
|
446 |
+
ctc_weight,
|
447 |
+
):
|
448 |
+
super().__init__(
|
449 |
+
task,
|
450 |
+
sentence_avg,
|
451 |
+
use_guided_attention_loss,
|
452 |
+
guided_attention_loss_sigma,
|
453 |
+
bce_pos_weight,
|
454 |
+
ctc_weight,
|
455 |
+
)
|
456 |
+
|
457 |
+
def forward(self, model, sample, reduction="mean"):
|
458 |
+
bsz, max_len, _ = sample["target"].size()
|
459 |
+
feat_tgt = sample["target"]
|
460 |
+
feat_len = sample["target_lengths"].view(bsz, 1).expand(-1, max_len)
|
461 |
+
eos_tgt = torch.arange(max_len).to(sample["target"].device)
|
462 |
+
eos_tgt = eos_tgt.view(1, max_len).expand(bsz, -1)
|
463 |
+
eos_tgt = (eos_tgt == (feat_len - 1)).float()
|
464 |
+
|
465 |
+
feat_out, eos_out, extra = model(
|
466 |
+
src_tokens=sample["net_input"]["src_tokens"],
|
467 |
+
src_lengths=sample["net_input"]["src_lengths"],
|
468 |
+
prev_output_tokens=sample["net_input"]["prev_output_tokens"],
|
469 |
+
prev_output_tokens_mt=sample["multitask"][model.mt_task_name]["net_input"][
|
470 |
+
"prev_output_tokens"
|
471 |
+
],
|
472 |
+
tgt_speaker=sample["net_input"]["tgt_speaker"],
|
473 |
+
target_lengths=sample["target_lengths"],
|
474 |
+
return_all_hiddens=True,
|
475 |
+
)
|
476 |
+
|
477 |
+
l1_loss, mse_loss, eos_loss = self.compute_loss(
|
478 |
+
extra["feature_out"],
|
479 |
+
feat_out,
|
480 |
+
eos_out,
|
481 |
+
feat_tgt,
|
482 |
+
eos_tgt,
|
483 |
+
sample["target_lengths"],
|
484 |
+
reduction,
|
485 |
+
)
|
486 |
+
attn_loss = torch.tensor(0.0).type_as(l1_loss)
|
487 |
+
if self.guided_attn is not None:
|
488 |
+
attn_loss = self.guided_attn(
|
489 |
+
extra["attn"],
|
490 |
+
sample["net_input"]["src_lengths"],
|
491 |
+
sample["target_lengths"],
|
492 |
+
reduction,
|
493 |
+
)
|
494 |
+
loss = (
|
495 |
+
l1_loss + mse_loss + eos_loss + attn_loss
|
496 |
+
) # do not include ctc loss as there's no text target
|
497 |
+
|
498 |
+
sample_size = sample["nsentences"] if self.sentence_avg else sample["ntokens"]
|
499 |
+
logging_output = {
|
500 |
+
"loss": utils.item(loss.data),
|
501 |
+
"ntokens": sample["ntokens"],
|
502 |
+
"nsentences": sample["nsentences"],
|
503 |
+
"sample_size": sample_size,
|
504 |
+
"l1_loss": utils.item(l1_loss.data),
|
505 |
+
"mse_loss": utils.item(mse_loss.data),
|
506 |
+
"eos_loss": utils.item(eos_loss.data),
|
507 |
+
"attn_loss": utils.item(attn_loss.data),
|
508 |
+
}
|
509 |
+
|
510 |
+
if len(self.multitask_criterion) == 0:
|
511 |
+
return loss, sample_size, logging_output
|
512 |
+
|
513 |
+
# multitask
|
514 |
+
multitask_loss, multitask_log = self.get_multitask_loss(model, sample, extra)
|
515 |
+
loss += multitask_loss
|
516 |
+
logging_output["multitask"] = multitask_log
|
517 |
+
return loss, sample_size, logging_output
|
fairseq/fairseq/criterions/tacotron2_loss.py
ADDED
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2017-present, Facebook, Inc.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the LICENSE file in
|
5 |
+
# the root directory of this source tree. An additional grant of patent rights
|
6 |
+
# can be found in the PATENTS file in the same directory.
|
7 |
+
|
8 |
+
import logging
|
9 |
+
from dataclasses import dataclass, field
|
10 |
+
from functools import lru_cache
|
11 |
+
from typing import Any, Dict, List
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torch.nn.functional as F
|
15 |
+
from omegaconf import II
|
16 |
+
|
17 |
+
from fairseq import utils
|
18 |
+
from fairseq.logging import metrics
|
19 |
+
from fairseq.criterions import FairseqCriterion, register_criterion
|
20 |
+
from fairseq.data.data_utils import lengths_to_mask
|
21 |
+
from fairseq.dataclass import FairseqDataclass
|
22 |
+
|
23 |
+
logger = logging.getLogger(__name__)
|
24 |
+
|
25 |
+
|
26 |
+
@dataclass
|
27 |
+
class Tacotron2CriterionConfig(FairseqDataclass):
|
28 |
+
bce_pos_weight: float = field(
|
29 |
+
default=1.0,
|
30 |
+
metadata={"help": "weight of positive examples for BCE loss"},
|
31 |
+
)
|
32 |
+
use_guided_attention_loss: bool = field(
|
33 |
+
default=False,
|
34 |
+
metadata={"help": "use guided attention loss"},
|
35 |
+
)
|
36 |
+
guided_attention_loss_sigma: float = field(
|
37 |
+
default=0.4,
|
38 |
+
metadata={"help": "weight of positive examples for BCE loss"},
|
39 |
+
)
|
40 |
+
ctc_weight: float = field(default=0.0, metadata={"help": "weight for CTC loss"})
|
41 |
+
sentence_avg: bool = II("optimization.sentence_avg")
|
42 |
+
|
43 |
+
|
44 |
+
class GuidedAttentionLoss(torch.nn.Module):
|
45 |
+
"""
|
46 |
+
Efficiently Trainable Text-to-Speech System Based on Deep Convolutional
|
47 |
+
Networks with Guided Attention (https://arxiv.org/abs/1710.08969)
|
48 |
+
"""
|
49 |
+
|
50 |
+
def __init__(self, sigma):
|
51 |
+
super().__init__()
|
52 |
+
self.sigma = sigma
|
53 |
+
|
54 |
+
@staticmethod
|
55 |
+
@lru_cache(maxsize=8)
|
56 |
+
def _get_weight(s_len, t_len, sigma):
|
57 |
+
grid_x, grid_y = torch.meshgrid(torch.arange(t_len), torch.arange(s_len))
|
58 |
+
grid_x = grid_x.to(s_len.device)
|
59 |
+
grid_y = grid_y.to(s_len.device)
|
60 |
+
w = (grid_y.float() / s_len - grid_x.float() / t_len) ** 2
|
61 |
+
return 1.0 - torch.exp(-w / (2 * (sigma**2)))
|
62 |
+
|
63 |
+
def _get_weights(self, src_lens, tgt_lens):
|
64 |
+
bsz, max_s_len, max_t_len = len(src_lens), max(src_lens), max(tgt_lens)
|
65 |
+
weights = torch.zeros((bsz, max_t_len, max_s_len))
|
66 |
+
for i, (s_len, t_len) in enumerate(zip(src_lens, tgt_lens)):
|
67 |
+
weights[i, :t_len, :s_len] = self._get_weight(s_len, t_len, self.sigma)
|
68 |
+
return weights
|
69 |
+
|
70 |
+
@staticmethod
|
71 |
+
def _get_masks(src_lens, tgt_lens):
|
72 |
+
in_masks = lengths_to_mask(src_lens)
|
73 |
+
out_masks = lengths_to_mask(tgt_lens)
|
74 |
+
return out_masks.unsqueeze(2) & in_masks.unsqueeze(1)
|
75 |
+
|
76 |
+
def forward(self, attn, src_lens, tgt_lens, reduction="mean"):
|
77 |
+
weights = self._get_weights(src_lens, tgt_lens).to(attn.device)
|
78 |
+
masks = self._get_masks(src_lens, tgt_lens).to(attn.device)
|
79 |
+
loss = (weights * attn.transpose(1, 2)).masked_select(masks)
|
80 |
+
loss = torch.sum(loss) if reduction == "sum" else torch.mean(loss)
|
81 |
+
return loss
|
82 |
+
|
83 |
+
|
84 |
+
@register_criterion("tacotron2", dataclass=Tacotron2CriterionConfig)
|
85 |
+
class Tacotron2Criterion(FairseqCriterion):
|
86 |
+
def __init__(
|
87 |
+
self,
|
88 |
+
task,
|
89 |
+
sentence_avg,
|
90 |
+
use_guided_attention_loss,
|
91 |
+
guided_attention_loss_sigma,
|
92 |
+
bce_pos_weight,
|
93 |
+
ctc_weight,
|
94 |
+
):
|
95 |
+
super().__init__(task)
|
96 |
+
self.sentence_avg = sentence_avg
|
97 |
+
self.bce_pos_weight = bce_pos_weight
|
98 |
+
|
99 |
+
self.guided_attn = None
|
100 |
+
if use_guided_attention_loss:
|
101 |
+
self.guided_attn = GuidedAttentionLoss(guided_attention_loss_sigma)
|
102 |
+
self.ctc_weight = ctc_weight
|
103 |
+
|
104 |
+
def forward(self, model, sample, reduction="mean"):
|
105 |
+
bsz, max_len, _ = sample["target"].size()
|
106 |
+
feat_tgt = sample["target"]
|
107 |
+
feat_len = sample["target_lengths"].view(bsz, 1).expand(-1, max_len)
|
108 |
+
eos_tgt = torch.arange(max_len).to(sample["target"].device)
|
109 |
+
eos_tgt = eos_tgt.view(1, max_len).expand(bsz, -1)
|
110 |
+
eos_tgt = (eos_tgt == (feat_len - 1)).float()
|
111 |
+
src_tokens = sample["net_input"]["src_tokens"]
|
112 |
+
src_lens = sample["net_input"]["src_lengths"]
|
113 |
+
tgt_lens = sample["target_lengths"]
|
114 |
+
|
115 |
+
feat_out, eos_out, extra = model(
|
116 |
+
src_tokens=src_tokens,
|
117 |
+
src_lengths=src_lens,
|
118 |
+
prev_output_tokens=sample["net_input"]["prev_output_tokens"],
|
119 |
+
incremental_state=None,
|
120 |
+
target_lengths=tgt_lens,
|
121 |
+
speaker=sample["speaker"],
|
122 |
+
)
|
123 |
+
|
124 |
+
l1_loss, mse_loss, eos_loss = self.compute_loss(
|
125 |
+
extra["feature_out"],
|
126 |
+
feat_out,
|
127 |
+
eos_out,
|
128 |
+
feat_tgt,
|
129 |
+
eos_tgt,
|
130 |
+
tgt_lens,
|
131 |
+
reduction,
|
132 |
+
)
|
133 |
+
attn_loss = torch.tensor(0.0).type_as(l1_loss)
|
134 |
+
if self.guided_attn is not None:
|
135 |
+
attn_loss = self.guided_attn(extra["attn"], src_lens, tgt_lens, reduction)
|
136 |
+
ctc_loss = torch.tensor(0.0).type_as(l1_loss)
|
137 |
+
if self.ctc_weight > 0.0:
|
138 |
+
net_output = (feat_out, eos_out, extra)
|
139 |
+
lprobs = model.get_normalized_probs(net_output, log_probs=True)
|
140 |
+
lprobs = lprobs.transpose(0, 1) # T x B x C
|
141 |
+
src_mask = lengths_to_mask(src_lens)
|
142 |
+
src_tokens_flat = src_tokens.masked_select(src_mask)
|
143 |
+
ctc_loss = (
|
144 |
+
F.ctc_loss(
|
145 |
+
lprobs,
|
146 |
+
src_tokens_flat,
|
147 |
+
tgt_lens,
|
148 |
+
src_lens,
|
149 |
+
reduction=reduction,
|
150 |
+
zero_infinity=True,
|
151 |
+
)
|
152 |
+
* self.ctc_weight
|
153 |
+
)
|
154 |
+
loss = l1_loss + mse_loss + eos_loss + attn_loss + ctc_loss
|
155 |
+
|
156 |
+
sample_size = sample["nsentences"] if self.sentence_avg else sample["ntokens"]
|
157 |
+
logging_output = {
|
158 |
+
"loss": utils.item(loss.data),
|
159 |
+
"ntokens": sample["ntokens"],
|
160 |
+
"nsentences": sample["nsentences"],
|
161 |
+
"sample_size": sample_size,
|
162 |
+
"l1_loss": utils.item(l1_loss.data),
|
163 |
+
"mse_loss": utils.item(mse_loss.data),
|
164 |
+
"eos_loss": utils.item(eos_loss.data),
|
165 |
+
"attn_loss": utils.item(attn_loss.data),
|
166 |
+
"ctc_loss": utils.item(ctc_loss.data),
|
167 |
+
}
|
168 |
+
return loss, sample_size, logging_output
|
169 |
+
|
170 |
+
def compute_loss(
|
171 |
+
self,
|
172 |
+
feat_out,
|
173 |
+
feat_out_post,
|
174 |
+
eos_out,
|
175 |
+
feat_tgt,
|
176 |
+
eos_tgt,
|
177 |
+
tgt_lens,
|
178 |
+
reduction="mean",
|
179 |
+
):
|
180 |
+
mask = lengths_to_mask(tgt_lens)
|
181 |
+
_eos_out = eos_out[mask].squeeze()
|
182 |
+
_eos_tgt = eos_tgt[mask]
|
183 |
+
_feat_tgt = feat_tgt[mask]
|
184 |
+
_feat_out = feat_out[mask]
|
185 |
+
_feat_out_post = feat_out_post[mask]
|
186 |
+
|
187 |
+
l1_loss = F.l1_loss(_feat_out, _feat_tgt, reduction=reduction) + F.l1_loss(
|
188 |
+
_feat_out_post, _feat_tgt, reduction=reduction
|
189 |
+
)
|
190 |
+
mse_loss = F.mse_loss(_feat_out, _feat_tgt, reduction=reduction) + F.mse_loss(
|
191 |
+
_feat_out_post, _feat_tgt, reduction=reduction
|
192 |
+
)
|
193 |
+
eos_loss = F.binary_cross_entropy_with_logits(
|
194 |
+
_eos_out,
|
195 |
+
_eos_tgt,
|
196 |
+
pos_weight=torch.tensor(self.bce_pos_weight),
|
197 |
+
reduction=reduction,
|
198 |
+
)
|
199 |
+
return l1_loss, mse_loss, eos_loss
|
200 |
+
|
201 |
+
@classmethod
|
202 |
+
def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None:
|
203 |
+
ns = [log.get("sample_size", 0) for log in logging_outputs]
|
204 |
+
ntot = sum(ns)
|
205 |
+
ws = [n / (ntot + 1e-8) for n in ns]
|
206 |
+
for key in ["loss", "l1_loss", "mse_loss", "eos_loss", "attn_loss", "ctc_loss"]:
|
207 |
+
vals = [log.get(key, 0) for log in logging_outputs]
|
208 |
+
val = sum(val * w for val, w in zip(vals, ws))
|
209 |
+
metrics.log_scalar(key, val, ntot, round=3)
|
210 |
+
metrics.log_scalar("sample_size", ntot, len(logging_outputs))
|
211 |
+
|
212 |
+
# inference metrics
|
213 |
+
if "targ_frames" not in logging_outputs[0]:
|
214 |
+
return
|
215 |
+
n = sum(log.get("targ_frames", 0) for log in logging_outputs)
|
216 |
+
for key, new_key in [
|
217 |
+
("mcd_loss", "mcd_loss"),
|
218 |
+
("pred_frames", "pred_ratio"),
|
219 |
+
("nins", "ins_rate"),
|
220 |
+
("ndel", "del_rate"),
|
221 |
+
]:
|
222 |
+
val = sum(log.get(key, 0) for log in logging_outputs)
|
223 |
+
metrics.log_scalar(new_key, val / n, n, round=3)
|
224 |
+
|
225 |
+
@staticmethod
|
226 |
+
def logging_outputs_can_be_summed() -> bool:
|
227 |
+
return False
|
fairseq/fairseq/criterions/wav2vec_criterion.py
ADDED
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import math
|
7 |
+
from dataclasses import dataclass, field
|
8 |
+
from typing import List, Optional
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn.functional as F
|
12 |
+
from fairseq import utils
|
13 |
+
from fairseq.logging import metrics
|
14 |
+
from fairseq.criterions import FairseqCriterion, register_criterion
|
15 |
+
from fairseq.dataclass import FairseqDataclass
|
16 |
+
from fairseq.logging.meters import safe_round
|
17 |
+
from fairseq.utils import is_xla_tensor
|
18 |
+
|
19 |
+
|
20 |
+
@dataclass
|
21 |
+
class Wav2VecCriterionConfig(FairseqDataclass):
|
22 |
+
infonce: bool = field(
|
23 |
+
default=False,
|
24 |
+
metadata={
|
25 |
+
"help": "if set, uses cross entropy instead of binary cross entropy (i.e. InfoNCE loss)"
|
26 |
+
},
|
27 |
+
)
|
28 |
+
loss_weights: Optional[List[float]] = field(
|
29 |
+
default=None,
|
30 |
+
metadata={"help": "weights for additional loss terms (not first one)"},
|
31 |
+
)
|
32 |
+
log_keys: List[str] = field(
|
33 |
+
default_factory=lambda: [],
|
34 |
+
metadata={"help": "output keys to log"},
|
35 |
+
)
|
36 |
+
|
37 |
+
|
38 |
+
@register_criterion("wav2vec", dataclass=Wav2VecCriterionConfig)
|
39 |
+
class Wav2vecCriterion(FairseqCriterion):
|
40 |
+
def __init__(self, task, infonce=False, loss_weights=None, log_keys=None):
|
41 |
+
super().__init__(task)
|
42 |
+
self.infonce = infonce
|
43 |
+
self.loss_weights = loss_weights
|
44 |
+
self.log_keys = [] if log_keys is None else log_keys
|
45 |
+
|
46 |
+
def forward(self, model, sample, reduce=True):
|
47 |
+
"""Compute the loss for the given sample.
|
48 |
+
|
49 |
+
Returns a tuple with three elements:
|
50 |
+
1) the loss
|
51 |
+
2) the sample size, which is used as the denominator for the gradient
|
52 |
+
3) logging outputs to display while training
|
53 |
+
"""
|
54 |
+
net_output = model(**sample["net_input"])
|
55 |
+
logits = model.get_logits(net_output).float()
|
56 |
+
target = model.get_targets(sample, net_output)
|
57 |
+
self.xla = is_xla_tensor(logits)
|
58 |
+
|
59 |
+
# XXX: handle weights on xla.
|
60 |
+
weights = None
|
61 |
+
if hasattr(model, "get_target_weights") and not self.infonce:
|
62 |
+
weights = model.get_target_weights(target, net_output)
|
63 |
+
if torch.is_tensor(weights):
|
64 |
+
weights = weights.float()
|
65 |
+
|
66 |
+
losses = []
|
67 |
+
|
68 |
+
reduction = "none" if ((not reduce) or self.xla) else "sum"
|
69 |
+
if self.infonce:
|
70 |
+
loss = F.cross_entropy(logits, target, reduction=reduction)
|
71 |
+
else:
|
72 |
+
loss = F.binary_cross_entropy_with_logits(
|
73 |
+
logits, target.float(), weights, reduction=reduction
|
74 |
+
)
|
75 |
+
|
76 |
+
if self.xla:
|
77 |
+
# tpu-comment: since dynamic shapes lead to recompilations on xla,
|
78 |
+
# we don't shrink tensors using mask_indices.
|
79 |
+
# Instead, we use mask indices to adjust loss.
|
80 |
+
mi = (
|
81 |
+
sample["net_input"]["mask_indices"]
|
82 |
+
.transpose(0, 1) # logits are transposed in `model.get_logits`
|
83 |
+
.reshape(logits.size(0))
|
84 |
+
)
|
85 |
+
loss = (loss * mi).sum() if reduce else (loss * mi)
|
86 |
+
|
87 |
+
if "sample_size" in sample:
|
88 |
+
sample_size = sample["sample_size"]
|
89 |
+
elif "mask_indices" in sample["net_input"]:
|
90 |
+
sample_size = sample["net_input"]["mask_indices"].sum()
|
91 |
+
else:
|
92 |
+
sample_size = target.numel() if self.infonce else target.long().sum().item()
|
93 |
+
losses.append(loss.detach().clone())
|
94 |
+
|
95 |
+
if self.loss_weights is not None:
|
96 |
+
assert hasattr(model, "get_extra_losses")
|
97 |
+
extra_losses = model.get_extra_losses(net_output)
|
98 |
+
if torch.is_tensor(extra_losses):
|
99 |
+
extra_losses = [extra_losses]
|
100 |
+
if len(self.loss_weights) == 1 and len(extra_losses) != 1:
|
101 |
+
self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
|
102 |
+
assert len(extra_losses) == len(
|
103 |
+
self.loss_weights
|
104 |
+
), f"{len(extra_losses)}, {len(self.loss_weights)}"
|
105 |
+
for p, coef in zip(extra_losses, self.loss_weights):
|
106 |
+
if coef != 0 and p is not None:
|
107 |
+
p = coef * p.float() * sample_size
|
108 |
+
loss += p
|
109 |
+
losses.append(p)
|
110 |
+
|
111 |
+
logging_output = {
|
112 |
+
"loss": loss.item() if (reduce and not self.xla) else loss.detach(),
|
113 |
+
"ntokens": sample_size,
|
114 |
+
"nsentences": sample["id"].numel(),
|
115 |
+
"sample_size": sample_size,
|
116 |
+
}
|
117 |
+
|
118 |
+
for lk in self.log_keys:
|
119 |
+
# Only store "logits" and "target" for computing MAP and MAUC
|
120 |
+
# during validation
|
121 |
+
if lk == "logits":
|
122 |
+
if not self.training:
|
123 |
+
logging_output["logits"] = logits.cpu().numpy()
|
124 |
+
elif lk == "target":
|
125 |
+
if not self.training:
|
126 |
+
# If the targets have been mixed with the predictions of
|
127 |
+
# teacher models, find the original targets
|
128 |
+
if hasattr(model, "get_original_targets"):
|
129 |
+
original_target = model.get_original_targets(sample, net_output)
|
130 |
+
else:
|
131 |
+
original_target = target
|
132 |
+
logging_output["target"] = original_target.cpu().numpy()
|
133 |
+
elif lk in net_output:
|
134 |
+
value = net_output[lk]
|
135 |
+
if not is_xla_tensor(value):
|
136 |
+
value = float(value)
|
137 |
+
logging_output[lk] = value
|
138 |
+
|
139 |
+
if len(losses) > 1:
|
140 |
+
for i, l in enumerate(losses):
|
141 |
+
logging_output[f"loss_{i}"] = l.item() if not self.xla else l.detach()
|
142 |
+
|
143 |
+
if self.infonce:
|
144 |
+
with torch.no_grad():
|
145 |
+
if logits.numel() == 0:
|
146 |
+
corr = 0
|
147 |
+
count = 0
|
148 |
+
else:
|
149 |
+
assert logits.dim() > 1, logits.shape
|
150 |
+
max = logits.argmax(-1) == 0
|
151 |
+
min = logits.argmin(-1) == 0
|
152 |
+
if is_xla_tensor(logits):
|
153 |
+
max, min = max * mi, min * mi
|
154 |
+
both = max & min
|
155 |
+
corr = max.long().sum() - both.long().sum()
|
156 |
+
count = mi.sum()
|
157 |
+
else:
|
158 |
+
both = max & min
|
159 |
+
corr = max.long().sum().item() - both.long().sum().item()
|
160 |
+
count = float(max.numel())
|
161 |
+
|
162 |
+
logging_output["correct"] = corr
|
163 |
+
logging_output["count"] = count
|
164 |
+
|
165 |
+
return loss, sample_size, logging_output
|
166 |
+
|
167 |
+
@staticmethod
|
168 |
+
def reduce_metrics(logging_outputs) -> None:
|
169 |
+
"""Aggregate logging outputs from data parallel training."""
|
170 |
+
loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
|
171 |
+
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
|
172 |
+
nsentences = utils.item(
|
173 |
+
sum(log.get("nsentences", 0) for log in logging_outputs)
|
174 |
+
)
|
175 |
+
sample_size = utils.item(
|
176 |
+
sum(log.get("sample_size", 0) for log in logging_outputs)
|
177 |
+
)
|
178 |
+
|
179 |
+
metrics.log_scalar(
|
180 |
+
"loss", loss_sum / (sample_size or 1) / math.log(2), sample_size, round=3
|
181 |
+
)
|
182 |
+
metrics.log_scalar("ntokens", ntokens)
|
183 |
+
metrics.log_scalar("nsentences", nsentences)
|
184 |
+
|
185 |
+
correct = sum(log.get("correct", 0) for log in logging_outputs)
|
186 |
+
metrics.log_scalar("_correct", correct)
|
187 |
+
|
188 |
+
total = sum(log.get("count", 0) for log in logging_outputs)
|
189 |
+
metrics.log_scalar("_total", total)
|
190 |
+
|
191 |
+
if total > 0:
|
192 |
+
metrics.log_derived(
|
193 |
+
"accuracy",
|
194 |
+
lambda meters: safe_round(
|
195 |
+
meters["_correct"].sum / meters["_total"].sum, 5
|
196 |
+
)
|
197 |
+
if meters["_total"].sum > 0
|
198 |
+
else float("nan"),
|
199 |
+
)
|
200 |
+
|
201 |
+
builtin_keys = {
|
202 |
+
"loss",
|
203 |
+
"ntokens",
|
204 |
+
"nsentences",
|
205 |
+
"sample_size",
|
206 |
+
"correct",
|
207 |
+
"count",
|
208 |
+
}
|
209 |
+
|
210 |
+
for k in logging_outputs[0]:
|
211 |
+
if k not in builtin_keys:
|
212 |
+
val = sum(log.get(k, 0) for log in logging_outputs)
|
213 |
+
if k.startswith("loss"):
|
214 |
+
metrics.log_scalar(
|
215 |
+
k, val / (sample_size or 1) / math.log(2), sample_size, round=3
|
216 |
+
)
|
217 |
+
else:
|
218 |
+
metrics.log_scalar(k, val / len(logging_outputs), round=3)
|
219 |
+
|
220 |
+
# FIXME: revert when gather based xla reduction is implemented
|
221 |
+
# @staticmethod
|
222 |
+
# def logging_outputs_can_be_summed() -> bool:
|
223 |
+
def logging_outputs_can_be_summed(self) -> bool:
|
224 |
+
"""
|
225 |
+
Whether the logging outputs returned by `forward` can be summed
|
226 |
+
across workers prior to calling `reduce_metrics`. Setting this
|
227 |
+
to True will improves distributed training speed.
|
228 |
+
"""
|
229 |
+
# XXX: Gather based reduction not implemented for xla yet.
|
230 |
+
# So we fall to sum based reduction for xla.
|
231 |
+
return self.xla
|
fairseq/fairseq/data/__init__.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
"""isort:skip_file"""
|
6 |
+
|
7 |
+
from .dictionary import Dictionary, TruncatedDictionary
|
8 |
+
|
9 |
+
from .fairseq_dataset import FairseqDataset, FairseqIterableDataset
|
10 |
+
|
11 |
+
from .base_wrapper_dataset import BaseWrapperDataset
|
12 |
+
|
13 |
+
from .add_target_dataset import AddTargetDataset
|
14 |
+
from .append_token_dataset import AppendTokenDataset
|
15 |
+
from .audio.raw_audio_dataset import BinarizedAudioDataset, FileAudioDataset
|
16 |
+
from .audio.hubert_dataset import HubertDataset
|
17 |
+
from .backtranslation_dataset import BacktranslationDataset
|
18 |
+
from .bucket_pad_length_dataset import BucketPadLengthDataset
|
19 |
+
from .colorize_dataset import ColorizeDataset
|
20 |
+
from .concat_dataset import ConcatDataset
|
21 |
+
from .concat_sentences_dataset import ConcatSentencesDataset
|
22 |
+
from .denoising_dataset import DenoisingDataset
|
23 |
+
from .id_dataset import IdDataset
|
24 |
+
from .indexed_dataset import (
|
25 |
+
IndexedCachedDataset,
|
26 |
+
IndexedDataset,
|
27 |
+
IndexedRawTextDataset,
|
28 |
+
MMapIndexedDataset,
|
29 |
+
)
|
30 |
+
from .language_pair_dataset import LanguagePairDataset
|
31 |
+
from .list_dataset import ListDataset
|
32 |
+
from .lm_context_window_dataset import LMContextWindowDataset
|
33 |
+
from .lru_cache_dataset import LRUCacheDataset
|
34 |
+
from .mask_tokens_dataset import MaskTokensDataset
|
35 |
+
from .monolingual_dataset import MonolingualDataset
|
36 |
+
from .multi_corpus_sampled_dataset import MultiCorpusSampledDataset
|
37 |
+
from .nested_dictionary_dataset import NestedDictionaryDataset
|
38 |
+
from .noising import NoisingDataset
|
39 |
+
from .numel_dataset import NumelDataset
|
40 |
+
from .num_samples_dataset import NumSamplesDataset
|
41 |
+
from .offset_tokens_dataset import OffsetTokensDataset
|
42 |
+
from .padding_mask_dataset import (
|
43 |
+
LeftPaddingMaskDataset,
|
44 |
+
PaddingMaskDataset,
|
45 |
+
RightPaddingMaskDataset,
|
46 |
+
)
|
47 |
+
from .pad_dataset import LeftPadDataset, PadDataset, RightPadDataset
|
48 |
+
from .prepend_dataset import PrependDataset
|
49 |
+
from .prepend_token_dataset import PrependTokenDataset
|
50 |
+
from .raw_label_dataset import RawLabelDataset
|
51 |
+
from .replace_dataset import ReplaceDataset
|
52 |
+
from .resampling_dataset import ResamplingDataset
|
53 |
+
from .roll_dataset import RollDataset
|
54 |
+
from .round_robin_zip_datasets import RoundRobinZipDatasets
|
55 |
+
from .sort_dataset import SortDataset
|
56 |
+
from .speech_dlm_dataset import SpeechDLMDataset
|
57 |
+
from .strip_token_dataset import StripTokenDataset
|
58 |
+
from .subsample_dataset import SubsampleDataset
|
59 |
+
from .token_block_dataset import TokenBlockDataset
|
60 |
+
from .transform_eos_dataset import TransformEosDataset
|
61 |
+
from .transform_eos_lang_pair_dataset import TransformEosLangPairDataset
|
62 |
+
from .shorten_dataset import TruncateDataset, RandomCropDataset
|
63 |
+
from .multilingual.sampled_multi_dataset import SampledMultiDataset
|
64 |
+
from .multilingual.sampled_multi_epoch_dataset import SampledMultiEpochDataset
|
65 |
+
from .fasta_dataset import FastaDataset, EncodedFastaDataset
|
66 |
+
from .transform_eos_concat_langpair_dataset import TransformEosConcatLangPairDataset
|
67 |
+
|
68 |
+
from .iterators import (
|
69 |
+
CountingIterator,
|
70 |
+
EpochBatchIterator,
|
71 |
+
GroupedIterator,
|
72 |
+
ShardedIterator,
|
73 |
+
)
|
74 |
+
|
75 |
+
__all__ = [
|
76 |
+
"AddTargetDataset",
|
77 |
+
"AppendTokenDataset",
|
78 |
+
"BacktranslationDataset",
|
79 |
+
"BaseWrapperDataset",
|
80 |
+
"BinarizedAudioDataset",
|
81 |
+
"BucketPadLengthDataset",
|
82 |
+
"ColorizeDataset",
|
83 |
+
"ConcatDataset",
|
84 |
+
"ConcatSentencesDataset",
|
85 |
+
"CountingIterator",
|
86 |
+
"DenoisingDataset",
|
87 |
+
"Dictionary",
|
88 |
+
"EncodedFastaDataset",
|
89 |
+
"EpochBatchIterator",
|
90 |
+
"FairseqDataset",
|
91 |
+
"FairseqIterableDataset",
|
92 |
+
"FastaDataset",
|
93 |
+
"FileAudioDataset",
|
94 |
+
"GroupedIterator",
|
95 |
+
"HubertDataset",
|
96 |
+
"IdDataset",
|
97 |
+
"IndexedCachedDataset",
|
98 |
+
"IndexedDataset",
|
99 |
+
"IndexedRawTextDataset",
|
100 |
+
"LanguagePairDataset",
|
101 |
+
"LeftPadDataset",
|
102 |
+
"ListDataset",
|
103 |
+
"LMContextWindowDataset",
|
104 |
+
"LRUCacheDataset",
|
105 |
+
"MaskTokensDataset",
|
106 |
+
"MMapIndexedDataset",
|
107 |
+
"MonolingualDataset",
|
108 |
+
"MultiCorpusSampledDataset",
|
109 |
+
"NestedDictionaryDataset",
|
110 |
+
"NoisingDataset",
|
111 |
+
"NumelDataset",
|
112 |
+
"NumSamplesDataset",
|
113 |
+
"OffsetTokensDataset",
|
114 |
+
"PadDataset",
|
115 |
+
"PrependDataset",
|
116 |
+
"PrependTokenDataset",
|
117 |
+
"RandomCropDataset",
|
118 |
+
"RawLabelDataset",
|
119 |
+
"ResamplingDataset",
|
120 |
+
"ReplaceDataset",
|
121 |
+
"RightPadDataset",
|
122 |
+
"RollDataset",
|
123 |
+
"RoundRobinZipDatasets",
|
124 |
+
"SampledMultiDataset",
|
125 |
+
"SampledMultiEpochDataset",
|
126 |
+
"ShardedIterator",
|
127 |
+
"SortDataset",
|
128 |
+
"SpeechDLMDataset",
|
129 |
+
"StripTokenDataset",
|
130 |
+
"SubsampleDataset",
|
131 |
+
"TokenBlockDataset",
|
132 |
+
"TransformEosDataset",
|
133 |
+
"TransformEosLangPairDataset",
|
134 |
+
"TransformEosConcatLangPairDataset",
|
135 |
+
"TruncateDataset",
|
136 |
+
"TruncatedDictionary",
|
137 |
+
]
|
fairseq/fairseq/data/add_class_target_dataset.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from . import BaseWrapperDataset, data_utils
|
9 |
+
from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel
|
10 |
+
|
11 |
+
|
12 |
+
class AddTargetDataset(BaseWrapperDataset):
|
13 |
+
def __init__(
|
14 |
+
self,
|
15 |
+
dataset,
|
16 |
+
labels,
|
17 |
+
pad,
|
18 |
+
eos,
|
19 |
+
batch_targets,
|
20 |
+
process_label=None,
|
21 |
+
label_len_fn=None,
|
22 |
+
add_to_input=False,
|
23 |
+
text_compression_level=TextCompressionLevel.none,
|
24 |
+
):
|
25 |
+
super().__init__(dataset)
|
26 |
+
self.labels = labels
|
27 |
+
self.batch_targets = batch_targets
|
28 |
+
self.pad = pad
|
29 |
+
self.eos = eos
|
30 |
+
self.process_label = process_label
|
31 |
+
self.label_len_fn = label_len_fn
|
32 |
+
self.add_to_input = add_to_input
|
33 |
+
self.text_compressor = TextCompressor(level=text_compression_level)
|
34 |
+
|
35 |
+
def get_label(self, index, process_fn=None):
|
36 |
+
lbl = self.labels[index]
|
37 |
+
lbl = self.text_compressor.decompress(lbl)
|
38 |
+
return lbl if process_fn is None else process_fn(lbl)
|
39 |
+
|
40 |
+
def __getitem__(self, index):
|
41 |
+
item = self.dataset[index]
|
42 |
+
item["label"] = self.get_label(index, process_fn=self.process_label)
|
43 |
+
return item
|
44 |
+
|
45 |
+
def size(self, index):
|
46 |
+
sz = self.dataset.size(index)
|
47 |
+
own_sz = self.label_len_fn(self.get_label(index))
|
48 |
+
return sz, own_sz
|
49 |
+
|
50 |
+
def collater(self, samples):
|
51 |
+
collated = self.dataset.collater(samples)
|
52 |
+
if len(collated) == 0:
|
53 |
+
return collated
|
54 |
+
indices = set(collated["id"].tolist())
|
55 |
+
target = [s["label"] for s in samples if s["id"] in indices]
|
56 |
+
|
57 |
+
if self.batch_targets:
|
58 |
+
collated["target_lengths"] = torch.LongTensor([len(t) for t in target])
|
59 |
+
target = data_utils.collate_tokens(target, pad_idx=self.pad, left_pad=False)
|
60 |
+
collated["ntokens"] = collated["target_lengths"].sum().item()
|
61 |
+
else:
|
62 |
+
collated["ntokens"] = sum([len(t) for t in target])
|
63 |
+
|
64 |
+
collated["target"] = target
|
65 |
+
|
66 |
+
if self.add_to_input:
|
67 |
+
eos = target.new_full((target.size(0), 1), self.eos)
|
68 |
+
collated["target"] = torch.cat([target, eos], dim=-1).long()
|
69 |
+
collated["net_input"]["prev_output_tokens"] = torch.cat(
|
70 |
+
[eos, target], dim=-1
|
71 |
+
).long()
|
72 |
+
collated["ntokens"] += target.size(0)
|
73 |
+
return collated
|
74 |
+
|
75 |
+
def filter_indices_by_size(self, indices, max_sizes):
|
76 |
+
indices, ignored = data_utils._filter_by_size_dynamic(
|
77 |
+
indices, self.size, max_sizes
|
78 |
+
)
|
79 |
+
return indices, ignored
|
fairseq/fairseq/data/add_target_dataset.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from . import BaseWrapperDataset, data_utils
|
9 |
+
from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel
|
10 |
+
|
11 |
+
|
12 |
+
class AddTargetDataset(BaseWrapperDataset):
|
13 |
+
def __init__(
|
14 |
+
self,
|
15 |
+
dataset,
|
16 |
+
labels,
|
17 |
+
pad,
|
18 |
+
eos,
|
19 |
+
batch_targets,
|
20 |
+
process_label=None,
|
21 |
+
label_len_fn=None,
|
22 |
+
add_to_input=False,
|
23 |
+
text_compression_level=TextCompressionLevel.none,
|
24 |
+
):
|
25 |
+
super().__init__(dataset)
|
26 |
+
self.labels = labels
|
27 |
+
self.batch_targets = batch_targets
|
28 |
+
self.pad = pad
|
29 |
+
self.eos = eos
|
30 |
+
self.process_label = process_label
|
31 |
+
self.label_len_fn = label_len_fn
|
32 |
+
self.add_to_input = add_to_input
|
33 |
+
self.text_compressor = TextCompressor(level=text_compression_level)
|
34 |
+
|
35 |
+
def get_label(self, index, process_fn=None):
|
36 |
+
lbl = self.labels[index]
|
37 |
+
lbl = self.text_compressor.decompress(lbl)
|
38 |
+
return lbl if process_fn is None else process_fn(lbl)
|
39 |
+
|
40 |
+
def __getitem__(self, index):
|
41 |
+
item = self.dataset[index]
|
42 |
+
item["label"] = self.get_label(index, process_fn=self.process_label)
|
43 |
+
return item
|
44 |
+
|
45 |
+
def size(self, index):
|
46 |
+
sz = self.dataset.size(index)
|
47 |
+
own_sz = self.label_len_fn(self.get_label(index))
|
48 |
+
return sz, own_sz
|
49 |
+
|
50 |
+
def collater(self, samples):
|
51 |
+
collated = self.dataset.collater(samples)
|
52 |
+
if len(collated) == 0:
|
53 |
+
return collated
|
54 |
+
indices = set(collated["id"].tolist())
|
55 |
+
target = [s["label"] for s in samples if s["id"] in indices]
|
56 |
+
|
57 |
+
if self.add_to_input:
|
58 |
+
eos = torch.LongTensor([self.eos])
|
59 |
+
prev_output_tokens = [torch.cat([eos, t], axis=-1) for t in target]
|
60 |
+
target = [torch.cat([t, eos], axis=-1) for t in target]
|
61 |
+
collated["net_input"]["prev_output_tokens"] = prev_output_tokens
|
62 |
+
|
63 |
+
if self.batch_targets:
|
64 |
+
collated["target_lengths"] = torch.LongTensor([len(t) for t in target])
|
65 |
+
target = data_utils.collate_tokens(target, pad_idx=self.pad, left_pad=False)
|
66 |
+
collated["ntokens"] = collated["target_lengths"].sum().item()
|
67 |
+
if getattr(collated["net_input"], "prev_output_tokens", None):
|
68 |
+
collated["net_input"]["prev_output_tokens"] = data_utils.collate_tokens(
|
69 |
+
collated["net_input"]["prev_output_tokens"],
|
70 |
+
pad_idx=self.pad,
|
71 |
+
left_pad=False,
|
72 |
+
)
|
73 |
+
else:
|
74 |
+
collated["ntokens"] = sum([len(t) for t in target])
|
75 |
+
|
76 |
+
collated["target"] = target
|
77 |
+
return collated
|
78 |
+
|
79 |
+
def filter_indices_by_size(self, indices, max_sizes):
|
80 |
+
indices, ignored = data_utils._filter_by_size_dynamic(
|
81 |
+
indices, self.size, max_sizes
|
82 |
+
)
|
83 |
+
return indices, ignored
|
fairseq/fairseq/data/append_token_dataset.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from . import BaseWrapperDataset
|
10 |
+
|
11 |
+
|
12 |
+
class AppendTokenDataset(BaseWrapperDataset):
|
13 |
+
def __init__(self, dataset, token=None):
|
14 |
+
super().__init__(dataset)
|
15 |
+
self.token = token
|
16 |
+
if token is not None:
|
17 |
+
self._sizes = np.array(dataset.sizes) + 1
|
18 |
+
else:
|
19 |
+
self._sizes = dataset.sizes
|
20 |
+
|
21 |
+
def __getitem__(self, idx):
|
22 |
+
item = self.dataset[idx]
|
23 |
+
if self.token is not None:
|
24 |
+
item = torch.cat([item, item.new([self.token])])
|
25 |
+
return item
|
26 |
+
|
27 |
+
@property
|
28 |
+
def sizes(self):
|
29 |
+
return self._sizes
|
30 |
+
|
31 |
+
def num_tokens(self, index):
|
32 |
+
n = self.dataset.num_tokens(index)
|
33 |
+
if self.token is not None:
|
34 |
+
n += 1
|
35 |
+
return n
|
36 |
+
|
37 |
+
def size(self, index):
|
38 |
+
n = self.dataset.size(index)
|
39 |
+
if self.token is not None:
|
40 |
+
n += 1
|
41 |
+
return n
|
fairseq/fairseq/data/audio/dataset_transforms/__pycache__/concataugment.cpython-310.pyc
ADDED
Binary file (1.85 kB). View file
|
|
fairseq/fairseq/data/audio/feature_transforms/__pycache__/delta_deltas.cpython-310.pyc
ADDED
Binary file (1.63 kB). View file
|
|
fairseq/fairseq/data/audio/feature_transforms/global_cmvn.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from fairseq.data.audio.feature_transforms import (
|
3 |
+
AudioFeatureTransform,
|
4 |
+
register_audio_feature_transform,
|
5 |
+
)
|
6 |
+
|
7 |
+
|
8 |
+
@register_audio_feature_transform("global_cmvn")
|
9 |
+
class GlobalCMVN(AudioFeatureTransform):
|
10 |
+
"""Global CMVN (cepstral mean and variance normalization). The global mean
|
11 |
+
and variance need to be pre-computed and stored in NumPy format (.npz)."""
|
12 |
+
|
13 |
+
@classmethod
|
14 |
+
def from_config_dict(cls, config=None):
|
15 |
+
_config = {} if config is None else config
|
16 |
+
return GlobalCMVN(_config.get("stats_npz_path"))
|
17 |
+
|
18 |
+
def __init__(self, stats_npz_path):
|
19 |
+
self.stats_npz_path = stats_npz_path
|
20 |
+
stats = np.load(stats_npz_path)
|
21 |
+
self.mean, self.std = stats["mean"], stats["std"]
|
22 |
+
|
23 |
+
def __repr__(self):
|
24 |
+
return self.__class__.__name__ + f'(stats_npz_path="{self.stats_npz_path}")'
|
25 |
+
|
26 |
+
def __call__(self, x):
|
27 |
+
x = np.subtract(x, self.mean)
|
28 |
+
x = np.divide(x, self.std)
|
29 |
+
return x
|
fairseq/fairseq/data/audio/speech_to_text_dataset.py
ADDED
@@ -0,0 +1,733 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import csv
|
7 |
+
import logging
|
8 |
+
import re
|
9 |
+
from argparse import Namespace
|
10 |
+
from collections import defaultdict
|
11 |
+
from dataclasses import dataclass
|
12 |
+
from pathlib import Path
|
13 |
+
from typing import Dict, List, Optional, Tuple, Union
|
14 |
+
|
15 |
+
import numpy as np
|
16 |
+
import torch
|
17 |
+
import torch.nn.functional as F
|
18 |
+
|
19 |
+
from fairseq.data import ConcatDataset, Dictionary, FairseqDataset, ResamplingDataset
|
20 |
+
from fairseq.data import data_utils as fairseq_data_utils
|
21 |
+
from fairseq.data import encoders
|
22 |
+
from fairseq.data.audio.audio_utils import get_features_or_waveform
|
23 |
+
from fairseq.data.audio.data_cfg import S2TDataConfig
|
24 |
+
from fairseq.data.audio.dataset_transforms import CompositeAudioDatasetTransform
|
25 |
+
from fairseq.data.audio.dataset_transforms.concataugment import ConcatAugment
|
26 |
+
from fairseq.data.audio.dataset_transforms.noisyoverlapaugment import (
|
27 |
+
NoisyOverlapAugment,
|
28 |
+
)
|
29 |
+
from fairseq.data.audio.feature_transforms import CompositeAudioFeatureTransform
|
30 |
+
from fairseq.data.audio.waveform_transforms import CompositeAudioWaveformTransform
|
31 |
+
|
32 |
+
logger = logging.getLogger(__name__)
|
33 |
+
|
34 |
+
|
35 |
+
def _collate_frames(
|
36 |
+
frames: List[torch.Tensor], is_audio_input: bool = False
|
37 |
+
) -> torch.Tensor:
|
38 |
+
"""
|
39 |
+
Convert a list of 2D frames into a padded 3D tensor
|
40 |
+
Args:
|
41 |
+
frames (list): list of 2D frames of size L[i]*f_dim. Where L[i] is
|
42 |
+
length of i-th frame and f_dim is static dimension of features
|
43 |
+
Returns:
|
44 |
+
3D tensor of size len(frames)*len_max*f_dim where len_max is max of L[i]
|
45 |
+
"""
|
46 |
+
max_len = max(frame.size(0) for frame in frames)
|
47 |
+
if is_audio_input:
|
48 |
+
out = frames[0].new_zeros((len(frames), max_len))
|
49 |
+
else:
|
50 |
+
out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1)))
|
51 |
+
for i, v in enumerate(frames):
|
52 |
+
out[i, : v.size(0)] = v
|
53 |
+
return out
|
54 |
+
|
55 |
+
|
56 |
+
def _is_int_or_np_int(n):
|
57 |
+
return isinstance(n, int) or (
|
58 |
+
isinstance(n, np.generic) and isinstance(n.item(), int)
|
59 |
+
)
|
60 |
+
|
61 |
+
|
62 |
+
@dataclass
|
63 |
+
class SpeechToTextDatasetItem(object):
|
64 |
+
index: int
|
65 |
+
source: torch.Tensor
|
66 |
+
target: Optional[torch.Tensor] = None
|
67 |
+
speaker_id: Optional[int] = None
|
68 |
+
|
69 |
+
|
70 |
+
class SpeechToTextDataset(FairseqDataset):
|
71 |
+
LANG_TAG_TEMPLATE = "<lang:{}>"
|
72 |
+
|
73 |
+
def __init__(
|
74 |
+
self,
|
75 |
+
split: str,
|
76 |
+
is_train_split: bool,
|
77 |
+
cfg: S2TDataConfig,
|
78 |
+
audio_paths: List[str],
|
79 |
+
n_frames: List[int],
|
80 |
+
src_texts: Optional[List[str]] = None,
|
81 |
+
tgt_texts: Optional[List[str]] = None,
|
82 |
+
speakers: Optional[List[str]] = None,
|
83 |
+
src_langs: Optional[List[str]] = None,
|
84 |
+
tgt_langs: Optional[List[str]] = None,
|
85 |
+
ids: Optional[List[str]] = None,
|
86 |
+
tgt_dict: Optional[Dictionary] = None,
|
87 |
+
pre_tokenizer=None,
|
88 |
+
bpe_tokenizer=None,
|
89 |
+
n_frames_per_step=1,
|
90 |
+
speaker_to_id=None,
|
91 |
+
append_eos=True,
|
92 |
+
):
|
93 |
+
self.split, self.is_train_split = split, is_train_split
|
94 |
+
self.cfg = cfg
|
95 |
+
self.audio_paths, self.n_frames = audio_paths, n_frames
|
96 |
+
self.n_samples = len(audio_paths)
|
97 |
+
assert len(n_frames) == self.n_samples > 0
|
98 |
+
assert src_texts is None or len(src_texts) == self.n_samples
|
99 |
+
assert tgt_texts is None or len(tgt_texts) == self.n_samples
|
100 |
+
assert speakers is None or len(speakers) == self.n_samples
|
101 |
+
assert src_langs is None or len(src_langs) == self.n_samples
|
102 |
+
assert tgt_langs is None or len(tgt_langs) == self.n_samples
|
103 |
+
assert ids is None or len(ids) == self.n_samples
|
104 |
+
assert (tgt_dict is None and tgt_texts is None) or (
|
105 |
+
tgt_dict is not None and tgt_texts is not None
|
106 |
+
)
|
107 |
+
self.src_texts, self.tgt_texts = src_texts, tgt_texts
|
108 |
+
self.src_langs, self.tgt_langs = src_langs, tgt_langs
|
109 |
+
self.speakers = speakers
|
110 |
+
self.tgt_dict = tgt_dict
|
111 |
+
self.check_tgt_lang_tag()
|
112 |
+
self.ids = ids
|
113 |
+
self.shuffle = cfg.shuffle if is_train_split else False
|
114 |
+
|
115 |
+
self.feature_transforms = CompositeAudioFeatureTransform.from_config_dict(
|
116 |
+
self.cfg.get_feature_transforms(split, is_train_split)
|
117 |
+
)
|
118 |
+
self.waveform_transforms = CompositeAudioWaveformTransform.from_config_dict(
|
119 |
+
self.cfg.get_waveform_transforms(split, is_train_split)
|
120 |
+
)
|
121 |
+
# TODO: add these to data_cfg.py
|
122 |
+
self.dataset_transforms = CompositeAudioDatasetTransform.from_config_dict(
|
123 |
+
self.cfg.get_dataset_transforms(split, is_train_split)
|
124 |
+
)
|
125 |
+
|
126 |
+
# check proper usage of transforms
|
127 |
+
if self.feature_transforms and self.cfg.use_audio_input:
|
128 |
+
logger.warning(
|
129 |
+
"Feature transforms will not be applied. To use feature transforms, "
|
130 |
+
"set use_audio_input as False in config."
|
131 |
+
)
|
132 |
+
|
133 |
+
self.pre_tokenizer = pre_tokenizer
|
134 |
+
self.bpe_tokenizer = bpe_tokenizer
|
135 |
+
self.n_frames_per_step = n_frames_per_step
|
136 |
+
self.speaker_to_id = speaker_to_id
|
137 |
+
|
138 |
+
self.tgt_lens = self.get_tgt_lens_and_check_oov()
|
139 |
+
self.append_eos = append_eos
|
140 |
+
|
141 |
+
logger.info(self.__repr__())
|
142 |
+
|
143 |
+
def get_tgt_lens_and_check_oov(self):
|
144 |
+
if self.tgt_texts is None:
|
145 |
+
return [0 for _ in range(self.n_samples)]
|
146 |
+
tgt_lens = []
|
147 |
+
n_tokens, n_oov_tokens = 0, 0
|
148 |
+
for i in range(self.n_samples):
|
149 |
+
tokenized = self.get_tokenized_tgt_text(i).split(" ")
|
150 |
+
oov_tokens = [
|
151 |
+
t
|
152 |
+
for t in tokenized
|
153 |
+
if self.tgt_dict.index(t) == self.tgt_dict.unk_index
|
154 |
+
]
|
155 |
+
n_tokens += len(tokenized)
|
156 |
+
n_oov_tokens += len(oov_tokens)
|
157 |
+
tgt_lens.append(len(tokenized))
|
158 |
+
logger.info(f"'{self.split}' has {n_oov_tokens / n_tokens * 100:.2f}% OOV")
|
159 |
+
return tgt_lens
|
160 |
+
|
161 |
+
def __repr__(self):
|
162 |
+
return (
|
163 |
+
self.__class__.__name__
|
164 |
+
+ f'(split="{self.split}", n_samples={self.n_samples:_}, '
|
165 |
+
f"prepend_tgt_lang_tag={self.cfg.prepend_tgt_lang_tag}, "
|
166 |
+
f"n_frames_per_step={self.n_frames_per_step}, "
|
167 |
+
f"shuffle={self.shuffle}, "
|
168 |
+
f"feature_transforms={self.feature_transforms}, "
|
169 |
+
f"waveform_transforms={self.waveform_transforms}, "
|
170 |
+
f"dataset_transforms={self.dataset_transforms})"
|
171 |
+
)
|
172 |
+
|
173 |
+
@classmethod
|
174 |
+
def is_lang_tag(cls, token):
|
175 |
+
pattern = cls.LANG_TAG_TEMPLATE.replace("{}", "(.*)")
|
176 |
+
return re.match(pattern, token)
|
177 |
+
|
178 |
+
def check_tgt_lang_tag(self):
|
179 |
+
if self.cfg.prepend_tgt_lang_tag:
|
180 |
+
assert self.tgt_langs is not None and self.tgt_dict is not None
|
181 |
+
tgt_lang_tags = [
|
182 |
+
self.LANG_TAG_TEMPLATE.format(t) for t in set(self.tgt_langs)
|
183 |
+
]
|
184 |
+
assert all(t in self.tgt_dict for t in tgt_lang_tags)
|
185 |
+
|
186 |
+
@classmethod
|
187 |
+
def tokenize(cls, tokenizer, text: str):
|
188 |
+
return text if tokenizer is None else tokenizer.encode(text)
|
189 |
+
|
190 |
+
def get_tokenized_tgt_text(self, index: Union[int, List[int]]):
|
191 |
+
if _is_int_or_np_int(index):
|
192 |
+
text = self.tgt_texts[index]
|
193 |
+
else:
|
194 |
+
text = " ".join([self.tgt_texts[i] for i in index])
|
195 |
+
|
196 |
+
text = self.tokenize(self.pre_tokenizer, text)
|
197 |
+
text = self.tokenize(self.bpe_tokenizer, text)
|
198 |
+
return text
|
199 |
+
|
200 |
+
def pack_frames(self, feature: torch.Tensor):
|
201 |
+
if self.n_frames_per_step == 1:
|
202 |
+
return feature
|
203 |
+
n_packed_frames = feature.shape[0] // self.n_frames_per_step
|
204 |
+
feature = feature[: self.n_frames_per_step * n_packed_frames]
|
205 |
+
return feature.reshape(n_packed_frames, -1)
|
206 |
+
|
207 |
+
@classmethod
|
208 |
+
def get_lang_tag_idx(cls, lang: str, dictionary: Dictionary):
|
209 |
+
lang_tag_idx = dictionary.index(cls.LANG_TAG_TEMPLATE.format(lang))
|
210 |
+
assert lang_tag_idx != dictionary.unk()
|
211 |
+
return lang_tag_idx
|
212 |
+
|
213 |
+
def _get_source_audio(self, index: Union[int, List[int]]) -> torch.Tensor:
|
214 |
+
"""
|
215 |
+
Gives source audio for given index with any relevant transforms
|
216 |
+
applied. For ConcatAug, source audios for given indices are
|
217 |
+
concatenated in given order.
|
218 |
+
Args:
|
219 |
+
index (int or List[int]): index—or in the case of ConcatAug,
|
220 |
+
indices—to pull the source audio for
|
221 |
+
Returns:
|
222 |
+
source audios concatenated for given indices with
|
223 |
+
relevant transforms appplied
|
224 |
+
"""
|
225 |
+
if _is_int_or_np_int(index):
|
226 |
+
source = get_features_or_waveform(
|
227 |
+
self.audio_paths[index],
|
228 |
+
need_waveform=self.cfg.use_audio_input,
|
229 |
+
use_sample_rate=self.cfg.use_sample_rate,
|
230 |
+
waveform_transforms=self.waveform_transforms,
|
231 |
+
)
|
232 |
+
else:
|
233 |
+
source = np.concatenate(
|
234 |
+
[
|
235 |
+
get_features_or_waveform(
|
236 |
+
self.audio_paths[i],
|
237 |
+
need_waveform=self.cfg.use_audio_input,
|
238 |
+
use_sample_rate=self.cfg.use_sample_rate,
|
239 |
+
waveform_transforms=self.waveform_transforms,
|
240 |
+
)
|
241 |
+
for i in index
|
242 |
+
]
|
243 |
+
)
|
244 |
+
if self.cfg.use_audio_input:
|
245 |
+
source = torch.from_numpy(source).float()
|
246 |
+
if self.cfg.standardize_audio:
|
247 |
+
with torch.no_grad():
|
248 |
+
source = F.layer_norm(source, source.shape)
|
249 |
+
else:
|
250 |
+
if self.feature_transforms is not None:
|
251 |
+
source = self.feature_transforms(source)
|
252 |
+
source = torch.from_numpy(source).float()
|
253 |
+
return source
|
254 |
+
|
255 |
+
def __getitem__(self, index: int) -> SpeechToTextDatasetItem:
|
256 |
+
has_concat = self.dataset_transforms.has_transform(ConcatAugment)
|
257 |
+
if has_concat:
|
258 |
+
concat = self.dataset_transforms.get_transform(ConcatAugment)
|
259 |
+
indices = concat.find_indices(index, self.n_frames, self.n_samples)
|
260 |
+
|
261 |
+
source = self._get_source_audio(indices if has_concat else index)
|
262 |
+
source = self.pack_frames(source)
|
263 |
+
|
264 |
+
target = None
|
265 |
+
if self.tgt_texts is not None:
|
266 |
+
tokenized = self.get_tokenized_tgt_text(indices if has_concat else index)
|
267 |
+
target = self.tgt_dict.encode_line(
|
268 |
+
tokenized, add_if_not_exist=False, append_eos=self.append_eos
|
269 |
+
).long()
|
270 |
+
if self.cfg.prepend_tgt_lang_tag:
|
271 |
+
lang_tag_idx = self.get_lang_tag_idx(
|
272 |
+
self.tgt_langs[index], self.tgt_dict
|
273 |
+
)
|
274 |
+
target = torch.cat((torch.LongTensor([lang_tag_idx]), target), 0)
|
275 |
+
|
276 |
+
if self.cfg.prepend_bos_and_append_tgt_lang_tag:
|
277 |
+
bos = torch.LongTensor([self.tgt_dict.bos()])
|
278 |
+
lang_tag_idx = self.get_lang_tag_idx(self.tgt_langs[index], self.tgt_dict)
|
279 |
+
assert lang_tag_idx != self.tgt_dict.unk()
|
280 |
+
lang_tag_idx = torch.LongTensor([lang_tag_idx])
|
281 |
+
target = torch.cat((bos, target, lang_tag_idx), 0)
|
282 |
+
|
283 |
+
speaker_id = None
|
284 |
+
if self.speaker_to_id is not None:
|
285 |
+
speaker_id = self.speaker_to_id[self.speakers[index]]
|
286 |
+
return SpeechToTextDatasetItem(
|
287 |
+
index=index, source=source, target=target, speaker_id=speaker_id
|
288 |
+
)
|
289 |
+
|
290 |
+
def __len__(self):
|
291 |
+
return self.n_samples
|
292 |
+
|
293 |
+
def collater(
|
294 |
+
self, samples: List[SpeechToTextDatasetItem], return_order: bool = False
|
295 |
+
) -> Dict:
|
296 |
+
if len(samples) == 0:
|
297 |
+
return {}
|
298 |
+
indices = torch.tensor([x.index for x in samples], dtype=torch.long)
|
299 |
+
|
300 |
+
sources = [x.source for x in samples]
|
301 |
+
has_NOAug = self.dataset_transforms.has_transform(NoisyOverlapAugment)
|
302 |
+
if has_NOAug and self.cfg.use_audio_input:
|
303 |
+
NOAug = self.dataset_transforms.get_transform(NoisyOverlapAugment)
|
304 |
+
sources = NOAug(sources)
|
305 |
+
|
306 |
+
frames = _collate_frames(sources, self.cfg.use_audio_input)
|
307 |
+
# sort samples by descending number of frames
|
308 |
+
n_frames = torch.tensor([x.size(0) for x in sources], dtype=torch.long)
|
309 |
+
n_frames, order = n_frames.sort(descending=True)
|
310 |
+
indices = indices.index_select(0, order)
|
311 |
+
frames = frames.index_select(0, order)
|
312 |
+
|
313 |
+
target, target_lengths = None, None
|
314 |
+
prev_output_tokens = None
|
315 |
+
ntokens = None
|
316 |
+
if self.tgt_texts is not None:
|
317 |
+
target = fairseq_data_utils.collate_tokens(
|
318 |
+
[x.target for x in samples],
|
319 |
+
self.tgt_dict.pad(),
|
320 |
+
self.tgt_dict.eos(),
|
321 |
+
left_pad=False,
|
322 |
+
move_eos_to_beginning=False,
|
323 |
+
)
|
324 |
+
target = target.index_select(0, order)
|
325 |
+
target_lengths = torch.tensor(
|
326 |
+
[x.target.size(0) for x in samples], dtype=torch.long
|
327 |
+
).index_select(0, order)
|
328 |
+
prev_output_tokens = fairseq_data_utils.collate_tokens(
|
329 |
+
[x.target for x in samples],
|
330 |
+
self.tgt_dict.pad(),
|
331 |
+
eos_idx=None,
|
332 |
+
left_pad=False,
|
333 |
+
move_eos_to_beginning=True,
|
334 |
+
)
|
335 |
+
prev_output_tokens = prev_output_tokens.index_select(0, order)
|
336 |
+
ntokens = sum(x.target.size(0) for x in samples)
|
337 |
+
|
338 |
+
speaker = None
|
339 |
+
if self.speaker_to_id is not None:
|
340 |
+
speaker = (
|
341 |
+
torch.tensor([s.speaker_id for s in samples], dtype=torch.long)
|
342 |
+
.index_select(0, order)
|
343 |
+
.view(-1, 1)
|
344 |
+
)
|
345 |
+
|
346 |
+
net_input = {
|
347 |
+
"src_tokens": frames,
|
348 |
+
"src_lengths": n_frames,
|
349 |
+
"prev_output_tokens": prev_output_tokens,
|
350 |
+
}
|
351 |
+
out = {
|
352 |
+
"id": indices,
|
353 |
+
"net_input": net_input,
|
354 |
+
"speaker": speaker,
|
355 |
+
"target": target,
|
356 |
+
"target_lengths": target_lengths,
|
357 |
+
"ntokens": ntokens,
|
358 |
+
"nsentences": len(samples),
|
359 |
+
}
|
360 |
+
if return_order:
|
361 |
+
out["order"] = order
|
362 |
+
return out
|
363 |
+
|
364 |
+
def num_tokens(self, index):
|
365 |
+
return self.n_frames[index]
|
366 |
+
|
367 |
+
def size(self, index):
|
368 |
+
return self.n_frames[index], self.tgt_lens[index]
|
369 |
+
|
370 |
+
@property
|
371 |
+
def sizes(self):
|
372 |
+
return np.array(self.n_frames)
|
373 |
+
|
374 |
+
@property
|
375 |
+
def can_reuse_epoch_itr_across_epochs(self):
|
376 |
+
return True
|
377 |
+
|
378 |
+
def ordered_indices(self):
|
379 |
+
if self.shuffle:
|
380 |
+
order = [np.random.permutation(len(self))]
|
381 |
+
else:
|
382 |
+
order = [np.arange(len(self))]
|
383 |
+
# first by descending order of # of frames then by original/random order
|
384 |
+
order.append([-n for n in self.n_frames])
|
385 |
+
return np.lexsort(order)
|
386 |
+
|
387 |
+
def prefetch(self, indices):
|
388 |
+
raise False
|
389 |
+
|
390 |
+
|
391 |
+
class TextTargetMultitaskData(object):
|
392 |
+
# mandatory columns
|
393 |
+
KEY_ID, KEY_TEXT = "id", "tgt_text"
|
394 |
+
LANG_TAG_TEMPLATE = "<lang:{}>"
|
395 |
+
|
396 |
+
def __init__(self, args, split, tgt_dict):
|
397 |
+
samples = SpeechToTextDatasetCreator._load_samples_from_tsv(args.data, split)
|
398 |
+
self.data = {s[self.KEY_ID]: s[self.KEY_TEXT] for s in samples}
|
399 |
+
self.dict = tgt_dict
|
400 |
+
self.append_eos = args.decoder_type != "ctc"
|
401 |
+
self.pre_tokenizer = self.build_tokenizer(args)
|
402 |
+
self.bpe_tokenizer = self.build_bpe(args)
|
403 |
+
self.prepend_bos_and_append_tgt_lang_tag = (
|
404 |
+
args.prepend_bos_and_append_tgt_lang_tag
|
405 |
+
)
|
406 |
+
self.eos_token = args.eos_token
|
407 |
+
self.lang_tag_mapping = args.get_lang_tag_mapping
|
408 |
+
|
409 |
+
@classmethod
|
410 |
+
def is_lang_tag(cls, token):
|
411 |
+
pattern = cls.LANG_TAG_TEMPLATE.replace("{}", "(.*)")
|
412 |
+
return re.match(pattern, token)
|
413 |
+
|
414 |
+
@classmethod
|
415 |
+
def tokenize(cls, tokenizer, text: str):
|
416 |
+
return text if tokenizer is None else tokenizer.encode(text)
|
417 |
+
|
418 |
+
def get_tokenized_tgt_text(self, index: int):
|
419 |
+
text = self.tokenize(self.pre_tokenizer, self.data[index])
|
420 |
+
text = self.tokenize(self.bpe_tokenizer, text)
|
421 |
+
return text
|
422 |
+
|
423 |
+
def get_lang_tag_idx(self, lang: str, dictionary: Dictionary):
|
424 |
+
lang_tag = self.LANG_TAG_TEMPLATE.format(lang)
|
425 |
+
lang_tag = self.lang_tag_mapping.get(lang_tag, lang_tag)
|
426 |
+
lang_tag_idx = dictionary.index(lang_tag)
|
427 |
+
assert lang_tag_idx != dictionary.unk(), (lang, lang_tag)
|
428 |
+
return lang_tag_idx
|
429 |
+
|
430 |
+
def build_tokenizer(self, args):
|
431 |
+
pre_tokenizer = args.config.get("pre_tokenizer")
|
432 |
+
if pre_tokenizer is not None:
|
433 |
+
logger.info(f"pre-tokenizer: {pre_tokenizer}")
|
434 |
+
return encoders.build_tokenizer(Namespace(**pre_tokenizer))
|
435 |
+
else:
|
436 |
+
return None
|
437 |
+
|
438 |
+
def build_bpe(self, args):
|
439 |
+
bpe_tokenizer = args.config.get("bpe_tokenizer")
|
440 |
+
if bpe_tokenizer is not None:
|
441 |
+
logger.info(f"tokenizer: {bpe_tokenizer}")
|
442 |
+
return encoders.build_bpe(Namespace(**bpe_tokenizer))
|
443 |
+
else:
|
444 |
+
return None
|
445 |
+
|
446 |
+
def get(self, sample_id, tgt_lang=None):
|
447 |
+
if sample_id in self.data:
|
448 |
+
tokenized = self.get_tokenized_tgt_text(sample_id)
|
449 |
+
target = self.dict.encode_line(
|
450 |
+
tokenized,
|
451 |
+
add_if_not_exist=False,
|
452 |
+
append_eos=self.append_eos,
|
453 |
+
)
|
454 |
+
if self.prepend_bos_and_append_tgt_lang_tag:
|
455 |
+
bos = torch.LongTensor([self.dict.bos()])
|
456 |
+
lang_tag_idx = self.get_lang_tag_idx(tgt_lang, self.dict)
|
457 |
+
assert lang_tag_idx != self.dict.unk()
|
458 |
+
lang_tag_idx = torch.LongTensor([lang_tag_idx])
|
459 |
+
target = torch.cat((bos, target, lang_tag_idx), 0)
|
460 |
+
return target
|
461 |
+
else:
|
462 |
+
logger.warning(f"no target for {sample_id}")
|
463 |
+
return torch.IntTensor([])
|
464 |
+
|
465 |
+
def collater(self, samples: List[torch.Tensor]) -> torch.Tensor:
|
466 |
+
out = fairseq_data_utils.collate_tokens(
|
467 |
+
samples,
|
468 |
+
self.dict.pad(),
|
469 |
+
eos_idx=None,
|
470 |
+
left_pad=False,
|
471 |
+
move_eos_to_beginning=False,
|
472 |
+
).long()
|
473 |
+
|
474 |
+
prev_out = fairseq_data_utils.collate_tokens(
|
475 |
+
samples,
|
476 |
+
self.dict.pad(),
|
477 |
+
eos_idx=None,
|
478 |
+
left_pad=False,
|
479 |
+
move_eos_to_beginning=True,
|
480 |
+
).long()
|
481 |
+
|
482 |
+
target_lengths = torch.tensor([t.size(0) for t in samples], dtype=torch.long)
|
483 |
+
ntokens = sum(t.size(0) for t in samples)
|
484 |
+
|
485 |
+
output = {
|
486 |
+
"prev_output_tokens": prev_out,
|
487 |
+
"target": out,
|
488 |
+
"target_lengths": target_lengths,
|
489 |
+
"ntokens": ntokens,
|
490 |
+
}
|
491 |
+
|
492 |
+
return output
|
493 |
+
|
494 |
+
|
495 |
+
class SpeechToTextMultitaskDataset(SpeechToTextDataset):
|
496 |
+
def __init__(self, **kwargs):
|
497 |
+
super().__init__(**kwargs)
|
498 |
+
self.multitask_data = {}
|
499 |
+
|
500 |
+
def add_multitask_dataset(self, task_name, task_data):
|
501 |
+
self.multitask_data[task_name] = task_data
|
502 |
+
|
503 |
+
def __getitem__(
|
504 |
+
self, index: int
|
505 |
+
) -> Tuple[SpeechToTextDatasetItem, Dict[str, torch.Tensor]]:
|
506 |
+
s2t_data = super().__getitem__(index)
|
507 |
+
|
508 |
+
multitask_target = {}
|
509 |
+
sample_id = self.ids[index]
|
510 |
+
tgt_lang = self.tgt_langs[index]
|
511 |
+
for task_name, task_dataset in self.multitask_data.items():
|
512 |
+
multitask_target[task_name] = task_dataset.get(sample_id, tgt_lang)
|
513 |
+
|
514 |
+
return s2t_data, multitask_target
|
515 |
+
|
516 |
+
def collater(
|
517 |
+
self, samples: List[Tuple[SpeechToTextDatasetItem, Dict[str, torch.Tensor]]]
|
518 |
+
) -> Dict:
|
519 |
+
if len(samples) == 0:
|
520 |
+
return {}
|
521 |
+
|
522 |
+
out = super().collater([s for s, _ in samples], return_order=True)
|
523 |
+
order = out["order"]
|
524 |
+
del out["order"]
|
525 |
+
|
526 |
+
for task_name, task_dataset in self.multitask_data.items():
|
527 |
+
if "multitask" not in out:
|
528 |
+
out["multitask"] = {}
|
529 |
+
d = [s[task_name] for _, s in samples]
|
530 |
+
task_target = task_dataset.collater(d)
|
531 |
+
out["multitask"][task_name] = {
|
532 |
+
"target": task_target["target"].index_select(0, order),
|
533 |
+
"target_lengths": task_target["target_lengths"].index_select(0, order),
|
534 |
+
"ntokens": task_target["ntokens"],
|
535 |
+
}
|
536 |
+
out["multitask"][task_name]["net_input"] = {
|
537 |
+
"prev_output_tokens": task_target["prev_output_tokens"].index_select(
|
538 |
+
0, order
|
539 |
+
),
|
540 |
+
}
|
541 |
+
|
542 |
+
return out
|
543 |
+
|
544 |
+
|
545 |
+
class SpeechToTextDatasetCreator(object):
|
546 |
+
# mandatory columns
|
547 |
+
KEY_ID, KEY_AUDIO, KEY_N_FRAMES = "id", "audio", "n_frames"
|
548 |
+
KEY_TGT_TEXT = "tgt_text"
|
549 |
+
# optional columns
|
550 |
+
KEY_SPEAKER, KEY_SRC_TEXT = "speaker", "src_text"
|
551 |
+
KEY_SRC_LANG, KEY_TGT_LANG = "src_lang", "tgt_lang"
|
552 |
+
# default values
|
553 |
+
DEFAULT_SPEAKER = DEFAULT_SRC_TEXT = DEFAULT_LANG = ""
|
554 |
+
|
555 |
+
@classmethod
|
556 |
+
def _from_list(
|
557 |
+
cls,
|
558 |
+
split_name: str,
|
559 |
+
is_train_split,
|
560 |
+
samples: List[Dict],
|
561 |
+
cfg: S2TDataConfig,
|
562 |
+
tgt_dict,
|
563 |
+
pre_tokenizer,
|
564 |
+
bpe_tokenizer,
|
565 |
+
n_frames_per_step,
|
566 |
+
speaker_to_id,
|
567 |
+
multitask: Optional[Dict] = None,
|
568 |
+
) -> SpeechToTextDataset:
|
569 |
+
audio_root = Path(cfg.audio_root)
|
570 |
+
ids = [s[cls.KEY_ID] for s in samples]
|
571 |
+
audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples]
|
572 |
+
n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples]
|
573 |
+
tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples]
|
574 |
+
src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples]
|
575 |
+
speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples]
|
576 |
+
src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
|
577 |
+
tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]
|
578 |
+
|
579 |
+
has_multitask = multitask is not None and len(multitask.keys()) > 0
|
580 |
+
dataset_cls = (
|
581 |
+
SpeechToTextMultitaskDataset if has_multitask else SpeechToTextDataset
|
582 |
+
)
|
583 |
+
|
584 |
+
ds = dataset_cls(
|
585 |
+
split=split_name,
|
586 |
+
is_train_split=is_train_split,
|
587 |
+
cfg=cfg,
|
588 |
+
audio_paths=audio_paths,
|
589 |
+
n_frames=n_frames,
|
590 |
+
src_texts=src_texts,
|
591 |
+
tgt_texts=tgt_texts,
|
592 |
+
speakers=speakers,
|
593 |
+
src_langs=src_langs,
|
594 |
+
tgt_langs=tgt_langs,
|
595 |
+
ids=ids,
|
596 |
+
tgt_dict=tgt_dict,
|
597 |
+
pre_tokenizer=pre_tokenizer,
|
598 |
+
bpe_tokenizer=bpe_tokenizer,
|
599 |
+
n_frames_per_step=n_frames_per_step,
|
600 |
+
speaker_to_id=speaker_to_id,
|
601 |
+
)
|
602 |
+
|
603 |
+
if has_multitask:
|
604 |
+
for task_name, task_obj in multitask.items():
|
605 |
+
task_data = TextTargetMultitaskData(
|
606 |
+
task_obj.args, split_name, task_obj.target_dictionary
|
607 |
+
)
|
608 |
+
ds.add_multitask_dataset(task_name, task_data)
|
609 |
+
return ds
|
610 |
+
|
611 |
+
@classmethod
|
612 |
+
def get_size_ratios(
|
613 |
+
cls, datasets: List[SpeechToTextDataset], alpha: float = 1.0
|
614 |
+
) -> List[float]:
|
615 |
+
"""Size ratios for temperature-based sampling
|
616 |
+
(https://arxiv.org/abs/1907.05019)"""
|
617 |
+
|
618 |
+
id_to_lp, lp_to_sz = {}, defaultdict(int)
|
619 |
+
for ds in datasets:
|
620 |
+
lang_pairs = {f"{s}->{t}" for s, t in zip(ds.src_langs, ds.tgt_langs)}
|
621 |
+
assert len(lang_pairs) == 1
|
622 |
+
lang_pair = list(lang_pairs)[0]
|
623 |
+
id_to_lp[ds.split] = lang_pair
|
624 |
+
lp_to_sz[lang_pair] += sum(ds.n_frames)
|
625 |
+
|
626 |
+
sz_sum = sum(v for v in lp_to_sz.values())
|
627 |
+
lp_to_prob = {k: v / sz_sum for k, v in lp_to_sz.items()}
|
628 |
+
lp_to_tgt_prob = {k: v**alpha for k, v in lp_to_prob.items()}
|
629 |
+
prob_sum = sum(v for v in lp_to_tgt_prob.values())
|
630 |
+
lp_to_tgt_prob = {k: v / prob_sum for k, v in lp_to_tgt_prob.items()}
|
631 |
+
lp_to_sz_ratio = {
|
632 |
+
k: (lp_to_tgt_prob[k] * sz_sum) / v for k, v in lp_to_sz.items()
|
633 |
+
}
|
634 |
+
size_ratio = [lp_to_sz_ratio[id_to_lp[ds.split]] for ds in datasets]
|
635 |
+
|
636 |
+
p_formatted = {
|
637 |
+
k: f"{lp_to_prob[k]:.3f}->{lp_to_tgt_prob[k]:.3f}" for k in lp_to_sz
|
638 |
+
}
|
639 |
+
logger.info(f"sampling probability balancing: {p_formatted}")
|
640 |
+
sr_formatted = {ds.split: f"{r:.3f}" for ds, r in zip(datasets, size_ratio)}
|
641 |
+
logger.info(f"balanced sampling size ratio: {sr_formatted}")
|
642 |
+
return size_ratio
|
643 |
+
|
644 |
+
@classmethod
|
645 |
+
def _load_samples_from_tsv(cls, root: str, split: str):
|
646 |
+
tsv_path = Path(root) / f"{split}.tsv"
|
647 |
+
if not tsv_path.is_file():
|
648 |
+
raise FileNotFoundError(f"Dataset not found: {tsv_path}")
|
649 |
+
with open(tsv_path) as f:
|
650 |
+
reader = csv.DictReader(
|
651 |
+
f,
|
652 |
+
delimiter="\t",
|
653 |
+
quotechar=None,
|
654 |
+
doublequote=False,
|
655 |
+
lineterminator="\n",
|
656 |
+
quoting=csv.QUOTE_NONE,
|
657 |
+
)
|
658 |
+
samples = [dict(e) for e in reader]
|
659 |
+
if len(samples) == 0:
|
660 |
+
raise ValueError(f"Empty manifest: {tsv_path}")
|
661 |
+
return samples
|
662 |
+
|
663 |
+
@classmethod
|
664 |
+
def _from_tsv(
|
665 |
+
cls,
|
666 |
+
root: str,
|
667 |
+
cfg: S2TDataConfig,
|
668 |
+
split: str,
|
669 |
+
tgt_dict,
|
670 |
+
is_train_split: bool,
|
671 |
+
pre_tokenizer,
|
672 |
+
bpe_tokenizer,
|
673 |
+
n_frames_per_step,
|
674 |
+
speaker_to_id,
|
675 |
+
multitask: Optional[Dict] = None,
|
676 |
+
) -> SpeechToTextDataset:
|
677 |
+
samples = cls._load_samples_from_tsv(root, split)
|
678 |
+
return cls._from_list(
|
679 |
+
split,
|
680 |
+
is_train_split,
|
681 |
+
samples,
|
682 |
+
cfg,
|
683 |
+
tgt_dict,
|
684 |
+
pre_tokenizer,
|
685 |
+
bpe_tokenizer,
|
686 |
+
n_frames_per_step,
|
687 |
+
speaker_to_id,
|
688 |
+
multitask,
|
689 |
+
)
|
690 |
+
|
691 |
+
@classmethod
|
692 |
+
def from_tsv(
|
693 |
+
cls,
|
694 |
+
root: str,
|
695 |
+
cfg: S2TDataConfig,
|
696 |
+
splits: str,
|
697 |
+
tgt_dict,
|
698 |
+
pre_tokenizer,
|
699 |
+
bpe_tokenizer,
|
700 |
+
is_train_split: bool,
|
701 |
+
epoch: int,
|
702 |
+
seed: int,
|
703 |
+
n_frames_per_step: int = 1,
|
704 |
+
speaker_to_id=None,
|
705 |
+
multitask: Optional[Dict] = None,
|
706 |
+
) -> SpeechToTextDataset:
|
707 |
+
datasets = [
|
708 |
+
cls._from_tsv(
|
709 |
+
root=root,
|
710 |
+
cfg=cfg,
|
711 |
+
split=split,
|
712 |
+
tgt_dict=tgt_dict,
|
713 |
+
is_train_split=is_train_split,
|
714 |
+
pre_tokenizer=pre_tokenizer,
|
715 |
+
bpe_tokenizer=bpe_tokenizer,
|
716 |
+
n_frames_per_step=n_frames_per_step,
|
717 |
+
speaker_to_id=speaker_to_id,
|
718 |
+
multitask=multitask,
|
719 |
+
)
|
720 |
+
for split in splits.split(",")
|
721 |
+
]
|
722 |
+
|
723 |
+
if is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0:
|
724 |
+
# temperature-based sampling
|
725 |
+
size_ratios = cls.get_size_ratios(datasets, alpha=cfg.sampling_alpha)
|
726 |
+
datasets = [
|
727 |
+
ResamplingDataset(
|
728 |
+
d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0)
|
729 |
+
)
|
730 |
+
for r, d in zip(size_ratios, datasets)
|
731 |
+
]
|
732 |
+
|
733 |
+
return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
|
fairseq/fairseq/data/backtranslation_dataset.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from fairseq import utils
|
8 |
+
|
9 |
+
from . import FairseqDataset
|
10 |
+
|
11 |
+
|
12 |
+
def backtranslate_samples(samples, collate_fn, generate_fn, cuda=True):
|
13 |
+
"""Backtranslate a list of samples.
|
14 |
+
|
15 |
+
Given an input (*samples*) of the form:
|
16 |
+
|
17 |
+
[{'id': 1, 'source': 'hallo welt'}]
|
18 |
+
|
19 |
+
this will return:
|
20 |
+
|
21 |
+
[{'id': 1, 'source': 'hello world', 'target': 'hallo welt'}]
|
22 |
+
|
23 |
+
Args:
|
24 |
+
samples (List[dict]): samples to backtranslate. Individual samples are
|
25 |
+
expected to have a 'source' key, which will become the 'target'
|
26 |
+
after backtranslation.
|
27 |
+
collate_fn (callable): function to collate samples into a mini-batch
|
28 |
+
generate_fn (callable): function to generate backtranslations
|
29 |
+
cuda (bool): use GPU for generation (default: ``True``)
|
30 |
+
|
31 |
+
Returns:
|
32 |
+
List[dict]: an updated list of samples with a backtranslated source
|
33 |
+
"""
|
34 |
+
collated_samples = collate_fn(samples)
|
35 |
+
s = utils.move_to_cuda(collated_samples) if cuda else collated_samples
|
36 |
+
generated_sources = generate_fn(s)
|
37 |
+
|
38 |
+
id_to_src = {sample["id"]: sample["source"] for sample in samples}
|
39 |
+
|
40 |
+
# Go through each tgt sentence in batch and its corresponding best
|
41 |
+
# generated hypothesis and create a backtranslation data pair
|
42 |
+
# {id: id, source: generated backtranslation, target: original tgt}
|
43 |
+
return [
|
44 |
+
{
|
45 |
+
"id": id.item(),
|
46 |
+
"target": id_to_src[id.item()],
|
47 |
+
"source": hypos[0]["tokens"].cpu(),
|
48 |
+
}
|
49 |
+
for id, hypos in zip(collated_samples["id"], generated_sources)
|
50 |
+
]
|
51 |
+
|
52 |
+
|
53 |
+
class BacktranslationDataset(FairseqDataset):
|
54 |
+
"""
|
55 |
+
Sets up a backtranslation dataset which takes a tgt batch, generates
|
56 |
+
a src using a tgt-src backtranslation function (*backtranslation_fn*),
|
57 |
+
and returns the corresponding `{generated src, input tgt}` batch.
|
58 |
+
|
59 |
+
Args:
|
60 |
+
tgt_dataset (~fairseq.data.FairseqDataset): the dataset to be
|
61 |
+
backtranslated. Only the source side of this dataset will be used.
|
62 |
+
After backtranslation, the source sentences in this dataset will be
|
63 |
+
returned as the targets.
|
64 |
+
src_dict (~fairseq.data.Dictionary): the dictionary of backtranslated
|
65 |
+
sentences.
|
66 |
+
tgt_dict (~fairseq.data.Dictionary, optional): the dictionary of
|
67 |
+
sentences to be backtranslated.
|
68 |
+
backtranslation_fn (callable, optional): function to call to generate
|
69 |
+
backtranslations. This is typically the `generate` method of a
|
70 |
+
:class:`~fairseq.sequence_generator.SequenceGenerator` object.
|
71 |
+
Pass in None when it is not available at initialization time, and
|
72 |
+
use set_backtranslation_fn function to set it when available.
|
73 |
+
output_collater (callable, optional): function to call on the
|
74 |
+
backtranslated samples to create the final batch
|
75 |
+
(default: ``tgt_dataset.collater``).
|
76 |
+
cuda: use GPU for generation
|
77 |
+
"""
|
78 |
+
|
79 |
+
def __init__(
|
80 |
+
self,
|
81 |
+
tgt_dataset,
|
82 |
+
src_dict,
|
83 |
+
tgt_dict=None,
|
84 |
+
backtranslation_fn=None,
|
85 |
+
output_collater=None,
|
86 |
+
cuda=True,
|
87 |
+
**kwargs
|
88 |
+
):
|
89 |
+
self.tgt_dataset = tgt_dataset
|
90 |
+
self.backtranslation_fn = backtranslation_fn
|
91 |
+
self.output_collater = (
|
92 |
+
output_collater if output_collater is not None else tgt_dataset.collater
|
93 |
+
)
|
94 |
+
self.cuda = cuda if torch.cuda.is_available() else False
|
95 |
+
self.src_dict = src_dict
|
96 |
+
self.tgt_dict = tgt_dict
|
97 |
+
|
98 |
+
def __getitem__(self, index):
|
99 |
+
"""
|
100 |
+
Returns a single sample from *tgt_dataset*. Note that backtranslation is
|
101 |
+
not applied in this step; use :func:`collater` instead to backtranslate
|
102 |
+
a batch of samples.
|
103 |
+
"""
|
104 |
+
return self.tgt_dataset[index]
|
105 |
+
|
106 |
+
def __len__(self):
|
107 |
+
return len(self.tgt_dataset)
|
108 |
+
|
109 |
+
def set_backtranslation_fn(self, backtranslation_fn):
|
110 |
+
self.backtranslation_fn = backtranslation_fn
|
111 |
+
|
112 |
+
def collater(self, samples):
|
113 |
+
"""Merge and backtranslate a list of samples to form a mini-batch.
|
114 |
+
|
115 |
+
Using the samples from *tgt_dataset*, load a collated target sample to
|
116 |
+
feed to the backtranslation model. Then take the backtranslation with
|
117 |
+
the best score as the source and the original input as the target.
|
118 |
+
|
119 |
+
Note: we expect *tgt_dataset* to provide a function `collater()` that
|
120 |
+
will collate samples into the format expected by *backtranslation_fn*.
|
121 |
+
After backtranslation, we will feed the new list of samples (i.e., the
|
122 |
+
`(backtranslated source, original source)` pairs) to *output_collater*
|
123 |
+
and return the result.
|
124 |
+
|
125 |
+
Args:
|
126 |
+
samples (List[dict]): samples to backtranslate and collate
|
127 |
+
|
128 |
+
Returns:
|
129 |
+
dict: a mini-batch with keys coming from *output_collater*
|
130 |
+
"""
|
131 |
+
if samples[0].get("is_dummy", False):
|
132 |
+
return samples
|
133 |
+
samples = backtranslate_samples(
|
134 |
+
samples=samples,
|
135 |
+
collate_fn=self.tgt_dataset.collater,
|
136 |
+
generate_fn=(lambda net_input: self.backtranslation_fn(net_input)),
|
137 |
+
cuda=self.cuda,
|
138 |
+
)
|
139 |
+
return self.output_collater(samples)
|
140 |
+
|
141 |
+
def num_tokens(self, index):
|
142 |
+
"""Just use the tgt dataset num_tokens"""
|
143 |
+
return self.tgt_dataset.num_tokens(index)
|
144 |
+
|
145 |
+
def ordered_indices(self):
|
146 |
+
"""Just use the tgt dataset ordered_indices"""
|
147 |
+
return self.tgt_dataset.ordered_indices()
|
148 |
+
|
149 |
+
def size(self, index):
|
150 |
+
"""Return an example's size as a float or tuple. This value is used
|
151 |
+
when filtering a dataset with ``--max-positions``.
|
152 |
+
|
153 |
+
Note: we use *tgt_dataset* to approximate the length of the source
|
154 |
+
sentence, since we do not know the actual length until after
|
155 |
+
backtranslation.
|
156 |
+
"""
|
157 |
+
tgt_size = self.tgt_dataset.size(index)[0]
|
158 |
+
return (tgt_size, tgt_size)
|
159 |
+
|
160 |
+
@property
|
161 |
+
def supports_prefetch(self):
|
162 |
+
return getattr(self.tgt_dataset, "supports_prefetch", False)
|
163 |
+
|
164 |
+
def prefetch(self, indices):
|
165 |
+
return self.tgt_dataset.prefetch(indices)
|
fairseq/fairseq/data/base_wrapper_dataset.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from torch.utils.data.dataloader import default_collate
|
7 |
+
|
8 |
+
from . import FairseqDataset
|
9 |
+
|
10 |
+
|
11 |
+
class BaseWrapperDataset(FairseqDataset):
|
12 |
+
def __init__(self, dataset):
|
13 |
+
super().__init__()
|
14 |
+
self.dataset = dataset
|
15 |
+
|
16 |
+
def __getitem__(self, index):
|
17 |
+
return self.dataset[index]
|
18 |
+
|
19 |
+
def __len__(self):
|
20 |
+
return len(self.dataset)
|
21 |
+
|
22 |
+
def collater(self, samples):
|
23 |
+
if hasattr(self.dataset, "collater"):
|
24 |
+
return self.dataset.collater(samples)
|
25 |
+
else:
|
26 |
+
return default_collate(samples)
|
27 |
+
|
28 |
+
@property
|
29 |
+
def sizes(self):
|
30 |
+
return self.dataset.sizes
|
31 |
+
|
32 |
+
def num_tokens(self, index):
|
33 |
+
return self.dataset.num_tokens(index)
|
34 |
+
|
35 |
+
def size(self, index):
|
36 |
+
return self.dataset.size(index)
|
37 |
+
|
38 |
+
def ordered_indices(self):
|
39 |
+
return self.dataset.ordered_indices()
|
40 |
+
|
41 |
+
@property
|
42 |
+
def supports_prefetch(self):
|
43 |
+
return getattr(self.dataset, "supports_prefetch", False)
|
44 |
+
|
45 |
+
def attr(self, attr: str, index: int):
|
46 |
+
return self.dataset.attr(attr, index)
|
47 |
+
|
48 |
+
def prefetch(self, indices):
|
49 |
+
self.dataset.prefetch(indices)
|
50 |
+
|
51 |
+
def get_batch_shapes(self):
|
52 |
+
return self.dataset.get_batch_shapes()
|
53 |
+
|
54 |
+
def batch_by_size(
|
55 |
+
self,
|
56 |
+
indices,
|
57 |
+
max_tokens=None,
|
58 |
+
max_sentences=None,
|
59 |
+
required_batch_size_multiple=1,
|
60 |
+
):
|
61 |
+
return self.dataset.batch_by_size(
|
62 |
+
indices,
|
63 |
+
max_tokens=max_tokens,
|
64 |
+
max_sentences=max_sentences,
|
65 |
+
required_batch_size_multiple=required_batch_size_multiple,
|
66 |
+
)
|
67 |
+
|
68 |
+
def filter_indices_by_size(self, indices, max_sizes):
|
69 |
+
return self.dataset.filter_indices_by_size(indices, max_sizes)
|
70 |
+
|
71 |
+
@property
|
72 |
+
def can_reuse_epoch_itr_across_epochs(self):
|
73 |
+
return self.dataset.can_reuse_epoch_itr_across_epochs
|
74 |
+
|
75 |
+
def set_epoch(self, epoch):
|
76 |
+
super().set_epoch(epoch)
|
77 |
+
if hasattr(self.dataset, "set_epoch"):
|
78 |
+
self.dataset.set_epoch(epoch)
|
fairseq/fairseq/data/bucket_pad_length_dataset.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from fairseq.data import BaseWrapperDataset
|
9 |
+
from fairseq.data.data_utils import get_buckets, get_bucketed_sizes
|
10 |
+
|
11 |
+
|
12 |
+
class BucketPadLengthDataset(BaseWrapperDataset):
|
13 |
+
"""
|
14 |
+
Bucket and pad item lengths to the nearest bucket size. This can be used to
|
15 |
+
reduce the number of unique batch shapes, which is important on TPUs since
|
16 |
+
each new batch shape requires a recompilation.
|
17 |
+
|
18 |
+
Args:
|
19 |
+
dataset (FairseqDatset): dataset to bucket
|
20 |
+
sizes (List[int]): all item sizes
|
21 |
+
num_buckets (int): number of buckets to create
|
22 |
+
pad_idx (int): padding symbol
|
23 |
+
left_pad (bool): if True, pad on the left; otherwise right pad
|
24 |
+
"""
|
25 |
+
|
26 |
+
def __init__(
|
27 |
+
self,
|
28 |
+
dataset,
|
29 |
+
sizes,
|
30 |
+
num_buckets,
|
31 |
+
pad_idx,
|
32 |
+
left_pad,
|
33 |
+
tensor_key=None,
|
34 |
+
):
|
35 |
+
super().__init__(dataset)
|
36 |
+
self.pad_idx = pad_idx
|
37 |
+
self.left_pad = left_pad
|
38 |
+
|
39 |
+
assert num_buckets > 0
|
40 |
+
self.buckets = get_buckets(sizes, num_buckets)
|
41 |
+
self._bucketed_sizes = get_bucketed_sizes(sizes, self.buckets)
|
42 |
+
self._tensor_key = tensor_key
|
43 |
+
|
44 |
+
def _set_tensor(self, item, val):
|
45 |
+
if self._tensor_key is None:
|
46 |
+
return val
|
47 |
+
item[self._tensor_key] = val
|
48 |
+
return item
|
49 |
+
|
50 |
+
def _get_tensor(self, item):
|
51 |
+
if self._tensor_key is None:
|
52 |
+
return item
|
53 |
+
return item[self._tensor_key]
|
54 |
+
|
55 |
+
def _pad(self, tensor, bucket_size, dim=-1):
|
56 |
+
num_pad = bucket_size - tensor.size(dim)
|
57 |
+
return F.pad(
|
58 |
+
tensor,
|
59 |
+
(num_pad if self.left_pad else 0, 0 if self.left_pad else num_pad),
|
60 |
+
value=self.pad_idx,
|
61 |
+
)
|
62 |
+
|
63 |
+
def __getitem__(self, index):
|
64 |
+
item = self.dataset[index]
|
65 |
+
bucket_size = self._bucketed_sizes[index]
|
66 |
+
tensor = self._get_tensor(item)
|
67 |
+
padded = self._pad(tensor, bucket_size)
|
68 |
+
return self._set_tensor(item, padded)
|
69 |
+
|
70 |
+
@property
|
71 |
+
def sizes(self):
|
72 |
+
return self._bucketed_sizes
|
73 |
+
|
74 |
+
def num_tokens(self, index):
|
75 |
+
return self._bucketed_sizes[index]
|
76 |
+
|
77 |
+
def size(self, index):
|
78 |
+
return self._bucketed_sizes[index]
|
fairseq/fairseq/data/codedataset.py
ADDED
@@ -0,0 +1,576 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
|
7 |
+
import json
|
8 |
+
import logging
|
9 |
+
import os
|
10 |
+
import random
|
11 |
+
from pathlib import Path
|
12 |
+
|
13 |
+
import numpy as np
|
14 |
+
import torch
|
15 |
+
import torch.utils.data
|
16 |
+
|
17 |
+
from . import data_utils
|
18 |
+
from fairseq.data.fairseq_dataset import FairseqDataset
|
19 |
+
|
20 |
+
F0_FRAME_SPACE = 0.005 # sec
|
21 |
+
|
22 |
+
|
23 |
+
logger = logging.getLogger(__name__)
|
24 |
+
|
25 |
+
|
26 |
+
class ExpressiveCodeDataConfig(object):
|
27 |
+
def __init__(self, json_path):
|
28 |
+
with open(json_path, "r") as f:
|
29 |
+
self.config = json.load(f)
|
30 |
+
self._manifests = self.config["manifests"]
|
31 |
+
|
32 |
+
@property
|
33 |
+
def manifests(self):
|
34 |
+
return self._manifests
|
35 |
+
|
36 |
+
@property
|
37 |
+
def n_units(self):
|
38 |
+
return self.config["n_units"]
|
39 |
+
|
40 |
+
@property
|
41 |
+
def sampling_rate(self):
|
42 |
+
return self.config["sampling_rate"]
|
43 |
+
|
44 |
+
@property
|
45 |
+
def code_hop_size(self):
|
46 |
+
return self.config["code_hop_size"]
|
47 |
+
|
48 |
+
@property
|
49 |
+
def f0_stats(self):
|
50 |
+
"""pre-computed f0 statistics path"""
|
51 |
+
return self.config.get("f0_stats", None)
|
52 |
+
|
53 |
+
@property
|
54 |
+
def f0_vq_type(self):
|
55 |
+
"""naive or precomp"""
|
56 |
+
return self.config["f0_vq_type"]
|
57 |
+
|
58 |
+
@property
|
59 |
+
def f0_vq_name(self):
|
60 |
+
return self.config["f0_vq_name"]
|
61 |
+
|
62 |
+
def get_f0_vq_naive_quantizer(self, log, norm_mean, norm_std):
|
63 |
+
key = "log" if log else "linear"
|
64 |
+
if norm_mean and norm_std:
|
65 |
+
key += "_mean_std_norm"
|
66 |
+
elif norm_mean:
|
67 |
+
key += "_mean_norm"
|
68 |
+
else:
|
69 |
+
key += "_none_norm"
|
70 |
+
return self.config["f0_vq_naive_quantizer"][key]
|
71 |
+
|
72 |
+
@property
|
73 |
+
def f0_vq_n_units(self):
|
74 |
+
return self.config["f0_vq_n_units"]
|
75 |
+
|
76 |
+
@property
|
77 |
+
def multispkr(self):
|
78 |
+
"""how to parse speaker label from audio path"""
|
79 |
+
return self.config.get("multispkr", None)
|
80 |
+
|
81 |
+
|
82 |
+
def get_f0(audio, rate=16000):
|
83 |
+
try:
|
84 |
+
import amfm_decompy.basic_tools as basic
|
85 |
+
import amfm_decompy.pYAAPT as pYAAPT
|
86 |
+
from librosa.util import normalize
|
87 |
+
except ImportError:
|
88 |
+
raise "Please install amfm_decompy (`pip install AMFM-decompy`) and librosa (`pip install librosa`)."
|
89 |
+
|
90 |
+
assert audio.ndim == 1
|
91 |
+
frame_length = 20.0 # ms
|
92 |
+
to_pad = int(frame_length / 1000 * rate) // 2
|
93 |
+
|
94 |
+
audio = normalize(audio) * 0.95
|
95 |
+
audio = np.pad(audio, (to_pad, to_pad), "constant", constant_values=0)
|
96 |
+
audio = basic.SignalObj(audio, rate)
|
97 |
+
pitch = pYAAPT.yaapt(
|
98 |
+
audio,
|
99 |
+
frame_length=frame_length,
|
100 |
+
frame_space=F0_FRAME_SPACE * 1000,
|
101 |
+
nccf_thresh1=0.25,
|
102 |
+
tda_frame_length=25.0,
|
103 |
+
)
|
104 |
+
f0 = pitch.samp_values
|
105 |
+
return f0
|
106 |
+
|
107 |
+
|
108 |
+
def interpolate_f0(f0):
|
109 |
+
try:
|
110 |
+
from scipy.interpolate import interp1d
|
111 |
+
except ImportError:
|
112 |
+
raise "Please install scipy (`pip install scipy`)"
|
113 |
+
|
114 |
+
orig_t = np.arange(f0.shape[0])
|
115 |
+
f0_interp = f0[:]
|
116 |
+
ii = f0_interp != 0
|
117 |
+
if ii.sum() > 1:
|
118 |
+
f0_interp = interp1d(
|
119 |
+
orig_t[ii], f0_interp[ii], bounds_error=False, kind="linear", fill_value=0
|
120 |
+
)(orig_t)
|
121 |
+
f0_interp = torch.Tensor(f0_interp).type_as(f0).to(f0.device)
|
122 |
+
return f0_interp
|
123 |
+
|
124 |
+
|
125 |
+
def naive_quantize(x, edges):
|
126 |
+
bin_idx = (x.view(-1, 1) > edges.view(1, -1)).long().sum(dim=1)
|
127 |
+
return bin_idx
|
128 |
+
|
129 |
+
|
130 |
+
def load_wav(full_path):
|
131 |
+
try:
|
132 |
+
import soundfile as sf
|
133 |
+
except ImportError:
|
134 |
+
raise "Please install soundfile (`pip install SoundFile`)"
|
135 |
+
data, sampling_rate = sf.read(full_path)
|
136 |
+
return data, sampling_rate
|
137 |
+
|
138 |
+
|
139 |
+
def parse_code(code_str, dictionary, append_eos):
|
140 |
+
code, duration = torch.unique_consecutive(
|
141 |
+
torch.ShortTensor(list(map(int, code_str.split()))), return_counts=True
|
142 |
+
)
|
143 |
+
code = " ".join(map(str, code.tolist()))
|
144 |
+
code = dictionary.encode_line(code, append_eos).short()
|
145 |
+
|
146 |
+
if append_eos:
|
147 |
+
duration = torch.cat((duration, duration.new_zeros((1,))), dim=0) # eos
|
148 |
+
duration = duration.short()
|
149 |
+
return code, duration
|
150 |
+
|
151 |
+
|
152 |
+
def parse_manifest(manifest, dictionary):
|
153 |
+
audio_files = []
|
154 |
+
codes = []
|
155 |
+
durations = []
|
156 |
+
speakers = []
|
157 |
+
|
158 |
+
with open(manifest) as info:
|
159 |
+
for line in info.readlines():
|
160 |
+
sample = eval(line.strip())
|
161 |
+
if "cpc_km100" in sample:
|
162 |
+
k = "cpc_km100"
|
163 |
+
elif "hubert_km100" in sample:
|
164 |
+
k = "hubert_km100"
|
165 |
+
elif "phone" in sample:
|
166 |
+
k = "phone"
|
167 |
+
else:
|
168 |
+
assert False, "unknown format"
|
169 |
+
code = sample[k]
|
170 |
+
code, duration = parse_code(code, dictionary, append_eos=True)
|
171 |
+
|
172 |
+
codes.append(code)
|
173 |
+
durations.append(duration)
|
174 |
+
audio_files.append(sample["audio"])
|
175 |
+
speakers.append(sample.get("speaker", None))
|
176 |
+
|
177 |
+
return audio_files, codes, durations, speakers
|
178 |
+
|
179 |
+
|
180 |
+
def parse_speaker(path, method):
|
181 |
+
if type(path) == str:
|
182 |
+
path = Path(path)
|
183 |
+
|
184 |
+
if method == "parent_name":
|
185 |
+
return path.parent.name
|
186 |
+
elif method == "parent_parent_name":
|
187 |
+
return path.parent.parent.name
|
188 |
+
elif method == "_":
|
189 |
+
return path.name.split("_")[0]
|
190 |
+
elif method == "single":
|
191 |
+
return "A"
|
192 |
+
elif callable(method):
|
193 |
+
return method(path)
|
194 |
+
else:
|
195 |
+
raise NotImplementedError()
|
196 |
+
|
197 |
+
|
198 |
+
def get_f0_by_filename(filename, tgt_sampling_rate):
|
199 |
+
audio, sampling_rate = load_wav(filename)
|
200 |
+
if sampling_rate != tgt_sampling_rate:
|
201 |
+
raise ValueError(
|
202 |
+
"{} SR doesn't match target {} SR".format(sampling_rate, tgt_sampling_rate)
|
203 |
+
)
|
204 |
+
|
205 |
+
# compute un-interpolated f0, and use Ann's interp in __getitem__ if set
|
206 |
+
f0 = get_f0(audio, rate=tgt_sampling_rate)
|
207 |
+
f0 = torch.from_numpy(f0.astype(np.float32))
|
208 |
+
return f0
|
209 |
+
|
210 |
+
|
211 |
+
def align_f0_to_durations(f0, durations, f0_code_ratio, tol=1):
|
212 |
+
code_len = durations.sum()
|
213 |
+
targ_len = int(f0_code_ratio * code_len)
|
214 |
+
diff = f0.size(0) - targ_len
|
215 |
+
assert abs(diff) <= tol, (
|
216 |
+
f"Cannot subsample F0: |{f0.size(0)} - {f0_code_ratio}*{code_len}|"
|
217 |
+
f" > {tol} (dur=\n{durations})"
|
218 |
+
)
|
219 |
+
if diff > 0:
|
220 |
+
f0 = f0[:targ_len]
|
221 |
+
elif diff < 0:
|
222 |
+
f0 = torch.cat((f0, f0.new_full((-diff,), f0[-1])), 0)
|
223 |
+
|
224 |
+
f0_offset = 0.0
|
225 |
+
seg_f0s = []
|
226 |
+
for dur in durations:
|
227 |
+
f0_dur = dur.item() * f0_code_ratio
|
228 |
+
seg_f0 = f0[int(f0_offset) : int(f0_offset + f0_dur)]
|
229 |
+
seg_f0 = seg_f0[seg_f0 != 0]
|
230 |
+
if len(seg_f0) == 0:
|
231 |
+
seg_f0 = torch.tensor(0).type(seg_f0.type())
|
232 |
+
else:
|
233 |
+
seg_f0 = seg_f0.mean()
|
234 |
+
seg_f0s.append(seg_f0)
|
235 |
+
f0_offset += f0_dur
|
236 |
+
|
237 |
+
assert int(f0_offset) == f0.size(0), f"{f0_offset} {f0.size()} {durations.sum()}"
|
238 |
+
return torch.tensor(seg_f0s)
|
239 |
+
|
240 |
+
|
241 |
+
class Paddings(object):
|
242 |
+
def __init__(self, code_val, dur_val=0, f0_val=-2.0):
|
243 |
+
self.code = code_val
|
244 |
+
self.dur = dur_val
|
245 |
+
self.f0 = f0_val
|
246 |
+
|
247 |
+
|
248 |
+
class Shifts(object):
|
249 |
+
def __init__(self, shifts_str, pads):
|
250 |
+
self._shifts = list(map(int, shifts_str.split(",")))
|
251 |
+
assert len(self._shifts) == 2, self._shifts
|
252 |
+
assert all(s >= 0 for s in self._shifts)
|
253 |
+
self.extra_length = max(s for s in self._shifts)
|
254 |
+
self.pads = pads
|
255 |
+
|
256 |
+
@property
|
257 |
+
def dur(self):
|
258 |
+
return self._shifts[0]
|
259 |
+
|
260 |
+
@property
|
261 |
+
def f0(self):
|
262 |
+
return self._shifts[1]
|
263 |
+
|
264 |
+
@staticmethod
|
265 |
+
def shift_one(seq, left_pad_num, right_pad_num, pad):
|
266 |
+
assert seq.ndim == 1
|
267 |
+
bos = seq.new_full((left_pad_num,), pad)
|
268 |
+
eos = seq.new_full((right_pad_num,), pad)
|
269 |
+
seq = torch.cat([bos, seq, eos])
|
270 |
+
mask = torch.ones_like(seq).bool()
|
271 |
+
mask[left_pad_num : len(seq) - right_pad_num] = 0
|
272 |
+
return seq, mask
|
273 |
+
|
274 |
+
def __call__(self, code, dur, f0):
|
275 |
+
if self.extra_length == 0:
|
276 |
+
code_mask = torch.zeros_like(code).bool()
|
277 |
+
dur_mask = torch.zeros_like(dur).bool()
|
278 |
+
f0_mask = torch.zeros_like(f0).bool()
|
279 |
+
return code, code_mask, dur, dur_mask, f0, f0_mask
|
280 |
+
|
281 |
+
code, code_mask = self.shift_one(code, 0, self.extra_length, self.pads.code)
|
282 |
+
dur, dur_mask = self.shift_one(
|
283 |
+
dur, self.dur, self.extra_length - self.dur, self.pads.dur
|
284 |
+
)
|
285 |
+
f0, f0_mask = self.shift_one(
|
286 |
+
f0, self.f0, self.extra_length - self.f0, self.pads.f0
|
287 |
+
)
|
288 |
+
return code, code_mask, dur, dur_mask, f0, f0_mask
|
289 |
+
|
290 |
+
|
291 |
+
class CodeDataset(FairseqDataset):
|
292 |
+
def __init__(
|
293 |
+
self,
|
294 |
+
manifest,
|
295 |
+
dictionary,
|
296 |
+
dur_dictionary,
|
297 |
+
f0_dictionary,
|
298 |
+
config,
|
299 |
+
discrete_dur,
|
300 |
+
discrete_f0,
|
301 |
+
log_f0,
|
302 |
+
normalize_f0_mean,
|
303 |
+
normalize_f0_std,
|
304 |
+
interpolate_f0,
|
305 |
+
return_filename=False,
|
306 |
+
strip_filename=True,
|
307 |
+
shifts="0,0",
|
308 |
+
return_continuous_f0=False,
|
309 |
+
):
|
310 |
+
random.seed(1234)
|
311 |
+
self.dictionary = dictionary
|
312 |
+
self.dur_dictionary = dur_dictionary
|
313 |
+
self.f0_dictionary = f0_dictionary
|
314 |
+
self.config = config
|
315 |
+
|
316 |
+
# duration config
|
317 |
+
self.discrete_dur = discrete_dur
|
318 |
+
|
319 |
+
# pitch config
|
320 |
+
self.discrete_f0 = discrete_f0
|
321 |
+
self.log_f0 = log_f0
|
322 |
+
self.normalize_f0_mean = normalize_f0_mean
|
323 |
+
self.normalize_f0_std = normalize_f0_std
|
324 |
+
self.interpolate_f0 = interpolate_f0
|
325 |
+
|
326 |
+
self.return_filename = return_filename
|
327 |
+
self.strip_filename = strip_filename
|
328 |
+
self.f0_code_ratio = config.code_hop_size / (
|
329 |
+
config.sampling_rate * F0_FRAME_SPACE
|
330 |
+
)
|
331 |
+
|
332 |
+
# use lazy loading to avoid sharing file handlers across workers
|
333 |
+
self.manifest = manifest
|
334 |
+
self._codes = None
|
335 |
+
self._durs = None
|
336 |
+
self._f0s = None
|
337 |
+
with open(f"{manifest}.leng.txt", "r") as f:
|
338 |
+
lengs = [int(line.rstrip()) for line in f]
|
339 |
+
edges = np.cumsum([0] + lengs)
|
340 |
+
self.starts, self.ends = edges[:-1], edges[1:]
|
341 |
+
with open(f"{manifest}.path.txt", "r") as f:
|
342 |
+
self.file_names = [line.rstrip() for line in f]
|
343 |
+
logger.info(f"num entries: {len(self.starts)}")
|
344 |
+
|
345 |
+
if os.path.exists(f"{manifest}.f0_stat.pt"):
|
346 |
+
self.f0_stats = torch.load(f"{manifest}.f0_stat.pt")
|
347 |
+
elif config.f0_stats:
|
348 |
+
self.f0_stats = torch.load(config.f0_stats)
|
349 |
+
|
350 |
+
self.multispkr = config.multispkr
|
351 |
+
if config.multispkr:
|
352 |
+
with open(f"{manifest}.speaker.txt", "r") as f:
|
353 |
+
self.spkrs = [line.rstrip() for line in f]
|
354 |
+
self.id_to_spkr = sorted(self.spkrs)
|
355 |
+
self.spkr_to_id = {k: v for v, k in enumerate(self.id_to_spkr)}
|
356 |
+
|
357 |
+
self.pads = Paddings(
|
358 |
+
dictionary.pad(),
|
359 |
+
0, # use 0 for duration padding
|
360 |
+
f0_dictionary.pad() if discrete_f0 else -5.0,
|
361 |
+
)
|
362 |
+
self.shifts = Shifts(shifts, pads=self.pads)
|
363 |
+
self.return_continuous_f0 = return_continuous_f0
|
364 |
+
|
365 |
+
def get_data_handlers(self):
|
366 |
+
logging.info(f"loading data for {self.manifest}")
|
367 |
+
self._codes = np.load(f"{self.manifest}.code.npy", mmap_mode="r")
|
368 |
+
self._durs = np.load(f"{self.manifest}.dur.npy", mmap_mode="r")
|
369 |
+
|
370 |
+
if self.discrete_f0:
|
371 |
+
if self.config.f0_vq_type == "precomp":
|
372 |
+
self._f0s = np.load(
|
373 |
+
f"{self.manifest}.{self.config.f0_vq_name}.npy", mmap_mode="r"
|
374 |
+
)
|
375 |
+
elif self.config.f0_vq_type == "naive":
|
376 |
+
self._f0s = np.load(f"{self.manifest}.f0.npy", mmap_mode="r")
|
377 |
+
quantizers_path = self.config.get_f0_vq_naive_quantizer(
|
378 |
+
self.log_f0, self.normalize_f0_mean, self.normalize_f0_std
|
379 |
+
)
|
380 |
+
quantizers = torch.load(quantizers_path)
|
381 |
+
n_units = self.config.f0_vq_n_units
|
382 |
+
self._f0_quantizer = torch.from_numpy(quantizers[n_units])
|
383 |
+
else:
|
384 |
+
raise ValueError(f"f0_vq_type {self.config.f0_vq_type} not supported")
|
385 |
+
else:
|
386 |
+
self._f0s = np.load(f"{self.manifest}.f0.npy", mmap_mode="r")
|
387 |
+
|
388 |
+
def preprocess_f0(self, f0, stats):
|
389 |
+
"""
|
390 |
+
1. interpolate
|
391 |
+
2. log transform (keep unvoiced frame 0)
|
392 |
+
"""
|
393 |
+
# TODO: change this to be dependent on config for naive quantizer
|
394 |
+
f0 = f0.clone()
|
395 |
+
if self.interpolate_f0:
|
396 |
+
f0 = interpolate_f0(f0)
|
397 |
+
|
398 |
+
mask = f0 != 0 # only process voiced frames
|
399 |
+
if self.log_f0:
|
400 |
+
f0[mask] = f0[mask].log()
|
401 |
+
if self.normalize_f0_mean:
|
402 |
+
mean = stats["logf0_mean"] if self.log_f0 else stats["f0_mean"]
|
403 |
+
f0[mask] = f0[mask] - mean
|
404 |
+
if self.normalize_f0_std:
|
405 |
+
std = stats["logf0_std"] if self.log_f0 else stats["f0_std"]
|
406 |
+
f0[mask] = f0[mask] / std
|
407 |
+
return f0
|
408 |
+
|
409 |
+
def _get_raw_item(self, index):
|
410 |
+
start, end = self.starts[index], self.ends[index]
|
411 |
+
if self._codes is None:
|
412 |
+
self.get_data_handlers()
|
413 |
+
code = torch.from_numpy(np.array(self._codes[start:end])).long()
|
414 |
+
dur = torch.from_numpy(np.array(self._durs[start:end]))
|
415 |
+
f0 = torch.from_numpy(np.array(self._f0s[start:end]))
|
416 |
+
return code, dur, f0
|
417 |
+
|
418 |
+
def __getitem__(self, index):
|
419 |
+
code, dur, f0 = self._get_raw_item(index)
|
420 |
+
code = torch.cat([code.new([self.dictionary.bos()]), code])
|
421 |
+
|
422 |
+
# use 0 for eos and bos
|
423 |
+
dur = torch.cat([dur.new([0]), dur])
|
424 |
+
if self.discrete_dur:
|
425 |
+
dur = self.dur_dictionary.encode_line(
|
426 |
+
" ".join(map(str, dur.tolist())), append_eos=False
|
427 |
+
).long()
|
428 |
+
else:
|
429 |
+
dur = dur.float()
|
430 |
+
|
431 |
+
# TODO: find a more elegant approach
|
432 |
+
raw_f0 = None
|
433 |
+
if self.discrete_f0:
|
434 |
+
if self.config.f0_vq_type == "precomp":
|
435 |
+
f0 = self.f0_dictionary.encode_line(
|
436 |
+
" ".join(map(str, f0.tolist())), append_eos=False
|
437 |
+
).long()
|
438 |
+
else:
|
439 |
+
f0 = f0.float()
|
440 |
+
f0 = self.preprocess_f0(f0, self.f0_stats[self.spkrs[index]])
|
441 |
+
if self.return_continuous_f0:
|
442 |
+
raw_f0 = f0
|
443 |
+
raw_f0 = torch.cat([raw_f0.new([self.f0_dictionary.bos()]), raw_f0])
|
444 |
+
f0 = naive_quantize(f0, self._f0_quantizer)
|
445 |
+
f0 = torch.cat([f0.new([self.f0_dictionary.bos()]), f0])
|
446 |
+
else:
|
447 |
+
f0 = f0.float()
|
448 |
+
if self.multispkr:
|
449 |
+
f0 = self.preprocess_f0(f0, self.f0_stats[self.spkrs[index]])
|
450 |
+
else:
|
451 |
+
f0 = self.preprocess_f0(f0, self.f0_stats)
|
452 |
+
f0 = torch.cat([f0.new([0]), f0])
|
453 |
+
|
454 |
+
if raw_f0 is not None:
|
455 |
+
*_, raw_f0, raw_f0_mask = self.shifts(code, dur, raw_f0)
|
456 |
+
else:
|
457 |
+
raw_f0_mask = None
|
458 |
+
|
459 |
+
code, code_mask, dur, dur_mask, f0, f0_mask = self.shifts(code, dur, f0)
|
460 |
+
if raw_f0_mask is not None:
|
461 |
+
assert (raw_f0_mask == f0_mask).all()
|
462 |
+
|
463 |
+
# is a padded frame if either input or output is padded
|
464 |
+
feats = {
|
465 |
+
"source": code[:-1],
|
466 |
+
"target": code[1:],
|
467 |
+
"mask": code_mask[1:].logical_or(code_mask[:-1]),
|
468 |
+
"dur_source": dur[:-1],
|
469 |
+
"dur_target": dur[1:],
|
470 |
+
"dur_mask": dur_mask[1:].logical_or(dur_mask[:-1]),
|
471 |
+
"f0_source": f0[:-1],
|
472 |
+
"f0_target": f0[1:],
|
473 |
+
"f0_mask": f0_mask[1:].logical_or(f0_mask[:-1]),
|
474 |
+
}
|
475 |
+
|
476 |
+
if raw_f0 is not None:
|
477 |
+
feats["raw_f0"] = raw_f0[1:]
|
478 |
+
|
479 |
+
if self.return_filename:
|
480 |
+
fname = self.file_names[index]
|
481 |
+
feats["filename"] = (
|
482 |
+
fname if not self.strip_filename else Path(fname).with_suffix("").name
|
483 |
+
)
|
484 |
+
return feats
|
485 |
+
|
486 |
+
def __len__(self):
|
487 |
+
return len(self.starts)
|
488 |
+
|
489 |
+
def size(self, index):
|
490 |
+
return self.ends[index] - self.starts[index] + self.shifts.extra_length
|
491 |
+
|
492 |
+
def num_tokens(self, index):
|
493 |
+
return self.size(index)
|
494 |
+
|
495 |
+
def collater(self, samples):
|
496 |
+
pad_idx, eos_idx = self.dictionary.pad(), self.dictionary.eos()
|
497 |
+
if len(samples) == 0:
|
498 |
+
return {}
|
499 |
+
|
500 |
+
src_tokens = data_utils.collate_tokens(
|
501 |
+
[s["source"] for s in samples], pad_idx, eos_idx, left_pad=False
|
502 |
+
)
|
503 |
+
|
504 |
+
tgt_tokens = data_utils.collate_tokens(
|
505 |
+
[s["target"] for s in samples],
|
506 |
+
pad_idx=pad_idx,
|
507 |
+
eos_idx=pad_idx, # appending padding, eos is there already
|
508 |
+
left_pad=False,
|
509 |
+
)
|
510 |
+
|
511 |
+
src_durs, tgt_durs = [
|
512 |
+
data_utils.collate_tokens(
|
513 |
+
[s[k] for s in samples],
|
514 |
+
pad_idx=self.pads.dur,
|
515 |
+
eos_idx=self.pads.dur,
|
516 |
+
left_pad=False,
|
517 |
+
)
|
518 |
+
for k in ["dur_source", "dur_target"]
|
519 |
+
]
|
520 |
+
|
521 |
+
src_f0s, tgt_f0s = [
|
522 |
+
data_utils.collate_tokens(
|
523 |
+
[s[k] for s in samples],
|
524 |
+
pad_idx=self.pads.f0,
|
525 |
+
eos_idx=self.pads.f0,
|
526 |
+
left_pad=False,
|
527 |
+
)
|
528 |
+
for k in ["f0_source", "f0_target"]
|
529 |
+
]
|
530 |
+
|
531 |
+
mask, dur_mask, f0_mask = [
|
532 |
+
data_utils.collate_tokens(
|
533 |
+
[s[k] for s in samples],
|
534 |
+
pad_idx=1,
|
535 |
+
eos_idx=1,
|
536 |
+
left_pad=False,
|
537 |
+
)
|
538 |
+
for k in ["mask", "dur_mask", "f0_mask"]
|
539 |
+
]
|
540 |
+
|
541 |
+
src_lengths = torch.LongTensor([s["source"].numel() for s in samples])
|
542 |
+
n_tokens = sum(len(s["source"]) for s in samples)
|
543 |
+
|
544 |
+
result = {
|
545 |
+
"nsentences": len(samples),
|
546 |
+
"ntokens": n_tokens,
|
547 |
+
"net_input": {
|
548 |
+
"src_tokens": src_tokens,
|
549 |
+
"src_lengths": src_lengths,
|
550 |
+
"dur_src": src_durs,
|
551 |
+
"f0_src": src_f0s,
|
552 |
+
},
|
553 |
+
"target": tgt_tokens,
|
554 |
+
"dur_target": tgt_durs,
|
555 |
+
"f0_target": tgt_f0s,
|
556 |
+
"mask": mask,
|
557 |
+
"dur_mask": dur_mask,
|
558 |
+
"f0_mask": f0_mask,
|
559 |
+
}
|
560 |
+
|
561 |
+
if "filename" in samples[0]:
|
562 |
+
result["filename"] = [s["filename"] for s in samples]
|
563 |
+
|
564 |
+
# TODO: remove this hack into the inference dataset
|
565 |
+
if "prefix" in samples[0]:
|
566 |
+
result["prefix"] = [s["prefix"] for s in samples]
|
567 |
+
|
568 |
+
if "raw_f0" in samples[0]:
|
569 |
+
raw_f0s = data_utils.collate_tokens(
|
570 |
+
[s["raw_f0"] for s in samples],
|
571 |
+
pad_idx=self.pads.f0,
|
572 |
+
eos_idx=self.pads.f0,
|
573 |
+
left_pad=False,
|
574 |
+
)
|
575 |
+
result["raw_f0"] = raw_f0s
|
576 |
+
return result
|
fairseq/fairseq/data/colorize_dataset.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from . import BaseWrapperDataset
|
9 |
+
|
10 |
+
|
11 |
+
class ColorizeDataset(BaseWrapperDataset):
|
12 |
+
"""Adds 'colors' property to net input that is obtained from the provided color getter for use by models"""
|
13 |
+
|
14 |
+
def __init__(self, dataset, color_getter):
|
15 |
+
super().__init__(dataset)
|
16 |
+
self.color_getter = color_getter
|
17 |
+
|
18 |
+
def collater(self, samples):
|
19 |
+
base_collate = super().collater(samples)
|
20 |
+
if len(base_collate) > 0:
|
21 |
+
base_collate["net_input"]["colors"] = torch.tensor(
|
22 |
+
list(self.color_getter(self.dataset, s["id"]) for s in samples),
|
23 |
+
dtype=torch.long,
|
24 |
+
)
|
25 |
+
return base_collate
|
fairseq/fairseq/data/concat_dataset.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import bisect
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
from torch.utils.data.dataloader import default_collate
|
10 |
+
|
11 |
+
from . import FairseqDataset
|
12 |
+
|
13 |
+
|
14 |
+
class ConcatDataset(FairseqDataset):
|
15 |
+
@staticmethod
|
16 |
+
def cumsum(sequence, sample_ratios):
|
17 |
+
r, s = [], 0
|
18 |
+
for e, ratio in zip(sequence, sample_ratios):
|
19 |
+
curr_len = int(ratio * len(e))
|
20 |
+
r.append(curr_len + s)
|
21 |
+
s += curr_len
|
22 |
+
return r
|
23 |
+
|
24 |
+
def __init__(self, datasets, sample_ratios=1):
|
25 |
+
super(ConcatDataset, self).__init__()
|
26 |
+
assert len(datasets) > 0, "datasets should not be an empty iterable"
|
27 |
+
self.datasets = list(datasets)
|
28 |
+
if isinstance(sample_ratios, int):
|
29 |
+
sample_ratios = [sample_ratios] * len(self.datasets)
|
30 |
+
self.sample_ratios = sample_ratios
|
31 |
+
self.cumulative_sizes = self.cumsum(self.datasets, sample_ratios)
|
32 |
+
self.real_sizes = [len(d) for d in self.datasets]
|
33 |
+
|
34 |
+
def __len__(self):
|
35 |
+
return self.cumulative_sizes[-1]
|
36 |
+
|
37 |
+
def __getitem__(self, idx):
|
38 |
+
dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
|
39 |
+
return self.datasets[dataset_idx][sample_idx]
|
40 |
+
|
41 |
+
def _get_dataset_and_sample_index(self, idx: int):
|
42 |
+
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
|
43 |
+
if dataset_idx == 0:
|
44 |
+
sample_idx = idx
|
45 |
+
else:
|
46 |
+
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
|
47 |
+
sample_idx = sample_idx % self.real_sizes[dataset_idx]
|
48 |
+
return dataset_idx, sample_idx
|
49 |
+
|
50 |
+
def collater(self, samples, **extra_args):
|
51 |
+
# For now only supports datasets with same underlying collater implementations
|
52 |
+
if hasattr(self.datasets[0], "collater"):
|
53 |
+
return self.datasets[0].collater(samples, **extra_args)
|
54 |
+
else:
|
55 |
+
return default_collate(samples, **extra_args)
|
56 |
+
|
57 |
+
def size(self, idx: int):
|
58 |
+
"""
|
59 |
+
Return an example's size as a float or tuple.
|
60 |
+
"""
|
61 |
+
dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
|
62 |
+
return self.datasets[dataset_idx].size(sample_idx)
|
63 |
+
|
64 |
+
def num_tokens(self, index: int):
|
65 |
+
return np.max(self.size(index))
|
66 |
+
|
67 |
+
def attr(self, attr: str, index: int):
|
68 |
+
dataset_idx = bisect.bisect_right(self.cumulative_sizes, index)
|
69 |
+
return getattr(self.datasets[dataset_idx], attr, None)
|
70 |
+
|
71 |
+
@property
|
72 |
+
def sizes(self):
|
73 |
+
_dataset_sizes = []
|
74 |
+
for ds, sr in zip(self.datasets, self.sample_ratios):
|
75 |
+
if isinstance(ds.sizes, np.ndarray):
|
76 |
+
_dataset_sizes.append(np.tile(ds.sizes, sr))
|
77 |
+
else:
|
78 |
+
# Only support underlying dataset with single size array.
|
79 |
+
assert isinstance(ds.sizes, list)
|
80 |
+
_dataset_sizes.append(np.tile(ds.sizes[0], sr))
|
81 |
+
return np.concatenate(_dataset_sizes)
|
82 |
+
|
83 |
+
@property
|
84 |
+
def supports_prefetch(self):
|
85 |
+
return all(d.supports_prefetch for d in self.datasets)
|
86 |
+
|
87 |
+
def ordered_indices(self):
|
88 |
+
"""
|
89 |
+
Returns indices sorted by length. So less padding is needed.
|
90 |
+
"""
|
91 |
+
if isinstance(self.sizes, np.ndarray) and len(self.sizes.shape) > 1:
|
92 |
+
# special handling for concatenating lang_pair_datasets
|
93 |
+
indices = np.arange(len(self))
|
94 |
+
sizes = self.sizes
|
95 |
+
tgt_sizes = (
|
96 |
+
sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None
|
97 |
+
)
|
98 |
+
src_sizes = (
|
99 |
+
sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes
|
100 |
+
)
|
101 |
+
# sort by target length, then source length
|
102 |
+
if tgt_sizes is not None:
|
103 |
+
indices = indices[np.argsort(tgt_sizes[indices], kind="mergesort")]
|
104 |
+
return indices[np.argsort(src_sizes[indices], kind="mergesort")]
|
105 |
+
else:
|
106 |
+
return np.argsort(self.sizes)
|
107 |
+
|
108 |
+
def prefetch(self, indices):
|
109 |
+
frm = 0
|
110 |
+
for to, ds in zip(self.cumulative_sizes, self.datasets):
|
111 |
+
real_size = len(ds)
|
112 |
+
if getattr(ds, "supports_prefetch", False):
|
113 |
+
ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to])
|
114 |
+
frm = to
|
115 |
+
|
116 |
+
@property
|
117 |
+
def can_reuse_epoch_itr_across_epochs(self):
|
118 |
+
return all(d.can_reuse_epoch_itr_across_epochs for d in self.datasets)
|
119 |
+
|
120 |
+
def set_epoch(self, epoch):
|
121 |
+
super().set_epoch(epoch)
|
122 |
+
for ds in self.datasets:
|
123 |
+
if hasattr(ds, "set_epoch"):
|
124 |
+
ds.set_epoch(epoch)
|
fairseq/fairseq/data/concat_sentences_dataset.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from . import FairseqDataset
|
9 |
+
|
10 |
+
|
11 |
+
class ConcatSentencesDataset(FairseqDataset):
|
12 |
+
def __init__(self, *datasets):
|
13 |
+
super().__init__()
|
14 |
+
self.datasets = datasets
|
15 |
+
assert all(
|
16 |
+
len(ds) == len(datasets[0]) for ds in datasets
|
17 |
+
), "datasets must have the same length"
|
18 |
+
|
19 |
+
def __getitem__(self, index):
|
20 |
+
return torch.cat([ds[index] for ds in self.datasets])
|
21 |
+
|
22 |
+
def __len__(self):
|
23 |
+
return len(self.datasets[0])
|
24 |
+
|
25 |
+
def collater(self, samples):
|
26 |
+
return self.datasets[0].collater(samples)
|
27 |
+
|
28 |
+
@property
|
29 |
+
def sizes(self):
|
30 |
+
return sum(ds.sizes for ds in self.datasets)
|
31 |
+
|
32 |
+
def num_tokens(self, index):
|
33 |
+
return sum(ds.num_tokens(index) for ds in self.datasets)
|
34 |
+
|
35 |
+
def size(self, index):
|
36 |
+
return sum(ds.size(index) for ds in self.datasets)
|
37 |
+
|
38 |
+
def ordered_indices(self):
|
39 |
+
return self.datasets[0].ordered_indices()
|
40 |
+
|
41 |
+
@property
|
42 |
+
def supports_prefetch(self):
|
43 |
+
return any(getattr(ds, "supports_prefetch", False) for ds in self.datasets)
|
44 |
+
|
45 |
+
def prefetch(self, indices):
|
46 |
+
for ds in self.datasets:
|
47 |
+
if getattr(ds, "supports_prefetch", False):
|
48 |
+
ds.prefetch(indices)
|
49 |
+
|
50 |
+
def set_epoch(self, epoch):
|
51 |
+
super().set_epoch(epoch)
|
52 |
+
for ds in self.datasets:
|
53 |
+
if hasattr(ds, "set_epoch"):
|
54 |
+
ds.set_epoch(epoch)
|
fairseq/fairseq/data/data_utils.py
ADDED
@@ -0,0 +1,1144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
try:
|
7 |
+
from collections.abc import Iterable
|
8 |
+
except ImportError:
|
9 |
+
from collections import Iterable
|
10 |
+
import contextlib
|
11 |
+
import itertools
|
12 |
+
import logging
|
13 |
+
import re
|
14 |
+
import warnings
|
15 |
+
from typing import Optional, Tuple
|
16 |
+
|
17 |
+
import math
|
18 |
+
import numpy as np
|
19 |
+
import torch
|
20 |
+
|
21 |
+
from fairseq.file_io import PathManager
|
22 |
+
from fairseq import utils
|
23 |
+
import os
|
24 |
+
|
25 |
+
logger = logging.getLogger(__name__)
|
26 |
+
|
27 |
+
|
28 |
+
def infer_language_pair(path):
|
29 |
+
"""Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
|
30 |
+
src, dst = None, None
|
31 |
+
for filename in PathManager.ls(path):
|
32 |
+
parts = filename.split(".")
|
33 |
+
if len(parts) >= 3 and len(parts[1].split("-")) == 2:
|
34 |
+
return parts[1].split("-")
|
35 |
+
return src, dst
|
36 |
+
|
37 |
+
|
38 |
+
def collate_tokens(
|
39 |
+
values,
|
40 |
+
pad_idx,
|
41 |
+
eos_idx=None,
|
42 |
+
left_pad=False,
|
43 |
+
move_eos_to_beginning=False,
|
44 |
+
pad_to_length=None,
|
45 |
+
pad_to_multiple=1,
|
46 |
+
pad_to_bsz=None,
|
47 |
+
):
|
48 |
+
"""Convert a list of 1d tensors into a padded 2d tensor."""
|
49 |
+
size = max(v.size(0) for v in values)
|
50 |
+
size = size if pad_to_length is None else max(size, pad_to_length)
|
51 |
+
if pad_to_multiple != 1 and size % pad_to_multiple != 0:
|
52 |
+
size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple)
|
53 |
+
|
54 |
+
batch_size = len(values) if pad_to_bsz is None else max(len(values), pad_to_bsz)
|
55 |
+
res = values[0].new(batch_size, size).fill_(pad_idx)
|
56 |
+
|
57 |
+
def copy_tensor(src, dst):
|
58 |
+
assert dst.numel() == src.numel()
|
59 |
+
if move_eos_to_beginning:
|
60 |
+
if eos_idx is None:
|
61 |
+
# if no eos_idx is specified, then use the last token in src
|
62 |
+
dst[0] = src[-1]
|
63 |
+
else:
|
64 |
+
dst[0] = eos_idx
|
65 |
+
dst[1:] = src[:-1]
|
66 |
+
else:
|
67 |
+
dst.copy_(src)
|
68 |
+
|
69 |
+
for i, v in enumerate(values):
|
70 |
+
copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)])
|
71 |
+
return res
|
72 |
+
|
73 |
+
|
74 |
+
def load_indexed_dataset(
|
75 |
+
path, dictionary=None, dataset_impl=None, combine=False, default="cached"
|
76 |
+
):
|
77 |
+
"""A helper function for loading indexed datasets.
|
78 |
+
|
79 |
+
Args:
|
80 |
+
path (str): path to indexed dataset (e.g., 'data-bin/train')
|
81 |
+
dictionary (~fairseq.data.Dictionary): data dictionary
|
82 |
+
dataset_impl (str, optional): which dataset implementation to use. If
|
83 |
+
not provided, it will be inferred automatically. For legacy indexed
|
84 |
+
data we use the 'cached' implementation by default.
|
85 |
+
combine (bool, optional): automatically load and combine multiple
|
86 |
+
datasets. For example, if *path* is 'data-bin/train', then we will
|
87 |
+
combine 'data-bin/train', 'data-bin/train1', ... and return a
|
88 |
+
single ConcatDataset instance.
|
89 |
+
"""
|
90 |
+
import fairseq.data.indexed_dataset as indexed_dataset
|
91 |
+
from fairseq.data.concat_dataset import ConcatDataset
|
92 |
+
|
93 |
+
datasets = []
|
94 |
+
for k in itertools.count():
|
95 |
+
path_k = path + (str(k) if k > 0 else "")
|
96 |
+
try:
|
97 |
+
path_k = indexed_dataset.get_indexed_dataset_to_local(path_k)
|
98 |
+
except Exception as e:
|
99 |
+
if "StorageException: [404] Path not found" in str(e):
|
100 |
+
logger.warning(f"path_k: {e} not found")
|
101 |
+
else:
|
102 |
+
raise e
|
103 |
+
|
104 |
+
dataset_impl_k = dataset_impl
|
105 |
+
if dataset_impl_k is None:
|
106 |
+
dataset_impl_k = indexed_dataset.infer_dataset_impl(path_k)
|
107 |
+
dataset = indexed_dataset.make_dataset(
|
108 |
+
path_k,
|
109 |
+
impl=dataset_impl_k or default,
|
110 |
+
fix_lua_indexing=True,
|
111 |
+
dictionary=dictionary,
|
112 |
+
)
|
113 |
+
if dataset is None:
|
114 |
+
break
|
115 |
+
logger.info("loaded {:,} examples from: {}".format(len(dataset), path_k))
|
116 |
+
datasets.append(dataset)
|
117 |
+
if not combine:
|
118 |
+
break
|
119 |
+
if len(datasets) == 0:
|
120 |
+
return None
|
121 |
+
elif len(datasets) == 1:
|
122 |
+
return datasets[0]
|
123 |
+
else:
|
124 |
+
return ConcatDataset(datasets)
|
125 |
+
|
126 |
+
|
127 |
+
@contextlib.contextmanager
|
128 |
+
def numpy_seed(seed, *addl_seeds):
|
129 |
+
"""Context manager which seeds the NumPy PRNG with the specified seed and
|
130 |
+
restores the state afterward"""
|
131 |
+
if seed is None:
|
132 |
+
yield
|
133 |
+
return
|
134 |
+
if len(addl_seeds) > 0:
|
135 |
+
seed = int(hash((seed, *addl_seeds)) % 1e6)
|
136 |
+
state = np.random.get_state()
|
137 |
+
np.random.seed(seed)
|
138 |
+
try:
|
139 |
+
yield
|
140 |
+
finally:
|
141 |
+
np.random.set_state(state)
|
142 |
+
|
143 |
+
|
144 |
+
def collect_filtered(function, iterable, filtered):
|
145 |
+
"""
|
146 |
+
Similar to :func:`filter` but collects filtered elements in ``filtered``.
|
147 |
+
|
148 |
+
Args:
|
149 |
+
function (callable): function that returns ``False`` for elements that
|
150 |
+
should be filtered
|
151 |
+
iterable (iterable): iterable to filter
|
152 |
+
filtered (list): list to store filtered elements
|
153 |
+
"""
|
154 |
+
for el in iterable:
|
155 |
+
if function(el):
|
156 |
+
yield el
|
157 |
+
else:
|
158 |
+
filtered.append(el)
|
159 |
+
|
160 |
+
|
161 |
+
def _filter_by_size_dynamic(indices, size_fn, max_positions, raise_exception=False):
|
162 |
+
def compare_leq(a, b):
|
163 |
+
return a <= b if not isinstance(a, tuple) else max(a) <= b
|
164 |
+
|
165 |
+
def check_size(idx):
|
166 |
+
if isinstance(max_positions, float) or isinstance(max_positions, int):
|
167 |
+
return size_fn(idx) <= max_positions
|
168 |
+
elif isinstance(max_positions, dict):
|
169 |
+
idx_size = size_fn(idx)
|
170 |
+
assert isinstance(idx_size, dict)
|
171 |
+
intersect_keys = set(max_positions.keys()) & set(idx_size.keys())
|
172 |
+
return all(
|
173 |
+
all(
|
174 |
+
a is None or b is None or a <= b
|
175 |
+
for a, b in zip(idx_size[key], max_positions[key])
|
176 |
+
)
|
177 |
+
for key in intersect_keys
|
178 |
+
)
|
179 |
+
else:
|
180 |
+
# For MultiCorpusSampledDataset, will generalize it later
|
181 |
+
if not isinstance(size_fn(idx), Iterable):
|
182 |
+
return all(size_fn(idx) <= b for b in max_positions)
|
183 |
+
return all(
|
184 |
+
a is None or b is None or a <= b
|
185 |
+
for a, b in zip(size_fn(idx), max_positions)
|
186 |
+
)
|
187 |
+
|
188 |
+
ignored = []
|
189 |
+
itr = collect_filtered(check_size, indices, ignored)
|
190 |
+
indices = np.fromiter(itr, dtype=np.int64, count=-1)
|
191 |
+
return indices, ignored
|
192 |
+
|
193 |
+
|
194 |
+
def filter_by_size(indices, dataset, max_positions, raise_exception=False):
|
195 |
+
"""
|
196 |
+
[deprecated] Filter indices based on their size.
|
197 |
+
Use `FairseqDataset::filter_indices_by_size` instead.
|
198 |
+
|
199 |
+
Args:
|
200 |
+
indices (List[int]): ordered list of dataset indices
|
201 |
+
dataset (FairseqDataset): fairseq dataset instance
|
202 |
+
max_positions (tuple): filter elements larger than this size.
|
203 |
+
Comparisons are done component-wise.
|
204 |
+
raise_exception (bool, optional): if ``True``, raise an exception if
|
205 |
+
any elements are filtered (default: False).
|
206 |
+
"""
|
207 |
+
warnings.warn(
|
208 |
+
"data_utils.filter_by_size is deprecated. "
|
209 |
+
"Use `FairseqDataset::filter_indices_by_size` instead.",
|
210 |
+
stacklevel=2,
|
211 |
+
)
|
212 |
+
if isinstance(max_positions, float) or isinstance(max_positions, int):
|
213 |
+
if hasattr(dataset, "sizes") and isinstance(dataset.sizes, np.ndarray):
|
214 |
+
ignored = indices[dataset.sizes[indices] > max_positions].tolist()
|
215 |
+
indices = indices[dataset.sizes[indices] <= max_positions]
|
216 |
+
elif (
|
217 |
+
hasattr(dataset, "sizes")
|
218 |
+
and isinstance(dataset.sizes, list)
|
219 |
+
and len(dataset.sizes) == 1
|
220 |
+
):
|
221 |
+
ignored = indices[dataset.sizes[0][indices] > max_positions].tolist()
|
222 |
+
indices = indices[dataset.sizes[0][indices] <= max_positions]
|
223 |
+
else:
|
224 |
+
indices, ignored = _filter_by_size_dynamic(
|
225 |
+
indices, dataset.size, max_positions
|
226 |
+
)
|
227 |
+
else:
|
228 |
+
indices, ignored = _filter_by_size_dynamic(indices, dataset.size, max_positions)
|
229 |
+
|
230 |
+
if len(ignored) > 0 and raise_exception:
|
231 |
+
raise Exception(
|
232 |
+
(
|
233 |
+
"Size of sample #{} is invalid (={}) since max_positions={}, "
|
234 |
+
"skip this example with --skip-invalid-size-inputs-valid-test"
|
235 |
+
).format(ignored[0], dataset.size(ignored[0]), max_positions)
|
236 |
+
)
|
237 |
+
if len(ignored) > 0:
|
238 |
+
logger.warning(
|
239 |
+
(
|
240 |
+
"{} samples have invalid sizes and will be skipped, "
|
241 |
+
"max_positions={}, first few sample ids={}"
|
242 |
+
).format(len(ignored), max_positions, ignored[:10])
|
243 |
+
)
|
244 |
+
return indices
|
245 |
+
|
246 |
+
|
247 |
+
def filter_paired_dataset_indices_by_size(src_sizes, tgt_sizes, indices, max_sizes):
|
248 |
+
"""Filter a list of sample indices. Remove those that are longer
|
249 |
+
than specified in max_sizes.
|
250 |
+
|
251 |
+
Args:
|
252 |
+
indices (np.array): original array of sample indices
|
253 |
+
max_sizes (int or list[int] or tuple[int]): max sample size,
|
254 |
+
can be defined separately for src and tgt (then list or tuple)
|
255 |
+
|
256 |
+
Returns:
|
257 |
+
np.array: filtered sample array
|
258 |
+
list: list of removed indices
|
259 |
+
"""
|
260 |
+
if max_sizes is None:
|
261 |
+
return indices, []
|
262 |
+
if type(max_sizes) in (int, float):
|
263 |
+
max_src_size, max_tgt_size = max_sizes, max_sizes
|
264 |
+
else:
|
265 |
+
max_src_size, max_tgt_size = max_sizes
|
266 |
+
if tgt_sizes is None:
|
267 |
+
ignored = indices[src_sizes[indices] > max_src_size]
|
268 |
+
else:
|
269 |
+
ignored = indices[
|
270 |
+
(src_sizes[indices] > max_src_size) | (tgt_sizes[indices] > max_tgt_size)
|
271 |
+
]
|
272 |
+
if len(ignored) > 0:
|
273 |
+
if tgt_sizes is None:
|
274 |
+
indices = indices[src_sizes[indices] <= max_src_size]
|
275 |
+
else:
|
276 |
+
indices = indices[
|
277 |
+
(src_sizes[indices] <= max_src_size)
|
278 |
+
& (tgt_sizes[indices] <= max_tgt_size)
|
279 |
+
]
|
280 |
+
return indices, ignored.tolist()
|
281 |
+
|
282 |
+
|
283 |
+
def batch_by_size(
|
284 |
+
indices,
|
285 |
+
num_tokens_fn,
|
286 |
+
num_tokens_vec=None,
|
287 |
+
max_tokens=None,
|
288 |
+
max_sentences=None,
|
289 |
+
required_batch_size_multiple=1,
|
290 |
+
fixed_shapes=None,
|
291 |
+
):
|
292 |
+
"""
|
293 |
+
Yield mini-batches of indices bucketed by size. Batches may contain
|
294 |
+
sequences of different lengths.
|
295 |
+
|
296 |
+
Args:
|
297 |
+
indices (List[int]): ordered list of dataset indices
|
298 |
+
num_tokens_fn (callable): function that returns the number of tokens at
|
299 |
+
a given index
|
300 |
+
num_tokens_vec (List[int], optional): precomputed vector of the number
|
301 |
+
of tokens for each index in indices (to enable faster batch generation)
|
302 |
+
max_tokens (int, optional): max number of tokens in each batch
|
303 |
+
(default: None).
|
304 |
+
max_sentences (int, optional): max number of sentences in each
|
305 |
+
batch (default: None).
|
306 |
+
required_batch_size_multiple (int, optional): require batch size to
|
307 |
+
be less than N or a multiple of N (default: 1).
|
308 |
+
fixed_shapes (List[Tuple[int, int]], optional): if given, batches will
|
309 |
+
only be created with the given shapes. *max_sentences* and
|
310 |
+
*required_batch_size_multiple* will be ignored (default: None).
|
311 |
+
"""
|
312 |
+
try:
|
313 |
+
from fairseq.data.data_utils_fast import (
|
314 |
+
batch_by_size_fn,
|
315 |
+
batch_by_size_vec,
|
316 |
+
batch_fixed_shapes_fast,
|
317 |
+
)
|
318 |
+
except ImportError:
|
319 |
+
raise ImportError(
|
320 |
+
"Please build Cython components with: "
|
321 |
+
"`python setup.py build_ext --inplace`"
|
322 |
+
)
|
323 |
+
except ValueError:
|
324 |
+
raise ValueError(
|
325 |
+
"Please build (or rebuild) Cython components with `python setup.py build_ext --inplace`."
|
326 |
+
)
|
327 |
+
|
328 |
+
# added int() to avoid TypeError: an integer is required
|
329 |
+
max_tokens = int(max_tokens) if max_tokens is not None else -1
|
330 |
+
max_sentences = max_sentences if max_sentences is not None else -1
|
331 |
+
bsz_mult = required_batch_size_multiple
|
332 |
+
|
333 |
+
if not isinstance(indices, np.ndarray):
|
334 |
+
indices = np.fromiter(indices, dtype=np.int64, count=-1)
|
335 |
+
|
336 |
+
if num_tokens_vec is not None and not isinstance(num_tokens_vec, np.ndarray):
|
337 |
+
num_tokens_vec = np.fromiter(num_tokens_vec, dtype=np.int64, count=-1)
|
338 |
+
|
339 |
+
if fixed_shapes is None:
|
340 |
+
if num_tokens_vec is None:
|
341 |
+
b = batch_by_size_fn(
|
342 |
+
indices,
|
343 |
+
num_tokens_fn,
|
344 |
+
max_tokens,
|
345 |
+
max_sentences,
|
346 |
+
bsz_mult,
|
347 |
+
)
|
348 |
+
else:
|
349 |
+
b = batch_by_size_vec(
|
350 |
+
indices,
|
351 |
+
num_tokens_vec,
|
352 |
+
max_tokens,
|
353 |
+
max_sentences,
|
354 |
+
bsz_mult,
|
355 |
+
)
|
356 |
+
|
357 |
+
if bsz_mult > 1 and len(b[-1]) % bsz_mult != 0:
|
358 |
+
b = b[:-1]
|
359 |
+
|
360 |
+
return b
|
361 |
+
|
362 |
+
else:
|
363 |
+
fixed_shapes = np.array(fixed_shapes, dtype=np.int64)
|
364 |
+
sort_order = np.lexsort(
|
365 |
+
[
|
366 |
+
fixed_shapes[:, 1].argsort(), # length
|
367 |
+
fixed_shapes[:, 0].argsort(), # bsz
|
368 |
+
]
|
369 |
+
)
|
370 |
+
fixed_shapes_sorted = fixed_shapes[sort_order]
|
371 |
+
return batch_fixed_shapes_fast(indices, num_tokens_fn, fixed_shapes_sorted)
|
372 |
+
|
373 |
+
|
374 |
+
def post_process(sentence: str, symbol: str):
|
375 |
+
if symbol == "sentencepiece":
|
376 |
+
sentence = sentence.replace(" ", "").replace("\u2581", " ").strip()
|
377 |
+
elif symbol == "wordpiece":
|
378 |
+
sentence = sentence.replace(" ", "").replace("_", " ").strip()
|
379 |
+
elif symbol == "letter":
|
380 |
+
sentence = sentence.replace(" ", "").replace("|", " ").strip()
|
381 |
+
elif symbol == "silence":
|
382 |
+
import re
|
383 |
+
|
384 |
+
sentence = sentence.replace("<SIL>", "")
|
385 |
+
sentence = re.sub(" +", " ", sentence).strip()
|
386 |
+
elif symbol == "_EOW":
|
387 |
+
sentence = sentence.replace(" ", "").replace("_EOW", " ").strip()
|
388 |
+
elif symbol in {"subword_nmt", "@@ ", "@@"}:
|
389 |
+
if symbol == "subword_nmt":
|
390 |
+
symbol = "@@ "
|
391 |
+
sentence = (sentence + " ").replace(symbol, "").rstrip()
|
392 |
+
elif symbol == "none":
|
393 |
+
pass
|
394 |
+
elif symbol is not None:
|
395 |
+
raise NotImplementedError(f"Unknown post_process option: {symbol}")
|
396 |
+
return sentence
|
397 |
+
|
398 |
+
|
399 |
+
def compute_mask_indices(
|
400 |
+
shape: Tuple[int, int],
|
401 |
+
padding_mask: Optional[torch.Tensor],
|
402 |
+
mask_prob: float,
|
403 |
+
mask_length: int,
|
404 |
+
mask_type: str = "static",
|
405 |
+
mask_other: float = 0.0,
|
406 |
+
min_masks: int = 0,
|
407 |
+
no_overlap: bool = False,
|
408 |
+
min_space: int = 0,
|
409 |
+
require_same_masks: bool = True,
|
410 |
+
mask_dropout: float = 0.0,
|
411 |
+
add_masks: bool = False,
|
412 |
+
seed: Optional[int] = None,
|
413 |
+
epoch: Optional[int] = None,
|
414 |
+
indices: Optional[torch.Tensor] = None,
|
415 |
+
idc_select_ver: int = 1, # 2 to reproduce mask_tokens_dataset
|
416 |
+
num_mask_ver: int = 2, # 2 to reproduce mask_tokens_dataset
|
417 |
+
) -> np.ndarray:
|
418 |
+
"""
|
419 |
+
Computes random mask spans for a given shape
|
420 |
+
|
421 |
+
Args:
|
422 |
+
shape: the the shape for which to compute masks.
|
423 |
+
should be of size 2 where first element is batch size and 2nd is timesteps
|
424 |
+
padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
|
425 |
+
mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
|
426 |
+
number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
|
427 |
+
however due to overlaps, the actual number will be smaller (unless no_overlap is True)
|
428 |
+
mask_type: how to compute mask lengths
|
429 |
+
static = fixed size
|
430 |
+
uniform = sample from uniform distribution [mask_other, mask_length*2]
|
431 |
+
normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
|
432 |
+
poisson = sample from possion distribution with lambda = mask length
|
433 |
+
min_masks: minimum number of masked spans
|
434 |
+
no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
|
435 |
+
min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
|
436 |
+
require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample
|
437 |
+
mask_dropout: randomly dropout this percentage of masks in each example
|
438 |
+
"""
|
439 |
+
|
440 |
+
bsz, all_sz = shape
|
441 |
+
mask = np.full((bsz, all_sz), False)
|
442 |
+
|
443 |
+
if num_mask_ver == 1:
|
444 |
+
all_num_mask = int(
|
445 |
+
# add a random number for probabilistic rounding
|
446 |
+
mask_prob * all_sz / float(mask_length)
|
447 |
+
+ np.random.rand()
|
448 |
+
)
|
449 |
+
all_num_mask = max(min_masks, all_num_mask)
|
450 |
+
|
451 |
+
mask_idcs = []
|
452 |
+
for i in range(bsz):
|
453 |
+
if seed is not None and epoch is not None and indices is not None:
|
454 |
+
seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6)
|
455 |
+
else:
|
456 |
+
seed_i = None
|
457 |
+
|
458 |
+
rng = np.random.default_rng(seed_i)
|
459 |
+
|
460 |
+
if padding_mask is not None:
|
461 |
+
sz = all_sz - padding_mask[i].long().sum().item()
|
462 |
+
assert sz >= 0, sz
|
463 |
+
else:
|
464 |
+
sz = all_sz
|
465 |
+
|
466 |
+
if num_mask_ver == 1:
|
467 |
+
if padding_mask is not None:
|
468 |
+
num_mask = int(
|
469 |
+
# add a random number for probabilistic rounding
|
470 |
+
mask_prob * sz / float(mask_length)
|
471 |
+
+ np.random.rand()
|
472 |
+
)
|
473 |
+
num_mask = max(min_masks, num_mask)
|
474 |
+
else:
|
475 |
+
num_mask = all_num_mask
|
476 |
+
elif num_mask_ver == 2:
|
477 |
+
num_mask = int(
|
478 |
+
# add a random number for probabilistic rounding
|
479 |
+
mask_prob * sz / float(mask_length)
|
480 |
+
+ rng.random()
|
481 |
+
)
|
482 |
+
num_mask = max(min_masks, num_mask)
|
483 |
+
else:
|
484 |
+
raise ValueError()
|
485 |
+
|
486 |
+
if mask_type == "static":
|
487 |
+
lengths = np.full(num_mask, mask_length)
|
488 |
+
elif mask_type == "uniform":
|
489 |
+
lengths = rng.randint(mask_other, mask_length * 2 + 1, size=num_mask)
|
490 |
+
elif mask_type == "normal":
|
491 |
+
lengths = rng.normal(mask_length, mask_other, size=num_mask)
|
492 |
+
lengths = [max(1, int(round(x))) for x in lengths]
|
493 |
+
elif mask_type == "poisson":
|
494 |
+
lengths = rng.poisson(mask_length, size=num_mask)
|
495 |
+
lengths = [int(round(x)) for x in lengths]
|
496 |
+
else:
|
497 |
+
raise Exception("unknown mask selection " + mask_type)
|
498 |
+
|
499 |
+
if sum(lengths) == 0:
|
500 |
+
if mask_type == "static":
|
501 |
+
raise ValueError(f"this should never happens")
|
502 |
+
else:
|
503 |
+
lengths = [min(mask_length, sz - 1)]
|
504 |
+
|
505 |
+
if no_overlap:
|
506 |
+
mask_idc = []
|
507 |
+
|
508 |
+
def arrange(s, e, length, keep_length):
|
509 |
+
span_start = rng.randint(s, e - length)
|
510 |
+
mask_idc.extend(span_start + i for i in range(length))
|
511 |
+
|
512 |
+
new_parts = []
|
513 |
+
if span_start - s - min_space >= keep_length:
|
514 |
+
new_parts.append((s, span_start - min_space + 1))
|
515 |
+
if e - span_start - length - min_space > keep_length:
|
516 |
+
new_parts.append((span_start + length + min_space, e))
|
517 |
+
return new_parts
|
518 |
+
|
519 |
+
parts = [(0, sz)]
|
520 |
+
min_length = min(lengths)
|
521 |
+
for length in sorted(lengths, reverse=True):
|
522 |
+
lens = np.fromiter(
|
523 |
+
(e - s if e - s >= length + min_space else 0 for s, e in parts),
|
524 |
+
np.int,
|
525 |
+
)
|
526 |
+
l_sum = np.sum(lens)
|
527 |
+
if l_sum == 0:
|
528 |
+
break
|
529 |
+
probs = lens / np.sum(lens)
|
530 |
+
c = rng.choice(len(parts), p=probs)
|
531 |
+
s, e = parts.pop(c)
|
532 |
+
parts.extend(arrange(s, e, length, min_length))
|
533 |
+
mask_idc = np.asarray(mask_idc)
|
534 |
+
else:
|
535 |
+
if idc_select_ver == 1:
|
536 |
+
min_len = min(lengths)
|
537 |
+
if sz - min_len <= num_mask:
|
538 |
+
min_len = sz - num_mask - 1
|
539 |
+
mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
|
540 |
+
elif idc_select_ver == 2:
|
541 |
+
mask_idc = rng.choice(sz, num_mask, replace=False)
|
542 |
+
else:
|
543 |
+
raise ValueError()
|
544 |
+
|
545 |
+
mask_idc = np.asarray(
|
546 |
+
[
|
547 |
+
mask_idc[j] + offset
|
548 |
+
for j in range(len(mask_idc))
|
549 |
+
for offset in range(lengths[j])
|
550 |
+
]
|
551 |
+
)
|
552 |
+
|
553 |
+
mask_idc = np.unique(mask_idc[mask_idc < sz])
|
554 |
+
if len(mask_idc) >= sz:
|
555 |
+
raise ValueError(
|
556 |
+
(
|
557 |
+
f"the entire sequence is masked. "
|
558 |
+
f"sz={sz}; mask_idc[mask_idc]; "
|
559 |
+
f"index={indices[i] if indices is not None else None}"
|
560 |
+
)
|
561 |
+
)
|
562 |
+
mask_idcs.append(mask_idc)
|
563 |
+
|
564 |
+
target_len = None
|
565 |
+
if require_same_masks:
|
566 |
+
if add_masks:
|
567 |
+
target_len = max([len(m) for m in mask_idcs])
|
568 |
+
else:
|
569 |
+
target_len = min([len(m) for m in mask_idcs])
|
570 |
+
|
571 |
+
for i, mask_idc in enumerate(mask_idcs):
|
572 |
+
if target_len is not None and len(mask_idc) > target_len:
|
573 |
+
mask_idc = rng.choice(mask_idc, target_len, replace=False)
|
574 |
+
|
575 |
+
mask[i, mask_idc] = True
|
576 |
+
|
577 |
+
if target_len is not None and len(mask_idc) < target_len:
|
578 |
+
unmasked = np.flatnonzero(~mask[i])
|
579 |
+
to_mask = rng.choice(unmasked, target_len - len(mask_idc), replace=False)
|
580 |
+
mask[i, to_mask] = True
|
581 |
+
|
582 |
+
if mask_dropout > 0:
|
583 |
+
masked = np.flatnonzero(mask[i])
|
584 |
+
num_holes = np.rint(len(masked) * mask_dropout).astype(int)
|
585 |
+
to_drop = rng.choice(masked, num_holes, replace=False)
|
586 |
+
mask[i, to_drop] = False
|
587 |
+
|
588 |
+
return mask
|
589 |
+
|
590 |
+
|
591 |
+
def compute_block_mask_2d(
|
592 |
+
shape: Tuple[int, int],
|
593 |
+
mask_prob: float,
|
594 |
+
mask_length: int,
|
595 |
+
mask_prob_adjust: float = 0,
|
596 |
+
inverse_mask: bool = False,
|
597 |
+
require_same_masks: bool = True,
|
598 |
+
expand_adjcent: bool = False,
|
599 |
+
mask_dropout: float = 0,
|
600 |
+
non_overlapping: bool = False,
|
601 |
+
) -> torch.Tensor:
|
602 |
+
|
603 |
+
assert mask_length > 1
|
604 |
+
|
605 |
+
B, L = shape
|
606 |
+
|
607 |
+
d = int(L**0.5)
|
608 |
+
|
609 |
+
if inverse_mask:
|
610 |
+
mask_prob = 1 - mask_prob
|
611 |
+
|
612 |
+
if non_overlapping:
|
613 |
+
sz = math.ceil(d / mask_length)
|
614 |
+
inp_len = sz * sz
|
615 |
+
|
616 |
+
inp = torch.zeros((B, 1, sz, sz))
|
617 |
+
w = torch.ones((1, 1, mask_length, mask_length))
|
618 |
+
|
619 |
+
mask_inds = torch.multinomial(
|
620 |
+
1 - inp.view(B, -1),
|
621 |
+
int(inp_len * (mask_prob + mask_prob_adjust) * (1 + mask_dropout)),
|
622 |
+
replacement=False,
|
623 |
+
)
|
624 |
+
inp.view(B, -1).scatter_(1, mask_inds, 1)
|
625 |
+
|
626 |
+
mask = torch.nn.functional.conv_transpose2d(inp, w, stride=mask_length).squeeze(
|
627 |
+
1
|
628 |
+
)
|
629 |
+
if mask.size(-1) > d:
|
630 |
+
mask = mask[..., :d, :d]
|
631 |
+
else:
|
632 |
+
mask = torch.zeros((B, d, d))
|
633 |
+
mask_inds = torch.randint(
|
634 |
+
0,
|
635 |
+
L,
|
636 |
+
size=(
|
637 |
+
B,
|
638 |
+
int(
|
639 |
+
L
|
640 |
+
* ((mask_prob + mask_prob_adjust) / mask_length**2)
|
641 |
+
* (1 + mask_dropout)
|
642 |
+
),
|
643 |
+
),
|
644 |
+
)
|
645 |
+
mask.view(B, -1).scatter_(1, mask_inds, 1)
|
646 |
+
centers = mask.nonzero(as_tuple=True)
|
647 |
+
|
648 |
+
inds = ([], [], [])
|
649 |
+
|
650 |
+
offset = mask_length // 2
|
651 |
+
for i in range(mask_length):
|
652 |
+
for j in range(mask_length):
|
653 |
+
k1 = i - offset
|
654 |
+
k2 = j - offset
|
655 |
+
inds[0].append(centers[0])
|
656 |
+
inds[1].append(centers[1] + k1)
|
657 |
+
inds[2].append(centers[2] + k2)
|
658 |
+
|
659 |
+
i0 = torch.cat(inds[0])
|
660 |
+
i1 = torch.cat(inds[1]).clamp_(min=0, max=d - 1)
|
661 |
+
i2 = torch.cat(inds[2]).clamp_(min=0, max=d - 1)
|
662 |
+
|
663 |
+
mask[(i0, i1, i2)] = 1
|
664 |
+
|
665 |
+
def get_nbs(b, m, w):
|
666 |
+
all_nbs = torch.nn.functional.conv2d(m.unsqueeze(1), w, padding="same")
|
667 |
+
all_nbs = all_nbs.clamp_max_(1).view(b, -1)
|
668 |
+
return all_nbs
|
669 |
+
|
670 |
+
if require_same_masks and expand_adjcent:
|
671 |
+
w = torch.zeros((1, 1, 3, 3))
|
672 |
+
w[..., 0, 1] = 1
|
673 |
+
w[..., 2, 1] = 1
|
674 |
+
w[..., 1, 0] = 1
|
675 |
+
w[..., 1, 2] = 1
|
676 |
+
|
677 |
+
all_nbs = get_nbs(B, mask, w)
|
678 |
+
|
679 |
+
mask = mask.reshape(B, -1)
|
680 |
+
|
681 |
+
if require_same_masks:
|
682 |
+
n_masks = mask.sum(dim=-1)
|
683 |
+
final_target_len = int(L * (mask_prob))
|
684 |
+
target_len = int(final_target_len * (1 + mask_dropout))
|
685 |
+
|
686 |
+
for i in range(len(mask)):
|
687 |
+
n = n_masks[i]
|
688 |
+
m = mask[i]
|
689 |
+
r = 0
|
690 |
+
while expand_adjcent and n < target_len:
|
691 |
+
if r == 0:
|
692 |
+
nbs = all_nbs[i]
|
693 |
+
else:
|
694 |
+
nbs = get_nbs(1, m.view(1, d, d), w).flatten()
|
695 |
+
|
696 |
+
cands = (1 - m + nbs) > 1
|
697 |
+
cand_sz = int(cands.sum().item())
|
698 |
+
|
699 |
+
assert cand_sz > 0, f"{nbs} {cand_sz}"
|
700 |
+
|
701 |
+
to_mask = torch.multinomial(
|
702 |
+
cands.float(), min(cand_sz, int(target_len - n)), replacement=False
|
703 |
+
)
|
704 |
+
m[to_mask] = 1
|
705 |
+
assert to_mask.numel() > 0
|
706 |
+
n += to_mask.numel()
|
707 |
+
r += 1
|
708 |
+
|
709 |
+
if n > final_target_len:
|
710 |
+
to_unmask = torch.multinomial(
|
711 |
+
m, int(n - final_target_len), replacement=False
|
712 |
+
)
|
713 |
+
m[to_unmask] = 0
|
714 |
+
elif n < final_target_len:
|
715 |
+
to_mask = torch.multinomial(
|
716 |
+
(1 - m), int(final_target_len - n), replacement=False
|
717 |
+
)
|
718 |
+
m[to_mask] = 1
|
719 |
+
|
720 |
+
if inverse_mask:
|
721 |
+
mask = 1 - mask
|
722 |
+
|
723 |
+
return mask
|
724 |
+
|
725 |
+
|
726 |
+
def compute_block_mask_1d(
|
727 |
+
shape: Tuple[int, int],
|
728 |
+
mask_prob: float,
|
729 |
+
mask_length: int,
|
730 |
+
mask_prob_adjust: float = 0,
|
731 |
+
inverse_mask: bool = False,
|
732 |
+
require_same_masks: bool = True,
|
733 |
+
expand_adjcent: bool = False,
|
734 |
+
mask_dropout: float = 0,
|
735 |
+
non_overlapping: bool = False,
|
736 |
+
) -> torch.Tensor:
|
737 |
+
|
738 |
+
B, L = shape
|
739 |
+
|
740 |
+
if inverse_mask:
|
741 |
+
mask_prob = 1 - mask_prob
|
742 |
+
|
743 |
+
if non_overlapping:
|
744 |
+
sz = math.ceil(L / mask_length)
|
745 |
+
|
746 |
+
inp = torch.zeros((B, 1, sz))
|
747 |
+
w = torch.ones((1, 1, mask_length))
|
748 |
+
|
749 |
+
mask_inds = torch.multinomial(
|
750 |
+
1 - inp.view(B, -1),
|
751 |
+
int(sz * (mask_prob + mask_prob_adjust) * (1 + mask_dropout)),
|
752 |
+
replacement=False,
|
753 |
+
)
|
754 |
+
inp.view(B, -1).scatter_(1, mask_inds, 1)
|
755 |
+
|
756 |
+
mask = torch.nn.functional.conv_transpose1d(inp, w, stride=mask_length).squeeze(
|
757 |
+
1
|
758 |
+
)
|
759 |
+
if mask.size(-1) > L:
|
760 |
+
mask = mask[..., :L]
|
761 |
+
|
762 |
+
else:
|
763 |
+
mask = torch.zeros((B, L))
|
764 |
+
mask_inds = torch.randint(
|
765 |
+
0,
|
766 |
+
L,
|
767 |
+
size=(
|
768 |
+
B,
|
769 |
+
int(
|
770 |
+
L
|
771 |
+
* ((mask_prob + mask_prob_adjust) / mask_length)
|
772 |
+
* (1 + mask_dropout)
|
773 |
+
),
|
774 |
+
),
|
775 |
+
)
|
776 |
+
|
777 |
+
mask.view(B, -1).scatter_(1, mask_inds, 1)
|
778 |
+
centers = mask.nonzero(as_tuple=True)
|
779 |
+
|
780 |
+
inds = ([], [])
|
781 |
+
|
782 |
+
offset = mask_length // 2
|
783 |
+
for i in range(mask_length):
|
784 |
+
k1 = i - offset
|
785 |
+
inds[0].append(centers[0])
|
786 |
+
inds[1].append(centers[1] + k1)
|
787 |
+
|
788 |
+
i0 = torch.cat(inds[0])
|
789 |
+
i1 = torch.cat(inds[1]).clamp_(min=0, max=L - 1)
|
790 |
+
|
791 |
+
mask[(i0, i1)] = 1
|
792 |
+
|
793 |
+
def get_nbs(b, m, w):
|
794 |
+
all_nbs = torch.nn.functional.conv1d(m.unsqueeze(1), w, padding="same")
|
795 |
+
all_nbs = all_nbs.clamp_max_(1).view(b, -1)
|
796 |
+
return all_nbs
|
797 |
+
|
798 |
+
if require_same_masks and expand_adjcent:
|
799 |
+
w = torch.ones((1, 1, 3))
|
800 |
+
w[..., 1] = 0
|
801 |
+
all_nbs = get_nbs(B, mask, w)
|
802 |
+
|
803 |
+
mask = mask.view(B, -1)
|
804 |
+
|
805 |
+
if require_same_masks:
|
806 |
+
n_masks = mask.sum(dim=-1)
|
807 |
+
final_target_len = int(L * (mask_prob))
|
808 |
+
target_len = int(final_target_len * (1 + mask_dropout))
|
809 |
+
|
810 |
+
for i in range(len(mask)):
|
811 |
+
n = n_masks[i]
|
812 |
+
m = mask[i]
|
813 |
+
r = 0
|
814 |
+
while expand_adjcent and n < target_len:
|
815 |
+
if r == 0:
|
816 |
+
nbs = all_nbs[i]
|
817 |
+
else:
|
818 |
+
nbs = get_nbs(1, m.unsqueeze(0), w).squeeze(0)
|
819 |
+
|
820 |
+
cands = (1 - m + nbs) > 1
|
821 |
+
cand_sz = int(cands.sum().item())
|
822 |
+
|
823 |
+
assert cand_sz > 0, f"{nbs} {cand_sz}"
|
824 |
+
|
825 |
+
to_mask = torch.multinomial(
|
826 |
+
cands.float(), min(cand_sz, int(target_len - n)), replacement=False
|
827 |
+
)
|
828 |
+
m[to_mask] = 1
|
829 |
+
assert to_mask.numel() > 0
|
830 |
+
n += to_mask.numel()
|
831 |
+
r += 1
|
832 |
+
|
833 |
+
if n > final_target_len:
|
834 |
+
to_unmask = torch.multinomial(
|
835 |
+
m, int(n - final_target_len), replacement=False
|
836 |
+
)
|
837 |
+
m[to_unmask] = 0
|
838 |
+
elif n < final_target_len:
|
839 |
+
to_mask = torch.multinomial(
|
840 |
+
(1 - m), int(final_target_len - n), replacement=False
|
841 |
+
)
|
842 |
+
m[to_mask] = 1
|
843 |
+
|
844 |
+
if inverse_mask:
|
845 |
+
mask = 1 - mask
|
846 |
+
|
847 |
+
return mask
|
848 |
+
|
849 |
+
|
850 |
+
def get_mem_usage():
|
851 |
+
try:
|
852 |
+
import psutil
|
853 |
+
|
854 |
+
mb = 1024 * 1024
|
855 |
+
return f"used={psutil.virtual_memory().used / mb}Mb; avail={psutil.virtual_memory().available / mb}Mb"
|
856 |
+
except ImportError:
|
857 |
+
return "N/A"
|
858 |
+
|
859 |
+
|
860 |
+
# lens: torch.LongTensor
|
861 |
+
# returns: torch.BoolTensor
|
862 |
+
def lengths_to_padding_mask(lens):
|
863 |
+
bsz, max_lens = lens.size(0), torch.max(lens).item()
|
864 |
+
mask = torch.arange(max_lens).to(lens.device).view(1, max_lens)
|
865 |
+
mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens)
|
866 |
+
return mask
|
867 |
+
|
868 |
+
|
869 |
+
# lens: torch.LongTensor
|
870 |
+
# returns: torch.BoolTensor
|
871 |
+
def lengths_to_mask(lens):
|
872 |
+
return ~lengths_to_padding_mask(lens)
|
873 |
+
|
874 |
+
|
875 |
+
def get_buckets(sizes, num_buckets):
|
876 |
+
buckets = np.unique(
|
877 |
+
np.percentile(
|
878 |
+
sizes,
|
879 |
+
np.linspace(0, 100, num_buckets + 1),
|
880 |
+
interpolation="lower",
|
881 |
+
)[1:]
|
882 |
+
)
|
883 |
+
return buckets
|
884 |
+
|
885 |
+
|
886 |
+
def get_bucketed_sizes(orig_sizes, buckets):
|
887 |
+
sizes = np.copy(orig_sizes)
|
888 |
+
assert np.min(sizes) >= 0
|
889 |
+
start_val = -1
|
890 |
+
for end_val in buckets:
|
891 |
+
mask = (sizes > start_val) & (sizes <= end_val)
|
892 |
+
sizes[mask] = end_val
|
893 |
+
start_val = end_val
|
894 |
+
return sizes
|
895 |
+
|
896 |
+
|
897 |
+
def _find_extra_valid_paths(dataset_path: str) -> set:
|
898 |
+
paths = utils.split_paths(dataset_path)
|
899 |
+
all_valid_paths = set()
|
900 |
+
for sub_dir in paths:
|
901 |
+
contents = PathManager.ls(sub_dir)
|
902 |
+
valid_paths = [c for c in contents if re.match("valid*[0-9].*", c) is not None]
|
903 |
+
all_valid_paths |= {os.path.basename(p) for p in valid_paths}
|
904 |
+
# Remove .bin, .idx etc
|
905 |
+
roots = {os.path.splitext(p)[0] for p in all_valid_paths}
|
906 |
+
return roots
|
907 |
+
|
908 |
+
|
909 |
+
def raise_if_valid_subsets_unintentionally_ignored(train_cfg) -> None:
|
910 |
+
"""Raises if there are paths matching 'valid*[0-9].*' which are not combined or ignored."""
|
911 |
+
if (
|
912 |
+
train_cfg.dataset.ignore_unused_valid_subsets
|
913 |
+
or train_cfg.dataset.combine_valid_subsets
|
914 |
+
or train_cfg.dataset.disable_validation
|
915 |
+
or not hasattr(train_cfg.task, "data")
|
916 |
+
):
|
917 |
+
return
|
918 |
+
other_paths = _find_extra_valid_paths(train_cfg.task.data)
|
919 |
+
specified_subsets = train_cfg.dataset.valid_subset.split(",")
|
920 |
+
ignored_paths = [p for p in other_paths if p not in specified_subsets]
|
921 |
+
if ignored_paths:
|
922 |
+
advice = "Set --combine-val to combine them or --ignore-unused-valid-subsets to ignore them."
|
923 |
+
msg = f"Valid paths {ignored_paths} will be ignored. {advice}"
|
924 |
+
raise ValueError(msg)
|
925 |
+
|
926 |
+
|
927 |
+
def compute_mask_indices_for_one(
|
928 |
+
sz,
|
929 |
+
mask_prob: float,
|
930 |
+
mask_length: int,
|
931 |
+
seed=None,
|
932 |
+
epoch=None,
|
933 |
+
index=None,
|
934 |
+
min_masks=0,
|
935 |
+
):
|
936 |
+
"""
|
937 |
+
set seed, epoch, index for deterministic masking
|
938 |
+
"""
|
939 |
+
seed = int(hash((seed, epoch, index)) % 1e6) if seed else None
|
940 |
+
rng = np.random.default_rng(seed)
|
941 |
+
|
942 |
+
# decide elements to mask
|
943 |
+
mask = np.full(sz, False)
|
944 |
+
num_mask = int(
|
945 |
+
# add a random number for probabilistic rounding
|
946 |
+
mask_prob * sz / float(mask_length)
|
947 |
+
+ rng.random()
|
948 |
+
)
|
949 |
+
num_mask = max(min_masks, num_mask)
|
950 |
+
|
951 |
+
# multiple masking as described in the vq-wav2vec paper (https://arxiv.org/abs/1910.05453)
|
952 |
+
mask_idc = rng.choice(sz, num_mask, replace=False)
|
953 |
+
mask_idc = np.concatenate([mask_idc + i for i in range(mask_length)])
|
954 |
+
mask_idc = mask_idc[mask_idc < len(mask)]
|
955 |
+
try:
|
956 |
+
mask[mask_idc] = True
|
957 |
+
except: # something wrong
|
958 |
+
print(f"Assigning mask indexes {mask_idc} to mask {mask} failed!")
|
959 |
+
raise
|
960 |
+
|
961 |
+
return mask
|
962 |
+
|
963 |
+
|
964 |
+
def compute_mask_indices_v2(
|
965 |
+
shape: Tuple[int, int],
|
966 |
+
padding_mask: Optional[torch.Tensor],
|
967 |
+
mask_prob: float,
|
968 |
+
mask_length: int,
|
969 |
+
min_masks: int = 0,
|
970 |
+
require_same_masks: bool = True,
|
971 |
+
seed: Optional[int] = None,
|
972 |
+
epoch: Optional[int] = None,
|
973 |
+
indices: Optional[torch.Tensor] = None,
|
974 |
+
) -> np.ndarray:
|
975 |
+
bsz, all_sz = shape
|
976 |
+
mask = np.full((bsz, all_sz), False)
|
977 |
+
for i in range(bsz):
|
978 |
+
if padding_mask is not None:
|
979 |
+
sz = all_sz - padding_mask[i].long().sum().item()
|
980 |
+
else:
|
981 |
+
sz = all_sz
|
982 |
+
index = indices[i].item() if indices is not None else None
|
983 |
+
mask_for_one = compute_mask_indices_for_one(
|
984 |
+
sz, mask_prob, mask_length, seed, epoch, index, min_masks
|
985 |
+
)
|
986 |
+
mask[i, :sz] = mask_for_one
|
987 |
+
|
988 |
+
if require_same_masks:
|
989 |
+
index_sum = indices.sum().item() if indices is not None else None
|
990 |
+
seed = int(hash((seed, epoch, index_sum)) % 1e6) if seed else None
|
991 |
+
rng = np.random.default_rng(seed)
|
992 |
+
|
993 |
+
num_mask = mask.sum(-1).min()
|
994 |
+
for i in range(bsz):
|
995 |
+
extra = mask[i].sum() - num_mask
|
996 |
+
if extra > 0:
|
997 |
+
to_unmask = rng.choice(np.nonzero(mask[i])[0], extra, replace=False)
|
998 |
+
mask[i, to_unmask] = False
|
999 |
+
|
1000 |
+
return mask
|
1001 |
+
|
1002 |
+
|
1003 |
+
# TODO: a copy of the original compute_mask_indices
|
1004 |
+
def compute_mask_indices_v3(
|
1005 |
+
shape: Tuple[int, int],
|
1006 |
+
padding_mask: Optional[torch.Tensor],
|
1007 |
+
mask_prob: float,
|
1008 |
+
mask_length: int,
|
1009 |
+
mask_type: str = "static",
|
1010 |
+
mask_other: float = 0.0,
|
1011 |
+
min_masks: int = 0,
|
1012 |
+
no_overlap: bool = False,
|
1013 |
+
min_space: int = 0,
|
1014 |
+
require_same_masks: bool = True,
|
1015 |
+
mask_dropout: float = 0.0,
|
1016 |
+
seed: Optional[int] = None,
|
1017 |
+
epoch: Optional[int] = None,
|
1018 |
+
indices: Optional[torch.Tensor] = None,
|
1019 |
+
) -> np.ndarray:
|
1020 |
+
"""
|
1021 |
+
Computes random mask spans for a given shape
|
1022 |
+
|
1023 |
+
Args:
|
1024 |
+
shape: the the shape for which to compute masks.
|
1025 |
+
should be of size 2 where first element is batch size and 2nd is timesteps
|
1026 |
+
padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
|
1027 |
+
mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
|
1028 |
+
number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
|
1029 |
+
however due to overlaps, the actual number will be smaller (unless no_overlap is True)
|
1030 |
+
mask_type: how to compute mask lengths
|
1031 |
+
static = fixed size
|
1032 |
+
uniform = sample from uniform distribution [mask_other, mask_length*2]
|
1033 |
+
normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
|
1034 |
+
poisson = sample from possion distribution with lambda = mask length
|
1035 |
+
min_masks: minimum number of masked spans
|
1036 |
+
no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
|
1037 |
+
min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
|
1038 |
+
require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample
|
1039 |
+
mask_dropout: randomly dropout this percentage of masks in each example
|
1040 |
+
"""
|
1041 |
+
bsz, all_sz = shape
|
1042 |
+
mask = np.full((bsz, all_sz), False)
|
1043 |
+
|
1044 |
+
all_num_mask = int(
|
1045 |
+
# add a random number for probabilistic rounding
|
1046 |
+
mask_prob * all_sz / float(mask_length)
|
1047 |
+
+ np.random.rand()
|
1048 |
+
)
|
1049 |
+
|
1050 |
+
all_num_mask = max(min_masks, all_num_mask)
|
1051 |
+
|
1052 |
+
mask_idcs = []
|
1053 |
+
for i in range(bsz):
|
1054 |
+
if seed is not None and epoch is not None and indices is not None:
|
1055 |
+
seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6)
|
1056 |
+
else:
|
1057 |
+
seed_i = None
|
1058 |
+
rng = np.random.default_rng(seed_i)
|
1059 |
+
|
1060 |
+
if padding_mask is not None:
|
1061 |
+
sz = all_sz - padding_mask[i].long().sum().item()
|
1062 |
+
num_mask = int(
|
1063 |
+
# add a random number for probabilistic rounding
|
1064 |
+
mask_prob * sz / float(mask_length)
|
1065 |
+
+ rng.random()
|
1066 |
+
)
|
1067 |
+
num_mask = max(min_masks, num_mask)
|
1068 |
+
else:
|
1069 |
+
sz = all_sz
|
1070 |
+
num_mask = all_num_mask
|
1071 |
+
|
1072 |
+
if mask_type == "static":
|
1073 |
+
lengths = np.full(num_mask, mask_length)
|
1074 |
+
elif mask_type == "uniform":
|
1075 |
+
lengths = rng.randint(mask_other, mask_length * 2 + 1, size=num_mask)
|
1076 |
+
elif mask_type == "normal":
|
1077 |
+
lengths = rng.normal(mask_length, mask_other, size=num_mask)
|
1078 |
+
lengths = [max(1, int(round(x))) for x in lengths]
|
1079 |
+
elif mask_type == "poisson":
|
1080 |
+
lengths = rng.poisson(mask_length, size=num_mask)
|
1081 |
+
lengths = [int(round(x)) for x in lengths]
|
1082 |
+
else:
|
1083 |
+
raise Exception("unknown mask selection " + mask_type)
|
1084 |
+
|
1085 |
+
if sum(lengths) == 0:
|
1086 |
+
lengths[0] = min(mask_length, sz - 1)
|
1087 |
+
|
1088 |
+
if no_overlap:
|
1089 |
+
mask_idc = []
|
1090 |
+
|
1091 |
+
def arrange(s, e, length, keep_length):
|
1092 |
+
span_start = rng.randint(s, e - length)
|
1093 |
+
mask_idc.extend(span_start + i for i in range(length))
|
1094 |
+
|
1095 |
+
new_parts = []
|
1096 |
+
if span_start - s - min_space >= keep_length:
|
1097 |
+
new_parts.append((s, span_start - min_space + 1))
|
1098 |
+
if e - span_start - length - min_space > keep_length:
|
1099 |
+
new_parts.append((span_start + length + min_space, e))
|
1100 |
+
return new_parts
|
1101 |
+
|
1102 |
+
parts = [(0, sz)]
|
1103 |
+
min_length = min(lengths)
|
1104 |
+
for length in sorted(lengths, reverse=True):
|
1105 |
+
lens = np.fromiter(
|
1106 |
+
(e - s if e - s >= length + min_space else 0 for s, e in parts),
|
1107 |
+
np.int,
|
1108 |
+
)
|
1109 |
+
l_sum = np.sum(lens)
|
1110 |
+
if l_sum == 0:
|
1111 |
+
break
|
1112 |
+
probs = lens / np.sum(lens)
|
1113 |
+
c = rng.choice(len(parts), p=probs)
|
1114 |
+
s, e = parts.pop(c)
|
1115 |
+
parts.extend(arrange(s, e, length, min_length))
|
1116 |
+
mask_idc = np.asarray(mask_idc)
|
1117 |
+
else:
|
1118 |
+
min_len = min(lengths)
|
1119 |
+
if sz - min_len <= num_mask:
|
1120 |
+
min_len = sz - num_mask - 1
|
1121 |
+
|
1122 |
+
mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
|
1123 |
+
|
1124 |
+
mask_idc = np.asarray(
|
1125 |
+
[
|
1126 |
+
mask_idc[j] + offset
|
1127 |
+
for j in range(len(mask_idc))
|
1128 |
+
for offset in range(lengths[j])
|
1129 |
+
]
|
1130 |
+
)
|
1131 |
+
|
1132 |
+
mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
|
1133 |
+
|
1134 |
+
min_len = min([len(m) for m in mask_idcs])
|
1135 |
+
for i, mask_idc in enumerate(mask_idcs):
|
1136 |
+
if len(mask_idc) > min_len and require_same_masks:
|
1137 |
+
mask_idc = rng.choice(mask_idc, min_len, replace=False)
|
1138 |
+
if mask_dropout > 0:
|
1139 |
+
num_holes = np.rint(len(mask_idc) * mask_dropout).astype(int)
|
1140 |
+
mask_idc = rng.choice(mask_idc, len(mask_idc) - num_holes, replace=False)
|
1141 |
+
|
1142 |
+
mask[i, mask_idc] = True
|
1143 |
+
|
1144 |
+
return mask
|
fairseq/fairseq/data/data_utils_fast.pyx
ADDED
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# cython: language_level=3
|
2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the MIT license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
cimport cython
|
10 |
+
cimport numpy as np
|
11 |
+
|
12 |
+
from libc.stdint cimport int32_t, int64_t
|
13 |
+
from libcpp cimport bool as bool_t
|
14 |
+
|
15 |
+
ctypedef int64_t DTYPE_t
|
16 |
+
|
17 |
+
@cython.cdivision(True)
|
18 |
+
@cython.boundscheck(False)
|
19 |
+
@cython.wraparound(False)
|
20 |
+
cpdef list batch_by_size_vec(
|
21 |
+
np.ndarray[int64_t, ndim=1] indices,
|
22 |
+
np.ndarray[int64_t, ndim=1] num_tokens_vec,
|
23 |
+
int64_t max_tokens,
|
24 |
+
int64_t max_sentences,
|
25 |
+
int32_t bsz_mult,
|
26 |
+
):
|
27 |
+
if indices.shape[0] == 0:
|
28 |
+
return []
|
29 |
+
|
30 |
+
assert max_tokens <= 0 or np.max(num_tokens_vec) <= max_tokens, (
|
31 |
+
f"Sentences lengths should not exceed max_tokens={max_tokens}"
|
32 |
+
)
|
33 |
+
|
34 |
+
cdef int32_t indices_len = indices.shape[0]
|
35 |
+
cdef np.ndarray[int32_t, ndim=1] batches_ends = \
|
36 |
+
np.zeros(indices_len, dtype=np.int32)
|
37 |
+
cdef int32_t[:] batches_ends_view = batches_ends
|
38 |
+
cdef int64_t[:] num_tokens_view = num_tokens_vec
|
39 |
+
|
40 |
+
cdef int32_t pos = 0
|
41 |
+
cdef int32_t new_batch_end = 0
|
42 |
+
|
43 |
+
cdef int64_t new_batch_max_tokens = 0
|
44 |
+
cdef int32_t new_batch_sentences = 0
|
45 |
+
cdef int64_t new_batch_num_tokens = 0
|
46 |
+
|
47 |
+
cdef bool_t overflow = False
|
48 |
+
cdef bool_t size_matches_with_bsz_mult = False
|
49 |
+
|
50 |
+
cdef int32_t batches_count = 0
|
51 |
+
cdef int32_t batch_start = 0
|
52 |
+
cdef int64_t tail_max_tokens = 0
|
53 |
+
cdef int64_t batch_max_tokens = 0
|
54 |
+
|
55 |
+
for pos in range(indices_len):
|
56 |
+
# At every pos we keep stats about the last complete batch [batch_start:batch_end),
|
57 |
+
# and tail [batch_end:pos].
|
58 |
+
# 1) Every time when (batch + tail) forms a valid batch
|
59 |
+
# (according to max_tokens, max_sentences and bsz_mult) we append tail to batch.
|
60 |
+
# 2) When (batch+tail) violates max_tokens or max_sentences constraints
|
61 |
+
# we finalize running batch, and tail becomes a new batch.
|
62 |
+
# 3) There is a corner case when tail also violates constraints.
|
63 |
+
# In that situation [batch_end:pos-1] (tail without the current pos)
|
64 |
+
# gets added to the finalized batches, while [pos:pos] becomes a new tail.
|
65 |
+
#
|
66 |
+
# Important: For the sake of performance try to avoid using function calls within this loop.
|
67 |
+
|
68 |
+
tail_max_tokens = tail_max_tokens \
|
69 |
+
if tail_max_tokens > num_tokens_view[pos] \
|
70 |
+
else num_tokens_view[pos]
|
71 |
+
new_batch_end = pos + 1
|
72 |
+
new_batch_max_tokens = batch_max_tokens \
|
73 |
+
if batch_max_tokens > tail_max_tokens \
|
74 |
+
else tail_max_tokens
|
75 |
+
new_batch_sentences = new_batch_end - batch_start
|
76 |
+
new_batch_num_tokens = new_batch_sentences * new_batch_max_tokens
|
77 |
+
|
78 |
+
overflow = (new_batch_sentences > max_sentences > 0 or
|
79 |
+
new_batch_num_tokens > max_tokens > 0)
|
80 |
+
size_matches_with_bsz_mult = (new_batch_sentences < bsz_mult or
|
81 |
+
new_batch_sentences % bsz_mult == 0)
|
82 |
+
|
83 |
+
if overflow:
|
84 |
+
tail_num_tokens = tail_max_tokens * \
|
85 |
+
(new_batch_end - batches_ends_view[batches_count])
|
86 |
+
tail_overflow = tail_num_tokens > max_tokens > 0
|
87 |
+
# In case of a tail overflow finalize two batches
|
88 |
+
if tail_overflow:
|
89 |
+
batches_count += 1
|
90 |
+
batches_ends_view[batches_count] = pos
|
91 |
+
tail_max_tokens = num_tokens_view[pos]
|
92 |
+
batch_start = batches_ends_view[batches_count]
|
93 |
+
batches_count += 1
|
94 |
+
new_batch_max_tokens = tail_max_tokens
|
95 |
+
|
96 |
+
if overflow or size_matches_with_bsz_mult:
|
97 |
+
batches_ends_view[batches_count] = new_batch_end
|
98 |
+
batch_max_tokens = new_batch_max_tokens
|
99 |
+
tail_max_tokens = 0
|
100 |
+
if batches_ends_view[batches_count] != indices_len:
|
101 |
+
batches_count += 1
|
102 |
+
# Memory and time-efficient split
|
103 |
+
return np.split(indices, batches_ends[:batches_count])
|
104 |
+
|
105 |
+
|
106 |
+
@cython.boundscheck(False)
|
107 |
+
@cython.wraparound(False)
|
108 |
+
cpdef list batch_by_size_fn(
|
109 |
+
np.ndarray[DTYPE_t, ndim=1] indices,
|
110 |
+
num_tokens_fn,
|
111 |
+
int64_t max_tokens,
|
112 |
+
int64_t max_sentences,
|
113 |
+
int32_t bsz_mult,
|
114 |
+
):
|
115 |
+
cdef int32_t indices_len = indices.shape[0]
|
116 |
+
cdef np.ndarray[int64_t, ndim=1] num_tokens_vec = np.zeros(indices_len,
|
117 |
+
dtype=np.int64)
|
118 |
+
cdef DTYPE_t[:] indices_view = indices
|
119 |
+
cdef DTYPE_t[:] num_tokens_vec_view = num_tokens_vec
|
120 |
+
cdef int64_t pos
|
121 |
+
for pos in range(indices_len):
|
122 |
+
num_tokens_vec[pos] = num_tokens_fn(indices_view[pos])
|
123 |
+
return batch_by_size_vec(indices, num_tokens_vec, max_tokens,
|
124 |
+
max_sentences, bsz_mult,)
|
125 |
+
|
126 |
+
|
127 |
+
cdef _find_valid_shape(
|
128 |
+
DTYPE_t[:, :] shapes_view,
|
129 |
+
int64_t num_sentences,
|
130 |
+
int64_t num_tokens,
|
131 |
+
):
|
132 |
+
"""Return index of first valid shape of -1 if none is found."""
|
133 |
+
for i in range(shapes_view.shape[0]):
|
134 |
+
if num_sentences <= shapes_view[i][0] and num_tokens <= shapes_view[i][1]:
|
135 |
+
return i
|
136 |
+
return -1
|
137 |
+
|
138 |
+
|
139 |
+
@cython.cdivision(True)
|
140 |
+
cpdef list batch_fixed_shapes_fast(
|
141 |
+
np.ndarray[DTYPE_t, ndim=1] indices,
|
142 |
+
num_tokens_fn,
|
143 |
+
np.ndarray[DTYPE_t, ndim=2] fixed_shapes_sorted,
|
144 |
+
):
|
145 |
+
cdef int64_t sample_len = 0
|
146 |
+
cdef list sample_lens = []
|
147 |
+
cdef list batch = []
|
148 |
+
cdef list batches = []
|
149 |
+
cdef int64_t mod_len
|
150 |
+
cdef int64_t i
|
151 |
+
cdef int64_t idx
|
152 |
+
cdef int64_t num_tokens
|
153 |
+
cdef DTYPE_t[:] indices_view = indices
|
154 |
+
cdef DTYPE_t[:, :] shapes_view = fixed_shapes_sorted
|
155 |
+
|
156 |
+
for i in range(len(indices_view)):
|
157 |
+
idx = indices_view[i]
|
158 |
+
num_tokens = num_tokens_fn(idx)
|
159 |
+
sample_lens.append(num_tokens)
|
160 |
+
sample_len = max(sample_len, num_tokens)
|
161 |
+
|
162 |
+
shape_idx = _find_valid_shape(shapes_view, len(batch) + 1, sample_len)
|
163 |
+
if shape_idx == -1:
|
164 |
+
batches.append(batch)
|
165 |
+
batch = []
|
166 |
+
sample_lens = []
|
167 |
+
sample_len = 0
|
168 |
+
shapes_view = fixed_shapes_sorted
|
169 |
+
elif shape_idx > 0:
|
170 |
+
# small optimization for the next call to _find_valid_shape
|
171 |
+
shapes_view = shapes_view[shape_idx:]
|
172 |
+
|
173 |
+
batch.append(idx)
|
174 |
+
|
175 |
+
if len(batch) > 0:
|
176 |
+
batches.append(batch)
|
177 |
+
|
178 |
+
return batches
|
fairseq/fairseq/data/denoising_dataset.py
ADDED
@@ -0,0 +1,443 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import math
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
|
11 |
+
from . import FairseqDataset, data_utils
|
12 |
+
|
13 |
+
|
14 |
+
def collate(
|
15 |
+
samples,
|
16 |
+
pad_idx,
|
17 |
+
eos_idx,
|
18 |
+
vocab,
|
19 |
+
left_pad_source=False,
|
20 |
+
left_pad_target=False,
|
21 |
+
input_feeding=True,
|
22 |
+
pad_to_length=None,
|
23 |
+
):
|
24 |
+
assert input_feeding
|
25 |
+
if len(samples) == 0:
|
26 |
+
return {}
|
27 |
+
|
28 |
+
def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
|
29 |
+
return data_utils.collate_tokens(
|
30 |
+
[s[key] for s in samples],
|
31 |
+
pad_idx,
|
32 |
+
eos_idx=None, # use eos_idx of each sample instead of vocab.eos()
|
33 |
+
left_pad=left_pad,
|
34 |
+
move_eos_to_beginning=move_eos_to_beginning,
|
35 |
+
pad_to_length=pad_to_length,
|
36 |
+
)
|
37 |
+
|
38 |
+
id = torch.LongTensor([s["id"] for s in samples])
|
39 |
+
src_tokens = merge(
|
40 |
+
"source",
|
41 |
+
left_pad=left_pad_source,
|
42 |
+
pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
|
43 |
+
)
|
44 |
+
# sort by descending source length
|
45 |
+
src_lengths = torch.LongTensor([s["source"].numel() for s in samples])
|
46 |
+
src_lengths, sort_order = src_lengths.sort(descending=True)
|
47 |
+
id = id.index_select(0, sort_order)
|
48 |
+
src_tokens = src_tokens.index_select(0, sort_order)
|
49 |
+
|
50 |
+
prev_output_tokens = None
|
51 |
+
target = None
|
52 |
+
if samples[0].get("target", None) is not None:
|
53 |
+
target = merge(
|
54 |
+
"target",
|
55 |
+
left_pad=left_pad_target,
|
56 |
+
pad_to_length=pad_to_length["target"]
|
57 |
+
if pad_to_length is not None
|
58 |
+
else None,
|
59 |
+
)
|
60 |
+
target = target.index_select(0, sort_order)
|
61 |
+
ntokens = sum(len(s["target"]) for s in samples)
|
62 |
+
|
63 |
+
if input_feeding:
|
64 |
+
# we create a shifted version of targets for feeding the
|
65 |
+
# previous output token(s) into the next decoder step
|
66 |
+
prev_output_tokens = merge(
|
67 |
+
"target",
|
68 |
+
left_pad=left_pad_target,
|
69 |
+
move_eos_to_beginning=True,
|
70 |
+
pad_to_length=pad_to_length["target"]
|
71 |
+
if pad_to_length is not None
|
72 |
+
else None,
|
73 |
+
)
|
74 |
+
prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
|
75 |
+
else:
|
76 |
+
ntokens = sum(len(s["source"]) for s in samples)
|
77 |
+
|
78 |
+
batch = {
|
79 |
+
"id": id,
|
80 |
+
"ntokens": ntokens,
|
81 |
+
"net_input": {
|
82 |
+
"src_tokens": src_tokens,
|
83 |
+
"src_lengths": src_lengths,
|
84 |
+
},
|
85 |
+
"target": target,
|
86 |
+
"nsentences": samples[0]["source"].size(0),
|
87 |
+
"sort_order": sort_order,
|
88 |
+
}
|
89 |
+
if prev_output_tokens is not None:
|
90 |
+
batch["net_input"]["prev_output_tokens"] = prev_output_tokens
|
91 |
+
|
92 |
+
return batch
|
93 |
+
|
94 |
+
|
95 |
+
class DenoisingDataset(FairseqDataset):
|
96 |
+
"""
|
97 |
+
A wrapper around TokenBlockDataset for BART dataset.
|
98 |
+
|
99 |
+
Args:
|
100 |
+
dataset (TokenBlockDataset): dataset to wrap
|
101 |
+
sizes (List[int]): sentence lengths
|
102 |
+
vocab (~fairseq.data.Dictionary): vocabulary
|
103 |
+
mask_idx (int): dictionary index used for masked token
|
104 |
+
mask_whole_words: only mask whole words. This should be a byte mask
|
105 |
+
over vocab indices, indicating whether it is the beginning of a
|
106 |
+
word. We will extend any mask to encompass the whole word.
|
107 |
+
shuffle (bool, optional): shuffle the elements before batching.
|
108 |
+
Default: ``True``
|
109 |
+
seed: Seed for random number generator for reproducibility.
|
110 |
+
"""
|
111 |
+
|
112 |
+
def __init__(
|
113 |
+
self,
|
114 |
+
dataset,
|
115 |
+
sizes,
|
116 |
+
vocab,
|
117 |
+
mask_idx,
|
118 |
+
mask_whole_words,
|
119 |
+
shuffle,
|
120 |
+
seed,
|
121 |
+
mask,
|
122 |
+
mask_random,
|
123 |
+
insert,
|
124 |
+
rotate,
|
125 |
+
permute_sentences,
|
126 |
+
bpe,
|
127 |
+
replace_length,
|
128 |
+
mask_length,
|
129 |
+
poisson_lambda,
|
130 |
+
eos=None,
|
131 |
+
item_transform_func=None,
|
132 |
+
):
|
133 |
+
self.dataset = dataset
|
134 |
+
|
135 |
+
self.sizes = sizes
|
136 |
+
|
137 |
+
self.vocab = vocab
|
138 |
+
self.shuffle = shuffle
|
139 |
+
self.seed = seed
|
140 |
+
self.mask_idx = mask_idx
|
141 |
+
self.mask_whole_word = mask_whole_words
|
142 |
+
self.mask_ratio = mask
|
143 |
+
self.random_ratio = mask_random
|
144 |
+
self.insert_ratio = insert
|
145 |
+
self.rotate_ratio = rotate
|
146 |
+
self.permute_sentence_ratio = permute_sentences
|
147 |
+
self.eos = eos if eos is not None else vocab.eos()
|
148 |
+
self.item_transform_func = item_transform_func
|
149 |
+
|
150 |
+
if bpe != "gpt2":
|
151 |
+
self.full_stop_index = self.vocab.eos()
|
152 |
+
else:
|
153 |
+
assert bpe == "gpt2"
|
154 |
+
self.full_stop_index = self.vocab.index("13")
|
155 |
+
|
156 |
+
self.replace_length = replace_length
|
157 |
+
if self.replace_length not in [-1, 0, 1]:
|
158 |
+
raise ValueError(f"invalid arg: replace_length={self.replace_length}")
|
159 |
+
if mask_length not in ["subword", "word", "span-poisson"]:
|
160 |
+
raise ValueError(f"invalid arg: mask-length={mask_length}")
|
161 |
+
if mask_length == "subword" and replace_length not in [0, 1]:
|
162 |
+
raise ValueError(f"if using subwords, use replace-length=1 or 0")
|
163 |
+
|
164 |
+
self.mask_span_distribution = None
|
165 |
+
if mask_length == "span-poisson":
|
166 |
+
_lambda = poisson_lambda
|
167 |
+
|
168 |
+
lambda_to_the_k = 1
|
169 |
+
e_to_the_minus_lambda = math.exp(-_lambda)
|
170 |
+
k_factorial = 1
|
171 |
+
ps = []
|
172 |
+
for k in range(0, 128):
|
173 |
+
ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial)
|
174 |
+
lambda_to_the_k *= _lambda
|
175 |
+
k_factorial *= k + 1
|
176 |
+
if ps[-1] < 0.0000001:
|
177 |
+
break
|
178 |
+
ps = torch.FloatTensor(ps)
|
179 |
+
self.mask_span_distribution = torch.distributions.Categorical(ps)
|
180 |
+
|
181 |
+
self.epoch = 0
|
182 |
+
|
183 |
+
@property
|
184 |
+
def can_reuse_epoch_itr_across_epochs(self):
|
185 |
+
return True # only the noise changes, not item sizes
|
186 |
+
|
187 |
+
def set_epoch(self, epoch, **unused):
|
188 |
+
self.epoch = epoch
|
189 |
+
|
190 |
+
def __getitem__(self, index):
|
191 |
+
with data_utils.numpy_seed(self.seed, self.epoch, index):
|
192 |
+
tokens = self.dataset[index]
|
193 |
+
assert tokens[-1] == self.eos
|
194 |
+
source, target = tokens, tokens.clone()
|
195 |
+
|
196 |
+
if self.permute_sentence_ratio > 0.0:
|
197 |
+
source = self.permute_sentences(source, self.permute_sentence_ratio)
|
198 |
+
|
199 |
+
if self.mask_ratio > 0:
|
200 |
+
source = self.add_whole_word_mask(source, self.mask_ratio)
|
201 |
+
|
202 |
+
if self.insert_ratio > 0:
|
203 |
+
source = self.add_insertion_noise(source, self.insert_ratio)
|
204 |
+
|
205 |
+
if self.rotate_ratio > 0.0 and np.random.random() < self.rotate_ratio:
|
206 |
+
source = self.add_rolling_noise(source)
|
207 |
+
# there can additional changes to make:
|
208 |
+
if self.item_transform_func is not None:
|
209 |
+
source, target = self.item_transform_func(source, target)
|
210 |
+
|
211 |
+
assert (source >= 0).all()
|
212 |
+
assert (source[1:-1] >= 1).all()
|
213 |
+
assert (source <= len(self.vocab)).all()
|
214 |
+
assert source[0] == self.vocab.bos()
|
215 |
+
assert source[-1] == self.eos
|
216 |
+
return {
|
217 |
+
"id": index,
|
218 |
+
"source": source,
|
219 |
+
"target": target,
|
220 |
+
}
|
221 |
+
|
222 |
+
def __len__(self):
|
223 |
+
return len(self.dataset)
|
224 |
+
|
225 |
+
def permute_sentences(self, source, p=1.0):
|
226 |
+
full_stops = source == self.full_stop_index
|
227 |
+
# Pretend it ends with a full stop so last span is a sentence
|
228 |
+
full_stops[-2] = 1
|
229 |
+
|
230 |
+
# Tokens that are full stops, where the previous token is not
|
231 |
+
sentence_ends = (full_stops[1:] * ~full_stops[:-1]).nonzero(as_tuple=False) + 2
|
232 |
+
result = source.clone()
|
233 |
+
|
234 |
+
num_sentences = sentence_ends.size(0)
|
235 |
+
num_to_permute = math.ceil((num_sentences * 2 * p) / 2.0)
|
236 |
+
substitutions = torch.randperm(num_sentences)[:num_to_permute]
|
237 |
+
ordering = torch.arange(0, num_sentences)
|
238 |
+
ordering[substitutions] = substitutions[torch.randperm(num_to_permute)]
|
239 |
+
|
240 |
+
# Ignore <bos> at start
|
241 |
+
index = 1
|
242 |
+
for i in ordering:
|
243 |
+
sentence = source[(sentence_ends[i - 1] if i > 0 else 1) : sentence_ends[i]]
|
244 |
+
result[index : index + sentence.size(0)] = sentence
|
245 |
+
index += sentence.size(0)
|
246 |
+
return result
|
247 |
+
|
248 |
+
def word_starts(self, source):
|
249 |
+
if self.mask_whole_word is not None:
|
250 |
+
is_word_start = self.mask_whole_word.gather(0, source)
|
251 |
+
else:
|
252 |
+
is_word_start = torch.ones(source.size())
|
253 |
+
is_word_start[0] = 0
|
254 |
+
is_word_start[-1] = 0
|
255 |
+
return is_word_start
|
256 |
+
|
257 |
+
def add_whole_word_mask(self, source, p):
|
258 |
+
is_word_start = self.word_starts(source)
|
259 |
+
num_to_mask = int(math.ceil(is_word_start.float().sum() * p))
|
260 |
+
num_inserts = 0
|
261 |
+
if num_to_mask == 0:
|
262 |
+
return source
|
263 |
+
|
264 |
+
if self.mask_span_distribution is not None:
|
265 |
+
lengths = self.mask_span_distribution.sample(sample_shape=(num_to_mask,))
|
266 |
+
|
267 |
+
# Make sure we have enough to mask
|
268 |
+
cum_length = torch.cumsum(lengths, 0)
|
269 |
+
while cum_length[-1] < num_to_mask:
|
270 |
+
lengths = torch.cat(
|
271 |
+
[
|
272 |
+
lengths,
|
273 |
+
self.mask_span_distribution.sample(sample_shape=(num_to_mask,)),
|
274 |
+
],
|
275 |
+
dim=0,
|
276 |
+
)
|
277 |
+
cum_length = torch.cumsum(lengths, 0)
|
278 |
+
|
279 |
+
# Trim to masking budget
|
280 |
+
i = 0
|
281 |
+
while cum_length[i] < num_to_mask:
|
282 |
+
i += 1
|
283 |
+
lengths[i] = num_to_mask - (0 if i == 0 else cum_length[i - 1])
|
284 |
+
num_to_mask = i + 1
|
285 |
+
lengths = lengths[:num_to_mask]
|
286 |
+
|
287 |
+
# Handle 0-length mask (inserts) separately
|
288 |
+
lengths = lengths[lengths > 0]
|
289 |
+
num_inserts = num_to_mask - lengths.size(0)
|
290 |
+
num_to_mask -= num_inserts
|
291 |
+
if num_to_mask == 0:
|
292 |
+
return self.add_insertion_noise(source, num_inserts / source.size(0))
|
293 |
+
|
294 |
+
assert (lengths > 0).all()
|
295 |
+
else:
|
296 |
+
lengths = torch.ones((num_to_mask,)).long()
|
297 |
+
assert is_word_start[-1] == 0
|
298 |
+
word_starts = is_word_start.nonzero(as_tuple=False)
|
299 |
+
indices = word_starts[
|
300 |
+
torch.randperm(word_starts.size(0))[:num_to_mask]
|
301 |
+
].squeeze(1)
|
302 |
+
mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio
|
303 |
+
|
304 |
+
source_length = source.size(0)
|
305 |
+
assert source_length - 1 not in indices
|
306 |
+
to_keep = torch.ones(source_length, dtype=torch.bool)
|
307 |
+
is_word_start[
|
308 |
+
-1
|
309 |
+
] = 255 # acts as a long length, so spans don't go over the end of doc
|
310 |
+
if self.replace_length == 0:
|
311 |
+
to_keep[indices] = 0
|
312 |
+
else:
|
313 |
+
# keep index, but replace it with [MASK]
|
314 |
+
source[indices] = self.mask_idx
|
315 |
+
source[indices[mask_random]] = torch.randint(
|
316 |
+
1, len(self.vocab), size=(mask_random.sum(),)
|
317 |
+
)
|
318 |
+
|
319 |
+
if self.mask_span_distribution is not None:
|
320 |
+
assert len(lengths.size()) == 1
|
321 |
+
assert lengths.size() == indices.size()
|
322 |
+
lengths -= 1
|
323 |
+
while indices.size(0) > 0:
|
324 |
+
assert lengths.size() == indices.size()
|
325 |
+
lengths -= is_word_start[indices + 1].long()
|
326 |
+
uncompleted = lengths >= 0
|
327 |
+
indices = indices[uncompleted] + 1
|
328 |
+
mask_random = mask_random[uncompleted]
|
329 |
+
lengths = lengths[uncompleted]
|
330 |
+
if self.replace_length != -1:
|
331 |
+
# delete token
|
332 |
+
to_keep[indices] = 0
|
333 |
+
else:
|
334 |
+
# keep index, but replace it with [MASK]
|
335 |
+
source[indices] = self.mask_idx
|
336 |
+
source[indices[mask_random]] = torch.randint(
|
337 |
+
1, len(self.vocab), size=(mask_random.sum(),)
|
338 |
+
)
|
339 |
+
else:
|
340 |
+
# A bit faster when all lengths are 1
|
341 |
+
while indices.size(0) > 0:
|
342 |
+
uncompleted = is_word_start[indices + 1] == 0
|
343 |
+
indices = indices[uncompleted] + 1
|
344 |
+
mask_random = mask_random[uncompleted]
|
345 |
+
if self.replace_length != -1:
|
346 |
+
# delete token
|
347 |
+
to_keep[indices] = 0
|
348 |
+
else:
|
349 |
+
# keep index, but replace it with [MASK]
|
350 |
+
source[indices] = self.mask_idx
|
351 |
+
source[indices[mask_random]] = torch.randint(
|
352 |
+
1, len(self.vocab), size=(mask_random.sum(),)
|
353 |
+
)
|
354 |
+
|
355 |
+
assert source_length - 1 not in indices
|
356 |
+
|
357 |
+
source = source[to_keep]
|
358 |
+
|
359 |
+
if num_inserts > 0:
|
360 |
+
source = self.add_insertion_noise(source, num_inserts / source.size(0))
|
361 |
+
|
362 |
+
return source
|
363 |
+
|
364 |
+
def add_permuted_noise(self, tokens, p):
|
365 |
+
num_words = len(tokens)
|
366 |
+
num_to_permute = math.ceil(((num_words * 2) * p) / 2.0)
|
367 |
+
substitutions = torch.randperm(num_words - 2)[:num_to_permute] + 1
|
368 |
+
tokens[substitutions] = tokens[substitutions[torch.randperm(num_to_permute)]]
|
369 |
+
return tokens
|
370 |
+
|
371 |
+
def add_rolling_noise(self, tokens):
|
372 |
+
offset = np.random.randint(1, max(1, tokens.size(-1) - 1) + 1)
|
373 |
+
tokens = torch.cat(
|
374 |
+
(tokens[0:1], tokens[offset:-1], tokens[1:offset], tokens[-1:]),
|
375 |
+
dim=0,
|
376 |
+
)
|
377 |
+
return tokens
|
378 |
+
|
379 |
+
def add_insertion_noise(self, tokens, p):
|
380 |
+
if p == 0.0:
|
381 |
+
return tokens
|
382 |
+
|
383 |
+
num_tokens = len(tokens)
|
384 |
+
n = int(math.ceil(num_tokens * p))
|
385 |
+
|
386 |
+
noise_indices = torch.randperm(num_tokens + n - 2)[:n] + 1
|
387 |
+
noise_mask = torch.zeros(size=(num_tokens + n,), dtype=torch.bool)
|
388 |
+
noise_mask[noise_indices] = 1
|
389 |
+
result = torch.LongTensor(n + len(tokens)).fill_(-1)
|
390 |
+
|
391 |
+
num_random = int(math.ceil(n * self.random_ratio))
|
392 |
+
result[noise_indices[num_random:]] = self.mask_idx
|
393 |
+
result[noise_indices[:num_random]] = torch.randint(
|
394 |
+
low=1, high=len(self.vocab), size=(num_random,)
|
395 |
+
)
|
396 |
+
|
397 |
+
result[~noise_mask] = tokens
|
398 |
+
|
399 |
+
assert (result >= 0).all()
|
400 |
+
return result
|
401 |
+
|
402 |
+
def collater(self, samples, pad_to_length=None):
|
403 |
+
"""Merge a list of samples to form a mini-batch.
|
404 |
+
Args:
|
405 |
+
samples (List[dict]): samples to collate
|
406 |
+
Returns:
|
407 |
+
dict: a mini-batch of data
|
408 |
+
"""
|
409 |
+
return collate(
|
410 |
+
samples, self.vocab.pad(), self.eos, self.vocab, pad_to_length=pad_to_length
|
411 |
+
)
|
412 |
+
|
413 |
+
def num_tokens(self, index):
|
414 |
+
"""Return the number of tokens in a sample. This value is used to
|
415 |
+
enforce ``--max-tokens`` during batching."""
|
416 |
+
return self.sizes[index]
|
417 |
+
|
418 |
+
def size(self, index):
|
419 |
+
"""Return an example's size as a float or tuple. This value is used when
|
420 |
+
filtering a dataset with ``--max-positions``."""
|
421 |
+
return self.sizes[index]
|
422 |
+
|
423 |
+
def ordered_indices(self):
|
424 |
+
"""Return an ordered list of indices. Batches will be constructed based
|
425 |
+
on this order."""
|
426 |
+
if self.shuffle:
|
427 |
+
indices = np.random.permutation(len(self))
|
428 |
+
else:
|
429 |
+
indices = np.arange(len(self))
|
430 |
+
return indices[np.argsort(self.sizes[indices], kind="mergesort")]
|
431 |
+
|
432 |
+
def prefetch(self, indices):
|
433 |
+
self.src.prefetch(indices)
|
434 |
+
self.tgt.prefetch(indices)
|
435 |
+
|
436 |
+
@property
|
437 |
+
def supports_prefetch(self):
|
438 |
+
return (
|
439 |
+
hasattr(self.src, "supports_prefetch")
|
440 |
+
and self.src.supports_prefetch
|
441 |
+
and hasattr(self.tgt, "supports_prefetch")
|
442 |
+
and self.tgt.supports_prefetch
|
443 |
+
)
|
fairseq/fairseq/data/dictionary.py
ADDED
@@ -0,0 +1,403 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import os
|
7 |
+
from collections import Counter
|
8 |
+
from multiprocessing import Pool
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from fairseq import utils
|
12 |
+
from fairseq.data import data_utils
|
13 |
+
from fairseq.file_chunker_utils import Chunker, find_offsets
|
14 |
+
from fairseq.file_io import PathManager
|
15 |
+
from fairseq.tokenizer import tokenize_line
|
16 |
+
|
17 |
+
|
18 |
+
class Dictionary:
|
19 |
+
"""A mapping from symbols to consecutive integers"""
|
20 |
+
|
21 |
+
def __init__(
|
22 |
+
self,
|
23 |
+
*, # begin keyword-only arguments
|
24 |
+
bos="<s>",
|
25 |
+
pad="<pad>",
|
26 |
+
eos="</s>",
|
27 |
+
unk="<unk>",
|
28 |
+
extra_special_symbols=None,
|
29 |
+
add_special_symbols=True,
|
30 |
+
):
|
31 |
+
self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
|
32 |
+
self.symbols = []
|
33 |
+
self.count = []
|
34 |
+
self.indices = {}
|
35 |
+
if add_special_symbols:
|
36 |
+
self.bos_index = self.add_symbol(bos)
|
37 |
+
self.pad_index = self.add_symbol(pad)
|
38 |
+
self.eos_index = self.add_symbol(eos)
|
39 |
+
self.unk_index = self.add_symbol(unk)
|
40 |
+
if extra_special_symbols:
|
41 |
+
for s in extra_special_symbols:
|
42 |
+
self.add_symbol(s)
|
43 |
+
self.nspecial = len(self.symbols)
|
44 |
+
|
45 |
+
def __eq__(self, other):
|
46 |
+
return self.indices == other.indices
|
47 |
+
|
48 |
+
def __getitem__(self, idx):
|
49 |
+
if idx < len(self.symbols):
|
50 |
+
return self.symbols[idx]
|
51 |
+
return self.unk_word
|
52 |
+
|
53 |
+
def get_count(self, idx):
|
54 |
+
return self.count[idx]
|
55 |
+
|
56 |
+
def __len__(self):
|
57 |
+
"""Returns the number of symbols in the dictionary"""
|
58 |
+
return len(self.symbols)
|
59 |
+
|
60 |
+
def __contains__(self, sym):
|
61 |
+
return sym in self.indices
|
62 |
+
|
63 |
+
def index(self, sym):
|
64 |
+
"""Returns the index of the specified symbol"""
|
65 |
+
assert isinstance(sym, str)
|
66 |
+
if sym in self.indices:
|
67 |
+
return self.indices[sym]
|
68 |
+
return self.unk_index
|
69 |
+
|
70 |
+
def string(
|
71 |
+
self,
|
72 |
+
tensor,
|
73 |
+
bpe_symbol=None,
|
74 |
+
escape_unk=False,
|
75 |
+
extra_symbols_to_ignore=None,
|
76 |
+
unk_string=None,
|
77 |
+
include_eos=False,
|
78 |
+
separator=" ",
|
79 |
+
):
|
80 |
+
"""Helper for converting a tensor of token indices to a string.
|
81 |
+
|
82 |
+
Can optionally remove BPE symbols or escape <unk> words.
|
83 |
+
"""
|
84 |
+
if torch.is_tensor(tensor) and tensor.dim() == 2:
|
85 |
+
return "\n".join(
|
86 |
+
self.string(
|
87 |
+
t,
|
88 |
+
bpe_symbol,
|
89 |
+
escape_unk,
|
90 |
+
extra_symbols_to_ignore,
|
91 |
+
include_eos=include_eos,
|
92 |
+
)
|
93 |
+
for t in tensor
|
94 |
+
)
|
95 |
+
|
96 |
+
extra_symbols_to_ignore = set(extra_symbols_to_ignore or [])
|
97 |
+
if not include_eos:
|
98 |
+
extra_symbols_to_ignore.add(self.eos())
|
99 |
+
|
100 |
+
def token_string(i):
|
101 |
+
if i == self.unk():
|
102 |
+
if unk_string is not None:
|
103 |
+
return unk_string
|
104 |
+
else:
|
105 |
+
return self.unk_string(escape_unk)
|
106 |
+
else:
|
107 |
+
return self[i]
|
108 |
+
|
109 |
+
if hasattr(self, "bos_index"):
|
110 |
+
extra_symbols_to_ignore.add(self.bos())
|
111 |
+
|
112 |
+
sent = separator.join(
|
113 |
+
token_string(i)
|
114 |
+
for i in tensor
|
115 |
+
if utils.item(i) not in extra_symbols_to_ignore
|
116 |
+
)
|
117 |
+
|
118 |
+
return data_utils.post_process(sent, bpe_symbol)
|
119 |
+
|
120 |
+
def unk_string(self, escape=False):
|
121 |
+
"""Return unknown string, optionally escaped as: <<unk>>"""
|
122 |
+
if escape:
|
123 |
+
return "<{}>".format(self.unk_word)
|
124 |
+
else:
|
125 |
+
return self.unk_word
|
126 |
+
|
127 |
+
def add_symbol(self, word, n=1, overwrite=False):
|
128 |
+
"""Adds a word to the dictionary"""
|
129 |
+
if word in self.indices and not overwrite:
|
130 |
+
idx = self.indices[word]
|
131 |
+
self.count[idx] = self.count[idx] + n
|
132 |
+
return idx
|
133 |
+
else:
|
134 |
+
idx = len(self.symbols)
|
135 |
+
self.indices[word] = idx
|
136 |
+
self.symbols.append(word)
|
137 |
+
self.count.append(n)
|
138 |
+
return idx
|
139 |
+
|
140 |
+
def update(self, new_dict):
|
141 |
+
"""Updates counts from new dictionary."""
|
142 |
+
for word in new_dict.symbols:
|
143 |
+
idx2 = new_dict.indices[word]
|
144 |
+
if word in self.indices:
|
145 |
+
idx = self.indices[word]
|
146 |
+
self.count[idx] = self.count[idx] + new_dict.count[idx2]
|
147 |
+
else:
|
148 |
+
idx = len(self.symbols)
|
149 |
+
self.indices[word] = idx
|
150 |
+
self.symbols.append(word)
|
151 |
+
self.count.append(new_dict.count[idx2])
|
152 |
+
|
153 |
+
def finalize(self, threshold=-1, nwords=-1, padding_factor=8):
|
154 |
+
"""Sort symbols by frequency in descending order, ignoring special ones.
|
155 |
+
|
156 |
+
Args:
|
157 |
+
- threshold defines the minimum word count
|
158 |
+
- nwords defines the total number of words in the final dictionary,
|
159 |
+
including special symbols
|
160 |
+
- padding_factor can be used to pad the dictionary size to be a
|
161 |
+
multiple of 8, which is important on some hardware (e.g., Nvidia
|
162 |
+
Tensor Cores).
|
163 |
+
"""
|
164 |
+
if nwords <= 0:
|
165 |
+
nwords = len(self)
|
166 |
+
|
167 |
+
new_indices = dict(zip(self.symbols[: self.nspecial], range(self.nspecial)))
|
168 |
+
new_symbols = self.symbols[: self.nspecial]
|
169 |
+
new_count = self.count[: self.nspecial]
|
170 |
+
|
171 |
+
c = Counter(
|
172 |
+
dict(
|
173 |
+
sorted(zip(self.symbols[self.nspecial :], self.count[self.nspecial :]))
|
174 |
+
)
|
175 |
+
)
|
176 |
+
for symbol, count in c.most_common(nwords - self.nspecial):
|
177 |
+
if count >= threshold:
|
178 |
+
new_indices[symbol] = len(new_symbols)
|
179 |
+
new_symbols.append(symbol)
|
180 |
+
new_count.append(count)
|
181 |
+
else:
|
182 |
+
break
|
183 |
+
|
184 |
+
assert len(new_symbols) == len(new_indices)
|
185 |
+
|
186 |
+
self.count = list(new_count)
|
187 |
+
self.symbols = list(new_symbols)
|
188 |
+
self.indices = new_indices
|
189 |
+
|
190 |
+
self.pad_to_multiple_(padding_factor)
|
191 |
+
|
192 |
+
def pad_to_multiple_(self, padding_factor):
|
193 |
+
"""Pad Dictionary size to be a multiple of *padding_factor*."""
|
194 |
+
if padding_factor > 1:
|
195 |
+
i = 0
|
196 |
+
while len(self) % padding_factor != 0:
|
197 |
+
symbol = "madeupword{:04d}".format(i)
|
198 |
+
self.add_symbol(symbol, n=0)
|
199 |
+
i += 1
|
200 |
+
|
201 |
+
def bos(self):
|
202 |
+
"""Helper to get index of beginning-of-sentence symbol"""
|
203 |
+
return self.bos_index
|
204 |
+
|
205 |
+
def pad(self):
|
206 |
+
"""Helper to get index of pad symbol"""
|
207 |
+
return self.pad_index
|
208 |
+
|
209 |
+
def eos(self):
|
210 |
+
"""Helper to get index of end-of-sentence symbol"""
|
211 |
+
return self.eos_index
|
212 |
+
|
213 |
+
def unk(self):
|
214 |
+
"""Helper to get index of unk symbol"""
|
215 |
+
return self.unk_index
|
216 |
+
|
217 |
+
@classmethod
|
218 |
+
def load(cls, f, add_special_symbols=True):
|
219 |
+
"""Loads the dictionary from a text file with the format:
|
220 |
+
|
221 |
+
```
|
222 |
+
<symbol0> <count0>
|
223 |
+
<symbol1> <count1>
|
224 |
+
...
|
225 |
+
```
|
226 |
+
"""
|
227 |
+
d = cls(add_special_symbols=add_special_symbols)
|
228 |
+
d.add_from_file(f)
|
229 |
+
return d
|
230 |
+
|
231 |
+
def add_from_file(self, f):
|
232 |
+
"""
|
233 |
+
Loads a pre-existing dictionary from a text file and adds its symbols
|
234 |
+
to this instance.
|
235 |
+
"""
|
236 |
+
if isinstance(f, str):
|
237 |
+
try:
|
238 |
+
with open(PathManager.get_local_path(f), "r", encoding="utf-8") as fd:
|
239 |
+
self.add_from_file(fd)
|
240 |
+
except FileNotFoundError as fnfe:
|
241 |
+
raise fnfe
|
242 |
+
except UnicodeError:
|
243 |
+
raise Exception(
|
244 |
+
"Incorrect encoding detected in {}, please "
|
245 |
+
"rebuild the dataset".format(f)
|
246 |
+
)
|
247 |
+
return
|
248 |
+
|
249 |
+
lines = f.readlines()
|
250 |
+
indices_start_line = self._load_meta(lines)
|
251 |
+
|
252 |
+
for line in lines[indices_start_line:]:
|
253 |
+
try:
|
254 |
+
line, field = line.rstrip().rsplit(" ", 1)
|
255 |
+
if field == "#fairseq:overwrite":
|
256 |
+
overwrite = True
|
257 |
+
line, field = line.rsplit(" ", 1)
|
258 |
+
else:
|
259 |
+
overwrite = False
|
260 |
+
count = int(field)
|
261 |
+
word = line
|
262 |
+
if word in self and not overwrite:
|
263 |
+
raise RuntimeError(
|
264 |
+
"Duplicate word found when loading Dictionary: '{}'. "
|
265 |
+
"Duplicate words can overwrite earlier ones by adding the "
|
266 |
+
"#fairseq:overwrite flag at the end of the corresponding row "
|
267 |
+
"in the dictionary file. If using the Camembert model, please "
|
268 |
+
"download an updated copy of the model file.".format(word)
|
269 |
+
)
|
270 |
+
self.add_symbol(word, n=count, overwrite=overwrite)
|
271 |
+
except ValueError:
|
272 |
+
raise ValueError(
|
273 |
+
f"Incorrect dictionary format, expected '<token> <cnt> [flags]': \"{line}\""
|
274 |
+
)
|
275 |
+
|
276 |
+
def _save(self, f, kv_iterator):
|
277 |
+
if isinstance(f, str):
|
278 |
+
PathManager.mkdirs(os.path.dirname(f))
|
279 |
+
with PathManager.open(f, "w", encoding="utf-8") as fd:
|
280 |
+
return self.save(fd)
|
281 |
+
for k, v in kv_iterator:
|
282 |
+
print("{} {}".format(k, v), file=f)
|
283 |
+
|
284 |
+
def _get_meta(self):
|
285 |
+
return [], []
|
286 |
+
|
287 |
+
def _load_meta(self, lines):
|
288 |
+
return 0
|
289 |
+
|
290 |
+
def save(self, f):
|
291 |
+
"""Stores dictionary into a text file"""
|
292 |
+
ex_keys, ex_vals = self._get_meta()
|
293 |
+
self._save(
|
294 |
+
f,
|
295 |
+
zip(
|
296 |
+
ex_keys + self.symbols[self.nspecial :],
|
297 |
+
ex_vals + self.count[self.nspecial :],
|
298 |
+
),
|
299 |
+
)
|
300 |
+
|
301 |
+
def dummy_sentence(self, length):
|
302 |
+
t = torch.Tensor(length).uniform_(self.nspecial + 1, len(self)).long()
|
303 |
+
t[-1] = self.eos()
|
304 |
+
return t
|
305 |
+
|
306 |
+
def encode_line(
|
307 |
+
self,
|
308 |
+
line,
|
309 |
+
line_tokenizer=tokenize_line,
|
310 |
+
add_if_not_exist=True,
|
311 |
+
consumer=None,
|
312 |
+
append_eos=True,
|
313 |
+
reverse_order=False,
|
314 |
+
) -> torch.IntTensor:
|
315 |
+
words = line_tokenizer(line)
|
316 |
+
if reverse_order:
|
317 |
+
words = list(reversed(words))
|
318 |
+
nwords = len(words)
|
319 |
+
ids = torch.IntTensor(nwords + 1 if append_eos else nwords)
|
320 |
+
|
321 |
+
for i, word in enumerate(words):
|
322 |
+
if add_if_not_exist:
|
323 |
+
idx = self.add_symbol(word)
|
324 |
+
else:
|
325 |
+
idx = self.index(word)
|
326 |
+
if consumer is not None:
|
327 |
+
consumer(word, idx)
|
328 |
+
ids[i] = idx
|
329 |
+
if append_eos:
|
330 |
+
ids[nwords] = self.eos_index
|
331 |
+
return ids
|
332 |
+
|
333 |
+
@staticmethod
|
334 |
+
def _add_file_to_dictionary_single_worker(
|
335 |
+
filename,
|
336 |
+
tokenize,
|
337 |
+
eos_word,
|
338 |
+
start_offset,
|
339 |
+
end_offset,
|
340 |
+
):
|
341 |
+
counter = Counter()
|
342 |
+
with Chunker(filename, start_offset, end_offset) as line_iterator:
|
343 |
+
for line in line_iterator:
|
344 |
+
for word in tokenize(line):
|
345 |
+
counter.update([word])
|
346 |
+
counter.update([eos_word])
|
347 |
+
return counter
|
348 |
+
|
349 |
+
@staticmethod
|
350 |
+
def add_file_to_dictionary(filename, dict, tokenize, num_workers):
|
351 |
+
def merge_result(counter):
|
352 |
+
for w, c in sorted(counter.items()):
|
353 |
+
dict.add_symbol(w, c)
|
354 |
+
|
355 |
+
local_file = PathManager.get_local_path(filename)
|
356 |
+
offsets = find_offsets(local_file, num_workers)
|
357 |
+
if num_workers > 1:
|
358 |
+
chunks = zip(offsets, offsets[1:])
|
359 |
+
pool = Pool(processes=num_workers)
|
360 |
+
results = []
|
361 |
+
for (start_offset, end_offset) in chunks:
|
362 |
+
results.append(
|
363 |
+
pool.apply_async(
|
364 |
+
Dictionary._add_file_to_dictionary_single_worker,
|
365 |
+
(
|
366 |
+
local_file,
|
367 |
+
tokenize,
|
368 |
+
dict.eos_word,
|
369 |
+
start_offset,
|
370 |
+
end_offset,
|
371 |
+
),
|
372 |
+
)
|
373 |
+
)
|
374 |
+
pool.close()
|
375 |
+
pool.join()
|
376 |
+
for r in results:
|
377 |
+
merge_result(r.get())
|
378 |
+
else:
|
379 |
+
merge_result(
|
380 |
+
Dictionary._add_file_to_dictionary_single_worker(
|
381 |
+
local_file, tokenize, dict.eos_word, offsets[0], offsets[1]
|
382 |
+
)
|
383 |
+
)
|
384 |
+
|
385 |
+
|
386 |
+
class TruncatedDictionary(object):
|
387 |
+
def __init__(self, wrapped_dict, length):
|
388 |
+
self.__class__ = type(
|
389 |
+
wrapped_dict.__class__.__name__,
|
390 |
+
(self.__class__, wrapped_dict.__class__),
|
391 |
+
{},
|
392 |
+
)
|
393 |
+
self.__dict__ = wrapped_dict.__dict__
|
394 |
+
self.wrapped_dict = wrapped_dict
|
395 |
+
self.length = min(len(self.wrapped_dict), length)
|
396 |
+
|
397 |
+
def __len__(self):
|
398 |
+
return self.length
|
399 |
+
|
400 |
+
def __getitem__(self, i):
|
401 |
+
if i < self.length:
|
402 |
+
return self.wrapped_dict[i]
|
403 |
+
return self.wrapped_dict.unk()
|
fairseq/fairseq/data/fairseq_dataset.py
ADDED
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import logging
|
7 |
+
import numpy as np
|
8 |
+
import torch.utils.data
|
9 |
+
from fairseq.data import data_utils
|
10 |
+
|
11 |
+
logger = logging.getLogger(__name__)
|
12 |
+
|
13 |
+
|
14 |
+
class EpochListening:
|
15 |
+
"""Mixin for receiving updates whenever the epoch increments."""
|
16 |
+
|
17 |
+
@property
|
18 |
+
def can_reuse_epoch_itr_across_epochs(self):
|
19 |
+
"""
|
20 |
+
Whether we can reuse the :class:`fairseq.data.EpochBatchIterator` for
|
21 |
+
this dataset across epochs.
|
22 |
+
|
23 |
+
This needs to return ``False`` if the sample sizes can change across
|
24 |
+
epochs, in which case we may need to regenerate batches at each epoch.
|
25 |
+
If your dataset relies in ``set_epoch`` then you should consider setting
|
26 |
+
this to ``False``.
|
27 |
+
"""
|
28 |
+
return True
|
29 |
+
|
30 |
+
def set_epoch(self, epoch):
|
31 |
+
"""Will receive the updated epoch number at the beginning of the epoch."""
|
32 |
+
pass
|
33 |
+
|
34 |
+
|
35 |
+
class FairseqDataset(torch.utils.data.Dataset, EpochListening):
|
36 |
+
"""A dataset that provides helpers for batching."""
|
37 |
+
|
38 |
+
def __getitem__(self, index):
|
39 |
+
raise NotImplementedError
|
40 |
+
|
41 |
+
def __len__(self):
|
42 |
+
raise NotImplementedError
|
43 |
+
|
44 |
+
def collater(self, samples):
|
45 |
+
"""Merge a list of samples to form a mini-batch.
|
46 |
+
|
47 |
+
Args:
|
48 |
+
samples (List[dict]): samples to collate
|
49 |
+
|
50 |
+
Returns:
|
51 |
+
dict: a mini-batch suitable for forwarding with a Model
|
52 |
+
"""
|
53 |
+
raise NotImplementedError
|
54 |
+
|
55 |
+
def num_tokens(self, index):
|
56 |
+
"""Return the number of tokens in a sample. This value is used to
|
57 |
+
enforce ``--max-tokens`` during batching."""
|
58 |
+
raise NotImplementedError
|
59 |
+
|
60 |
+
def num_tokens_vec(self, indices):
|
61 |
+
"""Return the number of tokens for a set of positions defined by indices.
|
62 |
+
This value is used to enforce ``--max-tokens`` during batching."""
|
63 |
+
raise NotImplementedError
|
64 |
+
|
65 |
+
def size(self, index):
|
66 |
+
"""Return an example's size as a float or tuple. This value is used when
|
67 |
+
filtering a dataset with ``--max-positions``."""
|
68 |
+
raise NotImplementedError
|
69 |
+
|
70 |
+
def ordered_indices(self):
|
71 |
+
"""Return an ordered list of indices. Batches will be constructed based
|
72 |
+
on this order."""
|
73 |
+
return np.arange(len(self), dtype=np.int64)
|
74 |
+
|
75 |
+
@property
|
76 |
+
def supports_prefetch(self):
|
77 |
+
"""Whether this dataset supports prefetching."""
|
78 |
+
return False
|
79 |
+
|
80 |
+
def attr(self, attr: str, index: int):
|
81 |
+
return getattr(self, attr, None)
|
82 |
+
|
83 |
+
def prefetch(self, indices):
|
84 |
+
"""Prefetch the data required for this epoch."""
|
85 |
+
raise NotImplementedError
|
86 |
+
|
87 |
+
def get_batch_shapes(self):
|
88 |
+
"""
|
89 |
+
Return a list of valid batch shapes, for example::
|
90 |
+
|
91 |
+
[(8, 512), (16, 256), (32, 128)]
|
92 |
+
|
93 |
+
The first dimension of each tuple is the batch size and can be ``None``
|
94 |
+
to automatically infer the max batch size based on ``--max-tokens``.
|
95 |
+
The second dimension of each tuple is the max supported length as given
|
96 |
+
by :func:`fairseq.data.FairseqDataset.num_tokens`.
|
97 |
+
|
98 |
+
This will be used by :func:`fairseq.data.FairseqDataset.batch_by_size`
|
99 |
+
to restrict batch shapes. This is useful on TPUs to avoid too many
|
100 |
+
dynamic shapes (and recompilations).
|
101 |
+
"""
|
102 |
+
return None
|
103 |
+
|
104 |
+
def batch_by_size(
|
105 |
+
self,
|
106 |
+
indices,
|
107 |
+
max_tokens=None,
|
108 |
+
max_sentences=None,
|
109 |
+
required_batch_size_multiple=1,
|
110 |
+
):
|
111 |
+
"""
|
112 |
+
Given an ordered set of indices, return batches according to
|
113 |
+
*max_tokens*, *max_sentences* and *required_batch_size_multiple*.
|
114 |
+
"""
|
115 |
+
from fairseq.data import data_utils
|
116 |
+
|
117 |
+
fixed_shapes = self.get_batch_shapes()
|
118 |
+
if fixed_shapes is not None:
|
119 |
+
|
120 |
+
def adjust_bsz(bsz, num_tokens):
|
121 |
+
if bsz is None:
|
122 |
+
assert max_tokens is not None, "Must specify --max-tokens"
|
123 |
+
bsz = max_tokens // num_tokens
|
124 |
+
if max_sentences is not None:
|
125 |
+
bsz = min(bsz, max_sentences)
|
126 |
+
elif (
|
127 |
+
bsz >= required_batch_size_multiple
|
128 |
+
and bsz % required_batch_size_multiple != 0
|
129 |
+
):
|
130 |
+
bsz -= bsz % required_batch_size_multiple
|
131 |
+
return bsz
|
132 |
+
|
133 |
+
fixed_shapes = np.array(
|
134 |
+
[
|
135 |
+
[adjust_bsz(bsz, num_tokens), num_tokens]
|
136 |
+
for (bsz, num_tokens) in fixed_shapes
|
137 |
+
]
|
138 |
+
)
|
139 |
+
|
140 |
+
try:
|
141 |
+
num_tokens_vec = self.num_tokens_vec(indices).astype("int64")
|
142 |
+
except NotImplementedError:
|
143 |
+
num_tokens_vec = None
|
144 |
+
|
145 |
+
return data_utils.batch_by_size(
|
146 |
+
indices,
|
147 |
+
num_tokens_fn=self.num_tokens,
|
148 |
+
num_tokens_vec=num_tokens_vec,
|
149 |
+
max_tokens=max_tokens,
|
150 |
+
max_sentences=max_sentences,
|
151 |
+
required_batch_size_multiple=required_batch_size_multiple,
|
152 |
+
fixed_shapes=fixed_shapes,
|
153 |
+
)
|
154 |
+
|
155 |
+
def filter_indices_by_size(self, indices, max_sizes):
|
156 |
+
"""
|
157 |
+
Filter a list of sample indices. Remove those that are longer than
|
158 |
+
specified in *max_sizes*.
|
159 |
+
|
160 |
+
WARNING: don't update, override method in child classes
|
161 |
+
|
162 |
+
Args:
|
163 |
+
indices (np.array): original array of sample indices
|
164 |
+
max_sizes (int or list[int] or tuple[int]): max sample size,
|
165 |
+
can be defined separately for src and tgt (then list or tuple)
|
166 |
+
|
167 |
+
Returns:
|
168 |
+
np.array: filtered sample array
|
169 |
+
list: list of removed indices
|
170 |
+
"""
|
171 |
+
if isinstance(max_sizes, float) or isinstance(max_sizes, int):
|
172 |
+
if hasattr(self, "sizes") and isinstance(self.sizes, np.ndarray):
|
173 |
+
ignored = indices[self.sizes[indices] > max_sizes].tolist()
|
174 |
+
indices = indices[self.sizes[indices] <= max_sizes]
|
175 |
+
elif (
|
176 |
+
hasattr(self, "sizes")
|
177 |
+
and isinstance(self.sizes, list)
|
178 |
+
and len(self.sizes) == 1
|
179 |
+
):
|
180 |
+
ignored = indices[self.sizes[0][indices] > max_sizes].tolist()
|
181 |
+
indices = indices[self.sizes[0][indices] <= max_sizes]
|
182 |
+
else:
|
183 |
+
indices, ignored = data_utils._filter_by_size_dynamic(
|
184 |
+
indices, self.size, max_sizes
|
185 |
+
)
|
186 |
+
else:
|
187 |
+
indices, ignored = data_utils._filter_by_size_dynamic(
|
188 |
+
indices, self.size, max_sizes
|
189 |
+
)
|
190 |
+
return indices, ignored
|
191 |
+
|
192 |
+
@property
|
193 |
+
def supports_fetch_outside_dataloader(self):
|
194 |
+
"""Whether this dataset supports fetching outside the workers of the dataloader."""
|
195 |
+
return True
|
196 |
+
|
197 |
+
|
198 |
+
class FairseqIterableDataset(torch.utils.data.IterableDataset, EpochListening):
|
199 |
+
"""
|
200 |
+
For datasets that need to be read sequentially, usually because the data is
|
201 |
+
being streamed or otherwise can't be manipulated on a single machine.
|
202 |
+
"""
|
203 |
+
|
204 |
+
def __iter__(self):
|
205 |
+
raise NotImplementedError
|
fairseq/fairseq/data/fasta_dataset.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import os
|
7 |
+
import subprocess
|
8 |
+
import threading
|
9 |
+
from pathlib import Path
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
|
14 |
+
|
15 |
+
def fasta_file_path(prefix_path):
|
16 |
+
return prefix_path + ".fasta"
|
17 |
+
|
18 |
+
|
19 |
+
class FastaDataset(torch.utils.data.Dataset):
|
20 |
+
"""
|
21 |
+
For loading protein sequence datasets in the common FASTA data format
|
22 |
+
"""
|
23 |
+
|
24 |
+
def __init__(self, path: str, cache_indices=False):
|
25 |
+
self.fn = fasta_file_path(path)
|
26 |
+
self.threadlocal = threading.local()
|
27 |
+
self.cache = Path(f"{path}.fasta.idx.npy")
|
28 |
+
if cache_indices:
|
29 |
+
if self.cache.exists():
|
30 |
+
self.offsets, self.sizes = np.load(self.cache)
|
31 |
+
else:
|
32 |
+
self.offsets, self.sizes = self._build_index(path)
|
33 |
+
np.save(self.cache, np.stack([self.offsets, self.sizes]))
|
34 |
+
else:
|
35 |
+
self.offsets, self.sizes = self._build_index(path)
|
36 |
+
|
37 |
+
def _get_file(self):
|
38 |
+
if not hasattr(self.threadlocal, "f"):
|
39 |
+
self.threadlocal.f = open(self.fn, "r")
|
40 |
+
return self.threadlocal.f
|
41 |
+
|
42 |
+
def __getitem__(self, idx):
|
43 |
+
f = self._get_file()
|
44 |
+
f.seek(self.offsets[idx])
|
45 |
+
desc = f.readline().strip()
|
46 |
+
line = f.readline()
|
47 |
+
seq = ""
|
48 |
+
while line != "" and line[0] != ">":
|
49 |
+
seq += line.strip()
|
50 |
+
line = f.readline()
|
51 |
+
return desc, seq
|
52 |
+
|
53 |
+
def __len__(self):
|
54 |
+
return self.offsets.size
|
55 |
+
|
56 |
+
def _build_index(self, path: str):
|
57 |
+
# Use grep and awk to get 100M/s on local SSD.
|
58 |
+
# Should process your enormous 100G fasta in ~10 min single core...
|
59 |
+
path = fasta_file_path(path)
|
60 |
+
bytes_offsets = subprocess.check_output(
|
61 |
+
f"cat {path} | tqdm --bytes --total $(wc -c < {path})"
|
62 |
+
"| grep --byte-offset '^>' -o | cut -d: -f1",
|
63 |
+
shell=True,
|
64 |
+
)
|
65 |
+
fasta_lengths = subprocess.check_output(
|
66 |
+
f"cat {path} | tqdm --bytes --total $(wc -c < {path})"
|
67 |
+
"| awk '/^>/ {print \"\";next;} { printf(\"%s\",$0);}' | tail -n+2 | awk '{print length($1)}'",
|
68 |
+
shell=True,
|
69 |
+
)
|
70 |
+
bytes_np = np.fromstring(bytes_offsets, dtype=np.int64, sep=" ")
|
71 |
+
sizes_np = np.fromstring(fasta_lengths, dtype=np.int64, sep=" ")
|
72 |
+
return bytes_np, sizes_np
|
73 |
+
|
74 |
+
def __setstate__(self, state):
|
75 |
+
self.__dict__ = state
|
76 |
+
self.threadlocal = threading.local()
|
77 |
+
|
78 |
+
def __getstate__(self):
|
79 |
+
d = {}
|
80 |
+
for i, v in self.__dict__.items():
|
81 |
+
if i != "threadlocal":
|
82 |
+
d[i] = v
|
83 |
+
return d
|
84 |
+
|
85 |
+
def __del__(self):
|
86 |
+
if hasattr(self.threadlocal, "f"):
|
87 |
+
self.threadlocal.f.close()
|
88 |
+
del self.threadlocal.f
|
89 |
+
|
90 |
+
@staticmethod
|
91 |
+
def exists(path):
|
92 |
+
return os.path.exists(fasta_file_path(path))
|
93 |
+
|
94 |
+
|
95 |
+
class EncodedFastaDataset(FastaDataset):
|
96 |
+
"""
|
97 |
+
The FastaDataset returns raw sequences - this allows us to return
|
98 |
+
indices with a dictionary instead.
|
99 |
+
"""
|
100 |
+
|
101 |
+
def __init__(self, path, dictionary):
|
102 |
+
super().__init__(path, cache_indices=True)
|
103 |
+
self.dictionary = dictionary
|
104 |
+
|
105 |
+
def __getitem__(self, idx):
|
106 |
+
desc, seq = super().__getitem__(idx)
|
107 |
+
return self.dictionary.encode_line(seq, line_tokenizer=list).long()
|
fairseq/fairseq/data/id_dataset.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from . import FairseqDataset
|
9 |
+
|
10 |
+
|
11 |
+
class IdDataset(FairseqDataset):
|
12 |
+
def __getitem__(self, index):
|
13 |
+
return index
|
14 |
+
|
15 |
+
def __len__(self):
|
16 |
+
return 0
|
17 |
+
|
18 |
+
def collater(self, samples):
|
19 |
+
return torch.tensor(samples)
|