Add files using upload-large-folder tool
- fairseq/examples/textless_nlp/pgslm/sample/sample.py +612 -0
- fairseq/examples/textless_nlp/pgslm/scripts/join_units_manifest.py +48 -0
- fairseq/examples/translation/prepare-iwslt14.sh +115 -0
- fairseq/examples/translation/prepare-wmt14en2fr.sh +136 -0
- fairseq/examples/translation_moe/README.md +89 -0
- fairseq/examples/translation_moe/score.py +197 -0
- fairseq/examples/translation_moe/translation_moe_src/__init__.py +6 -0
- fairseq/examples/translation_moe/translation_moe_src/logsumexp_moe.py +26 -0
- fairseq/examples/translation_moe/translation_moe_src/mean_pool_gating_network.py +50 -0
- fairseq/examples/translation_moe/translation_moe_src/translation_moe.py +259 -0
- fairseq/examples/truncated_bptt/README.md +70 -0
- fairseq/examples/truncated_bptt/__init__.py +6 -0
- fairseq/examples/truncated_bptt/transformer_xl_model.py +143 -0
- fairseq/examples/truncated_bptt/truncated_bptt_lm_task.py +285 -0
- fairseq/examples/unsupervised_quality_estimation/aggregate_scores.py +41 -0
- fairseq/examples/unsupervised_quality_estimation/meteor.py +109 -0
- fairseq/examples/unsupervised_quality_estimation/repeat_lines.py +28 -0
- fairseq/examples/wav2vec/__init__.py +0 -0
- fairseq/examples/wav2vec/config/finetuning/base_10m.yaml +63 -0
- fairseq/examples/wav2vec/config/finetuning/base_1h.yaml +63 -0
- fairseq/examples/wav2vec/config/finetuning/run_config/slurm_1.yaml +26 -0
- fairseq/examples/wav2vec/config/finetuning/run_config/slurm_16.yaml +27 -0
- fairseq/examples/wav2vec/config/finetuning/run_config/slurm_1_aws.yaml +37 -0
- fairseq/examples/wav2vec/config/finetuning/run_config/slurm_1_old.yaml +27 -0
- fairseq/examples/wav2vec/config/finetuning/run_config/slurm_2.yaml +27 -0
- fairseq/examples/wav2vec/config/finetuning/run_config/slurm_2_aws.yaml +37 -0
- fairseq/examples/wav2vec/config/finetuning/run_config/slurm_2g.yaml +26 -0
- fairseq/examples/wav2vec/config/finetuning/run_config/slurm_3.yaml +27 -0
- fairseq/examples/wav2vec/config/finetuning/run_config/slurm_4g.yaml +26 -0
- fairseq/examples/wav2vec/config/finetuning/run_config/slurm_4g_aws.yaml +37 -0
- fairseq/examples/wav2vec/config/finetuning/run_config/slurm_8.yaml +26 -0
- fairseq/examples/wav2vec/config/finetuning/vox_10h_2_aws.yaml +81 -0
- fairseq/examples/wav2vec/config/finetuning/vox_10h_aws.yaml +104 -0
- fairseq/examples/wav2vec/config/finetuning/vox_10m_2.yaml +114 -0
- fairseq/examples/wav2vec/config/finetuning/vox_10m_2_aws.yaml +114 -0
- fairseq/examples/wav2vec/config/finetuning/vox_10m_3.yaml +105 -0
- fairseq/examples/wav2vec/config/finetuning/vox_1h.yaml +63 -0
- fairseq/examples/wav2vec/config/finetuning/vox_1h_2.yaml +104 -0
- fairseq/examples/wav2vec/config/finetuning/vox_1h_2_aws.yaml +114 -0
- fairseq/examples/wav2vec/config/finetuning/vox_1h_aws.yaml +80 -0
- fairseq/examples/wav2vec/config/finetuning/vox_960h.yaml +57 -0
- fairseq/examples/wav2vec/config/finetuning/vox_960h_2.yaml +105 -0
- fairseq/examples/wav2vec/config/finetuning/vox_960h_2_aws.yaml +82 -0
- fairseq/examples/wav2vec/config/finetuning/vox_960h_3.yaml +101 -0
- fairseq/examples/wav2vec/config/pretraining/wav2vec2_base_librispeech.yaml +57 -0
- fairseq/examples/wav2vec/config/pretraining/wav2vec2_conformer_base_librispeech.yaml +60 -0
- fairseq/examples/wav2vec/config/pretraining/wav2vec2_conformer_large_librivox.yaml +72 -0
- fairseq/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml +70 -0
- fairseq/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu-pod.yaml +72 -0
- fairseq/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu.yaml +77 -0
fairseq/examples/textless_nlp/pgslm/sample/sample.py
ADDED
@@ -0,0 +1,612 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import torch.multiprocessing as mp
import numpy as np
import json

import torch
from torch.distributions.categorical import Categorical

from fairseq import checkpoint_utils, options, utils
from fairseq.data.codedataset import CodeDataset, ExpressiveCodeDataConfig
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from torch.utils.data import DataLoader, DistributedSampler
from fairseq.utils import move_to_cuda

import tqdm
import random
import pathlib

import sys

sys.path.append(str(pathlib.Path(__file__).parent.parent))
from inference_dataset import InferenceDataset, explode_batch
from naive_decoder import Naive_F0_Decoder
from truncated_laplace import truncated_laplace

CODETYPE_TO_FRAMETIME = {"cpc_km100": 0.01, "hubert": 0.02}  # 10 ms and 20 ms frames


class TemperatureDecoder:
    def __init__(self, Ts, discrete_dur=False, discrete_f0=False):
        self.T_token, self.T_dur, self.T_f0 = Ts
        self.discrete_dur = discrete_dur
        self.discrete_f0 = discrete_f0

    def __call__(self, output):
        def sample_multinomial(key, T):
            logits = output[key][:, -1, :].float()
            return Categorical(logits=logits / T).sample().unsqueeze(-1)

        def sample_laplace(key, T, truncate_at_zero):
            mean = output[key][:, -1, :].float()
            return truncated_laplace(mean=mean, T=T, truncate_by_zero=truncate_at_zero)

        if self.T_token > 0:
            new_tokens = sample_multinomial("token", self.T_token)
        else:
            new_tokens = output["token"][:, -1, :].argmax(dim=-1, keepdim=True)

        if not self.discrete_dur and self.T_dur == 0:
            new_durations = output["duration"][:, -1].round().int()
        elif not self.discrete_dur and self.T_dur > 0:
            new_durations = (
                sample_laplace("duration", self.T_dur, truncate_at_zero=True)
                .round()
                .int()
            )
        elif self.discrete_dur and self.T_dur > 0:
            new_durations = sample_multinomial("duration", self.T_dur)
        elif self.discrete_dur and self.T_dur == 0:
            new_durations = output["duration"][:, -1, :].argmax(dim=-1, keepdim=True)
        else:
            assert False

        if not self.discrete_f0 and self.T_f0 == 0:
            new_f0 = output["f0"][:, -1]
        elif not self.discrete_f0 and self.T_f0 > 0:
            new_f0 = sample_laplace("f0", self.T_f0, truncate_at_zero=False)
        elif self.discrete_f0 and self.T_f0 > 0:
            new_f0 = sample_multinomial("f0", self.T_f0)
        elif self.discrete_f0 and self.T_f0 == 0:
            new_f0 = output["f0"][:, -1, :].argmax(dim=-1, keepdim=True)
        else:
            assert False

        return new_tokens, new_durations, new_f0


class FilterNamesDataset:
    def __init__(self, dataset, fnames_path):
        self.dataset = dataset

        with open(fnames_path, "r") as fin:
            fnames = set((eval(line)["audio"] for line in fin))
            print(f"# will restrict the dataset to {len(fnames)} files")

        self.indexes = []

        for i, datapoint in enumerate(dataset):
            if datapoint["filename"] in fnames:
                self.indexes.append(i)
        assert len(self.indexes) == len(fnames), f"{len(self.indexes)} {len(fnames)}"

        self.collater = self.dataset.collater
        self.discrete_dur = self.dataset.discrete_dur
        self.discrete_f0 = self.dataset.discrete_f0

    def __len__(self):
        return len(self.indexes)

    def __getitem__(self, k):
        k = self.indexes[k]
        return self.dataset[k]

    def size(self, k):
        k = self.indexes[k]
        return self.dataset.size(k)


@torch.no_grad()
def do_sampling(
    model,
    batch,
    eos_token,
    decoder,
    autoregressive_steps=100,
    teacher_force_tokens=False,
    teacher_force_duration=False,
    teacher_force_f0=False,
    match_duration=False,
):
    def autoregressive_step_(output, autoregressive_steps):
        new_tokens, new_durations, new_f0 = decoder(output)

        n = output["token"].size(1) if output["token"].ndim == 3 else 1

        if teacher_force_tokens:
            new_tokens = batch["target"][:, n - 1].unsqueeze(-1)
        if teacher_force_duration:
            new_durations = batch["dur_target"][:, n - 1].unsqueeze(-1)
        if teacher_force_f0:
            new_f0 = batch["f0_target"][:, n - 1].unsqueeze(-1)

        batch["net_input"]["src_tokens"] = torch.cat(
            [batch["net_input"]["src_tokens"], new_tokens], dim=1
        )
        batch["net_input"]["dur_src"] = torch.cat(
            [batch["net_input"]["dur_src"], new_durations], dim=1
        )
        batch["net_input"]["f0_src"] = torch.cat(
            [batch["net_input"]["f0_src"], new_f0], dim=1
        )

    outputs = []

    if teacher_force_tokens or teacher_force_duration or teacher_force_f0:
        max_time = batch["target"].size(1)
        prefix_time = batch["net_input"]["src_tokens"].size(1)

        autoregressive_steps = max_time - prefix_time + 1  # should be 0

    for _ in range(autoregressive_steps):
        output = model(**batch["net_input"])

        last_steps = (
            output["token"][:, -1, ...],
            output["duration"][:, -1, ...],
            output["f0"][:, -1, ...],
        )
        outputs.append(last_steps)

        autoregressive_step_(output, autoregressive_steps)
        tokens, duration, f0 = (
            batch["net_input"]["src_tokens"],
            batch["net_input"]["dur_src"],
            batch["net_input"]["f0_src"],
        )

        if (
            match_duration
            and (batch["dur_target"].sum(dim=-1) < duration.sum(dim=-1)).all()
        ):
            break

    return tokens, duration, f0, outputs


def unroll_duration(token_stream, duration_stream):
    assert len(token_stream) == len(
        duration_stream
    ), f"{len(token_stream)} != {len(duration_stream)}"
    non_positive_durations = sum(d <= 0 for d in duration_stream)
    if non_positive_durations > 0:
        print(
            f"# {non_positive_durations} durations are non-positive, they will be capped to 1"
        )

    result = []

    duration_stream_rounded_capped = [max(1, int(round(x))) for x in duration_stream]
    for t, d in zip(token_stream, duration_stream_rounded_capped):
        result.extend([t] * d)

    return result


def realign_shifted_streams(tokens, durations, F0s, shifts):
    """
    Durations are shifted by 1, F0 by 2
    >>> tokens = ["<s>", "t1", "t2", "t3", "</s>", "x", "x"]
    >>> durations = ["<0>", "<0>", "d1", "d2", "d3", "<0>", "x"]
    >>> F0s = ["<0>", "<0>", "<0>", "f1", "f2", "f3", "<0>"]
    >>> shifts = [1, 2]
    >>> realign_shifted_streams(tokens, durations, F0s, shifts)
    (['<s>', 't1', 't2', 't3', '</s>'], ['<0>', 'd1', 'd2', 'd3', '<0>'], ['<0>', 'f1', 'f2', 'f3', '<0>'])
    """
    max_shift = max(shifts)
    if max_shift > 0:
        shift_durations, shift_F0s = shifts

        tokens = tokens[:-max_shift]
        durations = durations[shift_durations:]
        if shift_durations < max_shift:
            durations = durations[: -(max_shift - shift_durations)]

        if F0s is not None:
            F0s = F0s[shift_F0s:]
            if shift_F0s < max_shift:
                F0s = F0s[: -(max_shift - shift_F0s)]

    assert len(tokens) == len(durations), f"{len(tokens)} != {len(durations)}"
    if F0s is not None:
        assert len(tokens) == len(F0s), f"{len(tokens)} != {len(F0s)}"

    return tokens, durations, F0s


def maybe_cut_eos(produced_tokens, produced_duration, produced_f0, eos_idx):
    if eos_idx in produced_tokens:
        eos_index = produced_tokens.index(eos_idx)
        produced_tokens = produced_tokens[:eos_index]
        produced_duration = produced_duration[:eos_index]
        produced_f0 = produced_f0[:eos_index]
    return produced_tokens, produced_duration, produced_f0


def maybe_filter_pad(produced_tokens, produced_duration, produced_f0, pad_idx):
    if pad_idx not in produced_tokens:
        return produced_tokens, produced_duration, produced_f0

    assert len(produced_tokens) == len(produced_duration) == len(produced_f0)

    print("<pad> is detected in the output!")
    filtered_tokens, filtered_duration, filtered_f0 = [], [], []

    for t, d, f in zip(produced_tokens, produced_duration, produced_f0):
        if t != pad_idx:
            filtered_tokens.append(t)
            filtered_duration.append(d)
            filtered_f0.append(f)
    return filtered_tokens, filtered_duration, filtered_f0


def match_duration(produced_tokens, produced_duration, produced_f0, target_duration):
    """
    >>> tokens = ['t'] * 4
    >>> F0s = ['f0'] * 4
    >>> produced_duration = [1, 10, 10, 10]
    >>> match_duration(tokens, produced_duration, F0s, target_duration=100)
    (['t', 't', 't', 't'], [1, 10, 10, 10], ['f0', 'f0', 'f0', 'f0'])
    >>> match_duration(tokens, produced_duration, F0s, target_duration=5)
    (['t', 't'], [1, 4], ['f0', 'f0'])
    """
    if sum(produced_duration) <= target_duration:
        return produced_tokens, produced_duration, produced_f0

    running_duration = 0
    filtered_duration = []

    for next_tok_duration in produced_duration:
        if running_duration + next_tok_duration < target_duration:
            filtered_duration.append(next_tok_duration)
            running_duration += next_tok_duration
        else:
            to_add = target_duration - running_duration
            assert to_add <= next_tok_duration
            filtered_duration.append(to_add)
            break

    produced_duration = filtered_duration
    assert sum(produced_duration) == target_duration

    n_tok = len(filtered_duration)

    return produced_tokens[:n_tok], produced_duration, produced_f0[:n_tok]


def main(rank, world_size, args):
    if world_size > 1:
        torch.distributed.init_process_group(
            backend="gloo", init_method="env://", world_size=world_size, rank=rank
        )
        torch.cuda.set_device(rank)

    raw_args = args
    args = convert_namespace_to_omegaconf(args)
    if args.common.seed is not None:
        random.seed(args.common.seed)
        np.random.seed(args.common.seed)
        utils.set_torch_seed(args.common.seed)

    models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
        [raw_args.path], arg_overrides={"data": args.task.data}
    )
    tgt_dict = task.target_dictionary

    for model in models:
        model.prepare_for_inference_(args)
        model.cuda().eval()
        if raw_args.fp16:
            model = model.half()
    model = models[0]

    config = ExpressiveCodeDataConfig(args.task.data)

    dataset = CodeDataset(
        manifest=config.manifests[raw_args.subset],
        dictionary=task.source_dictionary,
        dur_dictionary=task.source_duration_dictionary,
        f0_dictionary=task.source_f0_dictionary,
        config=config,
        discrete_dur=task.cfg.discrete_duration,
        discrete_f0=task.cfg.discrete_f0,
        log_f0=task.cfg.log_f0,
        normalize_f0_mean=task.cfg.normalize_f0_mean,
        normalize_f0_std=task.cfg.normalize_f0_std,
        interpolate_f0=task.cfg.interpolate_f0,
        shifts=task.cfg.stream_shifts,
        return_filename=True,
        strip_filename=False,
    )
    tgt_dict = task.target_dictionary
    shifts = dataset.shifts.dur, dataset.shifts.f0
    max_shift = max(shifts)

    fname = raw_args.output
    if world_size > 1:
        fname += f"_{rank}"
    output_file = open(fname, "w")

    if raw_args.filter_names:
        dataset = FilterNamesDataset(dataset, raw_args.filter_names)

    dataset = InferenceDataset(dataset, raw_args.prefix_length, filter_short=True)
    print(f"Dataset size {len(dataset)}")
    sampler = (
        None
        if world_size == 1
        else DistributedSampler(
            dataset, num_replicas=world_size, rank=rank, shuffle=False
        )
    )
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        collate_fn=dataset.collater,
        sampler=sampler,
    )

    Ts = raw_args.T_token, raw_args.T_duration, raw_args.T_f0
    decoder = TemperatureDecoder(
        Ts, discrete_dur=task.cfg.discrete_duration, discrete_f0=task.cfg.discrete_f0
    )

    dataset_size = len(dataset)

    f0_decoder = None
    if raw_args.f0_discretization_bounds:
        assert task.cfg.discrete_f0
        f0_decoder = Naive_F0_Decoder(raw_args.f0_discretization_bounds).cuda()

    pbar = (
        tqdm.tqdm(
            total=dataset_size
            if raw_args.max_samples is None
            else min(raw_args.max_samples, dataset_size)
        )
        if world_size == 1
        else None
    )

    samples_produced = 0

    for batch in dataloader:
        if (
            raw_args.max_samples is not None
            and samples_produced >= raw_args.max_samples
        ):
            break

        prefix = batch["prefix"][0]

        batch = explode_batch(batch, raw_args.batch_explosion_rate)
        batch = move_to_cuda(batch)

        if not raw_args.short_curcuit:
            produced_tokens, produced_durations, produced_f0, _ = do_sampling(
                models[0],
                batch,
                tgt_dict.eos(),
                decoder,
                autoregressive_steps=raw_args.max_length - prefix + max_shift,
                teacher_force_tokens=raw_args.teacher_force_tokens,
                match_duration=raw_args.match_duration,
                teacher_force_duration=raw_args.teacher_force_duration,
                teacher_force_f0=raw_args.teacher_force_f0,
            )

            # strip entries corresponding to <s>
            produced_tokens = produced_tokens[:, 1:]
            produced_durations = produced_durations[:, 1:]
            produced_f0 = produced_f0[:, 1:]

        else:
            max_length = raw_args.max_length + max_shift
            produced_tokens, produced_durations, produced_f0 = (
                batch["target"][:, :max_length],
                batch["dur_target"][:, :max_length],
                batch["f0_target"][:, :max_length],
            )

        if f0_decoder is not None:
            produced_f0 = f0_decoder(produced_f0)

        produced_tokens, produced_durations, produced_f0 = (
            produced_tokens.cpu().tolist(),
            produced_durations.cpu().tolist(),
            produced_f0.cpu().tolist(),
        )

        bsz = batch["target"].size(0)
        assert bsz == raw_args.batch_explosion_rate

        for i in range(bsz):
            if (
                raw_args.max_samples is not None
                and samples_produced >= raw_args.max_samples
            ):
                break

            produced_tokens_i = produced_tokens[i]
            produced_durations_i = produced_durations[i]
            produced_f0_i = produced_f0[i]

            (
                produced_tokens_i,
                produced_durations_i,
                produced_f0_i,
            ) = realign_shifted_streams(
                produced_tokens_i, produced_durations_i, produced_f0_i, shifts
            )

            produced_tokens_i, produced_durations_i, produced_f0_i = maybe_cut_eos(
                produced_tokens_i, produced_durations_i, produced_f0_i, tgt_dict.eos()
            )

            produced_tokens_i, produced_durations_i, produced_f0_i = maybe_filter_pad(
                produced_tokens_i, produced_durations_i, produced_f0_i, tgt_dict.pad()
            )

            if raw_args.match_duration:
                # NB: here we cheat a bit and use that padding has duration 0
                # so no need to re-align and remove padding
                dur_target_i = batch["dur_target"][i, :].sum().item()
                produced_tokens_i, produced_durations_i, produced_f0_i = match_duration(
                    produced_tokens_i, produced_durations_i, produced_f0_i, dur_target_i
                )

            if raw_args.cut_prompt:
                produced_tokens_i, produced_durations_i, produced_f0_i = (
                    produced_tokens_i[prefix:],
                    produced_durations_i[prefix:],
                    produced_f0_i[prefix:],
                )

            prompt_fname = batch["filename"][0]
            fname = str(pathlib.Path(prompt_fname).with_suffix("")) + f"__{i}.wav"

            token_stream = unroll_duration(produced_tokens_i, produced_durations_i)
            f0_stream = unroll_duration(produced_f0_i, produced_durations_i)
            output_line = json.dumps(
                {
                    "audio": fname,
                    "prompt": prompt_fname,
                    raw_args.code_type: " ".join(map(str, token_stream)),
                    "duration": round(
                        sum(produced_durations_i)
                        * CODETYPE_TO_FRAMETIME[raw_args.code_type],
                        3,
                    ),
                    "raw_duration": produced_durations_i,
                    "raw_f0": produced_f0_i,
                    "f0": [round(f0, 3) for f0 in f0_stream],
                }
            )
            print(output_line, file=output_file)

            if pbar:
                pbar.update(1)
            samples_produced += 1

        if raw_args.debug:
            break

    output_file.close()

    if world_size > 1:
        # important that everything is flushed before aggregating
        torch.distributed.barrier()

    if world_size > 1 and rank == 0:
        with open(raw_args.output, "w") as fout:
            for i in range(world_size):
                f = raw_args.output + f"_{i}"
                with open(f, "r") as fin:
                    fout.write(fin.read())
                os.remove(f)


def cli_main():
    parser = options.get_interactive_generation_parser()
    parser.add_argument(
        "--prefix-length",
        type=int,
        default=1,
        help="Prompt prefix length (including <s>)",
    )
    parser.add_argument("--output", type=str, default=None, required=True)
    parser.add_argument(
        "--debug", action="store_true", help="Process only the first batch"
    )
    parser.add_argument(
        "--ignore-durations",
        action="store_true",
        help="If set, the duration stream is ignored",
    )
    parser.add_argument(
        "--max-length", type=int, default=200, help="Maximal produced length"
    )
    parser.add_argument(
        "--code-type", choices=["cpc_km100", "hubert"], default="cpc_km100"
    )
    parser.add_argument("--max-samples", type=int, default=None)
    parser.add_argument("--prompt-duration-scaler", type=float, default=1.0)
    parser.add_argument("--teacher-force-tokens", action="store_true", default=False)
    parser.add_argument("--teacher-force-duration", action="store_true", default=False)
    parser.add_argument("--teacher-force-f0", action="store_true", default=False)
    parser.add_argument("--filter-names", type=str, default=None)
    parser.add_argument(
        "--match-duration",
        action="store_true",
        help="Do not produce sequences longer than the ground truth",
    )
    parser.add_argument(
        "--cut-prompt",
        action="store_true",
        help="Remove prompt from the produced audio",
    )
    parser.add_argument(
        "--short-curcuit", action="store_true", help="Use 'target' as a sample"
    )
    parser.add_argument("--f0-discretization-bounds", type=str, default=None)

    parser.add_argument("--batch-explosion-rate", type=int, default=1)

    parser.add_argument("--T-token", type=float, default=1.0)
    parser.add_argument("--T-duration", type=float, default=1.0)
    parser.add_argument("--T-f0", type=float, default=1.0)

    parser.add_argument(
        "--subset", type=str, default="valid", choices=["test", "valid"]
    )

    args = options.parse_args_and_arch(parser)

    assert (
        args.prefix_length >= 1
    ), "Prefix length includes bos token <s>, hence the minimum is 1."
    assert all(
        t >= 0 for t in [args.T_token, args.T_f0, args.T_duration]
    ), "T must be non-negative!"

    world_size = torch.cuda.device_count()
    if world_size > 1:
        mp.set_start_method("spawn", force=True)
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(random.randint(10_000, 50_000))

        print(f"Using {world_size} devices, master port {os.environ['MASTER_PORT']}")

        mp.spawn(
            main,
            nprocs=world_size,
            args=(
                world_size,
                args,
            ),
            join=True,
        )
    else:
        main(rank=0, world_size=world_size, args=args)


if __name__ == "__main__":
    cli_main()
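The `TemperatureDecoder` above applies the same per-stream rule everywhere: `T == 0` means greedy argmax, `T > 0` means sampling from the temperature-softened distribution (multinomial for discrete streams, truncated Laplace for continuous ones). Below is a minimal sketch of the discrete case using only `torch`; the tensor shapes are illustrative assumptions, not values from the script:

```python
import torch
from torch.distributions.categorical import Categorical

# Toy last-step logits: batch of 2 prompts, vocabulary of 5 units.
logits = torch.randn(2, 5)

def sample_step(logits, T):
    # T == 0 falls back to greedy argmax, as in TemperatureDecoder.
    if T == 0:
        return logits.argmax(dim=-1, keepdim=True)
    # Dividing logits by T flattens (T > 1) or sharpens (T < 1) the softmax.
    return Categorical(logits=logits / T).sample().unsqueeze(-1)

greedy = sample_step(logits, T=0)     # deterministic
diverse = sample_step(logits, T=1.5)  # stochastic, more spread out
print(greedy.shape, diverse.shape)    # torch.Size([2, 1]) in both cases
```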
fairseq/examples/textless_nlp/pgslm/scripts/join_units_manifest.py
ADDED
@@ -0,0 +1,48 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
import argparse
import pathlib


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--manifest", required=True)
    parser.add_argument("--units", required=True)
    parser.add_argument("--output", required=True)
    parser.add_argument("--sample_rate", type=int, default=16_000)

    args = parser.parse_args()

    with open(args.manifest, "r") as manifest, open(args.units, "r") as units, open(
        args.output, "w"
    ) as outp:
        root = manifest.readline().strip()
        root = pathlib.Path(root)

        for manifest_line, unit_line in zip(manifest.readlines(), units.readlines()):
            path, frames = manifest_line.split()
            duration = int(frames) / float(args.sample_rate)
            fname = root / path
            speaker = fname.parent.parent.name

            # avoid shadowing the `units` file handle
            unit_str = unit_line.split("|")[1]

            print(
                json.dumps(
                    dict(
                        audio=str(root / path),
                        duration=duration,
                        hubert_km100=unit_str.strip(),
                        speaker=speaker,
                    )
                ),
                file=outp,
            )


if __name__ == "__main__":
    main()
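For reference, this is what a single output line of `join_units_manifest.py` would look like. The manifest layout (root directory on the first line, then `relative/path<TAB>num_frames`) and the `utt_id|units` format of the units file are assumptions inferred from the parsing above, and the paths are made up:

```python
import json
import pathlib

root = pathlib.Path("/data/LibriSpeech")               # hypothetical root
manifest_line = "103/1240/103-1240-0000.flac\t225360"  # path <TAB> frames
unit_line = "103-1240-0000|71 12 12 54 54 3"           # utt_id|unit string

path, frames = manifest_line.split()
print(
    json.dumps(
        dict(
            audio=str(root / path),
            duration=int(frames) / 16_000,             # frames -> seconds
            hubert_km100=unit_line.split("|")[1].strip(),
            speaker=(root / path).parent.parent.name,  # -> "103"
        )
    )
)
```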
fairseq/examples/translation/prepare-iwslt14.sh
ADDED
@@ -0,0 +1,115 @@
#!/usr/bin/env bash
#
# Adapted from https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh

echo 'Cloning Moses github repository (for tokenization scripts)...'
git clone https://github.com/moses-smt/mosesdecoder.git

echo 'Cloning Subword NMT repository (for BPE pre-processing)...'
git clone https://github.com/rsennrich/subword-nmt.git

SCRIPTS=mosesdecoder/scripts
TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
LC=$SCRIPTS/tokenizer/lowercase.perl
CLEAN=$SCRIPTS/training/clean-corpus-n.perl
BPEROOT=subword-nmt/subword_nmt
BPE_TOKENS=10000

URL="http://dl.fbaipublicfiles.com/fairseq/data/iwslt14/de-en.tgz"
GZ=de-en.tgz

if [ ! -d "$SCRIPTS" ]; then
    echo "Please set SCRIPTS variable correctly to point to Moses scripts."
    exit
fi

src=de
tgt=en
lang=de-en
prep=iwslt14.tokenized.de-en
tmp=$prep/tmp
orig=orig

mkdir -p $orig $tmp $prep

echo "Downloading data from ${URL}..."
cd $orig
wget "$URL"

if [ -f $GZ ]; then
    echo "Data successfully downloaded."
else
    echo "Data not successfully downloaded."
    exit
fi

tar zxvf $GZ
cd ..

echo "pre-processing train data..."
for l in $src $tgt; do
    f=train.tags.$lang.$l
    tok=train.tags.$lang.tok.$l

    cat $orig/$lang/$f | \
        grep -v '<url>' | \
        grep -v '<talkid>' | \
        grep -v '<keywords>' | \
        sed -e 's/<title>//g' | \
        sed -e 's/<\/title>//g' | \
        sed -e 's/<description>//g' | \
        sed -e 's/<\/description>//g' | \
        perl $TOKENIZER -threads 8 -l $l > $tmp/$tok
    echo ""
done
perl $CLEAN -ratio 1.5 $tmp/train.tags.$lang.tok $src $tgt $tmp/train.tags.$lang.clean 1 175
for l in $src $tgt; do
    perl $LC < $tmp/train.tags.$lang.clean.$l > $tmp/train.tags.$lang.$l
done

echo "pre-processing valid/test data..."
for l in $src $tgt; do
    for o in `ls $orig/$lang/IWSLT14.TED*.$l.xml`; do
        fname=${o##*/}
        f=$tmp/${fname%.*}
        echo $o $f
        grep '<seg id' $o | \
            sed -e 's/<seg id="[0-9]*">\s*//g' | \
            sed -e 's/\s*<\/seg>\s*//g' | \
            sed -e "s/\’/\'/g" | \
            perl $TOKENIZER -threads 8 -l $l | \
            perl $LC > $f
        echo ""
    done
done


echo "creating train, valid, test..."
for l in $src $tgt; do
    awk '{if (NR%23 == 0) print $0; }' $tmp/train.tags.de-en.$l > $tmp/valid.$l
    awk '{if (NR%23 != 0) print $0; }' $tmp/train.tags.de-en.$l > $tmp/train.$l

    cat $tmp/IWSLT14.TED.dev2010.de-en.$l \
        $tmp/IWSLT14.TEDX.dev2012.de-en.$l \
        $tmp/IWSLT14.TED.tst2010.de-en.$l \
        $tmp/IWSLT14.TED.tst2011.de-en.$l \
        $tmp/IWSLT14.TED.tst2012.de-en.$l \
        > $tmp/test.$l
done

TRAIN=$tmp/train.en-de
BPE_CODE=$prep/code
rm -f $TRAIN
for l in $src $tgt; do
    cat $tmp/train.$l >> $TRAIN
done

echo "learn_bpe.py on ${TRAIN}..."
python $BPEROOT/learn_bpe.py -s $BPE_TOKENS < $TRAIN > $BPE_CODE

for L in $src $tgt; do
    for f in train.$L valid.$L test.$L; do
        echo "apply_bpe.py to ${f}..."
        python $BPEROOT/apply_bpe.py -c $BPE_CODE < $tmp/$f > $prep/$f
    done
done
fairseq/examples/translation/prepare-wmt14en2fr.sh
ADDED
@@ -0,0 +1,136 @@
#!/bin/bash
# Adapted from https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh

echo 'Cloning Moses github repository (for tokenization scripts)...'
git clone https://github.com/moses-smt/mosesdecoder.git

echo 'Cloning Subword NMT repository (for BPE pre-processing)...'
git clone https://github.com/rsennrich/subword-nmt.git

SCRIPTS=mosesdecoder/scripts
TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
CLEAN=$SCRIPTS/training/clean-corpus-n.perl
NORM_PUNC=$SCRIPTS/tokenizer/normalize-punctuation.perl
REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl
BPEROOT=subword-nmt/subword_nmt
BPE_TOKENS=40000

URLS=(
    "http://statmt.org/wmt13/training-parallel-europarl-v7.tgz"
    "http://statmt.org/wmt13/training-parallel-commoncrawl.tgz"
    "http://statmt.org/wmt13/training-parallel-un.tgz"
    "http://statmt.org/wmt14/training-parallel-nc-v9.tgz"
    "http://statmt.org/wmt10/training-giga-fren.tar"
    "http://statmt.org/wmt14/test-full.tgz"
)
FILES=(
    "training-parallel-europarl-v7.tgz"
    "training-parallel-commoncrawl.tgz"
    "training-parallel-un.tgz"
    "training-parallel-nc-v9.tgz"
    "training-giga-fren.tar"
    "test-full.tgz"
)
CORPORA=(
    "training/europarl-v7.fr-en"
    "commoncrawl.fr-en"
    "un/undoc.2000.fr-en"
    "training/news-commentary-v9.fr-en"
    "giga-fren.release2.fixed"
)

if [ ! -d "$SCRIPTS" ]; then
    echo "Please set SCRIPTS variable correctly to point to Moses scripts."
    exit
fi

src=en
tgt=fr
lang=en-fr
prep=wmt14_en_fr
tmp=$prep/tmp
orig=orig

mkdir -p $orig $tmp $prep

cd $orig

for ((i=0;i<${#URLS[@]};++i)); do
    file=${FILES[i]}
    if [ -f $file ]; then
        echo "$file already exists, skipping download"
    else
        url=${URLS[i]}
        wget "$url"
        if [ -f $file ]; then
            echo "$url successfully downloaded."
        else
            echo "$url not successfully downloaded."
            exit -1
        fi
        if [ ${file: -4} == ".tgz" ]; then
            tar zxvf $file
        elif [ ${file: -4} == ".tar" ]; then
            tar xvf $file
        fi
    fi
done

gunzip giga-fren.release2.fixed.*.gz
cd ..

echo "pre-processing train data..."
for l in $src $tgt; do
    rm $tmp/train.tags.$lang.tok.$l
    for f in "${CORPORA[@]}"; do
        cat $orig/$f.$l | \
            perl $NORM_PUNC $l | \
            perl $REM_NON_PRINT_CHAR | \
            perl $TOKENIZER -threads 8 -a -l $l >> $tmp/train.tags.$lang.tok.$l
    done
done

echo "pre-processing test data..."
for l in $src $tgt; do
    if [ "$l" == "$src" ]; then
        t="src"
    else
        t="ref"
    fi
    grep '<seg id' $orig/test-full/newstest2014-fren-$t.$l.sgm | \
        sed -e 's/<seg id="[0-9]*">\s*//g' | \
        sed -e 's/\s*<\/seg>\s*//g' | \
        sed -e "s/\’/\'/g" | \
        perl $TOKENIZER -threads 8 -a -l $l > $tmp/test.$l
    echo ""
done

echo "splitting train and valid..."
for l in $src $tgt; do
    awk '{if (NR%1333 == 0) print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/valid.$l
    awk '{if (NR%1333 != 0) print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/train.$l
done

TRAIN=$tmp/train.fr-en
BPE_CODE=$prep/code
rm -f $TRAIN
for l in $src $tgt; do
    cat $tmp/train.$l >> $TRAIN
done

echo "learn_bpe.py on ${TRAIN}..."
python $BPEROOT/learn_bpe.py -s $BPE_TOKENS < $TRAIN > $BPE_CODE

for L in $src $tgt; do
    for f in train.$L valid.$L test.$L; do
        echo "apply_bpe.py to ${f}..."
        python $BPEROOT/apply_bpe.py -c $BPE_CODE < $tmp/$f > $tmp/bpe.$f
    done
done

perl $CLEAN -ratio 1.5 $tmp/bpe.train $src $tgt $prep/train 1 250
perl $CLEAN -ratio 1.5 $tmp/bpe.valid $src $tgt $prep/valid 1 250

for L in $src $tgt; do
    cp $tmp/bpe.test.$L $prep/test.$L
done
fairseq/examples/translation_moe/README.md
ADDED
@@ -0,0 +1,89 @@
# Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)

This page includes instructions for reproducing results from the paper [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](https://arxiv.org/abs/1902.07816).

## Download data

First, follow the [instructions to download and preprocess the WMT'17 En-De dataset](../translation#prepare-wmt14en2desh).
Make sure to learn a joint vocabulary by passing the `--joined-dictionary` option to `fairseq-preprocess`.

## Train a model

Then we can train a mixture of experts model using the `translation_moe` task.
Use the `--method` flag to choose the MoE variant; we support hard mixtures with a learned or uniform prior (`--method hMoElp` and `hMoEup`, respectively) and soft mixtures (`--method sMoElp` and `sMoEup`).
The model is trained with online responsibility assignment and shared parameterization.

The following command will train a `hMoElp` model with `3` experts:
```bash
fairseq-train --ddp-backend='legacy_ddp' \
    data-bin/wmt17_en_de \
    --max-update 100000 \
    --task translation_moe --user-dir examples/translation_moe/translation_moe_src \
    --method hMoElp --mean-pool-gating-network \
    --num-experts 3 \
    --arch transformer_wmt_en_de --share-all-embeddings \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
    --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \
    --lr 0.0007 \
    --dropout 0.1 --weight-decay 0.0 --criterion cross_entropy \
    --max-tokens 3584
```

## Translate

Once a model is trained, we can generate translations from different experts using the `--gen-expert` option.
For example, to generate from expert 0:
```bash
fairseq-generate data-bin/wmt17_en_de \
    --path checkpoints/checkpoint_best.pt \
    --beam 1 --remove-bpe \
    --task translation_moe --user-dir examples/translation_moe/translation_moe_src \
    --method hMoElp --mean-pool-gating-network \
    --num-experts 3 \
    --gen-expert 0
```

## Evaluate

First download a tokenized version of the WMT'14 En-De test set with multiple references:
```bash
wget dl.fbaipublicfiles.com/fairseq/data/wmt14-en-de.extra_refs.tok
```

Next apply BPE on the fly and run generation for each expert:
```bash
BPE_CODE=examples/translation/wmt17_en_de/code
for EXPERT in $(seq 0 2); do \
    cat wmt14-en-de.extra_refs.tok \
    | grep ^S | cut -f 2 \
    | fairseq-interactive data-bin/wmt17_en_de \
        --path checkpoints/checkpoint_best.pt \
        --beam 1 \
        --bpe subword_nmt --bpe-codes $BPE_CODE \
        --buffer-size 500 --max-tokens 6000 \
        --task translation_moe --user-dir examples/translation_moe/translation_moe_src \
        --method hMoElp --mean-pool-gating-network \
        --num-experts 3 \
        --gen-expert $EXPERT ; \
done > wmt14-en-de.extra_refs.tok.gen.3experts
```

Finally use `score.py` to compute pairwise BLEU and average oracle BLEU:
```bash
python examples/translation_moe/score.py --sys wmt14-en-de.extra_refs.tok.gen.3experts --ref wmt14-en-de.extra_refs.tok
# pairwise BLEU: 48.26
# #refs covered: 2.11
# multi-reference BLEU (leave-one-out): 59.46
```
This matches row 3 from Table 7 in the paper.

## Citation

```bibtex
@article{shen2019mixture,
  title = {Mixture Models for Diverse Machine Translation: Tricks of the Trade},
  author = {Tianxiao Shen and Myle Ott and Michael Auli and Marc'Aurelio Ranzato},
  journal = {International Conference on Machine Learning},
  year = 2019,
}
```
fairseq/examples/translation_moe/score.py
ADDED
@@ -0,0 +1,197 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Scoring script for computing pairwise BLEU and multi-ref BLEU over a set of
candidate hypotheses.

See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade"
(Shen et al., 2019) <https://arxiv.org/abs/1902.07816>`_.
"""

import argparse
import random
import sys
from itertools import chain

import numpy as np
import sacrebleu
from sacrebleu import corpus_bleu as _corpus_bleu


def main():
    parser = argparse.ArgumentParser(sys.argv[0])
    parser.add_argument(
        "--sys", nargs="*", default="", metavar="FILE", help="path to system output"
    )
    parser.add_argument("--ref", default="", metavar="FILE", help="path to references")
    parser.add_argument(
        "--output",
        default="",
        metavar="FILE",
        help="print outputs into a pretty format",
    )
    args = parser.parse_args()

    if args.sys:
        src, tgt, hypos, log_probs = load_sys(args.sys)
        print("pairwise BLEU: %.2f" % pairwise(hypos))
        if args.output:
            merge(src, tgt, hypos, log_probs, args.output)

    if args.ref:
        _, _, refs = load_ref(args.ref)
        if args.sys:
            multi_ref(refs, hypos)
        else:
            intra_ref(refs)


def dictolist(d):
    a = sorted(d.items(), key=lambda i: i[0])
    return [i[1] for i in a]


def load_sys(paths):
    src, tgt, hypos, log_probs = {}, {}, {}, {}
    for path in paths:
        with open(path) as f:
            for line in f:
                line = line.rstrip()
                # S: source
                # T: target
                # D: detokenized system output
                if line.startswith(("S-", "T-", "D-")):
                    i = int(line[line.find("-") + 1 : line.find("\t")])
                    if line.startswith("S-"):
                        src[i] = line.split("\t")[1]
                    if line.startswith("T-"):
                        tgt[i] = line.split("\t")[1]
                    if line.startswith("D-"):
                        if i not in hypos:
                            hypos[i] = []
                            log_probs[i] = []
                        hypos[i].append(line.split("\t")[2])
                        log_probs[i].append(float(line.split("\t")[1]))
    return dictolist(src), dictolist(tgt), dictolist(hypos), dictolist(log_probs)


def load_ref(path):
    with open(path) as f:
        lines = f.readlines()
    src, tgt, refs = [], [], []
    i = 0
    while i < len(lines):
        if lines[i].startswith("S-"):
            src.append(lines[i].split("\t")[1].rstrip())
            i += 1
        elif lines[i].startswith("T-"):
            tgt.append(lines[i].split("\t")[1].rstrip())
            i += 1
        else:
            a = []
            while i < len(lines) and lines[i].startswith("R"):
                a.append(lines[i].split("\t")[1].rstrip())
                i += 1
            refs.append(a)
    return src, tgt, refs


def merge(src, tgt, hypos, log_probs, path):
    with open(path, "w") as f:
        for s, t, hs, lps in zip(src, tgt, hypos, log_probs):
            f.write(s + "\n")
            f.write(t + "\n")
            f.write("\n")
            for h, lp in zip(hs, lps):
                f.write("\t%f\t%s\n" % (lp, h.strip()))
            f.write("------------------------------------------------------\n")


def corpus_bleu(sys_stream, ref_streams):
    bleu = _corpus_bleu(sys_stream, ref_streams, tokenize="none")
    return bleu.score


def sentence_bleu(hypothesis, reference):
    bleu = _corpus_bleu(hypothesis, reference)
    for i in range(1, 4):
        bleu.counts[i] += 1
        bleu.totals[i] += 1
    bleu = sacrebleu.BLEU.compute_bleu(
        bleu.counts,
        bleu.totals,
        bleu.sys_len,
        bleu.ref_len,
        smooth_method="exp",
    )
    return bleu.score


def pairwise(sents):
    _ref, _hypo = [], []
    for s in sents:
        for i in range(len(s)):
            for j in range(len(s)):
                if i != j:
                    _ref.append(s[i])
                    _hypo.append(s[j])
    return corpus_bleu(_hypo, [_ref])


def multi_ref(refs, hypos):
    _ref, _hypo = [], []
    ref_cnt = 0
    assert len(refs) == len(hypos)

    # count number of refs covered
    for rs, hs in zip(refs, hypos):
        a = set()
        for h in hs:
            s = [sentence_bleu(h, r) for r in rs]
            j = np.argmax(s)
            _ref.append(rs[j])
            _hypo.append(h)
            best = [k for k in range(len(rs)) if s[k] == s[j]]
            a.add(random.choice(best))
        ref_cnt += len(a)
    print("#refs covered: %.2f" % (ref_cnt / len(refs)))

    # transpose refs and hypos
    refs = list(zip(*refs))
    hypos = list(zip(*hypos))

    # compute multi-ref corpus BLEU (leave-one-out to be comparable to intra_ref)
    k = len(hypos)
    m = len(refs)
    flat_hypos = [hypos[j][i] for i in range(len(hypos[0])) for j in range(k)]
    duplicated_refs = [[ref for ref in refs_i for _ in range(k)] for refs_i in refs]
    loo_bleus = []
    for held_out_ref in range(m):
        remaining_refs = (
            duplicated_refs[:held_out_ref] + duplicated_refs[held_out_ref + 1 :]
        )
        assert len(remaining_refs) == m - 1
        loo_bleus.append(corpus_bleu(flat_hypos, remaining_refs))
    print("average multi-reference BLEU (leave-one-out): %.2f" % np.mean(loo_bleus))


def intra_ref(refs):
    print("ref pairwise BLEU: %.2f" % pairwise(refs))
    refs = list(zip(*refs))
    m = len(refs)
    concat_h = []
    concat_rest = [[] for j in range(m - 1)]
    for i, h in enumerate(refs):
        rest = refs[:i] + refs[i + 1 :]
        concat_h.append(h)
        for j in range(m - 1):
            concat_rest[j].extend(rest[j])
    concat_h = list(chain.from_iterable(concat_h))
    bleu = corpus_bleu(concat_h, concat_rest)
    print("multi-reference BLEU (leave-one-out): %.2f" % bleu)


if __name__ == "__main__":
    main()
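As a quick sanity check of the diversity metric, `pairwise()` can be called directly on a list of per-sentence hypothesis groups; lower pairwise BLEU means the hypotheses differ more from one another. A minimal sketch, assuming the fairseq checkout is on `PYTHONPATH` and `sacrebleu` is installed (the sentences are made up):

```python
from examples.translation_moe.score import pairwise

# Two hypotheses per source sentence; every ordered pair (i, j), i != j,
# within a group is scored against each other as hypothesis vs. reference.
hypos = [
    ["the cat sat on the mat", "a cat is sitting on the mat"],
    ["hello world", "hi world"],
]
print("pairwise BLEU: %.2f" % pairwise(hypos))
```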
fairseq/examples/translation_moe/translation_moe_src/__init__.py
ADDED
@@ -0,0 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import translation_moe  # noqa
fairseq/examples/translation_moe/translation_moe_src/logsumexp_moe.py
ADDED
@@ -0,0 +1,26 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch


class LogSumExpMoE(torch.autograd.Function):
    """Standard LogSumExp forward pass, but use *posterior* for the backward.

    See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade"
    (Shen et al., 2019) <https://arxiv.org/abs/1902.07816>`_.
    """

    @staticmethod
    def forward(ctx, logp, posterior, dim=-1):
        ctx.save_for_backward(posterior)
        ctx.dim = dim
        return torch.logsumexp(logp, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        (posterior,) = ctx.saved_tensors
        grad_logp = grad_output.unsqueeze(ctx.dim) * posterior
        return grad_logp, None, None
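`LogSumExpMoE` computes an ordinary `logsumexp` on the forward pass but routes the backward pass through the supplied posterior (the experts' responsibilities) instead of the usual `softmax(logp)` weights. A small sketch that checks both properties, assuming the fairseq checkout is on `PYTHONPATH`:

```python
import torch
from examples.translation_moe.translation_moe_src.logsumexp_moe import LogSumExpMoE

logp = torch.randn(4, 3, requires_grad=True)           # per-expert log-probs
posterior = torch.softmax(torch.randn(4, 3), dim=-1)   # responsibilities

out = LogSumExpMoE.apply(logp, posterior, -1)
assert torch.allclose(out, torch.logsumexp(logp, dim=-1))  # same forward

out.sum().backward()
# The gradient w.r.t. logp is the posterior, not softmax(logp).
assert torch.allclose(logp.grad, posterior)
```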
fairseq/examples/translation_moe/translation_moe_src/mean_pool_gating_network.py
ADDED
@@ -0,0 +1,50 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn.functional as F


class MeanPoolGatingNetwork(torch.nn.Module):
    """A simple mean-pooling gating network for selecting experts.

    This module applies mean pooling over an encoder's output and returns
    responsibilities for each expert. The encoder format is expected to match
    :class:`fairseq.models.transformer.TransformerEncoder`.
    """

    def __init__(self, embed_dim, num_experts, dropout=None):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_experts = num_experts

        self.fc1 = torch.nn.Linear(embed_dim, embed_dim)
        self.dropout = torch.nn.Dropout(dropout) if dropout is not None else None
        self.fc2 = torch.nn.Linear(embed_dim, num_experts)

    def forward(self, encoder_out):
        if not (
            "encoder_out" in encoder_out
            and "encoder_padding_mask" in encoder_out
            and encoder_out["encoder_out"][0].size(2) == self.embed_dim
        ):
            raise ValueError("Unexpected format for encoder_out")

        # mean pooling over time
        encoder_padding_mask = encoder_out["encoder_padding_mask"][0]  # B x T
        encoder_out = encoder_out["encoder_out"][0].transpose(0, 1)  # B x T x C
        if encoder_padding_mask is not None:
            encoder_out = encoder_out.clone()  # required because of transpose above
            encoder_out[encoder_padding_mask] = 0
            ntokens = torch.sum(~encoder_padding_mask, dim=1, keepdim=True)
            x = torch.sum(encoder_out, dim=1) / ntokens.type_as(encoder_out)
        else:
            x = torch.mean(encoder_out, dim=1)

        x = torch.tanh(self.fc1(x))
        if self.dropout is not None:
            x = self.dropout(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=-1, dtype=torch.float32).type_as(x)
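A small sketch of how `MeanPoolGatingNetwork` is fed: the `encoder_out` dict mimics the `TransformerEncoder` output format the class checks for. The shapes here are illustrative, and the import again assumes the fairseq checkout is on `PYTHONPATH`:

```python
import torch
from examples.translation_moe.translation_moe_src.mean_pool_gating_network import (
    MeanPoolGatingNetwork,
)

T, B, C, n_experts = 7, 2, 16, 3
gate = MeanPoolGatingNetwork(embed_dim=C, num_experts=n_experts, dropout=0.1)

encoder_out = {
    "encoder_out": [torch.randn(T, B, C)],  # T x B x C features
    "encoder_padding_mask": [torch.zeros(B, T, dtype=torch.bool)],  # True = pad
}

log_resp = gate(encoder_out)       # B x num_experts log-responsibilities
print(log_resp.exp().sum(dim=-1))  # each row sums to 1
```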
fairseq/examples/translation_moe/translation_moe_src/translation_moe.py
ADDED
@@ -0,0 +1,259 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from dataclasses import dataclass, field
|
7 |
+
import torch
|
8 |
+
from omegaconf import II
|
9 |
+
|
10 |
+
from fairseq import utils
|
11 |
+
from fairseq.logging import metrics
|
12 |
+
from fairseq.dataclass import ChoiceEnum
|
13 |
+
from fairseq.tasks import register_task
|
14 |
+
from fairseq.tasks.translation import TranslationConfig, TranslationTask
|
15 |
+
|
16 |
+
from .logsumexp_moe import LogSumExpMoE
|
17 |
+
from .mean_pool_gating_network import MeanPoolGatingNetwork
|
18 |
+
|
19 |
+
|
20 |
+
METHOD_CHOICES = ChoiceEnum(["sMoElp", "sMoEup", "hMoElp", "hMoEup"])
|
21 |
+
|
22 |
+
|
23 |
+
@dataclass
|
24 |
+
class TranslationMoEConfig(TranslationConfig):
|
25 |
+
method: METHOD_CHOICES = field(
|
26 |
+
default="hMoEup",
|
27 |
+
metadata={"help": "MoE method"},
|
28 |
+
)
|
29 |
+
num_experts: int = field(
|
30 |
+
default=3,
|
31 |
+
metadata={"help": "number of experts"},
|
32 |
+
)
|
33 |
+
mean_pool_gating_network: bool = field(
|
34 |
+
default=False,
|
35 |
+
metadata={"help": "use a simple mean-pooling gating network"},
|
36 |
+
)
|
37 |
+
mean_pool_gating_network_dropout: float = field(
|
38 |
+
default=0,
|
39 |
+
metadata={"help": "dropout for mean-pooling gating network"},
|
40 |
+
)
|
41 |
+
mean_pool_gating_network_encoder_dim: int = field(
|
42 |
+
default=0,
|
43 |
+
metadata={"help": "encoder output dim for mean-pooling gating network"},
|
44 |
+
)
|
45 |
+
gen_expert: int = field(
|
46 |
+
default=0,
|
47 |
+
metadata={"help": "which expert to use for generation"},
|
48 |
+
)
|
49 |
+
sentence_avg: bool = II("optimization.sentence_avg")
|
50 |
+
|
51 |
+
|
52 |
+
@register_task("translation_moe", dataclass=TranslationMoEConfig)
|
53 |
+
class TranslationMoETask(TranslationTask):
|
54 |
+
"""
|
55 |
+
Translation task for Mixture of Experts (MoE) models.
|
56 |
+
|
57 |
+
See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade"
|
58 |
+
(Shen et al., 2019) <https://arxiv.org/abs/1902.07816>`_.
|
59 |
+
|
60 |
+
Args:
|
61 |
+
src_dict (~fairseq.data.Dictionary): dictionary for the source language
|
62 |
+
tgt_dict (~fairseq.data.Dictionary): dictionary for the target language
|
63 |
+
|
64 |
+
.. note::
|
65 |
+
|
66 |
+
The translation task is compatible with :mod:`fairseq-train`,
|
67 |
+
:mod:`fairseq-generate` and :mod:`fairseq-interactive`.
|
68 |
+
|
69 |
+
The translation task provides the following additional command-line
|
70 |
+
arguments:
|
71 |
+
|
72 |
+
.. argparse::
|
73 |
+
:ref: fairseq.tasks.translation_parser
|
74 |
+
:prog:
|
75 |
+
"""
|
76 |
+
|
77 |
+
cfg: TranslationMoEConfig
|
78 |
+
|
79 |
+
def __init__(self, cfg: TranslationMoEConfig, src_dict, tgt_dict):
|
80 |
+
if cfg.method == "sMoElp":
|
81 |
+
# soft MoE with learned prior
|
82 |
+
self.uniform_prior = False
|
83 |
+
self.hard_selection = False
|
84 |
+
elif cfg.method == "sMoEup":
|
85 |
+
# soft MoE with uniform prior
|
86 |
+
self.uniform_prior = True
|
87 |
+
self.hard_selection = False
|
88 |
+
elif cfg.method == "hMoElp":
|
89 |
+
# hard MoE with learned prior
|
90 |
+
self.uniform_prior = False
|
91 |
+
self.hard_selection = True
|
92 |
+
elif cfg.method == "hMoEup":
|
93 |
+
# hard MoE with uniform prior
|
94 |
+
self.uniform_prior = True
|
95 |
+
self.hard_selection = True
|
96 |
+
|
97 |
+
# add indicator tokens for each expert
|
98 |
+
for i in range(cfg.num_experts):
|
99 |
+
# add to both dictionaries in case we're sharing embeddings
|
100 |
+
src_dict.add_symbol("<expert_{}>".format(i))
|
101 |
+
tgt_dict.add_symbol("<expert_{}>".format(i))
|
102 |
+
|
103 |
+
super().__init__(cfg, src_dict, tgt_dict)
|
104 |
+
|
105 |
+
def build_model(self, cfg, from_checkpoint=False):
|
106 |
+
from fairseq import models
|
107 |
+
|
108 |
+
model = models.build_model(cfg, self)
|
109 |
+
if not self.uniform_prior and not hasattr(model, "gating_network"):
|
110 |
+
if self.cfg.mean_pool_gating_network:
|
111 |
+
if self.cfg.mean_pool_gating_network_encoder_dim > 0:
|
112 |
+
encoder_dim = self.cfg.mean_pool_gating_network_encoder_dim
|
113 |
+
elif getattr(cfg, "encoder_embed_dim", None):
|
114 |
+
# assume that encoder_embed_dim is the encoder's output dimension
|
115 |
+
encoder_dim = cfg.encoder_embed_dim
|
116 |
+
else:
|
117 |
+
raise ValueError(
|
118 |
+
"Must specify --mean-pool-gating-network-encoder-dim"
|
119 |
+
)
|
120 |
+
|
121 |
+
if self.cfg.mean_pool_gating_network_dropout > 0:
|
122 |
+
dropout = self.cfg.mean_pool_gating_network_dropout
|
123 |
+
elif getattr(cfg, "dropout", None):
|
124 |
+
dropout = cfg.dropout
|
125 |
+
else:
|
126 |
+
raise ValueError("Must specify task.mean_pool_gating_network_dropout")
|
127 |
+
|
128 |
+
model.gating_network = MeanPoolGatingNetwork(
|
129 |
+
encoder_dim,
|
130 |
+
self.cfg.num_experts,
|
131 |
+
dropout,
|
132 |
+
)
|
133 |
+
else:
|
134 |
+
raise ValueError(
|
135 |
+
"translation_moe task with learned prior requires the model to "
|
136 |
+
"have a gating network; try using --mean-pool-gating-network"
|
137 |
+
)
|
138 |
+
return model
|
139 |
+
|
140 |
+
def expert_index(self, i):
|
141 |
+
return i + self.tgt_dict.index("<expert_0>")
|
142 |
+
|
143 |
+
def _get_loss(self, sample, model, criterion):
|
144 |
+
assert hasattr(
|
145 |
+
criterion, "compute_loss"
|
146 |
+
), "translation_moe task requires the criterion to implement the compute_loss() method"
|
147 |
+
|
148 |
+
k = self.cfg.num_experts
|
149 |
+
bsz = sample["target"].size(0)
|
150 |
+
|
151 |
+
def get_lprob_y(encoder_out, prev_output_tokens_k):
|
152 |
+
net_output = model.decoder(
|
153 |
+
prev_output_tokens=prev_output_tokens_k,
|
154 |
+
encoder_out=encoder_out,
|
155 |
+
)
|
156 |
+
loss, _ = criterion.compute_loss(model, net_output, sample, reduce=False)
|
157 |
+
loss = loss.view(bsz, -1)
|
158 |
+
return -loss.sum(dim=1, keepdim=True) # -> B x 1
|
159 |
+
|
160 |
+
def get_lprob_yz(winners=None):
|
161 |
+
encoder_out = model.encoder(
|
162 |
+
src_tokens=sample["net_input"]["src_tokens"],
|
163 |
+
src_lengths=sample["net_input"]["src_lengths"],
|
164 |
+
)
|
165 |
+
|
166 |
+
if winners is None:
|
167 |
+
lprob_y = []
|
168 |
+
for i in range(k):
|
169 |
+
prev_output_tokens_k = sample["net_input"][
|
170 |
+
"prev_output_tokens"
|
171 |
+
].clone()
|
172 |
+
assert not prev_output_tokens_k.requires_grad
|
173 |
+
prev_output_tokens_k[:, 0] = self.expert_index(i)
|
174 |
+
lprob_y.append(get_lprob_y(encoder_out, prev_output_tokens_k))
|
175 |
+
lprob_y = torch.cat(lprob_y, dim=1) # -> B x K
|
176 |
+
else:
|
177 |
+
prev_output_tokens_k = sample["net_input"]["prev_output_tokens"].clone()
|
178 |
+
prev_output_tokens_k[:, 0] = self.expert_index(winners)
|
179 |
+
lprob_y = get_lprob_y(encoder_out, prev_output_tokens_k) # -> B
|
180 |
+
|
181 |
+
if self.uniform_prior:
|
182 |
+
lprob_yz = lprob_y
|
183 |
+
else:
|
184 |
+
lprob_z = model.gating_network(encoder_out) # B x K
|
185 |
+
if winners is not None:
|
186 |
+
lprob_z = lprob_z.gather(dim=1, index=winners.unsqueeze(-1))
|
187 |
+
lprob_yz = lprob_y + lprob_z.type_as(lprob_y) # B x K
|
188 |
+
|
189 |
+
return lprob_yz
|
190 |
+
|
191 |
+
# compute responsibilities without dropout
|
192 |
+
with utils.model_eval(model): # disable dropout
|
193 |
+
with torch.no_grad(): # disable autograd
|
194 |
+
lprob_yz = get_lprob_yz() # B x K
|
195 |
+
prob_z_xy = torch.nn.functional.softmax(lprob_yz, dim=1)
|
196 |
+
assert not prob_z_xy.requires_grad
|
197 |
+
|
198 |
+
# compute loss with dropout
|
199 |
+
if self.hard_selection:
|
200 |
+
winners = prob_z_xy.max(dim=1)[1]
|
201 |
+
loss = -get_lprob_yz(winners)
|
202 |
+
else:
|
203 |
+
lprob_yz = get_lprob_yz() # B x K
|
204 |
+
loss = -LogSumExpMoE.apply(lprob_yz, prob_z_xy, 1)
|
205 |
+
|
206 |
+
loss = loss.sum()
|
207 |
+
sample_size = (
|
208 |
+
sample["target"].size(0) if self.cfg.sentence_avg else sample["ntokens"]
|
209 |
+
)
|
210 |
+
logging_output = {
|
211 |
+
"loss": utils.item(loss.data),
|
212 |
+
"ntokens": sample["ntokens"],
|
213 |
+
"nsentences": bsz,
|
214 |
+
"sample_size": sample_size,
|
215 |
+
"posterior": prob_z_xy.float().sum(dim=0).cpu(),
|
216 |
+
}
|
217 |
+
return loss, sample_size, logging_output
|
218 |
+
|
219 |
+
def train_step(
|
220 |
+
self, sample, model, criterion, optimizer, update_num, ignore_grad=False
|
221 |
+
):
|
222 |
+
model.train()
|
223 |
+
loss, sample_size, logging_output = self._get_loss(sample, model, criterion)
|
224 |
+
if ignore_grad:
|
225 |
+
loss *= 0
|
226 |
+
optimizer.backward(loss)
|
227 |
+
return loss, sample_size, logging_output
|
228 |
+
|
229 |
+
def valid_step(self, sample, model, criterion):
|
230 |
+
model.eval()
|
231 |
+
with torch.no_grad():
|
232 |
+
loss, sample_size, logging_output = self._get_loss(sample, model, criterion)
|
233 |
+
return loss, sample_size, logging_output
|
234 |
+
|
235 |
+
def inference_step(
|
236 |
+
self,
|
237 |
+
generator,
|
238 |
+
models,
|
239 |
+
sample,
|
240 |
+
prefix_tokens=None,
|
241 |
+
expert=None,
|
242 |
+
constraints=None,
|
243 |
+
):
|
244 |
+
expert = expert or self.cfg.gen_expert
|
245 |
+
with torch.no_grad():
|
246 |
+
return generator.generate(
|
247 |
+
models,
|
248 |
+
sample,
|
249 |
+
prefix_tokens=prefix_tokens,
|
250 |
+
constraints=constraints,
|
251 |
+
bos_token=self.expert_index(expert),
|
252 |
+
)
|
253 |
+
|
254 |
+
def reduce_metrics(self, logging_outputs, criterion):
|
255 |
+
super().reduce_metrics(logging_outputs, criterion)
|
256 |
+
metrics.log_scalar(
|
257 |
+
"posterior",
|
258 |
+
sum(log["posterior"] for log in logging_outputs if "posterior" in log),
|
259 |
+
)
|
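The least obvious piece above is `expert_index()`: the mixture identity is injected by overwriting the first decoder input token (normally the EOS placeholder) with an `<expert_i>` symbol, so each decoder pass is conditioned on one mixture component. A toy sketch of that trick, with hypothetical token ids:

```python
import torch

expert_0 = 1000  # stands in for tgt_dict.index("<expert_0>")
prev_output_tokens = torch.tensor([[2, 15, 37, 9],   # B x T; column 0 is EOS
                                   [2, 81, 4, 22]])
for i in range(3):  # score the target once per expert, as in get_lprob_yz()
    tokens_k = prev_output_tokens.clone()
    tokens_k[:, 0] = expert_0 + i  # expert_index(i)
    # ... model.decoder(tokens_k, encoder_out) would now condition on expert i
```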
fairseq/examples/truncated_bptt/README.md
ADDED
@@ -0,0 +1,70 @@
# Truncated Backpropagation Through Time (BPTT)

Truncated BPTT is a useful technique for training language models on very long
sequences. Typically a long sequence is split into chunks and a language model
is trained over the chunks sequentially. The LM may condition on previous
chunks, but gradients only flow through the current chunk. This technique was
the basis for the paper: [Transformer-XL: Attentive Language Models Beyond a
Fixed-Length Context](https://arxiv.org/abs/1901.02860), which achieved
state-of-the-art language modeling results at the time of publication.

It is slightly tricky to implement Truncated BPTT efficiently in fairseq, since
we need to iterate over the data sequentially and disable any batch shuffling
logic. The code provided in this example illustrates how to implement Truncated
BPTT in fairseq by overriding ``FairseqTask::get_batch_iterator`` to iterate
over the data sequentially. Crucially, this example supports batching and
multi-GPU (data parallel) training.

##### 0. Setup

First, see the general [language modeling README](README.md) for instructions on
preprocessing the WikiText-103 data.

##### 1. Train a Transformer-XL model on WikiText-103

We will train a 16-layer Transformer-XL model following the [hyperparameters
used in the original
paper](https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/run_wt103_base.sh).

The following command assumes 4 GPUs, so that the total batch size is 60
sequences (15 x 4). Training should take ~24 hours on 4 V100 GPUs:
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \
    --user-dir examples/truncated_bptt \
    data-bin/wikitext-103/ \
    --task truncated_bptt_lm --tokens-per-sample 150 \
    --batch-size 15 --max-update 200000 \
    --arch transformer_xl --n-layer 16 --d-model 410 --n-head 10 \
    --d-head 41 --d-inner 2100 --dropout 0.1 --dropatt 0.0 --mem-len 150 \
    --optimizer adam --clip-norm 0.25 \
    --lr-scheduler cosine --warmup-updates 0 --min-lr 0.0 --lr 0.00025 \
    --log-format json --log-interval 25 \
    --fp16
```

If training on a single GPU, set `--update-freq=4` to accumulate 4x gradients
and simulate training on 4 GPUs.

##### 2. Evaluate

```bash
fairseq-eval-lm data-bin/wikitext-103/ \
    --path checkpoints/checkpoint_best.pt \
    --user-dir examples/truncated_bptt/ \
    --task truncated_bptt_lm \
    --batch-size 1 --required-batch-size-multiple 1 \
    --model-overrides '{"mem_len":640,"clamp_len":400,"same_length":True}' \
    --tokens-per-sample 64
# ... | INFO | fairseq_cli.eval_lm | num. model params: 151123537
# ... | INFO | fairseq_cli.eval_lm | Evaluated 245569 tokens in 83.1s (2956.82 tokens/s)
# ... | INFO | fairseq_cli.eval_lm | Loss (base 2): 4.5668, Perplexity: 23.70
# Compare to 24.0 test perplexity from the paper
```

*Note:* During training the model saw 150 tokens of context
(``--tokens-per-sample=150``) and 150 extra memory tokens (``--mem-len=150``).
During evaluation we measure perplexity on sequences of 64 tokens
(``--tokens-per-sample=64``) and increase the memory length
(``--model-overrides='{"mem_len":640}'``). These settings match the evaluation
settings from [the original
paper](https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/run_wt103_base.sh).
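For intuition, the training loop that the task below implements inside fairseq boils down to something like this sketch (a made-up `model(chunk, mems)` interface, not fairseq's actual API; the remainder chunk is dropped for simplicity):

```python
import torch
import torch.nn.functional as F

def train_truncated_bptt(model, optimizer, seq, chunk_size=150):
    """seq: B x L tensor of token ids; model returns (logits, new_mems)."""
    mems = None
    for start in range(0, seq.size(1) - chunk_size, chunk_size):
        chunk = seq[:, start : start + chunk_size]
        target = seq[:, start + 1 : start + chunk_size + 1]
        logits, mems = model(chunk, mems)  # conditions on previous chunks
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                               target.reshape(-1))
        optimizer.zero_grad()
        loss.backward()  # gradients flow only through the current chunk
        optimizer.step()
        mems = [m.detach() for m in mems]  # the truncation point
```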
fairseq/examples/truncated_bptt/__init__.py
ADDED
@@ -0,0 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import transformer_xl_model, truncated_bptt_lm_task  # noqa
fairseq/examples/truncated_bptt/transformer_xl_model.py
ADDED
@@ -0,0 +1,143 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from dataclasses import dataclass, field
from typing import Dict, List, Optional

import torch
from fairseq.dataclass import FairseqDataclass
from fairseq.models import (
    FairseqIncrementalDecoder,
    FairseqLanguageModel,
    register_model,
)
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from omegaconf import II


logger = logging.getLogger(__name__)


@dataclass
class TransformerXLConfig(FairseqDataclass):
    # defaults come from the original Transformer-XL code
    cutoffs: List[int] = field(default_factory=lambda: [20000, 40000, 200000])
    d_model: int = 500
    n_head: int = 10
    d_head: int = 50
    d_inner: int = 1000
    div_val: int = 1
    n_layer: int = 12
    mem_len: int = 0
    clamp_len: int = -1
    same_length: bool = False
    dropout: float = 0.0
    dropatt: float = 0.0
    checkpoint_activations: bool = False
    offload_activations: bool = False
    max_target_positions: int = II("task.max_target_positions")


@register_model("transformer_xl", dataclass=TransformerXLConfig)
class TransformerXLLanguageModel(FairseqLanguageModel):
    @classmethod
    def build_model(cls, cfg: TransformerXLConfig, task):
        return cls(TransformerXLDecoder(cfg, task))


class TransformerXLDecoder(FairseqIncrementalDecoder):
    def __init__(self, cfg, task):
        try:
            from transformers.models.transfo_xl import (
                TransfoXLConfig,
                TransfoXLLMHeadModel,
            )
        except ImportError:
            from transformers.configuration_transfo_xl import TransfoXLConfig
            from transformers.modeling_transfo_xl import TransfoXLLMHeadModel

        super().__init__(task.target_dictionary)
        self.cfg = cfg

        # remove any cutoffs larger than the vocab size
        cutoffs = [
            cutoff for cutoff in cfg.cutoffs if cutoff < len(task.target_dictionary)
        ]

        config = TransfoXLConfig(
            vocab_size=len(task.target_dictionary),
            cutoffs=cutoffs,
            d_model=cfg.d_model,
            d_embed=cfg.d_model,
            n_head=cfg.n_head,
            d_head=cfg.d_head,
            d_inner=cfg.d_inner,
            div_val=cfg.div_val,
            n_layer=cfg.n_layer,
            mem_len=cfg.mem_len,
            clamp_len=cfg.clamp_len,
            same_length=cfg.same_length,
            dropout=cfg.dropout,
            dropatt=cfg.dropatt,
        )
        logger.info(config)
        self.model = TransfoXLLMHeadModel(config)

        if cfg.checkpoint_activations or cfg.offload_activations:
            for i in range(len(self.model.transformer.layers)):
                self.model.transformer.layers[i] = checkpoint_wrapper(
                    self.model.transformer.layers[i],
                    offload_to_cpu=cfg.offload_activations,
                )
                # TODO: may save mem to wrap(layer.pos_ff.CoreNet[3])

        self._mems = None

    def forward(
        self,
        src_tokens,
        src_lengths=None,  # unused
        incremental_state: Optional[Dict[str, List[torch.Tensor]]] = None,
        encoder_out=None,
    ):
        if incremental_state is not None:  # used during inference
            mems = self.get_incremental_state(incremental_state, "mems")
            src_tokens = src_tokens[:, -1:]  # only keep the most recent token
        else:
            mems = self._mems

        output = self.model(
            input_ids=src_tokens,
            mems=mems,
            return_dict=False,
        )

        if len(output) >= 2:
            if incremental_state is not None:
                self.set_incremental_state(incremental_state, "mems", output[1])
            else:
                self._mems = output[1]

        return (output[0],)

    def max_positions(self):
        return self.cfg.max_target_positions

    def reorder_incremental_state(
        self,
        incremental_state: Dict[str, Dict[str, Optional[torch.Tensor]]],
        new_order: torch.Tensor,
    ):
        """Reorder incremental state.

        This will be called when the order of the input has changed from the
        previous time step. A typical use case is beam search, where the input
        order changes between time steps based on the selection of beams.
        """
        mems = self.get_incremental_state(incremental_state, "mems")
        if mems is not None:
            new_mems = [mems_i.index_select(1, new_order) for mems_i in mems]
            self.set_incremental_state(incremental_state, "mems", new_mems)
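To see what `reorder_incremental_state` does concretely, here is a toy illustration with made-up sizes: HuggingFace's Transformer-XL stores each `mems` tensor as T x B x C, so beams live on dim 1.

```python
import torch

mem = torch.arange(24).view(2, 3, 4)  # T=2, B=3 beams, C=4
new_order = torch.tensor([2, 0, 1])   # beam permutation chosen by the search
reordered = mem.index_select(1, new_order)
assert torch.equal(reordered[:, 0], mem[:, 2])  # old beam 2 is now first
```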
fairseq/examples/truncated_bptt/truncated_bptt_lm_task.py
ADDED
@@ -0,0 +1,285 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
from dataclasses import dataclass, field
from typing import List, Optional, Tuple

import torch
from fairseq import utils
from fairseq.data import (
    Dictionary,
    TokenBlockDataset,
    data_utils,
    iterators,
)
from fairseq.dataclass import FairseqDataclass
from fairseq.distributed import utils as dist_utils
from fairseq.tasks import FairseqTask, register_task
from omegaconf import II


logger = logging.getLogger(__name__)


@dataclass
class TruncatedBPTTLMConfig(FairseqDataclass):
    data: str = field(default="???", metadata={"help": "path to data directory"})
    tokens_per_sample: int = field(
        default=1024, metadata={"help": "max number of tokens per sequence"},
    )
    batch_size: int = II("dataset.batch_size")
    # Some models use *max_target_positions* to know how many positional
    # embeddings to learn. We use II(...) to make it default to
    # *tokens_per_sample*, but in principle there could be more positional
    # embeddings than tokens in a single batch. This may also be irrelevant for
    # custom model implementations.
    max_target_positions: int = II("task.tokens_per_sample")
    # these will be populated automatically if not provided
    data_parallel_rank: Optional[int] = None
    data_parallel_size: Optional[int] = None


@register_task("truncated_bptt_lm", dataclass=TruncatedBPTTLMConfig)
class TruncatedBPTTLMTask(FairseqTask):
    def __init__(self, cfg: TruncatedBPTTLMConfig):
        super().__init__(cfg)

        if cfg.data_parallel_rank is None or cfg.data_parallel_size is None:
            if torch.distributed.is_initialized():
                cfg.data_parallel_rank = dist_utils.get_data_parallel_rank()
                cfg.data_parallel_size = dist_utils.get_data_parallel_world_size()
            else:
                cfg.data_parallel_rank = 0
                cfg.data_parallel_size = 1

        # load the dictionary
        paths = utils.split_paths(cfg.data)
        assert len(paths) > 0
        self.dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
        logger.info("dictionary: {} types".format(len(self.dictionary)))

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split (e.g., train, valid, test)"""

        # support sharded datasets
        paths = utils.split_paths(self.cfg.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]
        split_path = os.path.join(data_path, split)

        # each element of *data* will be a tensorized line from the original
        # text dataset, similar to ``open(split_path).readlines()``
        data = data_utils.load_indexed_dataset(
            split_path, self.dictionary, combine=combine
        )
        if data is None:
            raise FileNotFoundError(
                "Dataset not found: {} ({})".format(split, split_path)
            )

        # this is similar to ``data.view(-1).split(tokens_per_sample)``
        data = TokenBlockDataset(
            data,
            data.sizes,
            block_size=self.cfg.tokens_per_sample,
            pad=None,  # unused
            eos=None,  # unused
            break_mode="none",
        )

        self.datasets[split] = TruncatedBPTTDataset(
            data=data,
            bsz_per_shard=self.cfg.batch_size,
            shard_id=self.cfg.data_parallel_rank,
            num_shards=self.cfg.data_parallel_size,
        )

    def dataset(self, split):
        return self.datasets[split]

    def get_batch_iterator(
        self,
        dataset,
        num_workers=0,
        epoch=1,
        data_buffer_size=0,
        skip_remainder_batch=False,
        **kwargs
    ):
        return iterators.EpochBatchIterator(
            dataset=dataset,
            collate_fn=self._collate_fn,
            num_workers=num_workers,
            epoch=epoch,
            buffer_size=data_buffer_size,
            # we don't use the batching functionality from EpochBatchIterator;
            # instead every item in *dataset* is a whole batch
            batch_sampler=[[i] for i in range(len(dataset))],
            disable_shuffling=True,
            skip_remainder_batch=skip_remainder_batch,
        )

    def _collate_fn(self, items: List[List[torch.Tensor]]):
        # we don't use fairseq's batching functionality, so we expect a single
        # item: an (id, List[torch.Tensor]) tuple that is already a whole batch
        assert len(items) == 1

        # item will have shape B x T (the last batch may have length < T)
        id, item = items[0]
        item = data_utils.collate_tokens(item, pad_idx=self.source_dictionary.pad())
        B, T = item.size()

        # shift item one position over and append a padding token for the target
        target = torch.nn.functional.pad(
            item[:, 1:], (0, 1, 0, 0), value=self.target_dictionary.pad()
        )

        # fairseq expects batches to have the following structure
        return {
            "id": torch.tensor([id] * item.size(0)),
            "net_input": {"src_tokens": item,},
            "target": target,
            "nsentences": item.size(0),
            "ntokens": item.numel(),
        }

    def build_dataset_for_inference(
        self, src_tokens: List[torch.Tensor], src_lengths: List[int], **kwargs
    ) -> torch.utils.data.Dataset:
        eos = self.source_dictionary.eos()
        dataset = TokenBlockDataset(
            src_tokens,
            src_lengths,
            block_size=None,  # ignored for "eos" break mode
            pad=self.source_dictionary.pad(),
            eos=eos,
            break_mode="eos",
        )

        class Dataset(torch.utils.data.Dataset):
            def __getitem__(self, i):
                item = dataset[i]
                if item[-1] == eos:
                    # remove eos to support generating with a prefix
                    item = item[:-1]
                return (i, [item])

            def __len__(self):
                return len(dataset)

        return Dataset()

    def inference_step(
        self, generator, models, sample, prefix_tokens=None, constraints=None
    ):
        with torch.no_grad():
            if constraints is not None:
                raise NotImplementedError

            # SequenceGenerator doesn't use *src_tokens* directly, we need to
            # pass the *prefix_tokens* argument instead.
            if prefix_tokens is None and sample["net_input"]["src_tokens"].nelement():
                prefix_tokens = sample["net_input"]["src_tokens"]

            # begin generation with the end-of-sentence token
            bos_token = self.source_dictionary.eos()

            return generator.generate(
                models, sample, prefix_tokens=prefix_tokens, bos_token=bos_token
            )

    def eval_lm_dataloader(
        self,
        dataset,
        max_tokens: Optional[int] = 36000,
        batch_size: Optional[int] = None,
        max_positions: Optional[int] = None,
        num_shards: int = 1,
        shard_id: int = 0,
        num_workers: int = 1,
        data_buffer_size: int = 10,
        context_window: int = 0,
    ):
        if context_window > 0:
            raise NotImplementedError(
                "Transformer-XL doesn't need --context-window, try "
                "--model-overrides '{\"mem_len\":42}' instead "
            )
        return self.get_batch_iterator(
            dataset=dataset,
            max_tokens=max_tokens,
            max_sentences=batch_size,
            max_positions=max_positions,
            ignore_invalid_inputs=True,
            num_shards=num_shards,
            shard_id=shard_id,
            num_workers=num_workers,
            data_buffer_size=data_buffer_size,
        ).next_epoch_itr(shuffle=False)

    @property
    def source_dictionary(self):
        return self.dictionary

    @property
    def target_dictionary(self):
        return self.dictionary


class TruncatedBPTTDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        data: List[torch.Tensor],  # ordered list of items
        bsz_per_shard,  # number of items processed per GPU per forward
        shard_id,  # current GPU ID
        num_shards,  # number of GPUs
    ):
        super().__init__()
        self.data = data

        def batchify(data, bsz):
            # Work out how cleanly we can divide the dataset into bsz parts.
            nbatch = data.size(0) // bsz
            # Trim off any extra elements that wouldn't cleanly fit (remainders).
            data = data.narrow(0, 0, nbatch * bsz)
            # Evenly divide the data across the bsz batches.
            data = data.view(bsz, -1).contiguous()
            return data

        # total number of sequences processed by all GPUs in each forward pass
        global_batch_size = bsz_per_shard * num_shards

        """
        With a 16 item dataset, bsz_per_shard=2 and num_shards=3,
        *indices* might look like:

            indices = [[0, 1],
                       [2, 3],
                       [4, 5],
                       [6, 7],
                       [8, 9],
                       [10, 11]]

        The size of the TruncatedBPTTDataset instance will be 2,
        and shard 1 will see items:

            [(0, [data[4], data[6]]),
             (1, [data[5], data[7]])]
        """
        indices = batchify(torch.arange(len(data)), global_batch_size)
        assert indices.size(0) == global_batch_size

        self.my_indices = indices[
            shard_id * bsz_per_shard : (shard_id + 1) * bsz_per_shard
        ]
        assert self.my_indices.size(0) == bsz_per_shard

    def __len__(self):
        return self.my_indices.size(1)

    def __getitem__(self, i) -> Tuple[int, List[torch.Tensor]]:
        return (i, [self.data[idx] for idx in self.my_indices[:, i]])
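The `batchify`/sharding logic is easy to check in isolation, using the same numbers as the docstring above (16 items, `bsz_per_shard=2`, `num_shards=3`):

```python
import torch

indices = torch.arange(16)
nbatch = indices.size(0) // 6            # global batch size = 2 * 3
indices = indices.narrow(0, 0, nbatch * 6).view(6, -1)  # items 12-15 trimmed
print(indices.tolist())  # [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11]]
shard_1 = indices[1 * 2 : 2 * 2]         # rows owned by shard_id=1
print(shard_1.tolist())  # [[4, 5], [6, 7]]
```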
fairseq/examples/unsupervised_quality_estimation/aggregate_scores.py
ADDED
@@ -0,0 +1,41 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import sys

import numpy as np


aggregate_funcs = {
    "std": np.std,
    "var": np.var,
    "median": np.median,
    "mean": np.mean,
    "min": np.min,
    "max": np.max,
}


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input_file", required=True, type=str)
    parser.add_argument("-n", "--repeat_times", required=True, type=int)
    parser.add_argument("-o", "--output_file", required=False)
    parser.add_argument("-f", "--func", required=False, default="mean")
    args = parser.parse_args()

    stream = open(args.output_file, "w") if args.output_file else sys.stdout

    segment_scores = []
    for line in open(args.input_file):
        segment_scores.append(float(line.strip()))
        if len(segment_scores) == args.repeat_times:
            stream.write("{}\n".format(aggregate_funcs[args.func](segment_scores)))
            segment_scores = []


if __name__ == "__main__":
    main()
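The aggregation in miniature, with illustrative numbers: scores arrive flattened, `repeat_times` per segment, and each group is reduced independently.

```python
import numpy as np

scores = [0.1, 0.3, 0.2, 0.9, 0.7, 0.8]  # two segments, repeat_times=3
for i in range(0, len(scores), 3):
    print(np.mean(scores[i : i + 3]))    # 0.2, then 0.8
```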
fairseq/examples/unsupervised_quality_estimation/meteor.py
ADDED
@@ -0,0 +1,109 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import math
import os
import subprocess
import sys
import tempfile
from collections import defaultdict
from itertools import combinations


def read_translations(path, n_repeats):
    segment_counter = 0
    segment_translations = []
    translations = defaultdict(list)
    for line in open(path):
        segment_translations.append(" ".join(line.split()))
        if len(segment_translations) == n_repeats:
            translations[segment_counter] = segment_translations
            segment_translations = []
            segment_counter += 1
    return translations


def generate_input(translations, n_repeats):
    _, ref_path = tempfile.mkstemp()
    _, mt_path = tempfile.mkstemp()
    ref_fh = open(ref_path, "w")
    mt_fh = open(mt_path, "w")
    for segid in sorted(translations.keys()):
        assert len(translations[segid]) == n_repeats
        indexes = combinations(range(n_repeats), 2)
        for idx1, idx2 in indexes:
            mt_fh.write(translations[segid][idx1].strip() + "\n")
            ref_fh.write(translations[segid][idx2].strip() + "\n")
    sys.stderr.write("\nSaved translations to %s and %s" % (ref_path, mt_path))
    return ref_path, mt_path


def run_meteor(ref_path, mt_path, metric_path, lang="en"):
    _, out_path = tempfile.mkstemp()
    subprocess.call(
        [
            "java",
            "-Xmx2G",
            "-jar",
            metric_path,
            mt_path,
            ref_path,
            "-p",
            "0.5 0.2 0.6 0.75",  # default parameters, only changed alpha to give equal weight to P and R
            "-norm",
            "-l",
            lang,
        ],
        stdout=open(out_path, "w"),
    )
    os.remove(ref_path)
    os.remove(mt_path)
    sys.stderr.write("\nSaved Meteor output to %s" % out_path)
    return out_path


def read_output(meteor_output_path, n_repeats):
    n_combinations = math.factorial(n_repeats) / (
        math.factorial(2) * math.factorial(n_repeats - 2)
    )
    raw_scores = []
    average_scores = []
    for line in open(meteor_output_path):
        if not line.startswith("Segment "):
            continue
        score = float(line.strip().split("\t")[1])
        raw_scores.append(score)
        if len(raw_scores) == n_combinations:
            average_scores.append(sum(raw_scores) / n_combinations)
            raw_scores = []
    os.remove(meteor_output_path)
    return average_scores


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--infile")
    parser.add_argument("-n", "--repeat_times", type=int)
    parser.add_argument("-m", "--meteor")
    parser.add_argument("-o", "--output")
    args = parser.parse_args()

    translations = read_translations(args.infile, args.repeat_times)
    sys.stderr.write("\nGenerating input for Meteor...")
    ref_path, mt_path = generate_input(translations, args.repeat_times)
    sys.stderr.write("\nRunning Meteor...")
    out_path = run_meteor(ref_path, mt_path, args.meteor)
    sys.stderr.write("\nReading output...")
    scores = read_output(out_path, args.repeat_times)
    sys.stderr.write("\nWriting results...")
    with open(args.output, "w") as o:
        for scr in scores:
            o.write("{}\n".format(scr))
        o.close()


if __name__ == "__main__":
    main()
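A sanity check of the pairing logic above: each source sentence gets `n` stochastic translations, Meteor scores every unordered pair, and `read_output()` averages the resulting C(n, 2) scores per segment.

```python
from itertools import combinations

n = 4
pairs = list(combinations(range(n), 2))
print(pairs)  # [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
assert len(pairs) == n * (n - 1) // 2 == 6
```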
fairseq/examples/unsupervised_quality_estimation/repeat_lines.py
ADDED
@@ -0,0 +1,28 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import sys


def _normalize_spaces(line):
    return " ".join(line.split())


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input_file", required=True, type=str)
    parser.add_argument("-n", "--repeat_times", required=True, type=int)
    parser.add_argument("-o", "--output_file", required=False, type=str)
    args = parser.parse_args()
    stream = open(args.output_file, "w") if args.output_file else sys.stdout

    for line in open(args.input_file):
        for _ in range(args.repeat_times):
            stream.write(_normalize_spaces(line) + "\n")


if __name__ == "__main__":
    main()
fairseq/examples/wav2vec/__init__.py
ADDED
File without changes
fairseq/examples/wav2vec/config/finetuning/base_10m.yaml
ADDED
@@ -0,0 +1,63 @@
# @package _group_

common:
  fp16: true
  log_format: json
  log_interval: 200

checkpoint:
  save_interval: 1000
  save_interval_updates: 50
  keep_interval_updates: 1
  no_epoch_checkpoints: true
  best_checkpoint_metric: wer

task:
  _name: audio_finetuning
  data: ???
  normalize: false
  labels: ltr

dataset:
  num_workers: 6
  max_tokens: 3200000
  skip_invalid_size_inputs_valid_test: true
  validate_after_updates: 10000
  validate_interval: 1000
  valid_subset: dev_other

distributed_training:
  ddp_backend: legacy_ddp
  distributed_world_size: 2

criterion:
  _name: ctc
  zero_infinity: true

optimization:
  max_update: 13000
  lr: [0.00005]
  sentence_avg: true
  update_freq: [4]

optimizer:
  _name: adam
  adam_betas: (0.9,0.98)
  adam_eps: 1e-08

lr_scheduler:
  _name: tri_stage
  phase_ratio: [0.1, 0.4, 0.5]
  final_lr_scale: 0.05

model:
  _name: wav2vec_ctc
  w2v_path: ???
  apply_mask: true
  mask_prob: 0.65
  mask_channel_prob: 0.25
  mask_channel_length: 64
  layerdrop: 0.1
  activation_dropout: 0.1
  feature_grad_mult: 0.0
  freeze_finetune_updates: 10000
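The `???` entries above are OmegaConf's mandatory-value markers; they must be overridden before training can start. A small sketch of inspecting and filling them programmatically (the paths are hypothetical):

```python
from omegaconf import OmegaConf

cfg = OmegaConf.load("examples/wav2vec/config/finetuning/base_10m.yaml")
assert OmegaConf.is_missing(cfg.task, "data")  # '???' means mandatory, unset
cfg.task.data = "/path/to/librispeech/10m"     # hypothetical
cfg.model.w2v_path = "/path/to/wav2vec_small.pt"
print(OmegaConf.to_yaml(cfg.optimization))
```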
fairseq/examples/wav2vec/config/finetuning/base_1h.yaml
ADDED
@@ -0,0 +1,63 @@
# @package _group_

common:
  fp16: true
  log_format: json
  log_interval: 200

checkpoint:
  save_interval: 50
  save_interval_updates: 1000
  keep_interval_updates: 1
  no_epoch_checkpoints: true
  best_checkpoint_metric: wer

task:
  _name: audio_finetuning
  data: ???
  normalize: false
  labels: ltr

dataset:
  num_workers: 6
  max_tokens: 3200000
  skip_invalid_size_inputs_valid_test: true
  validate_after_updates: 10000
  validate_interval: 1000
  valid_subset: dev_other

distributed_training:
  ddp_backend: legacy_ddp
  distributed_world_size: 2

criterion:
  _name: ctc
  zero_infinity: true

optimization:
  max_update: 13000
  lr: [0.00005]
  sentence_avg: true
  update_freq: [4]

optimizer:
  _name: adam
  adam_betas: (0.9,0.98)
  adam_eps: 1e-08

lr_scheduler:
  _name: tri_stage
  phase_ratio: [0.1, 0.4, 0.5]
  final_lr_scale: 0.05

model:
  _name: wav2vec_ctc
  w2v_path: ???
  apply_mask: true
  mask_prob: 0.65
  mask_channel_prob: 0.25
  mask_channel_length: 64
  layerdrop: 0.1
  activation_dropout: 0.1
  feature_grad_mult: 0.0
  freeze_finetune_updates: 10000
fairseq/examples/wav2vec/config/finetuning/run_config/slurm_1.yaml
ADDED
@@ -0,0 +1,26 @@
# @package _global_

hydra:
  job:
    config:
      override_dirname:
        kv_sep: ':'
        item_sep: '__'
        exclude_keys:
          - run_config
          - distributed_training.distributed_port
  sweep:
    dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
    subdir: ${hydra.job.num}
  launcher:
    submitit_folder: ${hydra.sweep.dir}
    timeout_min: 4320
    cpus_per_task: 10
    gpus_per_node: 8
    tasks_per_node: 8
    mem_gb: 450
    nodes: 1
    name: ${env:PREFIX}_${hydra.job.config_name}
    partition: devlab,learnlab,learnfair,scavenge
    constraint: volta32gb
    max_num_timeout: 30
fairseq/examples/wav2vec/config/finetuning/run_config/slurm_16.yaml
ADDED
@@ -0,0 +1,27 @@
# @package _global_

hydra:
  job:
    config:
      override_dirname:
        kv_sep: ':'
        item_sep: '__'
        exclude_keys:
          - run_config
          - distributed_training.distributed_port
  sweep:
    dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
    subdir: ${hydra.job.num}
  launcher:
    submitit_folder: ${hydra.sweep.dir}
    timeout_min: 4320
    cpus_per_task: 80
    gpus_per_node: 8
    tasks_per_node: 1
    mem_gb: 450
    nodes: 16
    name: ${env:PREFIX}_${hydra.job.config_name}
    partition: learnlab,learnfair,scavenge
    constraint: volta32gb
    max_num_timeout: 30
    exclude: learnfair1381,learnfair5192,learnfair2304
fairseq/examples/wav2vec/config/finetuning/run_config/slurm_1_aws.yaml
ADDED
@@ -0,0 +1,37 @@
# @package _global_

hydra:
  job:
    config:
      override_dirname:
        kv_sep: ':'
        item_sep: '/'
        exclude_keys:
          - run_config
          - distributed_training.distributed_port
          - distributed_training.distributed_world_size
          - model.pretrained_model_path
          - model.target_network_path
          - next_script
          - task.cache_in_scratch
          - task.local_cache_path
          - task.data
          - checkpoint.save_interval_updates
          - checkpoint.keep_interval_updates
          - checkpoint.save_on_overflow
          - common.log_interval
          - common.user_dir
  sweep:
    dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
    subdir: ''
  launcher:
    submitit_folder: ${hydra.sweep.dir}
    timeout_min: 4320
    cpus_per_task: 80
    gpus_per_node: 8
    tasks_per_node: 1
    mem_gb: 0
    nodes: 1
    name: ${env:PREFIX}_${hydra.job.config_name}
    partition: wav2vec,learnlab,learnfair
    max_num_timeout: 30
fairseq/examples/wav2vec/config/finetuning/run_config/slurm_1_old.yaml
ADDED
@@ -0,0 +1,27 @@
# @package _global_

hydra:
  job:
    config:
      override_dirname:
        kv_sep: ':'
        item_sep: '__'
        exclude_keys:
          - run_config
          - distributed_training.distributed_port
  sweep:
    dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
    subdir: ${hydra.job.num}
  launcher:
    submitit_folder: ${hydra.sweep.dir}
    timeout_min: 4320
    cpus_per_task: 80
    gpus_per_node: 8
    tasks_per_node: 1
    mem_gb: 450
    nodes: 1
    name: ${env:PREFIX}_wav2vec3_small_librispeech
    partition: devlab,learnlab,learnfair,scavenge
    constraint: volta32gb
    max_num_timeout: 30
    exclude: learnfair1381
fairseq/examples/wav2vec/config/finetuning/run_config/slurm_2.yaml
ADDED
@@ -0,0 +1,27 @@
# @package _global_

hydra:
  job:
    config:
      override_dirname:
        kv_sep: ':'
        item_sep: '__'
        exclude_keys:
          - run_config
          - distributed_training.distributed_port
  sweep:
    dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
    subdir: ${hydra.job.num}
  launcher:
    submitit_folder: ${hydra.sweep.dir}
    timeout_min: 4320
    cpus_per_task: 10
    gpus_per_node: 8
    tasks_per_node: 8
    mem_gb: 450
    nodes: 2
    name: ${env:PREFIX}_${hydra.job.config_name}
    partition: devlab,learnlab,learnfair,scavenge
    constraint: volta32gb
    max_num_timeout: 30
    exclude: learnfair7491,learnfair7477,learnfair7487
fairseq/examples/wav2vec/config/finetuning/run_config/slurm_2_aws.yaml
ADDED
@@ -0,0 +1,37 @@
# @package _global_

hydra:
  job:
    config:
      override_dirname:
        kv_sep: ':'
        item_sep: '/'
        exclude_keys:
          - run_config
          - distributed_training.distributed_port
          - distributed_training.distributed_world_size
          - model.pretrained_model_path
          - model.target_network_path
          - next_script
          - task.cache_in_scratch
          - task.local_cache_path
          - task.data
          - checkpoint.save_interval_updates
          - checkpoint.keep_interval_updates
          - checkpoint.save_on_overflow
          - common.log_interval
          - common.user_dir
  sweep:
    dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
    subdir: ''
  launcher:
    submitit_folder: ${hydra.sweep.dir}
    timeout_min: 4320
    cpus_per_task: 80
    gpus_per_node: 8
    tasks_per_node: 1
    mem_gb: 0
    nodes: 2
    name: ${env:PREFIX}_${hydra.job.config_name}
    partition: wav2vec,learnlab,learnfair
    max_num_timeout: 30
fairseq/examples/wav2vec/config/finetuning/run_config/slurm_2g.yaml
ADDED
@@ -0,0 +1,26 @@
# @package _global_

hydra:
  job:
    config:
      override_dirname:
        kv_sep: ':'
        item_sep: '__'
        exclude_keys:
          - run_config
          - distributed_training.distributed_port
  sweep:
    dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
    subdir: ${hydra.job.num}
  launcher:
    submitit_folder: ${hydra.sweep.dir}
    timeout_min: 4320
    cpus_per_task: 10
    gpus_per_node: 2
    tasks_per_node: 2
    mem_gb: 200
    nodes: 1
    name: ${env:PREFIX}_${hydra.job.config_name}
    partition: devlab,learnlab,learnfair,scavenge
    constraint: volta32gb
    max_num_timeout: 30
fairseq/examples/wav2vec/config/finetuning/run_config/slurm_3.yaml
ADDED
@@ -0,0 +1,27 @@
# @package _global_

hydra:
  job:
    config:
      override_dirname:
        kv_sep: ':'
        item_sep: '__'
        exclude_keys:
          - run_config
          - distributed_training.distributed_port
  sweep:
    dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
    subdir: ${hydra.job.num}
  launcher:
    submitit_folder: ${hydra.sweep.dir}
    timeout_min: 4320
    cpus_per_task: 10
    gpus_per_node: 8
    tasks_per_node: 8
    mem_gb: 450
    nodes: 3
    name: ${env:PREFIX}_${hydra.job.config_name}
    partition: devlab,learnlab,learnfair,scavenge
    constraint: volta32gb
    max_num_timeout: 30
    exclude: learnfair7491,learnfair7477,learnfair7487
fairseq/examples/wav2vec/config/finetuning/run_config/slurm_4g.yaml
ADDED
@@ -0,0 +1,26 @@
# @package _global_

hydra:
  job:
    config:
      override_dirname:
        kv_sep: ':'
        item_sep: '__'
        exclude_keys:
          - run_config
          - distributed_training.distributed_port
  sweep:
    dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
    subdir: ${hydra.job.num}
  launcher:
    submitit_folder: ${hydra.sweep.dir}
    timeout_min: 4320
    cpus_per_task: 10
    gpus_per_node: 4
    tasks_per_node: 4
    mem_gb: 200
    nodes: 1
    name: ${env:PREFIX}_${hydra.job.config_name}
    partition: devlab,learnlab,learnfair,scavenge
    constraint: volta32gb
    max_num_timeout: 30
fairseq/examples/wav2vec/config/finetuning/run_config/slurm_4g_aws.yaml
ADDED
@@ -0,0 +1,37 @@
# @package _global_

hydra:
  job:
    config:
      override_dirname:
        kv_sep: ':'
        item_sep: '/'
        exclude_keys:
          - run_config
          - distributed_training.distributed_port
          - distributed_training.distributed_world_size
          - model.pretrained_model_path
          - model.target_network_path
          - next_script
          - task.cache_in_scratch
          - task.local_cache_path
          - task.data
          - checkpoint.save_interval_updates
          - checkpoint.keep_interval_updates
          - checkpoint.save_on_overflow
          - common.log_interval
          - common.user_dir
  sweep:
    dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
    subdir: ''
  launcher:
    submitit_folder: ${hydra.sweep.dir}
    timeout_min: 4320
    cpus_per_task: 80
    gpus_per_node: 4
    tasks_per_node: 1
    mem_gb: 0
    nodes: 1
    name: ${env:PREFIX}_${hydra.job.config_name}
    partition: wav2vec,learnlab,learnfair
    max_num_timeout: 30
fairseq/examples/wav2vec/config/finetuning/run_config/slurm_8.yaml
ADDED
@@ -0,0 +1,26 @@
# @package _global_

hydra:
  job:
    config:
      override_dirname:
        kv_sep: ':'
        item_sep: '__'
        exclude_keys:
          - run_config
          - distributed_training.distributed_port
  sweep:
    dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
    subdir: ${hydra.job.num}
  launcher:
    submitit_folder: ${hydra.sweep.dir}
    timeout_min: 4320
    cpus_per_task: 10
    gpus_per_node: 8
    tasks_per_node: 8
    mem_gb: 400
    nodes: 8
    name: ${env:PREFIX}_${hydra.job.config_name}
    partition: devlab,learnlab,learnfair,scavenge
    constraint: volta32gb
    max_num_timeout: 30
fairseq/examples/wav2vec/config/finetuning/vox_10h_2_aws.yaml
ADDED
@@ -0,0 +1,81 @@
# @package _group_

common:
  fp16: true
  log_format: json
  log_interval: 200
  user_dir: /data/home/abaevski/fairseq-py/examples/data2vec
  # tensorboard_logdir: tb

checkpoint:
  save_interval: 10
  no_epoch_checkpoints: true
  best_checkpoint_metric: wer

task:
  _name: audio_finetuning
  data: /fsx-wav2vec/abaevski/data/libri/10h/wav2vec/raw
  labels: ltr
  normalize: true

dataset:
  num_workers: 6
  max_tokens: 1280000
  skip_invalid_size_inputs_valid_test: true
  validate_after_updates: 100
  validate_interval: 10
  valid_subset: dev_other
  required_batch_size_multiple: 1

distributed_training:
  ddp_backend: legacy_ddp
  distributed_world_size: 4

criterion:
  _name: ctc
  zero_infinity: true
  post_process: letter
  wer_kenlm_model: /fsx-wav2vec/abaevski/data/libri/4-gram.bin
  wer_lexicon: /fsx-wav2vec/abaevski/data/libri/10h/wav2vec/raw/lexicon_ltr2.lst
  wer_lm_weight: 2.0
  wer_word_score: 4
  wer_sil_weight: -5

optimization:
  max_update: 60000
  lr: [1e-5]
  # lr: [1e-5] # base 10h wer
  sentence_avg: true
  update_freq: [1] # base 10h we -> 2/4

optimizer:
  _name: adam
  adam_betas: (0.9,0.98)
  adam_eps: 1e-08

lr_scheduler:
  _name: tri_stage
  phase_ratio: null
  warmup_steps: 8000
  hold_steps: 0
  decay_steps: 72000
  final_lr_scale: 0.05

model:
  _name: wav2vec_ctc
  w2v_path: ???
  apply_mask: true
  mask_prob: 0.75
  mask_length: 5
  # mask_prob: 0.65 # base 10h wer
  mask_channel_prob: 0.1
  # mask_channel_prob: 0.6 # base 10h wer
  mask_channel_length: 64
  layerdrop: 0
  # layerdrop: 0.05 # base 10h wer
  activation_dropout: 0.1
  feature_grad_mult: 0.0
  freeze_finetune_updates: 100
  dropout: 0
  final_dropout: 0
  attention_dropout: 0
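
In all of these fine-tuning configs, w2v_path: ??? is an OmegaConf mandatory value, so a pretrained checkpoint must be supplied on the command line. A minimal local launch sketch in the style of the wav2vec README (paths are placeholders):

    fairseq-hydra-train \
        task.data=/path/to/manifests \
        model.w2v_path=/path/to/pretrained_model.pt \
        --config-dir examples/wav2vec/config/finetuning \
        --config-name vox_10h_2_aws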
fairseq/examples/wav2vec/config/finetuning/vox_10h_aws.yaml
ADDED
@@ -0,0 +1,104 @@
# @package _group_

common:
  fp16: true
  log_format: json
  log_interval: 200
  user_dir: /data/home/abaevski/fairseq-py/examples/data2vec
  # tensorboard_logdir: tb

checkpoint:
  save_interval: 10
  no_epoch_checkpoints: true
  best_checkpoint_metric: wer

task:
  _name: audio_finetuning
  data: /fsx-wav2vec/abaevski/data/libri/10h/wav2vec/raw
  labels: ltr
  normalize: true

dataset:
  num_workers: 6
  max_tokens: 1280000
  skip_invalid_size_inputs_valid_test: true
  validate_after_updates: 100
  validate_interval: 10
  valid_subset: dev_other
  required_batch_size_multiple: 1

distributed_training:
  ddp_backend: legacy_ddp
  distributed_world_size: 4

criterion:
  _name: ctc
  zero_infinity: true
  post_process: letter
  # wer_kenlm_model: /fsx-wav2vec/abaevski/data/libri/4-gram.bin
  # wer_lexicon: /fsx-wav2vec/abaevski/data/libri/10h/wav2vec/raw/lexicon_ltr2.lst
  # wer_lm_weight: 2.0
  # wer_word_score: -1.0

optimization:
  max_update: 60000
  lr: [2e-5]
  # lr: [1e-5] # base 10h wer
  sentence_avg: true
  update_freq: [1] # base 10h we -> 2/4

optimizer:
  _name: adam
  adam_betas: (0.9,0.98)
  adam_eps: 1e-08

lr_scheduler:
  _name: tri_stage
  phase_ratio: null
  warmup_steps: 8000
  hold_steps: 0
  decay_steps: 72000
  final_lr_scale: 0.05

model:
  _name: wav2vec_ctc
  w2v_path: ???
  apply_mask: true
  mask_prob: 0.4
  mask_length: 5
  # mask_prob: 0.65 # base 10h wer
  mask_channel_prob: 0.1
  # mask_channel_prob: 0.6 # base 10h wer
  mask_channel_length: 64
  layerdrop: 0.1
  # layerdrop: 0.05 # base 10h wer
  activation_dropout: 0.1
  feature_grad_mult: 0.0
  freeze_finetune_updates: 100
  dropout: 0
  final_dropout: 0
  attention_dropout: 0

hydra:
  job:
    config:
      override_dirname:
        kv_sep: ':'
        item_sep: '__'
        exclude_keys:
          - run_config
          - distributed_training.distributed_port
  sweep:
    dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname}
    subdir: ${hydra.job.num}
  launcher:
    submitit_folder: ${hydra.sweep.dir}
    timeout_min: 3000
    cpus_per_task: 10
    gpus_per_node: 4
    tasks_per_node: 4
    mem_gb: 0
    nodes: 1
    name: ${env:PREFIX}_${hydra.job.config_name}
    partition: wav2vec,learnlab
    max_num_timeout: 30
fairseq/examples/wav2vec/config/finetuning/vox_10m_2.yaml
ADDED
@@ -0,0 +1,114 @@
# @package _group_

common:
  fp16: true
  fp16_no_flatten_grads: true
  log_format: json
  log_interval: 200
  user_dir: /private/home/abaevski/fairseq-py/examples/data2vec
  # tensorboard_logdir: tb

checkpoint:
  save_interval: 500
  save_interval_updates: 500
  keep_interval_updates: 1
  no_epoch_checkpoints: true
  best_checkpoint_metric: wer

task:
  _name: audio_finetuning
  data: /checkpoint/abaevski/data/speech/libri/10m/wav2vec/raw
  labels: ltr
  normalize: true

dataset:
  num_workers: 6
  max_tokens: 1000000
  skip_invalid_size_inputs_valid_test: true
  validate_after_updates: 100
  validate_interval: 500
  valid_subset: dev_other
  required_batch_size_multiple: 1

distributed_training:
  ddp_backend: legacy_ddp
  distributed_world_size: 4

criterion:
  _name: ctc
  zero_infinity: true
  post_process: letter
  wer_kenlm_model: /checkpoint/abaevski/data/speech/libri/4-gram.bin
  wer_lexicon: /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw/lexicon_ltr2.lst
  wer_lm_weight: 5
  wer_word_score: 2
  wer_sil_weight: -2

optimization:
  max_update: 10000
  lr: [2e-6]
  # lr: [1e-5] # base 10h wer
  sentence_avg: true
  update_freq: [4] # base 10h we -> 2/4

optimizer:
  _name: composite
  dynamic_groups: true
  groups:
    default:
      lr_float: 2e-6
      optimizer:
        _name: adam
        adam_betas: [0.9,0.95]
      lr_scheduler:
        _name: cosine
        warmup_updates: 1000

lr_scheduler: pass_through

model:
  _name: wav2vec_ctc
  w2v_path: ???
  apply_mask: true
  mask_prob: 0.4
  mask_length: 3
  # mask_prob: 0.65 # base 10h wer
  mask_channel_prob: 0.25
  # mask_channel_prob: 0.6 # base 10h wer
  mask_channel_length: 64
  layerdrop: 0.1
  # layerdrop: 0.05 # base 10h wer
  freeze_finetune_updates: 100

  zero_mask: true
  feature_grad_mult: 0.0
  activation_dropout: 0.1
  dropout: 0
  final_dropout: 0
  attention_dropout: 0
  update_alibi: false

#hydra:
#  job:
#    config:
#      override_dirname:
#        kv_sep: ':'
#        item_sep: '__'
#        exclude_keys:
#          - run_config
#          - distributed_training.distributed_port
#  sweep:
#    dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname}
#    subdir: ${hydra.job.num}
#  launcher:
#    submitit_folder: ${hydra.sweep.dir}
#    timeout_min: 3000
#    cpus_per_task: 10
#    gpus_per_node: 4
#    tasks_per_node: 4
#    mem_gb: 250
#    nodes: 1
#    name: ${env:PREFIX}_${hydra.job.config_name}
#    partition: devlab,learnlab,learnfair,scavenge
#    constraint: volta32gb
#    max_num_timeout: 30
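
Unlike the plain adam configs above, this file uses the composite optimizer: each entry under groups carries its own optimizer and lr_scheduler, and the top-level lr_scheduler: pass_through defers scheduling to those group-level schedulers. To sweep the group learning rate from the command line, an override along these lines should work (a sketch; the dotted path just mirrors the structure above and has not been verified):

    fairseq-hydra-train optimizer.groups.default.lr_float=5e-6 \
        model.w2v_path=/path/to/model.pt \
        --config-dir examples/wav2vec/config/finetuning --config-name vox_10m_2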
fairseq/examples/wav2vec/config/finetuning/vox_10m_2_aws.yaml
ADDED
@@ -0,0 +1,114 @@
# @package _group_

common:
  fp16: true
  fp16_no_flatten_grads: true
  log_format: json
  log_interval: 200
  user_dir: /data/home/abaevski/fairseq-py/examples/data2vec
  # tensorboard_logdir: tb

checkpoint:
  save_interval: 500
  save_interval_updates: 500
  keep_interval_updates: 1
  no_epoch_checkpoints: true
  best_checkpoint_metric: wer

task:
  _name: audio_finetuning
  data: /fsx-wav2vec/abaevski/data/libri/10m/wav2vec/raw
  labels: ltr
  normalize: true

dataset:
  num_workers: 6
  max_tokens: 1000000
  skip_invalid_size_inputs_valid_test: true
  validate_after_updates: 100
  validate_interval: 500
  valid_subset: dev_other
  required_batch_size_multiple: 1

distributed_training:
  ddp_backend: legacy_ddp
  distributed_world_size: 4

criterion:
  _name: ctc
  zero_infinity: true
  post_process: letter
  wer_kenlm_model: /fsx-wav2vec/abaevski/data/libri/4-gram.bin
  wer_lexicon: /fsx-wav2vec/abaevski/data/libri/10h/wav2vec/raw/lexicon_ltr2.lst
  wer_lm_weight: 5
  wer_word_score: 2
  wer_sil_weight: -2

optimization:
  max_update: 10000
  lr: [2e-6]
  # lr: [1e-5] # base 10h wer
  sentence_avg: true
  update_freq: [4] # base 10h we -> 2/4

optimizer:
  _name: composite
  dynamic_groups: true
  groups:
    default:
      lr_float: 2e-6
      optimizer:
        _name: adam
        adam_betas: [0.9,0.95]
      lr_scheduler:
        _name: cosine
        warmup_updates: 1000

lr_scheduler: pass_through

model:
  _name: wav2vec_ctc
  w2v_path: ???
  apply_mask: true
  mask_prob: 0.4
  mask_length: 3
  # mask_prob: 0.65 # base 10h wer
  mask_channel_prob: 0.25
  # mask_channel_prob: 0.6 # base 10h wer
  mask_channel_length: 64
  layerdrop: 0.1
  # layerdrop: 0.05 # base 10h wer
  freeze_finetune_updates: 100

  zero_mask: true
  feature_grad_mult: 0.0
  activation_dropout: 0.1
  dropout: 0
  final_dropout: 0
  attention_dropout: 0
  update_alibi: false

#hydra:
#  job:
#    config:
#      override_dirname:
#        kv_sep: ':'
#        item_sep: '__'
#        exclude_keys:
#          - run_config
#          - distributed_training.distributed_port
#  sweep:
#    dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname}
#    subdir: ${hydra.job.num}
#  launcher:
#    submitit_folder: ${hydra.sweep.dir}
#    timeout_min: 3000
#    cpus_per_task: 10
#    gpus_per_node: 4
#    tasks_per_node: 4
#    mem_gb: 250
#    nodes: 1
#    name: ${env:PREFIX}_${hydra.job.config_name}
#    partition: devlab,learnlab,learnfair,scavenge
#    constraint: volta32gb
#    max_num_timeout: 30
fairseq/examples/wav2vec/config/finetuning/vox_10m_3.yaml
ADDED
@@ -0,0 +1,105 @@
# @package _group_

common:
  fp16: true
  log_format: json
  log_interval: 200
  user_dir: /private/home/abaevski/fairseq-py/examples/data2vec
  # tensorboard_logdir: tb

checkpoint:
  save_interval: 1000
  save_interval_updates: 100
  keep_interval_updates: 1
  no_epoch_checkpoints: true
  best_checkpoint_metric: wer

task:
  _name: audio_finetuning
  data: /checkpoint/abaevski/data/speech/libri/10m/wav2vec/raw
  labels: ltr
  normalize: true

dataset:
  num_workers: 6
  max_tokens: 1280000
  skip_invalid_size_inputs_valid_test: true
  validate_after_updates: 10000
  validate_interval: 500
  valid_subset: dev_other
  required_batch_size_multiple: 8

distributed_training:
  ddp_backend: legacy_ddp
  distributed_world_size: 4

criterion:
  _name: ctc
  zero_infinity: true
  post_process: letter
  wer_kenlm_model: /checkpoint/abaevski/data/speech/libri/4-gram.bin
  wer_lexicon: /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw/lexicon_ltr2.lst
  wer_lm_weight: 8
  wer_word_score: 5.8
  wer_sil_weight: -8

optimization:
  max_update: 13000
  lr: [2e-5]
  # lr: [1e-5] # base 10h wer
  sentence_avg: true
  update_freq: [5] # base 10h we -> 2/4

optimizer:
  _name: adam
  adam_betas: (0.9,0.98)
  adam_eps: 1e-08

lr_scheduler:
  _name: tri_stage
  phase_ratio: [0.1, 0.4, 0.5]
  final_lr_scale: 0.05

model:
  _name: wav2vec_ctc
  w2v_path: ???
  apply_mask: true
  mask_prob: 0.65
  mask_length: 10
  # mask_prob: 0.65 # base 10h wer
  mask_channel_prob: 0.25
  # mask_channel_prob: 0.6 # base 10h wer
  mask_channel_length: 64
  layerdrop: 0.1
  # layerdrop: 0.05 # base 10h wer
  activation_dropout: 0.1
  feature_grad_mult: 0.0
  freeze_finetune_updates: 10000
  dropout: 0
  final_dropout: 0
  attention_dropout: 0

hydra:
  job:
    config:
      override_dirname:
        kv_sep: ':'
        item_sep: '__'
        exclude_keys:
          - run_config
          - distributed_training.distributed_port
  sweep:
    dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname}
    subdir: ${hydra.job.num}
  launcher:
    submitit_folder: ${hydra.sweep.dir}
    timeout_min: 3000
    cpus_per_task: 10
    gpus_per_node: 4
    tasks_per_node: 4
    mem_gb: 250
    nodes: 1
    name: ${env:PREFIX}_${hydra.job.config_name}
    partition: devlab,learnlab,learnfair,scavenge
    constraint: volta32gb
    max_num_timeout: 30
fairseq/examples/wav2vec/config/finetuning/vox_1h.yaml
ADDED
@@ -0,0 +1,63 @@
# @package _group_

common:
  fp16: true
  log_format: json
  log_interval: 200

checkpoint:
  save_interval: 1000
  save_interval_updates: 50
  keep_interval_updates: 1
  no_epoch_checkpoints: true
  best_checkpoint_metric: wer

task:
  _name: audio_finetuning
  data: ???
  normalize: true
  labels: ltr

dataset:
  num_workers: 6
  max_tokens: 1280000
  skip_invalid_size_inputs_valid_test: true
  validate_after_updates: 10000
  validate_interval: 1000
  valid_subset: dev_other

distributed_training:
  ddp_backend: legacy_ddp
  distributed_world_size: 4

criterion:
  _name: ctc
  zero_infinity: true

optimization:
  max_update: 13000
  lr: [0.0003]
  sentence_avg: true
  update_freq: [5]

optimizer:
  _name: adam
  adam_betas: (0.9,0.98)
  adam_eps: 1e-08

lr_scheduler:
  _name: tri_stage
  phase_ratio: [0.1, 0.4, 0.5]
  final_lr_scale: 0.05

model:
  _name: wav2vec_ctc
  w2v_path: ???
  apply_mask: true
  mask_prob: 0.75
  mask_channel_prob: 0.25
  mask_channel_length: 64
  layerdrop: 0.1
  activation_dropout: 0.1
  feature_grad_mult: 0.0
  freeze_finetune_updates: 10000
fairseq/examples/wav2vec/config/finetuning/vox_1h_2.yaml
ADDED
@@ -0,0 +1,104 @@
# @package _group_

common:
  fp16: true
  log_format: json
  log_interval: 200
  user_dir: /private/home/abaevski/fairseq-py/examples/data2vec
  # tensorboard_logdir: tb

checkpoint:
  save_interval: 100
  save_interval_updates: 500
  keep_interval_updates: 1
  no_epoch_checkpoints: true
  best_checkpoint_metric: wer

task:
  _name: audio_finetuning
  data: /checkpoint/abaevski/data/speech/libri/1h/wav2vec/raw
  labels: ltr
  normalize: true

dataset:
  num_workers: 6
  max_tokens: 1000000
  skip_invalid_size_inputs_valid_test: true
  validate_after_updates: 100
  validate_interval: 100
  valid_subset: dev_other
  required_batch_size_multiple: 1

distributed_training:
  ddp_backend: legacy_ddp
  distributed_world_size: 8

criterion:
  _name: ctc
  zero_infinity: true
  post_process: letter
  wer_kenlm_model: /checkpoint/abaevski/data/speech/libri/4-gram.bin
  wer_lexicon: /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw/lexicon_ltr2.lst
  wer_lm_weight: 6
  wer_word_score: -0.1
  wer_sil_weight: -4.7

optimization:
  max_update: 60000
  lr: [1e-5]
  # lr: [1e-5] # base 10h wer
  sentence_avg: true
  update_freq: [1] # base 10h we -> 2/4

optimizer:
  _name: adam
  adam_betas: (0.9,0.98)
  adam_eps: 1e-08

lr_scheduler:
  _name: cosine
  warmup_updates: 4000

model:
  _name: wav2vec_ctc
  w2v_path: ???
  apply_mask: true
  mask_prob: 0.65
  mask_length: 5
  # mask_prob: 0.65 # base 10h wer
  mask_channel_prob: 0.25
  # mask_channel_prob: 0.6 # base 10h wer
  mask_channel_length: 64
  layerdrop: 0.1
  # layerdrop: 0.05 # base 10h wer
  activation_dropout: 0.1
  feature_grad_mult: 0.0
  freeze_finetune_updates: 100
  dropout: 0
  final_dropout: 0
  attention_dropout: 0

hydra:
  job:
    config:
      override_dirname:
        kv_sep: ':'
        item_sep: '__'
        exclude_keys:
          - run_config
          - distributed_training.distributed_port
  sweep:
    dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname}
    subdir: ${hydra.job.num}
  launcher:
    submitit_folder: ${hydra.sweep.dir}
    timeout_min: 3000
    cpus_per_task: 10
    gpus_per_node: 4
    tasks_per_node: 4
    mem_gb: 250
    nodes: 1
    name: ${env:PREFIX}_${hydra.job.config_name}
    partition: devlab,learnlab,learnfair,scavenge
    constraint: volta32gb
    max_num_timeout: 30
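
The wer_* criterion fields above only affect the KenLM-rescored WER logged during validation (which feeds best_checkpoint_metric: wer). For final decoding, the wav2vec README uses examples/speech_recognition/infer.py; a sketch with placeholder paths, flags as documented there (verify against your checkout):

    python examples/speech_recognition/infer.py /path/to/manifests --task audio_finetuning \
        --nbest 1 --path /path/to/checkpoint_best.pt --gen-subset dev_other \
        --results-path /path/to/results --w2l-decoder kenlm \
        --lm-model /path/to/4-gram.bin --lexicon /path/to/lexicon_ltr2.lst \
        --lm-weight 2 --word-score -1 --sil-weight 0 \
        --criterion ctc --labels ltr --max-tokens 4000000 --post-process letter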
fairseq/examples/wav2vec/config/finetuning/vox_1h_2_aws.yaml
ADDED
@@ -0,0 +1,114 @@
# @package _group_

common:
  fp16: true
  fp16_no_flatten_grads: true
  log_format: json
  log_interval: 200
  user_dir: /data/home/abaevski/fairseq-py/examples/data2vec
  # tensorboard_logdir: tb

checkpoint:
  save_interval: 100
  save_interval_updates: 500
  keep_interval_updates: 1
  no_epoch_checkpoints: true
  best_checkpoint_metric: wer

task:
  _name: audio_finetuning
  data: /fsx-wav2vec/abaevski/data/libri/1h/wav2vec/raw
  labels: ltr
  normalize: true

dataset:
  num_workers: 6
  max_tokens: 1000000
  skip_invalid_size_inputs_valid_test: true
  validate_after_updates: 100
  validate_interval: 500
  valid_subset: dev_other
  required_batch_size_multiple: 1

distributed_training:
  ddp_backend: legacy_ddp
  distributed_world_size: 4

criterion:
  _name: ctc
  zero_infinity: true
  post_process: letter
  wer_kenlm_model: /fsx-wav2vec/abaevski/data/libri/4-gram.bin
  wer_lexicon: /fsx-wav2vec/abaevski/data/libri/10h/wav2vec/raw/lexicon_ltr2.lst
  wer_lm_weight: 5
  wer_word_score: 0
  wer_sil_weight: -4

optimization:
  max_update: 10000
  lr: [2e-6]
  # lr: [1e-5] # base 10h wer
  sentence_avg: true
  update_freq: [4] # base 10h we -> 2/4

optimizer:
  _name: composite
  dynamic_groups: true
  groups:
    default:
      lr_float: 2e-6
      optimizer:
        _name: adam
        adam_betas: [0.9,0.95]
      lr_scheduler:
        _name: cosine
        warmup_updates: 1000

lr_scheduler: pass_through

model:
  _name: wav2vec_ctc
  w2v_path: ???
  apply_mask: true
  mask_prob: 0.4
  mask_length: 3
  # mask_prob: 0.65 # base 10h wer
  mask_channel_prob: 0.25
  # mask_channel_prob: 0.6 # base 10h wer
  mask_channel_length: 64
  layerdrop: 0.1
  # layerdrop: 0.05 # base 10h wer
  freeze_finetune_updates: 100

  zero_mask: true
  feature_grad_mult: 0.0
  activation_dropout: 0.1
  dropout: 0
  final_dropout: 0
  attention_dropout: 0
  update_alibi: false

#hydra:
#  job:
#    config:
#      override_dirname:
#        kv_sep: ':'
#        item_sep: '__'
#        exclude_keys:
#          - run_config
#          - distributed_training.distributed_port
#  sweep:
#    dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname}
#    subdir: ${hydra.job.num}
#  launcher:
#    submitit_folder: ${hydra.sweep.dir}
#    timeout_min: 3000
#    cpus_per_task: 10
#    gpus_per_node: 4
#    tasks_per_node: 4
#    mem_gb: 250
#    nodes: 1
#    name: ${env:PREFIX}_${hydra.job.config_name}
#    partition: devlab,learnlab,learnfair,scavenge
#    constraint: volta32gb
#    max_num_timeout: 30
fairseq/examples/wav2vec/config/finetuning/vox_1h_aws.yaml
ADDED
@@ -0,0 +1,80 @@
# @package _group_

common:
  fp16: true
  log_format: json
  log_interval: 200
  user_dir: /data/home/abaevski/fairseq-py/examples/data2vec
  # tensorboard_logdir: tb

checkpoint:
  save_interval: 100
  save_interval_updates: 500
  keep_interval_updates: 1
  no_epoch_checkpoints: true
  best_checkpoint_metric: wer

task:
  _name: audio_finetuning
  data: /fsx-wav2vec/abaevski/data/libri/10m/wav2vec/raw
  labels: ltr
  normalize: true

dataset:
  num_workers: 6
  max_tokens: 1000000
  skip_invalid_size_inputs_valid_test: true
  validate_after_updates: 10000
  validate_interval: 100
  valid_subset: dev_other
  required_batch_size_multiple: 8

distributed_training:
  ddp_backend: legacy_ddp
  distributed_world_size: 8

criterion:
  _name: ctc
  zero_infinity: true
  post_process: letter
  wer_kenlm_model: /fsx-wav2vec/abaevski/data/libri/4-gram.bin
  wer_lexicon: /fsx-wav2vec/abaevski/data/libri/10h/wav2vec/raw/lexicon_ltr2.lst
  wer_lm_weight: 5
  wer_word_score: -0.1
  wer_sil_weight: -4.7

optimization:
  max_update: 13000
  lr: [6e-5]
  # lr: [1e-5] # base 10h wer
  sentence_avg: true
  update_freq: [5] # base 10h we -> 2/4

optimizer:
  _name: adam
  adam_betas: (0.9,0.98)
  adam_eps: 1e-08

lr_scheduler:
  _name: cosine
  warmup_updates: 4000

model:
  _name: wav2vec_ctc
  w2v_path: ???
  apply_mask: true
  mask_prob: 0.3
  mask_length: 3
  # mask_prob: 0.65 # base 10h wer
  mask_channel_prob: 0.25
  # mask_channel_prob: 0.6 # base 10h wer
  mask_channel_length: 64
  layerdrop: 0.1
  # layerdrop: 0.05 # base 10h wer
  activation_dropout: 0.1
  feature_grad_mult: 0.0
  freeze_finetune_updates: 10000
  dropout: 0
  final_dropout: 0
  attention_dropout: 0
  update_alibi: false
fairseq/examples/wav2vec/config/finetuning/vox_960h.yaml
ADDED
@@ -0,0 +1,57 @@
# @package _group_

common:
  fp16: true
  log_format: json
  log_interval: 200

checkpoint:
  no_epoch_checkpoints: true
  best_checkpoint_metric: wer

task:
  _name: audio_finetuning
  data: ???
  normalize: true
  labels: ltr

dataset:
  num_workers: 6
  max_tokens: 1280000
  skip_invalid_size_inputs_valid_test: true
  valid_subset: dev_other

distributed_training:
  ddp_backend: legacy_ddp
  distributed_world_size: 24

criterion:
  _name: ctc
  zero_infinity: true

optimization:
  max_update: 320000
  lr: [0.00003]
  sentence_avg: true

optimizer:
  _name: adam
  adam_betas: (0.9,0.98)
  adam_eps: 1e-08

lr_scheduler:
  _name: tri_stage
  phase_ratio: [0.1, 0.4, 0.5]
  final_lr_scale: 0.05

model:
  _name: wav2vec_ctc
  w2v_path: ???
  apply_mask: true
  mask_prob: 0.5
  mask_channel_prob: 0.25
  mask_channel_length: 64
  layerdrop: 0.1
  activation_dropout: 0.1
  feature_grad_mult: 0.0
  freeze_finetune_updates: 10000
fairseq/examples/wav2vec/config/finetuning/vox_960h_2.yaml
ADDED
@@ -0,0 +1,105 @@
# @package _group_

common:
  fp16: true
  log_format: json
  log_interval: 200
  user_dir: /private/home/abaevski/fairseq-py/examples/data2vec
  # tensorboard_logdir: tb

checkpoint:
  save_interval: 1
  no_epoch_checkpoints: true
  best_checkpoint_metric: wer

task:
  _name: audio_finetuning
  data: /checkpoint/abaevski/data/speech/libri/960h/wav2vec/raw
  labels: ltr
  normalize: true

dataset:
  num_workers: 6
  max_tokens: 1000000
  skip_invalid_size_inputs_valid_test: true
  validate_after_updates: 100
  validate_interval: 1
  valid_subset: dev_other
  required_batch_size_multiple: 1

distributed_training:
  ddp_backend: legacy_ddp
  distributed_world_size: 16

criterion:
  _name: ctc
  zero_infinity: true
  post_process: letter
  wer_kenlm_model: /checkpoint/abaevski/data/speech/libri/4-gram.bin
  wer_lexicon: /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw/lexicon_ltr2.lst
  wer_lm_weight: 2.0
  wer_word_score: -1.0

optimization:
  max_update: 200000
  lr: [1e-5]
  # lr: [1e-5] # base 10h wer
  sentence_avg: true
  update_freq: [1] # base 10h we -> 2/4

optimizer:
  _name: adam
  adam_betas: (0.9,0.98)
  adam_eps: 1e-08

lr_scheduler:
  _name: tri_stage
  phase_ratio: null
  warmup_steps: 8000
  hold_steps: 0
  decay_steps: 200000
  final_lr_scale: 0.05

model:
  _name: wav2vec_ctc
  w2v_path: ???
  apply_mask: true
  mask_prob: 0.4
  mask_length: 5
  # mask_prob: 0.65 # base 10h wer
  mask_channel_prob: 0.1
  # mask_channel_prob: 0.6 # base 10h wer
  mask_channel_length: 64
  layerdrop: 0.1
  # layerdrop: 0.05 # base 10h wer
  activation_dropout: 0.1
  feature_grad_mult: 0.0
  freeze_finetune_updates: 100
  dropout: 0
  final_dropout: 0
  attention_dropout: 0

hydra:
  job:
    config:
      override_dirname:
        kv_sep: ':'
        item_sep: '__'
        exclude_keys:
          - run_config
          - distributed_training.distributed_port
  sweep:
    dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname}
    subdir: ${hydra.job.num}
  launcher:
    submitit_folder: ${hydra.sweep.dir}
    timeout_min: 3000
    cpus_per_task: 10
    gpus_per_node: 4
    tasks_per_node: 4
    mem_gb: 250
    nodes: 1
    name: ${env:PREFIX}_${hydra.job.config_name}
    partition: devlab,learnlab,learnfair,scavenge
    constraint: volta32gb
    max_num_timeout: 30
fairseq/examples/wav2vec/config/finetuning/vox_960h_2_aws.yaml
ADDED
@@ -0,0 +1,82 @@
# @package _group_

common:
  fp16: true
  log_format: json
  log_interval: 200
  user_dir: /data/home/abaevski/fairseq-py/examples/data2vec
  # tensorboard_logdir: tb

checkpoint:
  save_interval: 1
  no_epoch_checkpoints: true
  best_checkpoint_metric: wer

task:
  _name: audio_finetuning
  data: /fsx-wav2vec/abaevski/data/librispeech
  labels: ltr
  normalize: true

dataset:
  num_workers: 6
  max_tokens: 1280000
  skip_invalid_size_inputs_valid_test: true
  validate_after_updates: 100
  validate_interval: 1
  valid_subset: dev_other
  required_batch_size_multiple: 1

distributed_training:
  ddp_backend: legacy_ddp
  distributed_world_size: 16

criterion:
  _name: ctc
  zero_infinity: true
  post_process: letter
  wer_kenlm_model: /fsx-wav2vec/abaevski/data/libri/4-gram.bin
  wer_lexicon: /fsx-wav2vec/abaevski/data/libri/10h/wav2vec/raw/lexicon_ltr2.lst
  wer_lm_weight: 1.5
  wer_word_score: 0
  wer_sil_weight: -1

optimization:
  max_update: 200000
  lr: [2e-5]
  # lr: [1e-5] # base 10h wer
  sentence_avg: true
  update_freq: [1] # base 10h we -> 2/4

optimizer:
  _name: adam
  adam_betas: (0.9,0.98)
  adam_eps: 1e-08

lr_scheduler:
  _name: tri_stage
  phase_ratio: null
  warmup_steps: 8000
  hold_steps: 0
  decay_steps: 192000
  final_lr_scale: 0.05

model:
  _name: wav2vec_ctc
  w2v_path: ???
  apply_mask: true
  mask_prob: 0.3
  mask_length: 5
  # mask_prob: 0.65 # base 10h wer
  mask_channel_prob: 0.1
  # mask_channel_prob: 0.6 # base 10h wer
  mask_channel_length: 64
  layerdrop: 0
  # layerdrop: 0.05 # base 10h wer
  activation_dropout: 0.1
  feature_grad_mult: 0.0
  freeze_finetune_updates: 100
  dropout: 0
  final_dropout: 0
  attention_dropout: 0
fairseq/examples/wav2vec/config/finetuning/vox_960h_3.yaml
ADDED
@@ -0,0 +1,101 @@
# @package _group_

common:
  fp16: true
  log_format: json
  log_interval: 200
  user_dir: /private/home/abaevski/fairseq-py/examples/data2vec
  # tensorboard_logdir: tb

checkpoint:
  save_interval: 1
  no_epoch_checkpoints: true
  best_checkpoint_metric: wer

task:
  _name: audio_finetuning
  data: /checkpoint/abaevski/data/speech/libri/1h/wav2vec/raw
  labels: ltr
  normalize: true

dataset:
  num_workers: 6
  max_tokens: 1000000
  skip_invalid_size_inputs_valid_test: true
  validate_after_updates: 100
  validate_interval: 1
  valid_subset: dev_other
  required_batch_size_multiple: 1

distributed_training:
  ddp_backend: legacy_ddp
  distributed_world_size: 16

criterion:
  _name: ctc
  zero_infinity: true
  post_process: letter
  wer_kenlm_model: /checkpoint/abaevski/data/speech/libri/4-gram.bin
  wer_lexicon: /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw/lexicon_ltr2.lst
  wer_lm_weight: 2.0
  wer_word_score: -1.0

optimization:
  max_update: 200000
  lr: [1e-5]
  # lr: [1e-5] # base 10h wer
  sentence_avg: true
  update_freq: [1] # base 10h we -> 2/4

optimizer:
  _name: adam
  adam_betas: (0.9,0.98)
  adam_eps: 1e-08

lr_scheduler:
  _name: cosine
  warmup_updates: 8000

model:
  _name: wav2vec_ctc
  w2v_path: ???
  apply_mask: true
  mask_prob: 0.4
  mask_length: 5
  # mask_prob: 0.65 # base 10h wer
  mask_channel_prob: 0.1
  # mask_channel_prob: 0.6 # base 10h wer
  mask_channel_length: 64
  layerdrop: 0.1
  # layerdrop: 0.05 # base 10h wer
  activation_dropout: 0.1
  feature_grad_mult: 0.0
  freeze_finetune_updates: 100
  dropout: 0
  final_dropout: 0
  attention_dropout: 0

hydra:
  job:
    config:
      override_dirname:
        kv_sep: ':'
        item_sep: '__'
        exclude_keys:
          - run_config
          - distributed_training.distributed_port
  sweep:
    dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname}
    subdir: ${hydra.job.num}
  launcher:
    submitit_folder: ${hydra.sweep.dir}
    timeout_min: 3000
    cpus_per_task: 10
    gpus_per_node: 4
    tasks_per_node: 4
    mem_gb: 250
    nodes: 1
    name: ${env:PREFIX}_${hydra.job.config_name}
    partition: devlab,learnlab,learnfair,scavenge
    constraint: volta32gb
    max_num_timeout: 30
fairseq/examples/wav2vec/config/pretraining/wav2vec2_base_librispeech.yaml
ADDED
@@ -0,0 +1,57 @@
# @package _group_

common:
  fp16: true
  log_format: json
  log_interval: 200

checkpoint:
  save_interval_updates: 25000
  keep_interval_updates: 1
  no_epoch_checkpoints: true

task:
  _name: audio_pretraining
  data: ???
  max_sample_size: 250000
  min_sample_size: 32000
  normalize: false

dataset:
  num_workers: 6
  max_tokens: 1400000
  skip_invalid_size_inputs_valid_test: true

distributed_training:
  distributed_world_size: 64
  ddp_backend: legacy_ddp

criterion:
  _name: wav2vec
  infonce: true
  log_keys: ["prob_perplexity","code_perplexity","temp"]
  loss_weights: [0.1, 10]

optimization:
  max_update: 400000
  lr: [0.0005]

optimizer:
  _name: adam
  adam_betas: (0.9,0.98)
  adam_eps: 1e-06
  weight_decay: 0.01

lr_scheduler:
  _name: polynomial_decay
  warmup_updates: 32000

model:
  _name: wav2vec2
  quantize_targets: true
  final_dim: 256
  encoder_layerdrop: 0.05
  dropout_input: 0.1
  dropout_features: 0.1
  feature_grad_mult: 0.1
  encoder_embed_dim: 768
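
This is the wav2vec 2.0 Base pretraining recipe, sized for 64 GPUs (distributed_world_size: 64). A launch sketch per the wav2vec README; with k < 64 GPUs, the README's note applies: keep the effective batch size by adding distributed_training.distributed_world_size=k +optimization.update_freq='[x]' with x = 64/k:

    fairseq-hydra-train \
        task.data=/path/to/manifests \
        --config-dir examples/wav2vec/config/pretraining \
        --config-name wav2vec2_base_librispeech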
fairseq/examples/wav2vec/config/pretraining/wav2vec2_conformer_base_librispeech.yaml
ADDED
@@ -0,0 +1,60 @@
# @package _group_

common:
  fp16: true
  log_format: json
  log_interval: 200

checkpoint:
  save_interval_updates: 25000
  keep_interval_updates: 1
  no_epoch_checkpoints: true

task:
  _name: audio_pretraining
  data: ???
  max_sample_size: 250000
  min_sample_size: 32000
  normalize: false

dataset:
  num_workers: 6
  max_tokens: 1400000
  skip_invalid_size_inputs_valid_test: true

distributed_training:
  distributed_world_size: 64
  ddp_backend: legacy_ddp

criterion:
  _name: wav2vec
  infonce: true
  log_keys: ["prob_perplexity","code_perplexity","temp"]
  loss_weights: [0.1, 10]

optimization:
  max_update: 400000
  lr: [0.0005]

optimizer:
  _name: adam
  adam_betas: (0.9,0.98)
  adam_eps: 1e-06
  weight_decay: 0.01

lr_scheduler:
  _name: polynomial_decay
  warmup_updates: 32000

model:
  _name: wav2vec2
  quantize_targets: true
  final_dim: 256
  encoder_layerdrop: 0.05
  dropout_input: 0.1
  dropout_features: 0.1
  feature_grad_mult: 0.1
  encoder_embed_dim: 768
  layer_type: conformer
  attn_type: espnet
  pos_enc_type: rel_pos
fairseq/examples/wav2vec/config/pretraining/wav2vec2_conformer_large_librivox.yaml
ADDED
@@ -0,0 +1,72 @@
# @package _group_

common:
  fp16: true
  log_format: json
  log_interval: 200

checkpoint:
  save_interval_updates: 25000
  keep_interval_updates: 1
  no_epoch_checkpoints: true

task:
  _name: audio_pretraining
  data: ???
  max_sample_size: 320000
  min_sample_size: 32000
  normalize: true

dataset:
  num_workers: 6
  max_tokens: 1200000
  skip_invalid_size_inputs_valid_test: true

distributed_training:
  distributed_world_size: 128
  ddp_backend: legacy_ddp

criterion:
  _name: wav2vec
  infonce: true
  log_keys: ["prob_perplexity","code_perplexity","temp"]
  loss_weights: [0.1, 0]

optimization:
  max_update: 1000000
  lr: [0.005]

optimizer:
  _name: adam
  adam_betas: (0.9,0.98)
  adam_eps: 1e-06
  weight_decay: 0.01

lr_scheduler:
  _name: polynomial_decay
  warmup_updates: 32000

model:
  _name: wav2vec2
  quantize_targets: true
  extractor_mode: layer_norm
  layer_norm_first: true
  final_dim: 768
  latent_temp: [2.0,0.1,0.999995]
  encoder_layerdrop: 0.00
  dropout_input: 0.0
  dropout_features: 0.0
  dropout: 0.0
  attention_dropout: 0.0
  conv_bias: true

  encoder_layers: 24
  encoder_embed_dim: 1024
  encoder_ffn_embed_dim: 4096
  encoder_attention_heads: 16

  feature_grad_mult: 1.0

  layer_type: conformer
  attn_type: espnet
  pos_enc_type: rel_pos
fairseq/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml
ADDED
@@ -0,0 +1,70 @@
# @package _group_

common:
  fp16: true
  log_format: json
  log_interval: 200

checkpoint:
  save_interval_updates: 25000
  keep_interval_updates: 1
  no_epoch_checkpoints: true

task:
  _name: audio_pretraining
  data: ???
  max_sample_size: 320000
  min_sample_size: 32000
  normalize: true

dataset:
  batch_size: 4
  num_workers: 6
  max_tokens: 1200000
  skip_invalid_size_inputs_valid_test: true

distributed_training:
  distributed_world_size: 128
  ddp_backend: legacy_ddp

criterion:
  _name: wav2vec
  infonce: true
  log_keys: ["prob_perplexity","code_perplexity","temp"]
  loss_weights: [0.1, 0]

optimization:
  max_update: 1000000
  lr: [0.005]

optimizer:
  _name: adam
  adam_betas: (0.9,0.98)
  adam_eps: 1e-06
  weight_decay: 0.01

lr_scheduler:
  _name: polynomial_decay
  warmup_updates: 32000

model:
  _name: wav2vec2
  quantize_targets: true
  extractor_mode: layer_norm
  layer_norm_first: true
  final_dim: 768
  latent_temp: [2.0,0.1,0.999995]
  encoder_layerdrop: 0.00
  dropout_input: 0.0
  dropout_features: 0.0
  dropout: 0.0
  attention_dropout: 0.0
  conv_bias: true

  encoder_layers: 24
  encoder_embed_dim: 1024
  encoder_ffn_embed_dim: 4096
  encoder_attention_heads: 16

  feature_grad_mult: 1.0
fairseq/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu-pod.yaml
ADDED
@@ -0,0 +1,72 @@
# @package _group_

common:
  tpu: true
  fp16: false
  log_format: json
  log_interval: 10

checkpoint:
  save_interval_updates: 25000
  keep_interval_updates: 1
  no_epoch_checkpoints: true

task:
  _name: audio_pretraining
  data: ???
  max_sample_size: 250000
  min_sample_size: 32000
  normalize: true
  num_batch_buckets: 3
  precompute_mask_indices: true
  enable_padding: true

dataset:
  num_workers: 6
  max_tokens: 1200000
  skip_invalid_size_inputs_valid_test: true

distributed_training:
  distributed_world_size: 128
  ddp_backend: legacy_ddp

criterion:
  _name: wav2vec
  infonce: true
  log_keys: ["prob_perplexity","code_perplexity","temp"]
  loss_weights: [0.1, 0]

optimization:
  max_update: 1000000
  lr: [0.005]

optimizer:
  _name: adam
  adam_betas: (0.9,0.98)
  adam_eps: 1e-06
  weight_decay: 0.01

lr_scheduler:
  _name: polynomial_decay
  warmup_updates: 32000

model:
  _name: wav2vec2
  quantize_targets: true
  extractor_mode: layer_norm
  layer_norm_first: true
  final_dim: 768
  latent_temp: [2.0,0.1,0.999995]
  encoder_layerdrop: 0.00
  dropout_input: 0.0
  dropout_features: 0.0
  dropout: 0.0
  attention_dropout: 0.0
  conv_bias: true

  encoder_layers: 24
  encoder_embed_dim: 1024
  encoder_ffn_embed_dim: 4096
  encoder_attention_heads: 16

  feature_grad_mult: 1.0
fairseq/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu.yaml
ADDED
@@ -0,0 +1,77 @@
# @package _group_

common:
  tpu: true
  fp16: false
  log_format: json
  log_interval: 10

checkpoint:
  save_interval_updates: 25000
  keep_interval_updates: 1
  no_epoch_checkpoints: true

task:
  _name: audio_pretraining
  data: ???
  max_sample_size: 250000
  min_sample_size: 32000
  normalize: true
  num_batch_buckets: 3
  precompute_mask_indices: true
  enable_padding: true
  inferred_w2v_config:
    mask_prob: 0.65
    mask_selection: 'static'
    mask_other: 0
    mask_channel_prob: 0.1

dataset:
  num_workers: 6
  max_tokens: 1200000
  skip_invalid_size_inputs_valid_test: true

distributed_training:
  distributed_world_size: 8
  ddp_backend: legacy_ddp

criterion:
  _name: wav2vec
  infonce: true
  log_keys: ["prob_perplexity","code_perplexity","temp"]
  loss_weights: [0.1, 0]

optimization:
  max_update: 1000000
  lr: [0.005]

optimizer:
  _name: adam
  adam_betas: (0.9,0.98)
  adam_eps: 1e-06
  weight_decay: 0.01

lr_scheduler:
  _name: polynomial_decay
  warmup_updates: 32000

model:
  _name: wav2vec2
  quantize_targets: true
  extractor_mode: layer_norm
  layer_norm_first: true
  final_dim: 768
  latent_temp: [2.0,0.1,0.999995]
  encoder_layerdrop: 0.00
  dropout_input: 0.0
  dropout_features: 0.0
  dropout: 0.0
  attention_dropout: 0.0
  conv_bias: true

  encoder_layers: 24
  encoder_embed_dim: 1024
  encoder_ffn_embed_dim: 4096
  encoder_attention_heads: 16

  feature_grad_mult: 1.0