PyTorch · ssl-aasist · custom_code

Commit 66a0dab (verified) · ash56 committed · 1 parent: eaa8a4e

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the complete set.

Files changed (50)
  1. fairseq/examples/textless_nlp/pgslm/sample/sample.py +612 -0
  2. fairseq/examples/textless_nlp/pgslm/scripts/join_units_manifest.py +48 -0
  3. fairseq/examples/translation/prepare-iwslt14.sh +115 -0
  4. fairseq/examples/translation/prepare-wmt14en2fr.sh +136 -0
  5. fairseq/examples/translation_moe/README.md +89 -0
  6. fairseq/examples/translation_moe/score.py +197 -0
  7. fairseq/examples/translation_moe/translation_moe_src/__init__.py +6 -0
  8. fairseq/examples/translation_moe/translation_moe_src/logsumexp_moe.py +26 -0
  9. fairseq/examples/translation_moe/translation_moe_src/mean_pool_gating_network.py +50 -0
  10. fairseq/examples/translation_moe/translation_moe_src/translation_moe.py +259 -0
  11. fairseq/examples/truncated_bptt/README.md +70 -0
  12. fairseq/examples/truncated_bptt/__init__.py +6 -0
  13. fairseq/examples/truncated_bptt/transformer_xl_model.py +143 -0
  14. fairseq/examples/truncated_bptt/truncated_bptt_lm_task.py +285 -0
  15. fairseq/examples/unsupervised_quality_estimation/aggregate_scores.py +41 -0
  16. fairseq/examples/unsupervised_quality_estimation/meteor.py +109 -0
  17. fairseq/examples/unsupervised_quality_estimation/repeat_lines.py +28 -0
  18. fairseq/examples/wav2vec/__init__.py +0 -0
  19. fairseq/examples/wav2vec/config/finetuning/base_10m.yaml +63 -0
  20. fairseq/examples/wav2vec/config/finetuning/base_1h.yaml +63 -0
  21. fairseq/examples/wav2vec/config/finetuning/run_config/slurm_1.yaml +26 -0
  22. fairseq/examples/wav2vec/config/finetuning/run_config/slurm_16.yaml +27 -0
  23. fairseq/examples/wav2vec/config/finetuning/run_config/slurm_1_aws.yaml +37 -0
  24. fairseq/examples/wav2vec/config/finetuning/run_config/slurm_1_old.yaml +27 -0
  25. fairseq/examples/wav2vec/config/finetuning/run_config/slurm_2.yaml +27 -0
  26. fairseq/examples/wav2vec/config/finetuning/run_config/slurm_2_aws.yaml +37 -0
  27. fairseq/examples/wav2vec/config/finetuning/run_config/slurm_2g.yaml +26 -0
  28. fairseq/examples/wav2vec/config/finetuning/run_config/slurm_3.yaml +27 -0
  29. fairseq/examples/wav2vec/config/finetuning/run_config/slurm_4g.yaml +26 -0
  30. fairseq/examples/wav2vec/config/finetuning/run_config/slurm_4g_aws.yaml +37 -0
  31. fairseq/examples/wav2vec/config/finetuning/run_config/slurm_8.yaml +26 -0
  32. fairseq/examples/wav2vec/config/finetuning/vox_10h_2_aws.yaml +81 -0
  33. fairseq/examples/wav2vec/config/finetuning/vox_10h_aws.yaml +104 -0
  34. fairseq/examples/wav2vec/config/finetuning/vox_10m_2.yaml +114 -0
  35. fairseq/examples/wav2vec/config/finetuning/vox_10m_2_aws.yaml +114 -0
  36. fairseq/examples/wav2vec/config/finetuning/vox_10m_3.yaml +105 -0
  37. fairseq/examples/wav2vec/config/finetuning/vox_1h.yaml +63 -0
  38. fairseq/examples/wav2vec/config/finetuning/vox_1h_2.yaml +104 -0
  39. fairseq/examples/wav2vec/config/finetuning/vox_1h_2_aws.yaml +114 -0
  40. fairseq/examples/wav2vec/config/finetuning/vox_1h_aws.yaml +80 -0
  41. fairseq/examples/wav2vec/config/finetuning/vox_960h.yaml +57 -0
  42. fairseq/examples/wav2vec/config/finetuning/vox_960h_2.yaml +105 -0
  43. fairseq/examples/wav2vec/config/finetuning/vox_960h_2_aws.yaml +82 -0
  44. fairseq/examples/wav2vec/config/finetuning/vox_960h_3.yaml +101 -0
  45. fairseq/examples/wav2vec/config/pretraining/wav2vec2_base_librispeech.yaml +57 -0
  46. fairseq/examples/wav2vec/config/pretraining/wav2vec2_conformer_base_librispeech.yaml +60 -0
  47. fairseq/examples/wav2vec/config/pretraining/wav2vec2_conformer_large_librivox.yaml +72 -0
  48. fairseq/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml +70 -0
  49. fairseq/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu-pod.yaml +72 -0
  50. fairseq/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu.yaml +77 -0
fairseq/examples/textless_nlp/pgslm/sample/sample.py ADDED
@@ -0,0 +1,612 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ import torch.multiprocessing as mp
8
+ import numpy as np
9
+ import json
10
+
11
+ import torch
12
+ from torch.distributions.categorical import Categorical
13
+
14
+ from fairseq import checkpoint_utils, options, utils
15
+ from fairseq.data.codedataset import CodeDataset, ExpressiveCodeDataConfig
16
+ from fairseq.dataclass.utils import convert_namespace_to_omegaconf
17
+ from torch.utils.data import DataLoader, DistributedSampler
18
+ from fairseq.utils import move_to_cuda
19
+
20
+ import tqdm
21
+ import random
22
+ import pathlib
23
+
24
+ import sys, pathlib
25
+
26
+ sys.path.append(str(pathlib.Path(__file__).parent.parent))
27
+ from inference_dataset import InferenceDataset, explode_batch
28
+ from naive_decoder import Naive_F0_Decoder
29
+ from truncated_laplace import truncated_laplace
30
+
31
+ CODETYPE_TO_FRAMETIME = {"cpc_km100": 0.01, "hubert": 0.02}  # frame times: 10 ms and 20 ms, respectively
32
+
33
+
34
+ class TemperatureDecoder:
35
+ def __init__(self, Ts, discrete_dur=False, discrete_f0=False):
36
+ self.T_token, self.T_dur, self.T_f0 = Ts
37
+ self.discrete_dur = discrete_dur
38
+ self.discrete_f0 = discrete_f0
39
+
40
+ def __call__(self, output):
41
+ def sample_multinomial(key, T):
42
+ logits = output[key][:, -1, :].float()
43
+ return Categorical(logits=logits / T).sample().unsqueeze(-1)
44
+
45
+ def sample_laplace(key, T, truncate_at_zero):
46
+ mean = output[key][:, -1, :].float()
47
+ return truncated_laplace(mean=mean, T=T, truncate_by_zero=truncate_at_zero)
48
+
49
+ if self.T_token > 0:
50
+ new_tokens = sample_multinomial("token", self.T_token)
51
+ else:
52
+ new_tokens = output["token"][:, -1, :].argmax(dim=-1, keepdim=True)
53
+
54
+ if not self.discrete_dur and self.T_dur == 0:
55
+ new_durations = output["duration"][:, -1].round().int()
56
+ elif not self.discrete_dur and self.T_dur > 0:
57
+ new_durations = (
58
+ sample_laplace("duration", self.T_dur, truncate_at_zero=True)
59
+ .round()
60
+ .int()
61
+ )
62
+ elif self.discrete_dur and self.T_dur > 0:
63
+ new_durations = sample_multinomial("duration", self.T_dur)
64
+ elif self.discrete_dur and self.T_dur == 0:
65
+ new_durations = output["duration"][:, -1, :].argmax(dim=-1, keepdim=True)
66
+ else:
67
+ assert False
68
+
69
+ if not self.discrete_f0 and self.T_f0 == 0:
70
+ new_f0 = output["f0"][:, -1]
71
+ elif not self.discrete_f0 and self.T_f0 > 0:
72
+ new_f0 = sample_laplace("f0", self.T_f0, truncate_at_zero=False)
73
+ elif self.discrete_f0 and self.T_f0 > 0:
74
+ new_f0 = sample_multinomial("f0", self.T_f0)
75
+ elif self.discrete_f0 and self.T_f0 == 0:
76
+ new_f0 = output["f0"][:, -1, :].argmax(dim=-1, keepdim=True)
77
+ else:
78
+ assert False
79
+
80
+ return new_tokens, new_durations, new_f0
81
+
82
+
83
+ class FilterNamesDataset:
84
+ def __init__(self, dataset, fnames_path):
85
+ self.dataset = dataset
86
+
87
+ with open(fnames_path, "r") as fin:
88
+ fnames = set((eval(line)["audio"] for line in fin))
89
+ print(f"# will retrict the dataset for {len(fnames)} files")
90
+
91
+ self.indexes = []
92
+
93
+ for i, datapoint in enumerate(dataset):
94
+ if datapoint["filename"] in fnames:
95
+ self.indexes.append(i)
96
+ assert len(self.indexes) == len(fnames), f"{len(self.indexes)} {len(fnames)}"
97
+
98
+ self.collater = self.dataset.collater
99
+ self.discrete_dur = self.dataset.discrete_dur
100
+ self.discrete_f0 = self.dataset.discrete_f0
101
+
102
+ def __len__(self):
103
+ return len(self.indexes)
104
+
105
+ def __getitem__(self, k):
106
+ k = self.indexes[k]
107
+ return self.dataset[k]
108
+
109
+ def size(self, k):
110
+ k = self.indexes[k]
111
+ return self.dataset.size(k)
112
+
113
+
114
+ @torch.no_grad()
115
+ def do_sampling(
116
+ model,
117
+ batch,
118
+ eos_token,
119
+ decoder,
120
+ autoregressive_steps=100,
121
+ teacher_force_tokens=False,
122
+ teacher_force_duration=False,
123
+ teacher_force_f0=False,
124
+ match_duration=False,
125
+ ):
126
+ def autoregressive_step_(output, autoregressive_steps):
127
+ new_tokens, new_durations, new_f0 = decoder(output)
128
+
129
+ n = output["token"].size(1) if output["token"].ndim == 3 else 1
130
+
131
+ if teacher_force_tokens:
132
+ new_tokens = batch["target"][:, n - 1].unsqueeze(-1)
133
+ if teacher_force_duration:
134
+ new_durations = batch["dur_target"][:, n - 1].unsqueeze(-1)
135
+ if teacher_force_f0:
136
+ new_f0 = batch["f0_target"][:, n - 1].unsqueeze(-1)
137
+
138
+ batch["net_input"]["src_tokens"] = torch.cat(
139
+ [batch["net_input"]["src_tokens"], new_tokens], dim=1
140
+ )
141
+ batch["net_input"]["dur_src"] = torch.cat(
142
+ [batch["net_input"]["dur_src"], new_durations], dim=1
143
+ )
144
+ batch["net_input"]["f0_src"] = torch.cat(
145
+ [batch["net_input"]["f0_src"], new_f0], dim=1
146
+ )
147
+
148
+ outputs = []
149
+
150
+ if teacher_force_tokens or teacher_force_duration or teacher_force_f0:
151
+ max_time = batch["target"].size(1)
152
+ prefix_time = batch["net_input"]["src_tokens"].size(1)
153
+
154
+ autoregressive_steps = max_time - prefix_time + 1 # should be 0
155
+
156
+ for _ in range(autoregressive_steps):
157
+ output = model(**batch["net_input"])
158
+
159
+ last_steps = (
160
+ output["token"][:, -1, ...],
161
+ output["duration"][:, -1, ...],
162
+ output["f0"][:, -1, ...],
163
+ )
164
+ outputs.append(last_steps)
165
+
166
+ autoregressive_step_(output, autoregressive_steps)
167
+ tokens, duration, f0 = (
168
+ batch["net_input"]["src_tokens"],
169
+ batch["net_input"]["dur_src"],
170
+ batch["net_input"]["f0_src"],
171
+ )
172
+
173
+ if (
174
+ match_duration
175
+ and (batch["dur_target"].sum(dim=-1) < duration.sum(dim=-1)).all()
176
+ ):
177
+ break
178
+
179
+ return tokens, duration, f0, outputs
180
+
181
+
182
+ def unroll_duration(token_stream, duration_stream):
183
+ assert len(token_stream) == len(
184
+ duration_stream
185
+ ), f"{len(token_stream)} != {len(duration_stream)}"
186
+ non_positive_durations = sum(d <= 0 for d in duration_stream)
187
+ if non_positive_durations > 0:
188
+ print(
189
+ f"# {non_positive_durations} durations are non-positive, they will be capped to 1"
190
+ )
191
+
192
+ result = []
193
+
194
+ duration_stream_rounded_capped = [max(1, int(round(x))) for x in duration_stream]
195
+ for t, d in zip(token_stream, duration_stream_rounded_capped):
196
+ result.extend([t] * d)
197
+
198
+ return result
199
+
200
+
201
+ def realign_shifted_streams(tokens, durations, F0s, shifts):
202
+ """
203
+ Durations are shifted by 1, F0 by 2
204
+ >>> tokens = ["<s>", "t1", "t2", "t3", "</s>", "x", "x"]
205
+ >>> durations = ["<0>", "<0>", "d1", "d2", "d3", "<0>", "x"]
206
+ >>> F0s = ["<0>", "<0>", "<0>", "f1", "f2", "f3", "<0>"]
207
+ >>> shifts = [1,2]
208
+ >>> realign_shifted_streams(tokens, durations, F0s, shifts)
209
+ (['<s>', 't1', 't2', 't3', '</s>'], ['<0>', 'd1', 'd2', 'd3', '<0>'], ['<0>', 'f1', 'f2', 'f3', '<0>'])
210
+ """
211
+ max_shift = max(shifts)
212
+ if max_shift > 0:
213
+ shift_durations, shift_F0s = shifts
214
+
215
+ tokens = tokens[:-max_shift]
216
+ durations = durations[shift_durations:]
217
+ if shift_durations < max_shift:
218
+ durations = durations[: -(max_shift - shift_durations)]
219
+
220
+ if F0s is not None:
221
+ F0s = F0s[shift_F0s:]
222
+ if shift_F0s < max_shift:
223
+ F0s = F0s[: -(max_shift - shift_F0s)]
224
+
225
+ assert len(tokens) == len(durations), f"{len(tokens)} != {len(durations)}"
226
+ if F0s is not None:
227
+ assert len(tokens) == len(F0s), f"{len(tokens)} != {len(F0s)}"
228
+
229
+ return tokens, durations, F0s
230
+
231
+
232
+ def maybe_cut_eos(produced_tokens, produced_duration, produced_f0, eos_idx):
233
+ if eos_idx in produced_tokens:
234
+ eos_index = produced_tokens.index(eos_idx)
235
+ produced_tokens = produced_tokens[:eos_index]
236
+ produced_duration = produced_duration[:eos_index]
237
+ produced_f0 = produced_f0[:eos_index]
238
+ return produced_tokens, produced_duration, produced_f0
239
+
240
+
241
+ def maybe_filter_pad(produced_tokens, produced_duration, produced_f0, pad_idx):
242
+ if pad_idx not in produced_tokens:
243
+ return produced_tokens, produced_duration, produced_f0
244
+
245
+ assert len(produced_tokens) == len(produced_duration) == len(produced_f0)
246
+
247
+ print("<pad> is detected in the output!")
248
+ filtered_tokens, filtered_duration, filtered_f0 = [], [], []
249
+
250
+ for t, d, f in zip(produced_tokens, produced_duration, produced_f0):
251
+ if t != pad_idx:
252
+ filtered_tokens.append(t)
253
+ filtered_duration.append(d)
254
+ filtered_f0.append(f)
255
+ return filtered_tokens, filtered_duration, filtered_f0
256
+
257
+
258
+ def match_duration(produced_tokens, produced_duration, produced_f0, target_duration):
259
+ """
260
+ >>> tokens = ['t'] * 4
261
+ >>> F0s = ['f0'] * 4
262
+ >>> produced_duration = [1, 10, 10, 10]
263
+ >>> match_duration(tokens, produced_duration, F0s, target_duration=100)
264
+ (['t', 't', 't', 't'], [1, 10, 10, 10], ['f0', 'f0', 'f0', 'f0'])
265
+ >>> match_duration(tokens, produced_duration, F0s, target_duration=5)
266
+ (['t', 't'], [1, 4], ['f0', 'f0'])
267
+ """
268
+ if sum(produced_duration) <= target_duration:
269
+ return produced_tokens, produced_duration, produced_f0
270
+
271
+ running_duration = 0
272
+ filtered_duration = []
273
+
274
+ for next_tok_duration in produced_duration:
275
+ if running_duration + next_tok_duration < target_duration:
276
+ filtered_duration.append(next_tok_duration)
277
+ running_duration += next_tok_duration
278
+ else:
279
+ to_add = target_duration - running_duration
280
+ assert to_add <= next_tok_duration
281
+ filtered_duration.append(to_add)
282
+ break
283
+
284
+ produced_duration = filtered_duration
285
+ assert sum(produced_duration) == target_duration
286
+
287
+ n_tok = len(filtered_duration)
288
+
289
+ return produced_tokens[:n_tok], produced_duration, produced_f0[:n_tok]
290
+
291
+
292
+ def main(rank, world_size, args):
293
+ if world_size > 1:
294
+ torch.distributed.init_process_group(
295
+ backend="gloo", init_method="env://", world_size=world_size, rank=rank
296
+ )
297
+ torch.cuda.set_device(rank)
298
+
299
+ raw_args = args
300
+ args = convert_namespace_to_omegaconf(args)
301
+ if args.common.seed is not None:
302
+ random.seed(args.common.seed)
303
+ np.random.seed(args.common.seed)
304
+ utils.set_torch_seed(args.common.seed)
305
+
306
+ models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
307
+ [raw_args.path], arg_overrides={"data": args.task.data}
308
+ )
309
+ tgt_dict = task.target_dictionary
310
+
311
+ for model in models:
312
+ model.prepare_for_inference_(args)
313
+ model.cuda().eval()
314
+ if raw_args.fp16:
315
+ model = model.half()
316
+ model = models[0]
317
+
318
+ config = ExpressiveCodeDataConfig(args.task.data)
319
+
320
+ dataset = CodeDataset(
321
+ manifest=config.manifests[raw_args.subset],
322
+ dictionary=task.source_dictionary,
323
+ dur_dictionary=task.source_duration_dictionary,
324
+ f0_dictionary=task.source_f0_dictionary,
325
+ config=config,
326
+ discrete_dur=task.cfg.discrete_duration,
327
+ discrete_f0=task.cfg.discrete_f0,
328
+ log_f0=task.cfg.log_f0,
329
+ normalize_f0_mean=task.cfg.normalize_f0_mean,
330
+ normalize_f0_std=task.cfg.normalize_f0_std,
331
+ interpolate_f0=task.cfg.interpolate_f0,
332
+ shifts=task.cfg.stream_shifts,
333
+ return_filename=True,
334
+ strip_filename=False,
335
+ )
336
+ tgt_dict = task.target_dictionary
337
+ shifts = dataset.shifts.dur, dataset.shifts.f0
338
+ max_shift = max(shifts)
339
+
340
+ fname = raw_args.output
341
+ if world_size > 1:
342
+ fname += f"_{rank}"
343
+ output_file = open(fname, "w")
344
+
345
+ if raw_args.filter_names:
346
+ dataset = FilterNamesDataset(dataset, raw_args.filter_names)
347
+
348
+ dataset = InferenceDataset(dataset, raw_args.prefix_length, filter_short=True)
349
+ print(f"Dataset size {len(dataset)}")
350
+ sampler = (
351
+ None
352
+ if world_size == 1
353
+ else DistributedSampler(
354
+ dataset, num_replicas=world_size, rank=rank, shuffle=False
355
+ )
356
+ )
357
+ dataloader = DataLoader(
358
+ dataset,
359
+ batch_size=1,
360
+ shuffle=False,
361
+ collate_fn=dataset.collater,
362
+ sampler=sampler,
363
+ )
364
+
365
+ Ts = raw_args.T_token, raw_args.T_duration, raw_args.T_f0
366
+ decoder = TemperatureDecoder(
367
+ Ts, discrete_dur=task.cfg.discrete_duration, discrete_f0=task.cfg.discrete_f0
368
+ )
369
+
370
+ dataset_size = len(dataset)
371
+
372
+ f0_decoder = None
373
+ if raw_args.f0_discretization_bounds:
374
+ assert task.cfg.discrete_f0
375
+ f0_decoder = Naive_F0_Decoder(raw_args.f0_discretization_bounds).cuda()
376
+
377
+ pbar = (
378
+ tqdm.tqdm(
379
+ total=dataset_size
380
+ if raw_args.max_samples is None
381
+ else min(raw_args.max_samples, dataset_size)
382
+ )
383
+ if world_size == 1
384
+ else None
385
+ )
386
+
387
+ samples_produced = 0
388
+
389
+ for batch in dataloader:
390
+ if (
391
+ raw_args.max_samples is not None
392
+ and samples_produced >= raw_args.max_samples
393
+ ):
394
+ break
395
+
396
+ prefix = batch["prefix"][0]
397
+
398
+ batch = explode_batch(batch, raw_args.batch_explosion_rate)
399
+ batch = move_to_cuda(batch)
400
+
401
+ if not raw_args.short_curcuit:
402
+ produced_tokens, produced_durations, produced_f0, _ = do_sampling(
403
+ models[0],
404
+ batch,
405
+ tgt_dict.eos(),
406
+ decoder,
407
+ autoregressive_steps=raw_args.max_length - prefix + max_shift,
408
+ teacher_force_tokens=raw_args.teacher_force_tokens,
409
+ match_duration=raw_args.match_duration,
410
+ teacher_force_duration=raw_args.teacher_force_duration,
411
+ teacher_force_f0=raw_args.teacher_force_f0,
412
+ )
413
+
414
+ # strip entries corresponding to <s>
415
+ produced_tokens = produced_tokens[:, 1:]
416
+ produced_durations = produced_durations[:, 1:]
417
+ produced_f0 = produced_f0[:, 1:]
418
+
419
+ else:
420
+ max_length = raw_args.max_length + max_shift
421
+ produced_tokens, produced_durations, produced_f0 = (
422
+ batch["target"][:, :max_length],
423
+ batch["dur_target"][:, :max_length],
424
+ batch["f0_target"][:, :max_length],
425
+ )
426
+
427
+ if f0_decoder is not None:
428
+ produced_f0 = f0_decoder(produced_f0)
429
+
430
+ produced_tokens, produced_durations, produced_f0 = (
431
+ produced_tokens.cpu().tolist(),
432
+ produced_durations.cpu().tolist(),
433
+ produced_f0.cpu().tolist(),
434
+ )
435
+
436
+ bsz = batch["target"].size(0)
437
+ assert bsz == raw_args.batch_explosion_rate
438
+
439
+ for i in range(bsz):
440
+ if (
441
+ raw_args.max_samples is not None
442
+ and samples_produced >= raw_args.max_samples
443
+ ):
444
+ break
445
+
446
+ produced_tokens_i = produced_tokens[i]
447
+ produced_durations_i = produced_durations[i]
448
+ produced_f0_i = produced_f0[i]
449
+
450
+ (
451
+ produced_tokens_i,
452
+ produced_durations_i,
453
+ produced_f0_i,
454
+ ) = realign_shifted_streams(
455
+ produced_tokens_i, produced_durations_i, produced_f0_i, shifts
456
+ )
457
+
458
+ produced_tokens_i, produced_durations_i, produced_f0_i = maybe_cut_eos(
459
+ produced_tokens_i, produced_durations_i, produced_f0_i, tgt_dict.eos()
460
+ )
461
+
462
+ produced_tokens_i, produced_durations_i, produced_f0_i = maybe_filter_pad(
463
+ produced_tokens_i, produced_durations_i, produced_f0_i, tgt_dict.pad()
464
+ )
465
+
466
+ if raw_args.match_duration:
467
+ # NB: here we cheat a bit and rely on the fact that padding has duration 0
468
+ # so no need to re-align and remove padding
469
+ dur_target_i = batch["dur_target"][i, :].sum().item()
470
+ produced_tokens_i, produced_durations_i, produced_f0_i = match_duration(
471
+ produced_tokens_i, produced_durations_i, produced_f0_i, dur_target_i
472
+ )
473
+
474
+ if raw_args.cut_prompt:
475
+ produced_tokens_i, produced_durations_i, produced_f0_i = (
476
+ produced_tokens_i[prefix:],
477
+ produced_durations_i[prefix:],
478
+ produced_f0_i[prefix:],
479
+ )
480
+
481
+ prompt_fname = batch["filename"][0]
482
+ fname = str(pathlib.Path(prompt_fname).with_suffix("")) + f"__{i}.wav"
483
+
484
+ token_stream = unroll_duration(produced_tokens_i, produced_durations_i)
485
+ f0_stream = unroll_duration(produced_f0_i, produced_durations_i)
486
+ output_line = json.dumps(
487
+ {
488
+ "audio": fname,
489
+ "prompt": prompt_fname,
490
+ raw_args.code_type: " ".join(map(str, token_stream)),
491
+ "duration": round(
492
+ sum(produced_durations_i)
493
+ * CODETYPE_TO_FRAMETIME[raw_args.code_type],
494
+ 3,
495
+ ),
496
+ "raw_duration": produced_durations_i,
497
+ "raw_f0": produced_f0_i,
498
+ "f0": [round(f0, 3) for f0 in f0_stream],
499
+ }
500
+ )
501
+ print(output_line, file=output_file)
502
+
503
+ if pbar:
504
+ pbar.update(1)
505
+ samples_produced += 1
506
+
507
+ if raw_args.debug:
508
+ break
509
+
510
+ output_file.close()
511
+
512
+ if world_size > 1:
513
+ # important that everything is flushed before aggregating
514
+ torch.distributed.barrier()
515
+
516
+ if world_size > 1 and rank == 0:
517
+ with open(raw_args.output, "w") as fout:
518
+ for i in range(world_size):
519
+ f = raw_args.output + f"_{i}"
520
+ with open(f, "r") as fin:
521
+ fout.write(fin.read())
522
+ os.remove(f)
523
+
524
+
525
+ def cli_main():
526
+ parser = options.get_interactive_generation_parser()
527
+ parser.add_argument(
528
+ "--prefix-length",
529
+ type=int,
530
+ default=1,
531
+ help="Prompt prefix length (including <s>)",
532
+ )
533
+ parser.add_argument("--output", type=str, default=None, required=True)
534
+ parser.add_argument(
535
+ "--debug", action="store_true", help="Process only the first batch"
536
+ )
537
+ parser.add_argument(
538
+ "--ignore-durations",
539
+ action="store_true",
540
+ help="If set, the duration stream is ignored",
541
+ )
542
+ parser.add_argument(
543
+ "--max-length", type=int, default=200, help="Maximal produced length"
544
+ )
545
+ parser.add_argument(
546
+ "--code-type", choices=["cpc_km100", "hubert"], default="cpc_km100"
547
+ )
548
+ parser.add_argument("--max-samples", type=int, default=None)
549
+ parser.add_argument("--prompt-duration-scaler", type=float, default=1.0)
550
+ parser.add_argument("--teacher-force-tokens", action="store_true", default=False)
551
+ parser.add_argument("--teacher-force-duration", action="store_true", default=False)
552
+ parser.add_argument("--teacher-force-f0", action="store_true", default=False)
553
+ parser.add_argument("--filter-names", type=str, default=None)
554
+ parser.add_argument(
555
+ "--match-duration",
556
+ action="store_true",
557
+ help="Do not produce sequences longer that ground-truth",
558
+ )
559
+ parser.add_argument(
560
+ "--cut-prompt",
561
+ action="store_true",
562
+ help="Remove prompt from the produced audio",
563
+ )
564
+ parser.add_argument(
565
+ "--short-curcuit", action="store_true", help="Use 'target' as a sample"
566
+ )
567
+ parser.add_argument("--f0-discretization-bounds", type=str, default=None)
568
+
569
+ parser.add_argument("--batch-explosion-rate", type=int, default=1)
570
+
571
+ parser.add_argument("--T-token", type=float, default=1.0)
572
+ parser.add_argument("--T-duration", type=float, default=1.0)
573
+ parser.add_argument("--T-f0", type=float, default=1.0)
574
+
575
+ parser.add_argument(
576
+ "--subset", type=str, default="valid", choices=["test", "valid"]
577
+ )
578
+
579
+ args = options.parse_args_and_arch(parser)
580
+
581
+ assert (
582
+ args.prefix_length >= 1
583
+ ), "Prefix length includes bos token <s>, hence the minimum is 1."
584
+ assert all(
585
+ t >= 0 for t in [args.T_token, args.T_f0, args.T_duration]
586
+ ), "T must be non-negative!"
587
+
588
+ world_size = torch.cuda.device_count()
589
+ if world_size > 1:
590
+ import random
591
+
592
+ mp.set_start_method("spawn", force=True)
593
+ os.environ["MASTER_ADDR"] = "localhost"
594
+ os.environ["MASTER_PORT"] = str(random.randint(10_000, 50_000))
595
+
596
+ print(f"Using {world_size} devices, master port {os.environ['MASTER_PORT']}")
597
+
598
+ mp.spawn(
599
+ main,
600
+ nprocs=world_size,
601
+ args=(
602
+ world_size,
603
+ args,
604
+ ),
605
+ join=True,
606
+ )
607
+ else:
608
+ main(rank=0, world_size=world_size, args=args)
609
+
610
+
611
+ if __name__ == "__main__":
612
+ cli_main()
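As a quick illustration of how the `TemperatureDecoder` defined above is driven, here is a minimal sketch with fabricated model outputs — the batch size, number of decoded steps, and vocabulary sizes below are assumptions for demonstration, not values taken from the script (a temperature of 0 would switch the corresponding stream to argmax decoding):

```python
import torch

# Hypothetical per-step logits; in sample.py these come from model(**batch["net_input"]).
output = {
    "token": torch.randn(2, 5, 100),    # B x T x |token vocab|
    "duration": torch.randn(2, 5, 32),  # B x T x |duration vocab|
    "f0": torch.randn(2, 5, 64),        # B x T x |F0 vocab|
}

# All three streams treated as discrete, each with its own sampling temperature.
decoder = TemperatureDecoder(Ts=(0.7, 1.0, 1.0), discrete_dur=True, discrete_f0=True)
new_tokens, new_durations, new_f0 = decoder(output)
print(new_tokens.shape, new_durations.shape, new_f0.shape)  # each torch.Size([2, 1])
```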
fairseq/examples/textless_nlp/pgslm/scripts/join_units_manifest.py ADDED
@@ -0,0 +1,48 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import json
7
+ import argparse
8
+ import pathlib
9
+
10
+
11
+ def main():
12
+ parser = argparse.ArgumentParser()
13
+ parser.add_argument("--manifest", required=True)
14
+ parser.add_argument("--units", required=True)
15
+ parser.add_argument("--output", required=True)
16
+ parser.add_argument("--sample_rate", type=int, default=16_000)
17
+
18
+ args = parser.parse_args()
19
+
20
+ with open(args.manifest, "r") as manifest, open(args.units, "r") as units, open(
21
+ args.output, "w"
22
+ ) as outp:
23
+ root = manifest.readline().strip()
24
+ root = pathlib.Path(root)
25
+
26
+ for manifest_line, unit_line in zip(manifest.readlines(), units.readlines()):
27
+ path, frames = manifest_line.split()
28
+ duration = int(frames) / float(args.sample_rate)
29
+ fname = root / path
30
+ speaker = fname.parent.parent.name
31
+
32
+ units = unit_line.split("|")[1]
33
+
34
+ print(
35
+ json.dumps(
36
+ dict(
37
+ audio=str(root / path),
38
+ duration=duration,
39
+ hubert_km100=units.strip(),
40
+ speaker=speaker,
41
+ )
42
+ ),
43
+ file=outp,
44
+ )
45
+
46
+
47
+ if __name__ == "__main__":
48
+ main()
fairseq/examples/translation/prepare-iwslt14.sh ADDED
@@ -0,0 +1,115 @@
1
+ #!/usr/bin/env bash
2
+ #
3
+ # Adapted from https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh
4
+
5
+ echo 'Cloning Moses github repository (for tokenization scripts)...'
6
+ git clone https://github.com/moses-smt/mosesdecoder.git
7
+
8
+ echo 'Cloning Subword NMT repository (for BPE pre-processing)...'
9
+ git clone https://github.com/rsennrich/subword-nmt.git
10
+
11
+ SCRIPTS=mosesdecoder/scripts
12
+ TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
13
+ LC=$SCRIPTS/tokenizer/lowercase.perl
14
+ CLEAN=$SCRIPTS/training/clean-corpus-n.perl
15
+ BPEROOT=subword-nmt/subword_nmt
16
+ BPE_TOKENS=10000
17
+
18
+ URL="http://dl.fbaipublicfiles.com/fairseq/data/iwslt14/de-en.tgz"
19
+ GZ=de-en.tgz
20
+
21
+ if [ ! -d "$SCRIPTS" ]; then
22
+ echo "Please set SCRIPTS variable correctly to point to Moses scripts."
23
+ exit
24
+ fi
25
+
26
+ src=de
27
+ tgt=en
28
+ lang=de-en
29
+ prep=iwslt14.tokenized.de-en
30
+ tmp=$prep/tmp
31
+ orig=orig
32
+
33
+ mkdir -p $orig $tmp $prep
34
+
35
+ echo "Downloading data from ${URL}..."
36
+ cd $orig
37
+ wget "$URL"
38
+
39
+ if [ -f $GZ ]; then
40
+ echo "Data successfully downloaded."
41
+ else
42
+ echo "Data not successfully downloaded."
43
+ exit
44
+ fi
45
+
46
+ tar zxvf $GZ
47
+ cd ..
48
+
49
+ echo "pre-processing train data..."
50
+ for l in $src $tgt; do
51
+ f=train.tags.$lang.$l
52
+ tok=train.tags.$lang.tok.$l
53
+
54
+ cat $orig/$lang/$f | \
55
+ grep -v '<url>' | \
56
+ grep -v '<talkid>' | \
57
+ grep -v '<keywords>' | \
58
+ sed -e 's/<title>//g' | \
59
+ sed -e 's/<\/title>//g' | \
60
+ sed -e 's/<description>//g' | \
61
+ sed -e 's/<\/description>//g' | \
62
+ perl $TOKENIZER -threads 8 -l $l > $tmp/$tok
63
+ echo ""
64
+ done
65
+ perl $CLEAN -ratio 1.5 $tmp/train.tags.$lang.tok $src $tgt $tmp/train.tags.$lang.clean 1 175
66
+ for l in $src $tgt; do
67
+ perl $LC < $tmp/train.tags.$lang.clean.$l > $tmp/train.tags.$lang.$l
68
+ done
69
+
70
+ echo "pre-processing valid/test data..."
71
+ for l in $src $tgt; do
72
+ for o in `ls $orig/$lang/IWSLT14.TED*.$l.xml`; do
73
+ fname=${o##*/}
74
+ f=$tmp/${fname%.*}
75
+ echo $o $f
76
+ grep '<seg id' $o | \
77
+ sed -e 's/<seg id="[0-9]*">\s*//g' | \
78
+ sed -e 's/\s*<\/seg>\s*//g' | \
79
+ sed -e "s/\’/\'/g" | \
80
+ perl $TOKENIZER -threads 8 -l $l | \
81
+ perl $LC > $f
82
+ echo ""
83
+ done
84
+ done
85
+
86
+
87
+ echo "creating train, valid, test..."
88
+ for l in $src $tgt; do
89
+ awk '{if (NR%23 == 0) print $0; }' $tmp/train.tags.de-en.$l > $tmp/valid.$l
90
+ awk '{if (NR%23 != 0) print $0; }' $tmp/train.tags.de-en.$l > $tmp/train.$l
91
+
92
+ cat $tmp/IWSLT14.TED.dev2010.de-en.$l \
93
+ $tmp/IWSLT14.TEDX.dev2012.de-en.$l \
94
+ $tmp/IWSLT14.TED.tst2010.de-en.$l \
95
+ $tmp/IWSLT14.TED.tst2011.de-en.$l \
96
+ $tmp/IWSLT14.TED.tst2012.de-en.$l \
97
+ > $tmp/test.$l
98
+ done
99
+
100
+ TRAIN=$tmp/train.en-de
101
+ BPE_CODE=$prep/code
102
+ rm -f $TRAIN
103
+ for l in $src $tgt; do
104
+ cat $tmp/train.$l >> $TRAIN
105
+ done
106
+
107
+ echo "learn_bpe.py on ${TRAIN}..."
108
+ python $BPEROOT/learn_bpe.py -s $BPE_TOKENS < $TRAIN > $BPE_CODE
109
+
110
+ for L in $src $tgt; do
111
+ for f in train.$L valid.$L test.$L; do
112
+ echo "apply_bpe.py to ${f}..."
113
+ python $BPEROOT/apply_bpe.py -c $BPE_CODE < $tmp/$f > $prep/$f
114
+ done
115
+ done
fairseq/examples/translation/prepare-wmt14en2fr.sh ADDED
@@ -0,0 +1,136 @@
1
+ #!/bin/bash
2
+ # Adapted from https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh
3
+
4
+ echo 'Cloning Moses github repository (for tokenization scripts)...'
5
+ git clone https://github.com/moses-smt/mosesdecoder.git
6
+
7
+ echo 'Cloning Subword NMT repository (for BPE pre-processing)...'
8
+ git clone https://github.com/rsennrich/subword-nmt.git
9
+
10
+ SCRIPTS=mosesdecoder/scripts
11
+ TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
12
+ CLEAN=$SCRIPTS/training/clean-corpus-n.perl
13
+ NORM_PUNC=$SCRIPTS/tokenizer/normalize-punctuation.perl
14
+ REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl
15
+ BPEROOT=subword-nmt/subword_nmt
16
+ BPE_TOKENS=40000
17
+
18
+ URLS=(
19
+ "http://statmt.org/wmt13/training-parallel-europarl-v7.tgz"
20
+ "http://statmt.org/wmt13/training-parallel-commoncrawl.tgz"
21
+ "http://statmt.org/wmt13/training-parallel-un.tgz"
22
+ "http://statmt.org/wmt14/training-parallel-nc-v9.tgz"
23
+ "http://statmt.org/wmt10/training-giga-fren.tar"
24
+ "http://statmt.org/wmt14/test-full.tgz"
25
+ )
26
+ FILES=(
27
+ "training-parallel-europarl-v7.tgz"
28
+ "training-parallel-commoncrawl.tgz"
29
+ "training-parallel-un.tgz"
30
+ "training-parallel-nc-v9.tgz"
31
+ "training-giga-fren.tar"
32
+ "test-full.tgz"
33
+ )
34
+ CORPORA=(
35
+ "training/europarl-v7.fr-en"
36
+ "commoncrawl.fr-en"
37
+ "un/undoc.2000.fr-en"
38
+ "training/news-commentary-v9.fr-en"
39
+ "giga-fren.release2.fixed"
40
+ )
41
+
42
+ if [ ! -d "$SCRIPTS" ]; then
43
+ echo "Please set SCRIPTS variable correctly to point to Moses scripts."
44
+ exit
45
+ fi
46
+
47
+ src=en
48
+ tgt=fr
49
+ lang=en-fr
50
+ prep=wmt14_en_fr
51
+ tmp=$prep/tmp
52
+ orig=orig
53
+
54
+ mkdir -p $orig $tmp $prep
55
+
56
+ cd $orig
57
+
58
+ for ((i=0;i<${#URLS[@]};++i)); do
59
+ file=${FILES[i]}
60
+ if [ -f $file ]; then
61
+ echo "$file already exists, skipping download"
62
+ else
63
+ url=${URLS[i]}
64
+ wget "$url"
65
+ if [ -f $file ]; then
66
+ echo "$url successfully downloaded."
67
+ else
68
+ echo "$url not successfully downloaded."
69
+ exit -1
70
+ fi
71
+ if [ ${file: -4} == ".tgz" ]; then
72
+ tar zxvf $file
73
+ elif [ ${file: -4} == ".tar" ]; then
74
+ tar xvf $file
75
+ fi
76
+ fi
77
+ done
78
+
79
+ gunzip giga-fren.release2.fixed.*.gz
80
+ cd ..
81
+
82
+ echo "pre-processing train data..."
83
+ for l in $src $tgt; do
84
+ rm $tmp/train.tags.$lang.tok.$l
85
+ for f in "${CORPORA[@]}"; do
86
+ cat $orig/$f.$l | \
87
+ perl $NORM_PUNC $l | \
88
+ perl $REM_NON_PRINT_CHAR | \
89
+ perl $TOKENIZER -threads 8 -a -l $l >> $tmp/train.tags.$lang.tok.$l
90
+ done
91
+ done
92
+
93
+ echo "pre-processing test data..."
94
+ for l in $src $tgt; do
95
+ if [ "$l" == "$src" ]; then
96
+ t="src"
97
+ else
98
+ t="ref"
99
+ fi
100
+ grep '<seg id' $orig/test-full/newstest2014-fren-$t.$l.sgm | \
101
+ sed -e 's/<seg id="[0-9]*">\s*//g' | \
102
+ sed -e 's/\s*<\/seg>\s*//g' | \
103
+ sed -e "s/\’/\'/g" | \
104
+ perl $TOKENIZER -threads 8 -a -l $l > $tmp/test.$l
105
+ echo ""
106
+ done
107
+
108
+ echo "splitting train and valid..."
109
+ for l in $src $tgt; do
110
+ awk '{if (NR%1333 == 0) print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/valid.$l
111
+ awk '{if (NR%1333 != 0) print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/train.$l
112
+ done
113
+
114
+ TRAIN=$tmp/train.fr-en
115
+ BPE_CODE=$prep/code
116
+ rm -f $TRAIN
117
+ for l in $src $tgt; do
118
+ cat $tmp/train.$l >> $TRAIN
119
+ done
120
+
121
+ echo "learn_bpe.py on ${TRAIN}..."
122
+ python $BPEROOT/learn_bpe.py -s $BPE_TOKENS < $TRAIN > $BPE_CODE
123
+
124
+ for L in $src $tgt; do
125
+ for f in train.$L valid.$L test.$L; do
126
+ echo "apply_bpe.py to ${f}..."
127
+ python $BPEROOT/apply_bpe.py -c $BPE_CODE < $tmp/$f > $tmp/bpe.$f
128
+ done
129
+ done
130
+
131
+ perl $CLEAN -ratio 1.5 $tmp/bpe.train $src $tgt $prep/train 1 250
132
+ perl $CLEAN -ratio 1.5 $tmp/bpe.valid $src $tgt $prep/valid 1 250
133
+
134
+ for L in $src $tgt; do
135
+ cp $tmp/bpe.test.$L $prep/test.$L
136
+ done
fairseq/examples/translation_moe/README.md ADDED
@@ -0,0 +1,89 @@
1
+ # Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)
2
+
3
+ This page includes instructions for reproducing results from the paper [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](https://arxiv.org/abs/1902.07816).
4
+
5
+ ## Download data
6
+
7
+ First, follow the [instructions to download and preprocess the WMT'17 En-De dataset](../translation#prepare-wmt14en2desh).
8
+ Make sure to learn a joint vocabulary by passing the `--joined-dictionary` option to `fairseq-preprocess`.
9
+
10
+ ## Train a model
11
+
12
+ Then we can train a mixture of experts model using the `translation_moe` task.
13
+ Use the `--method` flag to choose the MoE variant; we support hard mixtures with a learned or uniform prior (`--method hMoElp` and `hMoEup`, respectively) and soft mixtures (`--method sMoElp` and `sMoEup`).
14
+ The model is trained with online responsibility assignment and shared parameterization.
15
+
16
+ The following command will train a `hMoElp` model with `3` experts:
17
+ ```bash
18
+ fairseq-train --ddp-backend='legacy_ddp' \
19
+ data-bin/wmt17_en_de \
20
+ --max-update 100000 \
21
+ --task translation_moe --user-dir examples/translation_moe/translation_moe_src \
22
+ --method hMoElp --mean-pool-gating-network \
23
+ --num-experts 3 \
24
+ --arch transformer_wmt_en_de --share-all-embeddings \
25
+ --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
26
+ --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \
27
+ --lr 0.0007 \
28
+ --dropout 0.1 --weight-decay 0.0 --criterion cross_entropy \
29
+ --max-tokens 3584
30
+ ```
31
+
32
+ ## Translate
33
+
34
+ Once a model is trained, we can generate translations from different experts using the `--gen-expert` option.
35
+ For example, to generate from expert 0:
36
+ ```bash
37
+ fairseq-generate data-bin/wmt17_en_de \
38
+ --path checkpoints/checkpoint_best.pt \
39
+ --beam 1 --remove-bpe \
40
+ --task translation_moe --user-dir examples/translation_moe/translation_moe_src \
41
+ --method hMoElp --mean-pool-gating-network \
42
+ --num-experts 3 \
43
+ --gen-expert 0
44
+ ```
45
+
46
+ ## Evaluate
47
+
48
+ First download a tokenized version of the WMT'14 En-De test set with multiple references:
49
+ ```bash
50
+ wget dl.fbaipublicfiles.com/fairseq/data/wmt14-en-de.extra_refs.tok
51
+ ```
52
+
53
+ Next apply BPE on the fly and run generation for each expert:
54
+ ```bash
55
+ BPE_CODE=examples/translation/wmt17_en_de/code
56
+ for EXPERT in $(seq 0 2); do \
57
+ cat wmt14-en-de.extra_refs.tok \
58
+ | grep ^S | cut -f 2 \
59
+ | fairseq-interactive data-bin/wmt17_en_de \
60
+ --path checkpoints/checkpoint_best.pt \
61
+ --beam 1 \
62
+ --bpe subword_nmt --bpe-codes $BPE_CODE \
63
+ --buffer-size 500 --max-tokens 6000 \
64
+ --task translation_moe --user-dir examples/translation_moe/translation_moe_src \
65
+ --method hMoElp --mean-pool-gating-network \
66
+ --num-experts 3 \
67
+ --gen-expert $EXPERT ; \
68
+ done > wmt14-en-de.extra_refs.tok.gen.3experts
69
+ ```
70
+
71
+ Finally use `score.py` to compute pairwise BLEU and average oracle BLEU:
72
+ ```bash
73
+ python examples/translation_moe/score.py --sys wmt14-en-de.extra_refs.tok.gen.3experts --ref wmt14-en-de.extra_refs.tok
74
+ # pairwise BLEU: 48.26
75
+ # #refs covered: 2.11
76
+ # multi-reference BLEU (leave-one-out): 59.46
77
+ ```
78
+ This matches row 3 from Table 7 in the paper.
79
+
80
+ ## Citation
81
+
82
+ ```bibtex
83
+ @article{shen2019mixture,
84
+ title = {Mixture Models for Diverse Machine Translation: Tricks of the Trade},
85
+ author = {Tianxiao Shen and Myle Ott and Michael Auli and Marc'Aurelio Ranzato},
86
+ journal = {International Conference on Machine Learning},
87
+ year = 2019,
88
+ }
89
+ ```
fairseq/examples/translation_moe/score.py ADDED
@@ -0,0 +1,197 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """
7
+ Scoring script for computing pairwise BLEU and multi-ref BLEU over a set of
8
+ candidate hypotheses.
9
+
10
+ See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade"
11
+ (Shen et al., 2019) <https://arxiv.org/abs/1902.07816>`_.
12
+ """
13
+
14
+ import argparse
15
+ import random
16
+ import sys
17
+ from itertools import chain
18
+
19
+ import numpy as np
20
+ import sacrebleu
21
+ from sacrebleu import corpus_bleu as _corpus_bleu
22
+
23
+ def main():
24
+ parser = argparse.ArgumentParser(sys.argv[0])
25
+ parser.add_argument(
26
+ "--sys", nargs="*", default="", metavar="FILE", help="path to system output"
27
+ )
28
+ parser.add_argument("--ref", default="", metavar="FILE", help="path to references")
29
+ parser.add_argument(
30
+ "--output",
31
+ default="",
32
+ metavar="FILE",
33
+ help="print outputs into a pretty format",
34
+ )
35
+ args = parser.parse_args()
36
+
37
+ if args.sys:
38
+ src, tgt, hypos, log_probs = load_sys(args.sys)
39
+ print("pairwise BLEU: %.2f" % pairwise(hypos))
40
+ if args.output:
41
+ merge(src, tgt, hypos, log_probs, args.output)
42
+
43
+ if args.ref:
44
+ _, _, refs = load_ref(args.ref)
45
+ if args.sys:
46
+ multi_ref(refs, hypos)
47
+ else:
48
+ intra_ref(refs)
49
+
50
+
51
+ def dictolist(d):
52
+ a = sorted(d.items(), key=lambda i: i[0])
53
+ return [i[1] for i in a]
54
+
55
+
56
+ def load_sys(paths):
57
+ src, tgt, hypos, log_probs = {}, {}, {}, {}
58
+ for path in paths:
59
+ with open(path) as f:
60
+ for line in f:
61
+ line = line.rstrip()
62
+ # S: source
63
+ # T: target
64
+ # D: detokenized system output
65
+ if line.startswith(("S-", "T-", "D-")):
66
+ i = int(line[line.find("-") + 1 : line.find("\t")])
67
+ if line.startswith("S-"):
68
+ src[i] = line.split("\t")[1]
69
+ if line.startswith("T-"):
70
+ tgt[i] = line.split("\t")[1]
71
+ if line.startswith("D-"):
72
+ if i not in hypos:
73
+ hypos[i] = []
74
+ log_probs[i] = []
75
+ hypos[i].append(line.split("\t")[2])
76
+ log_probs[i].append(float(line.split("\t")[1]))
77
+ return dictolist(src), dictolist(tgt), dictolist(hypos), dictolist(log_probs)
78
+
79
+
80
+ def load_ref(path):
81
+ with open(path) as f:
82
+ lines = f.readlines()
83
+ src, tgt, refs = [], [], []
84
+ i = 0
85
+ while i < len(lines):
86
+ if lines[i].startswith("S-"):
87
+ src.append(lines[i].split("\t")[1].rstrip())
88
+ i += 1
89
+ elif lines[i].startswith("T-"):
90
+ tgt.append(lines[i].split("\t")[1].rstrip())
91
+ i += 1
92
+ else:
93
+ a = []
94
+ while i < len(lines) and lines[i].startswith("R"):
95
+ a.append(lines[i].split("\t")[1].rstrip())
96
+ i += 1
97
+ refs.append(a)
98
+ return src, tgt, refs
99
+
100
+
101
+ def merge(src, tgt, hypos, log_probs, path):
102
+ with open(path, "w") as f:
103
+ for s, t, hs, lps in zip(src, tgt, hypos, log_probs):
104
+ f.write(s + "\n")
105
+ f.write(t + "\n")
106
+ f.write("\n")
107
+ for h, lp in zip(hs, lps):
108
+ f.write("\t%f\t%s\n" % (lp, h.strip()))
109
+ f.write("------------------------------------------------------\n")
110
+
111
+
112
+ def corpus_bleu(sys_stream, ref_streams):
113
+ bleu = _corpus_bleu(sys_stream, ref_streams, tokenize="none")
114
+ return bleu.score
115
+
116
+
117
+ def sentence_bleu(hypothesis, reference):
118
+ bleu = _corpus_bleu(hypothesis, reference)
119
+ for i in range(1, 4):
120
+ bleu.counts[i] += 1
121
+ bleu.totals[i] += 1
122
+ bleu = sacrebleu.BLEU.compute_bleu(
123
+ bleu.counts,
124
+ bleu.totals,
125
+ bleu.sys_len,
126
+ bleu.ref_len,
127
+ smooth_method="exp",
128
+ )
129
+ return bleu.score
130
+
131
+
132
+ def pairwise(sents):
133
+ _ref, _hypo = [], []
134
+ for s in sents:
135
+ for i in range(len(s)):
136
+ for j in range(len(s)):
137
+ if i != j:
138
+ _ref.append(s[i])
139
+ _hypo.append(s[j])
140
+ return corpus_bleu(_hypo, [_ref])
141
+
142
+
143
+ def multi_ref(refs, hypos):
144
+ _ref, _hypo = [], []
145
+ ref_cnt = 0
146
+ assert len(refs) == len(hypos)
147
+
148
+ # count number of refs covered
149
+ for rs, hs in zip(refs, hypos):
150
+ a = set()
151
+ for h in hs:
152
+ s = [sentence_bleu(h, r) for r in rs]
153
+ j = np.argmax(s)
154
+ _ref.append(rs[j])
155
+ _hypo.append(h)
156
+ best = [k for k in range(len(rs)) if s[k] == s[j]]
157
+ a.add(random.choice(best))
158
+ ref_cnt += len(a)
159
+ print("#refs covered: %.2f" % (ref_cnt / len(refs)))
160
+
161
+ # transpose refs and hypos
162
+ refs = list(zip(*refs))
163
+ hypos = list(zip(*hypos))
164
+
165
+ # compute multi-ref corpus BLEU (leave-one-out to be comparable to intra_ref)
166
+ k = len(hypos)
167
+ m = len(refs)
168
+ flat_hypos = [hypos[j][i] for i in range(len(hypos[0])) for j in range(k)]
169
+ duplicated_refs = [[ref for ref in refs_i for _ in range(k)] for refs_i in refs]
170
+ loo_bleus = []
171
+ for held_out_ref in range(m):
172
+ remaining_refs = (
173
+ duplicated_refs[:held_out_ref] + duplicated_refs[held_out_ref + 1 :]
174
+ )
175
+ assert len(remaining_refs) == m - 1
176
+ loo_bleus.append(corpus_bleu(flat_hypos, remaining_refs))
177
+ print("average multi-reference BLEU (leave-one-out): %.2f" % np.mean(loo_bleus))
178
+
179
+
180
+ def intra_ref(refs):
181
+ print("ref pairwise BLEU: %.2f" % pairwise(refs))
182
+ refs = list(zip(*refs))
183
+ m = len(refs)
184
+ concat_h = []
185
+ concat_rest = [[] for j in range(m - 1)]
186
+ for i, h in enumerate(refs):
187
+ rest = refs[:i] + refs[i + 1 :]
188
+ concat_h.append(h)
189
+ for j in range(m - 1):
190
+ concat_rest[j].extend(rest[j])
191
+ concat_h = list(chain.from_iterable(concat_h))
192
+ bleu = corpus_bleu(concat_h, concat_rest)
193
+ print("multi-reference BLEU (leave-one-out): %.2f" % bleu)
194
+
195
+
196
+ if __name__ == "__main__":
197
+ main()
fairseq/examples/translation_moe/translation_moe_src/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from . import translation_moe # noqa
fairseq/examples/translation_moe/translation_moe_src/logsumexp_moe.py ADDED
@@ -0,0 +1,26 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+
8
+
9
+ class LogSumExpMoE(torch.autograd.Function):
10
+ """Standard LogSumExp forward pass, but use *posterior* for the backward.
11
+
12
+ See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade"
13
+ (Shen et al., 2019) <https://arxiv.org/abs/1902.07816>`_.
14
+ """
15
+
16
+ @staticmethod
17
+ def forward(ctx, logp, posterior, dim=-1):
18
+ ctx.save_for_backward(posterior)
19
+ ctx.dim = dim
20
+ return torch.logsumexp(logp, dim=dim)
21
+
22
+ @staticmethod
23
+ def backward(ctx, grad_output):
24
+ (posterior,) = ctx.saved_tensors
25
+ grad_logp = grad_output.unsqueeze(ctx.dim) * posterior
26
+ return grad_logp, None, None
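A minimal sketch of what this custom function buys (the shapes below are illustrative assumptions): the forward value is an ordinary `logsumexp` over the expert dimension, while the backward pass distributes the incoming gradient according to the fixed `posterior` rather than differentiating through the softmax of `logp`:

```python
import torch

B, K = 2, 3  # hypothetical batch size and number of experts
lprob_yz = torch.randn(B, K, requires_grad=True)     # log p(y, z=k | x) for each expert k
posterior = torch.softmax(lprob_yz.detach(), dim=1)  # responsibilities p(z | x, y), held fixed

loss = -LogSumExpMoE.apply(lprob_yz, posterior, 1).sum()  # forward value == -logsumexp(...).sum()
loss.backward()

# The gradient w.r.t. each expert's log-probability is -posterior: gradients are routed
# through the precomputed responsibilities instead of a freshly computed softmax.
print(torch.allclose(lprob_yz.grad, -posterior))  # True
```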
fairseq/examples/translation_moe/translation_moe_src/mean_pool_gating_network.py ADDED
@@ -0,0 +1,50 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+
10
+ class MeanPoolGatingNetwork(torch.nn.Module):
11
+ """A simple mean-pooling gating network for selecting experts.
12
+
13
+ This module applies mean pooling over an encoder's output and returns
14
+ responsibilities for each expert. The encoder format is expected to match
15
+ :class:`fairseq.models.transformer.TransformerEncoder`.
16
+ """
17
+
18
+ def __init__(self, embed_dim, num_experts, dropout=None):
19
+ super().__init__()
20
+ self.embed_dim = embed_dim
21
+ self.num_experts = num_experts
22
+
23
+ self.fc1 = torch.nn.Linear(embed_dim, embed_dim)
24
+ self.dropout = torch.nn.Dropout(dropout) if dropout is not None else None
25
+ self.fc2 = torch.nn.Linear(embed_dim, num_experts)
26
+
27
+ def forward(self, encoder_out):
28
+ if not (
29
+ "encoder_out" in encoder_out
30
+ and "encoder_padding_mask" in encoder_out
31
+ and encoder_out["encoder_out"][0].size(2) == self.embed_dim
32
+ ):
33
+ raise ValueError("Unexpected format for encoder_out")
34
+
35
+ # mean pooling over time
36
+ encoder_padding_mask = encoder_out["encoder_padding_mask"][0] # B x T
37
+ encoder_out = encoder_out["encoder_out"][0].transpose(0, 1) # B x T x C
38
+ if encoder_padding_mask is not None:
39
+ encoder_out = encoder_out.clone() # required because of transpose above
40
+ encoder_out[encoder_padding_mask] = 0
41
+ ntokens = torch.sum(~encoder_padding_mask, dim=1, keepdim=True)
42
+ x = torch.sum(encoder_out, dim=1) / ntokens.type_as(encoder_out)
43
+ else:
44
+ x = torch.mean(encoder_out, dim=1)
45
+
46
+ x = torch.tanh(self.fc1(x))
47
+ if self.dropout is not None:
48
+ x = self.dropout(x)
49
+ x = self.fc2(x)
50
+ return F.log_softmax(x, dim=-1, dtype=torch.float32).type_as(x)
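For reference, a small usage sketch with a fabricated encoder output in the dictionary format the module expects — list-wrapped `T x B x C` states and a `B x T` padding mask; the concrete sizes are assumptions:

```python
import torch

T, B, C, num_experts = 7, 2, 16, 3  # hypothetical sizes
gating = MeanPoolGatingNetwork(embed_dim=C, num_experts=num_experts, dropout=0.1)

encoder_out = {
    "encoder_out": [torch.randn(T, B, C)],                          # T x B x C hidden states
    "encoder_padding_mask": [torch.zeros(B, T, dtype=torch.bool)],  # no padded positions
}

log_prior = gating(encoder_out)     # B x num_experts expert log-probabilities
print(log_prior.exp().sum(dim=-1))  # each row sums to ~1
```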
fairseq/examples/translation_moe/translation_moe_src/translation_moe.py ADDED
@@ -0,0 +1,259 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from dataclasses import dataclass, field
7
+ import torch
8
+ from omegaconf import II
9
+
10
+ from fairseq import utils
11
+ from fairseq.logging import metrics
12
+ from fairseq.dataclass import ChoiceEnum
13
+ from fairseq.tasks import register_task
14
+ from fairseq.tasks.translation import TranslationConfig, TranslationTask
15
+
16
+ from .logsumexp_moe import LogSumExpMoE
17
+ from .mean_pool_gating_network import MeanPoolGatingNetwork
18
+
19
+
20
+ METHOD_CHOICES = ChoiceEnum(["sMoElp", "sMoEup", "hMoElp", "hMoEup"])
21
+
22
+
23
+ @dataclass
24
+ class TranslationMoEConfig(TranslationConfig):
25
+ method: METHOD_CHOICES = field(
26
+ default="hMoEup",
27
+ metadata={"help": "MoE method"},
28
+ )
29
+ num_experts: int = field(
30
+ default=3,
31
+ metadata={"help": "number of experts"},
32
+ )
33
+ mean_pool_gating_network: bool = field(
34
+ default=False,
35
+ metadata={"help": "use a simple mean-pooling gating network"},
36
+ )
37
+ mean_pool_gating_network_dropout: float = field(
38
+ default=0,
39
+ metadata={"help": "dropout for mean-pooling gating network"},
40
+ )
41
+ mean_pool_gating_network_encoder_dim: int = field(
42
+ default=0,
43
+ metadata={"help": "encoder output dim for mean-pooling gating network"},
44
+ )
45
+ gen_expert: int = field(
46
+ default=0,
47
+ metadata={"help": "which expert to use for generation"},
48
+ )
49
+ sentence_avg: bool = II("optimization.sentence_avg")
50
+
51
+
52
+ @register_task("translation_moe", dataclass=TranslationMoEConfig)
53
+ class TranslationMoETask(TranslationTask):
54
+ """
55
+ Translation task for Mixture of Experts (MoE) models.
56
+
57
+ See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade"
58
+ (Shen et al., 2019) <https://arxiv.org/abs/1902.07816>`_.
59
+
60
+ Args:
61
+ src_dict (~fairseq.data.Dictionary): dictionary for the source language
62
+ tgt_dict (~fairseq.data.Dictionary): dictionary for the target language
63
+
64
+ .. note::
65
+
66
+ The translation task is compatible with :mod:`fairseq-train`,
67
+ :mod:`fairseq-generate` and :mod:`fairseq-interactive`.
68
+
69
+ The translation task provides the following additional command-line
70
+ arguments:
71
+
72
+ .. argparse::
73
+ :ref: fairseq.tasks.translation_parser
74
+ :prog:
75
+ """
76
+
77
+ cfg: TranslationMoEConfig
78
+
79
+ def __init__(self, cfg: TranslationMoEConfig, src_dict, tgt_dict):
80
+ if cfg.method == "sMoElp":
81
+ # soft MoE with learned prior
82
+ self.uniform_prior = False
83
+ self.hard_selection = False
84
+ elif cfg.method == "sMoEup":
85
+ # soft MoE with uniform prior
86
+ self.uniform_prior = True
87
+ self.hard_selection = False
88
+ elif cfg.method == "hMoElp":
89
+ # hard MoE with learned prior
90
+ self.uniform_prior = False
91
+ self.hard_selection = True
92
+ elif cfg.method == "hMoEup":
93
+ # hard MoE with uniform prior
94
+ self.uniform_prior = True
95
+ self.hard_selection = True
96
+
97
+ # add indicator tokens for each expert
98
+ for i in range(cfg.num_experts):
99
+ # add to both dictionaries in case we're sharing embeddings
100
+ src_dict.add_symbol("<expert_{}>".format(i))
101
+ tgt_dict.add_symbol("<expert_{}>".format(i))
102
+
103
+ super().__init__(cfg, src_dict, tgt_dict)
104
+
105
+ def build_model(self, cfg, from_checkpoint=False):
106
+ from fairseq import models
107
+
108
+ model = models.build_model(cfg, self)
109
+ if not self.uniform_prior and not hasattr(model, "gating_network"):
110
+ if self.cfg.mean_pool_gating_network:
111
+ if self.cfg.mean_pool_gating_network_encoder_dim > 0:
112
+ encoder_dim = self.cfg.mean_pool_gating_network_encoder_dim
113
+ elif getattr(cfg, "encoder_embed_dim", None):
114
+ # assume that encoder_embed_dim is the encoder's output dimension
115
+ encoder_dim = cfg.encoder_embed_dim
116
+ else:
117
+ raise ValueError(
118
+ "Must specify --mean-pool-gating-network-encoder-dim"
119
+ )
120
+
121
+ if self.cfg.mean_pool_gating_network_dropout > 0:
122
+ dropout = self.cfg.mean_pool_gating_network_dropout
123
+ elif getattr(cfg, "dropout", None):
124
+ dropout = cfg.dropout
125
+ else:
126
+ raise ValueError("Must specify task.mean_pool_gating_network_dropout")
127
+
128
+ model.gating_network = MeanPoolGatingNetwork(
129
+ encoder_dim,
130
+ self.cfg.num_experts,
131
+ dropout,
132
+ )
133
+ else:
134
+ raise ValueError(
135
+ "translation_moe task with learned prior requires the model to "
136
+ "have a gating network; try using --mean-pool-gating-network"
137
+ )
138
+ return model
139
+
140
+ def expert_index(self, i):
141
+ return i + self.tgt_dict.index("<expert_0>")
142
+
143
+ def _get_loss(self, sample, model, criterion):
144
+ assert hasattr(
145
+ criterion, "compute_loss"
146
+ ), "translation_moe task requires the criterion to implement the compute_loss() method"
147
+
148
+ k = self.cfg.num_experts
149
+ bsz = sample["target"].size(0)
150
+
151
+ def get_lprob_y(encoder_out, prev_output_tokens_k):
152
+ net_output = model.decoder(
153
+ prev_output_tokens=prev_output_tokens_k,
154
+ encoder_out=encoder_out,
155
+ )
156
+ loss, _ = criterion.compute_loss(model, net_output, sample, reduce=False)
157
+ loss = loss.view(bsz, -1)
158
+ return -loss.sum(dim=1, keepdim=True) # -> B x 1
159
+
160
+ def get_lprob_yz(winners=None):
161
+ encoder_out = model.encoder(
162
+ src_tokens=sample["net_input"]["src_tokens"],
163
+ src_lengths=sample["net_input"]["src_lengths"],
164
+ )
165
+
166
+ if winners is None:
167
+ lprob_y = []
168
+ for i in range(k):
169
+ prev_output_tokens_k = sample["net_input"][
170
+ "prev_output_tokens"
171
+ ].clone()
172
+ assert not prev_output_tokens_k.requires_grad
173
+ prev_output_tokens_k[:, 0] = self.expert_index(i)
174
+ lprob_y.append(get_lprob_y(encoder_out, prev_output_tokens_k))
175
+ lprob_y = torch.cat(lprob_y, dim=1) # -> B x K
176
+ else:
177
+ prev_output_tokens_k = sample["net_input"]["prev_output_tokens"].clone()
178
+ prev_output_tokens_k[:, 0] = self.expert_index(winners)
179
+ lprob_y = get_lprob_y(encoder_out, prev_output_tokens_k) # -> B
180
+
181
+ if self.uniform_prior:
182
+ lprob_yz = lprob_y
183
+ else:
184
+ lprob_z = model.gating_network(encoder_out) # B x K
185
+ if winners is not None:
186
+ lprob_z = lprob_z.gather(dim=1, index=winners.unsqueeze(-1))
187
+ lprob_yz = lprob_y + lprob_z.type_as(lprob_y) # B x K
188
+
189
+ return lprob_yz
190
+
191
+ # compute responsibilities without dropout
192
+ with utils.model_eval(model): # disable dropout
193
+ with torch.no_grad(): # disable autograd
194
+ lprob_yz = get_lprob_yz() # B x K
195
+ prob_z_xy = torch.nn.functional.softmax(lprob_yz, dim=1)
196
+ assert not prob_z_xy.requires_grad
197
+
198
+ # compute loss with dropout
199
+ if self.hard_selection:
200
+ winners = prob_z_xy.max(dim=1)[1]
201
+ loss = -get_lprob_yz(winners)
202
+ else:
203
+ lprob_yz = get_lprob_yz() # B x K
204
+ loss = -LogSumExpMoE.apply(lprob_yz, prob_z_xy, 1)
205
+
206
+ loss = loss.sum()
207
+ sample_size = (
208
+ sample["target"].size(0) if self.cfg.sentence_avg else sample["ntokens"]
209
+ )
210
+ logging_output = {
211
+ "loss": utils.item(loss.data),
212
+ "ntokens": sample["ntokens"],
213
+ "nsentences": bsz,
214
+ "sample_size": sample_size,
215
+ "posterior": prob_z_xy.float().sum(dim=0).cpu(),
216
+ }
217
+ return loss, sample_size, logging_output
218
+
219
+ def train_step(
220
+ self, sample, model, criterion, optimizer, update_num, ignore_grad=False
221
+ ):
222
+ model.train()
223
+ loss, sample_size, logging_output = self._get_loss(sample, model, criterion)
224
+ if ignore_grad:
225
+ loss *= 0
226
+ optimizer.backward(loss)
227
+ return loss, sample_size, logging_output
228
+
229
+ def valid_step(self, sample, model, criterion):
230
+ model.eval()
231
+ with torch.no_grad():
232
+ loss, sample_size, logging_output = self._get_loss(sample, model, criterion)
233
+ return loss, sample_size, logging_output
234
+
235
+ def inference_step(
236
+ self,
237
+ generator,
238
+ models,
239
+ sample,
240
+ prefix_tokens=None,
241
+ expert=None,
242
+ constraints=None,
243
+ ):
244
+ expert = expert or self.cfg.gen_expert
245
+ with torch.no_grad():
246
+ return generator.generate(
247
+ models,
248
+ sample,
249
+ prefix_tokens=prefix_tokens,
250
+ constraints=constraints,
251
+ bos_token=self.expert_index(expert),
252
+ )
253
+
254
+ def reduce_metrics(self, logging_outputs, criterion):
255
+ super().reduce_metrics(logging_outputs, criterion)
256
+ metrics.log_scalar(
257
+ "posterior",
258
+ sum(log["posterior"] for log in logging_outputs if "posterior" in log),
259
+ )
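The `_get_loss` above follows a generalized-EM recipe: expert responsibilities `prob_z_xy` are computed in eval mode with autograd disabled (E-step), then the loss is recomputed with dropout, either training only the winning expert (hard selection) or marginalizing over experts with `LogSumExpMoE` (soft selection). Below is a minimal, self-contained sketch of how such an op can be built: a plain log-sum-exp forward whose backward weights the gradient by the fixed posterior. The class name and toy usage are illustrative, not necessarily identical to the op imported above.

```python
import torch


class LogSumExpWithPosteriorGrad(torch.autograd.Function):
    """logsumexp(logp, dim) in the forward pass; the backward pass
    distributes the incoming gradient across experts according to a
    fixed (detached) posterior instead of the softmax of logp."""

    @staticmethod
    def forward(ctx, logp, posterior, dim):
        ctx.save_for_backward(posterior)
        ctx.dim = dim
        return torch.logsumexp(logp, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        (posterior,) = ctx.saved_tensors
        # weight each expert's gradient by its responsibility
        return grad_output.unsqueeze(ctx.dim) * posterior, None, None


# toy usage: B=2 sentences, K=3 experts
lprob_yz = torch.randn(2, 3, requires_grad=True)      # log p(y, z | x) per expert
posterior = torch.softmax(lprob_yz.detach(), dim=1)   # E-step responsibilities
loss = -LogSumExpWithPosteriorGrad.apply(lprob_yz, posterior, 1).sum()
loss.backward()
```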
fairseq/examples/truncated_bptt/README.md ADDED
@@ -0,0 +1,70 @@
1
+ # Truncated Backpropagation Through Time (BPTT)
2
+
3
+ Truncated BPTT is a useful technique for training language models on very long
4
+ sequences. Typically a long sequence is split into chunks and a language model
5
+ is trained over the chunks sequentially. The LM may condition on previous
6
+ chunks, but gradients only flow through the current chunk. This technique was
7
+ the basis for the paper: [Transformer-XL: Attentive Language Models Beyond a
8
+ Fixed-Length Context](https://arxiv.org/abs/1901.02860), which achieved
9
+ state-of-the-art language modeling results at the time of publication.
10
+
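+ As a concrete picture of this chunking pattern, here is a minimal, self-contained
+ PyTorch sketch with a toy LSTM and random data (illustrative only, not the
+ Transformer-XL code used in this example): the hidden state is carried across
+ chunks so the model keeps context, but it is detached so that gradients stop at
+ the chunk boundary.
+ 
+ ```python
+ import torch
+ import torch.nn as nn
+ 
+ lstm = nn.LSTM(input_size=16, hidden_size=32, batch_first=True)
+ proj = nn.Linear(32, 16)
+ opt = torch.optim.Adam(list(lstm.parameters()) + list(proj.parameters()))
+ long_seq = torch.randn(4, 1200, 16)       # B x T x C: too long to backprop through at once
+ 
+ state = None
+ for chunk in long_seq.split(150, dim=1):  # train on 150-step chunks, in order
+     out, state = lstm(chunk, state)
+     loss = (proj(out) - chunk).pow(2).mean()  # stand-in training loss
+     opt.zero_grad()
+     loss.backward()
+     opt.step()
+     # keep the state as context for the next chunk, but cut the autograd graph
+     state = tuple(s.detach() for s in state)
+ ```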
11
+ It is slightly tricky to implement Truncated BPTT efficiently in fairseq, since
12
+ we need to iterate over the data sequentially and disable any batch shuffling
13
+ logic. The code provided in this example illustrates how to implement Truncated
14
+ BPTT in fairseq by overriding ``FairseqTask::get_batch_iterator`` to iterate
15
+ over the data sequentially. Crucially, this example supports batching and
16
+ multi-GPU (data parallel) training.
17
+
18
+ ##### 0. Setup
19
+
20
+ First, see the general [language modeling README](README.md) for instructions on
21
+ preprocessing the WikiText-103 data.
22
+
23
+ ##### 1. Train a Transformer-XL model on WikiText-103
24
+
25
+ We will train a 16-layer Transformer-XL model following the [hyperparameters
26
+ used in the original
27
+ paper](https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/run_wt103_base.sh).
28
+
29
+ The following command assumes 4 GPUs, so that the total batch size is 60
30
+ sequences (15 x 4). Training should take ~24 hours on 4 V100 GPUs:
31
+ ```bash
32
+ CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \
33
+ --user-dir examples/truncated_bptt \
34
+ data-bin/wikitext-103/ \
35
+ --task truncated_bptt_lm --tokens-per-sample 150 \
36
+ --batch-size 15 --max-update 200000 \
37
+ --arch transformer_xl --n-layer 16 --d-model 410 --n-head 10 \
38
+ --d-head 41 --d-inner 2100 --dropout 0.1 --dropatt 0.0 --mem-len 150 \
39
+ --optimizer adam --clip-norm 0.25 \
40
+ --lr-scheduler cosine --warmup-updates 0 --min-lr 0.0 --lr 0.00025 \
41
+ --log-format json --log-interval 25 \
42
+ --fp16
43
+ ```
44
+
45
+ If training on a single GPU, set `--update-freq=4` to accumulate 4x gradients
46
+ and simulate training on 4 GPUs.
47
+
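+ Conceptually, `--update-freq=4` accumulates gradients over 4 forward/backward
+ passes before every optimizer step. A minimal sketch of that pattern (the
+ `compute_loss`, `batches` and `optimizer` names below are hypothetical, and this
+ is not fairseq's actual trainer loop):
+ 
+ ```python
+ for step, batch in enumerate(batches):
+     loss = compute_loss(model, batch) / 4  # scale so the summed gradient matches a 4x batch
+     loss.backward()                        # gradients accumulate in .grad across micro-batches
+     if (step + 1) % 4 == 0:
+         optimizer.step()
+         optimizer.zero_grad()
+ ```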
48
+ ##### 2. Evaluate
49
+
50
+ ```bash
51
+ fairseq-eval-lm data-bin/wikitext-103/ \
52
+ --path checkpoints/checkpoint_best.pt \
53
+ --user-dir examples/truncated_bptt/ \
54
+ --task truncated_bptt_lm \
55
+ --batch-size 1 --required-batch-size-multiple 1 \
56
+ --model-overrides '{"mem_len":640,"clamp_len":400,"same_length":True}' \
57
+ --tokens-per-sample 64
58
+ # ... | INFO | fairseq_cli.eval_lm | num. model params: 151123537
59
+ # ... | INFO | fairseq_cli.eval_lm | Evaluated 245569 tokens in 83.1s (2956.82 tokens/s)
60
+ # ... | INFO | fairseq_cli.eval_lm | Loss (base 2): 4.5668, Perplexity: 23.70
61
+ # Compare to 24.0 test perplexity from the paper
62
+ ```
63
+
64
+ *Note:* During training the model saw 150 tokens of context
65
+ (``--tokens-per-sample=150``) and 150 extra memory tokens (``--mem-len=150``).
66
+ During evaluation we measure perplexity on sequences of 64 tokens
67
+ (``--tokens-per-sample=64``) and increase the memory length
68
+ (``--model-overrides='{"mem_len":640}'``). These settings match the evaluation
69
+ settings from [the original
70
+ paper](https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/run_wt103_base.sh).
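+ 
+ As a quick sanity check on these numbers, each Transformer-XL attention layer can
+ attend to at most the current segment plus the memory (the multi-layer receptive
+ field can be larger, so treat this as a per-layer bound):
+ 
+ ```python
+ print(150 + 150)  # training: at most 300 positions visible per layer
+ print(64 + 640)   # evaluation: at most 704 positions visible per layer
+ ```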
fairseq/examples/truncated_bptt/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from . import transformer_xl_model, truncated_bptt_lm_task # noqa
fairseq/examples/truncated_bptt/transformer_xl_model.py ADDED
@@ -0,0 +1,143 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ from dataclasses import dataclass, field
8
+ from typing import Dict, List, Optional
9
+
10
+ import torch
11
+ from fairseq.dataclass import FairseqDataclass
12
+ from fairseq.models import (
13
+ FairseqIncrementalDecoder,
14
+ FairseqLanguageModel,
15
+ register_model,
16
+ )
17
+ from fairseq.modules.checkpoint_activations import checkpoint_wrapper
18
+ from omegaconf import II
19
+
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
+ @dataclass
25
+ class TransformerXLConfig(FairseqDataclass):
26
+ # defaults come from the original Transformer-XL code
27
+ cutoffs: List[int] = field(default_factory=lambda: [20000, 40000, 200000])
28
+ d_model: int = 500
29
+ n_head: int = 10
30
+ d_head: int = 50
31
+ d_inner: int = 1000
32
+ div_val: int = 1
33
+ n_layer: int = 12
34
+ mem_len: int = 0
35
+ clamp_len: int = -1
36
+ same_length: bool = False
37
+ dropout: float = 0.0
38
+ dropatt: float = 0.0
39
+ checkpoint_activations: bool = False
40
+ offload_activations: bool = False
41
+ max_target_positions: int = II("task.max_target_positions")
42
+
43
+
44
+ @register_model("transformer_xl", dataclass=TransformerXLConfig)
45
+ class TransformerXLLanguageModel(FairseqLanguageModel):
46
+ @classmethod
47
+ def build_model(cls, cfg: TransformerXLConfig, task):
48
+ return cls(TransformerXLDecoder(cfg, task))
49
+
50
+
51
+ class TransformerXLDecoder(FairseqIncrementalDecoder):
52
+ def __init__(self, cfg, task):
53
+ try:
54
+ from transformers.models.transfo_xl import (
55
+ TransfoXLConfig,
56
+ TransfoXLLMHeadModel,
57
+ )
58
+ except ImportError:
59
+ from transformers.configuration_transfo_xl import TransfoXLConfig
60
+ from transformers.modeling_transfo_xl import TransfoXLLMHeadModel
61
+
62
+ super().__init__(task.target_dictionary)
63
+ self.cfg = cfg
64
+
65
+ # remove any cutoffs larger than the vocab size
66
+ cutoffs = [
67
+ cutoff for cutoff in cfg.cutoffs if cutoff < len(task.target_dictionary)
68
+ ]
69
+
70
+ config = TransfoXLConfig(
71
+ vocab_size=len(task.target_dictionary),
72
+ cutoffs=cutoffs,
73
+ d_model=cfg.d_model,
74
+ d_embed=cfg.d_model,
75
+ n_head=cfg.n_head,
76
+ d_head=cfg.d_head,
77
+ d_inner=cfg.d_inner,
78
+ div_val=cfg.div_val,
79
+ n_layer=cfg.n_layer,
80
+ mem_len=cfg.mem_len,
81
+ clamp_len=cfg.clamp_len,
82
+ same_length=cfg.same_length,
83
+ dropout=cfg.dropout,
84
+ dropatt=cfg.dropatt,
85
+ )
86
+ logger.info(config)
87
+ self.model = TransfoXLLMHeadModel(config)
88
+
89
+ if cfg.checkpoint_activations or cfg.offload_activations:
90
+ for i in range(len(self.model.transformer.layers)):
91
+ self.model.transformer.layers[i] = checkpoint_wrapper(
92
+ self.model.transformer.layers[i],
93
+ offload_to_cpu=cfg.offload_activations,
94
+ )
95
+ # TODO: may save mem to wrap(layer.pos_ff.CoreNet[3])
96
+
97
+ self._mems = None
98
+
99
+ def forward(
100
+ self,
101
+ src_tokens,
102
+ src_lengths=None, # unused
103
+ incremental_state: Optional[Dict[str, List[torch.Tensor]]] = None,
104
+ encoder_out=None,
105
+ ):
106
+ if incremental_state is not None: # used during inference
107
+ mems = self.get_incremental_state(incremental_state, "mems")
108
+ src_tokens = src_tokens[:, -1:] # only keep the most recent token
109
+ else:
110
+ mems = self._mems
111
+
112
+ output = self.model(
113
+ input_ids=src_tokens,
114
+ mems=mems,
115
+ return_dict=False,
116
+ )
117
+
118
+ if len(output) >= 2:
119
+ if incremental_state is not None:
120
+ self.set_incremental_state(incremental_state, "mems", output[1])
121
+ else:
122
+ self._mems = output[1]
123
+
124
+ return (output[0],)
125
+
126
+ def max_positions(self):
127
+ return self.cfg.max_target_positions
128
+
129
+ def reorder_incremental_state(
130
+ self,
131
+ incremental_state: Dict[str, Dict[str, Optional[torch.Tensor]]],
132
+ new_order: torch.Tensor,
133
+ ):
134
+ """Reorder incremental state.
135
+
136
+ This will be called when the order of the input has changed from the
137
+ previous time step. A typical use case is beam search, where the input
138
+ order changes between time steps based on the selection of beams.
139
+ """
140
+ mems = self.get_incremental_state(incremental_state, "mems")
141
+ if mems is not None:
142
+ new_mems = [mems_i.index_select(1, new_order) for mems_i in mems]
143
+ self.set_incremental_state(incremental_state, "mems", new_mems)
fairseq/examples/truncated_bptt/truncated_bptt_lm_task.py ADDED
@@ -0,0 +1,285 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ import os
8
+ from dataclasses import dataclass, field
9
+ from typing import List, Optional, Tuple
10
+
11
+ import torch
12
+ from fairseq import utils
13
+ from fairseq.data import (
14
+ Dictionary,
15
+ TokenBlockDataset,
16
+ data_utils,
17
+ iterators,
18
+ )
19
+ from fairseq.dataclass import FairseqDataclass
20
+ from fairseq.distributed import utils as dist_utils
21
+ from fairseq.tasks import FairseqTask, register_task
22
+ from omegaconf import II
23
+
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ @dataclass
29
+ class TruncatedBPTTLMConfig(FairseqDataclass):
30
+ data: str = field(default="???", metadata={"help": "path to data directory"})
31
+ tokens_per_sample: int = field(
32
+ default=1024, metadata={"help": "max number of tokens per sequence"},
33
+ )
34
+ batch_size: int = II("dataset.batch_size")
35
+ # Some models use *max_target_positions* to know how many positional
36
+ # embeddings to learn. We use II(...) to make it default to
37
+ # *tokens_per_sample*, but in principle there could be more positional
38
+ # embeddings than tokens in a single batch. This may also be irrelevant for
39
+ # custom model implementations.
40
+ max_target_positions: int = II("task.tokens_per_sample")
41
+ # these will be populated automatically if not provided
42
+ data_parallel_rank: Optional[int] = None
43
+ data_parallel_size: Optional[int] = None
44
+
45
+
46
+ @register_task("truncated_bptt_lm", dataclass=TruncatedBPTTLMConfig)
47
+ class TruncatedBPTTLMTask(FairseqTask):
48
+ def __init__(self, cfg: TruncatedBPTTLMConfig):
49
+ super().__init__(cfg)
50
+
51
+ if cfg.data_parallel_rank is None or cfg.data_parallel_size is None:
52
+ if torch.distributed.is_initialized():
53
+ cfg.data_parallel_rank = dist_utils.get_data_parallel_rank()
54
+ cfg.data_parallel_size = dist_utils.get_data_parallel_world_size()
55
+ else:
56
+ cfg.data_parallel_rank = 0
57
+ cfg.data_parallel_size = 1
58
+
59
+ # load the dictionary
60
+ paths = utils.split_paths(cfg.data)
61
+ assert len(paths) > 0
62
+ self.dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
63
+ logger.info("dictionary: {} types".format(len(self.dictionary)))
64
+
65
+ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
66
+ """Load a given dataset split (e.g., train, valid, test)"""
67
+
68
+ # support sharded datasets
69
+ paths = utils.split_paths(self.cfg.data)
70
+ assert len(paths) > 0
71
+ data_path = paths[(epoch - 1) % len(paths)]
72
+ split_path = os.path.join(data_path, split)
73
+
74
+ # each element of *data* will be a tensorized line from the original
75
+ # text dataset, similar to ``open(split_path).readlines()``
76
+ data = data_utils.load_indexed_dataset(
77
+ split_path, self.dictionary, combine=combine
78
+ )
79
+ if data is None:
80
+ raise FileNotFoundError(
81
+ "Dataset not found: {} ({})".format(split, split_path)
82
+ )
83
+
84
+ # this is similar to ``data.view(-1).split(tokens_per_sample)``
85
+ data = TokenBlockDataset(
86
+ data,
87
+ data.sizes,
88
+ block_size=self.cfg.tokens_per_sample,
89
+ pad=None, # unused
90
+ eos=None, # unused
91
+ break_mode="none",
92
+ )
93
+
94
+ self.datasets[split] = TruncatedBPTTDataset(
95
+ data=data,
96
+ bsz_per_shard=self.cfg.batch_size,
97
+ shard_id=self.cfg.data_parallel_rank,
98
+ num_shards=self.cfg.data_parallel_size,
99
+ )
100
+
101
+ def dataset(self, split):
102
+ return self.datasets[split]
103
+
104
+ def get_batch_iterator(
105
+ self,
106
+ dataset,
107
+ num_workers=0,
108
+ epoch=1,
109
+ data_buffer_size=0,
110
+ skip_remainder_batch=False,
111
+ **kwargs
112
+ ):
113
+ return iterators.EpochBatchIterator(
114
+ dataset=dataset,
115
+ collate_fn=self._collate_fn,
116
+ num_workers=num_workers,
117
+ epoch=epoch,
118
+ buffer_size=data_buffer_size,
119
+ # we don't use the batching functionality from EpochBatchIterator;
120
+ # instead every item in *dataset* is a whole batch
121
+ batch_sampler=[[i] for i in range(len(dataset))],
122
+ disable_shuffling=True,
123
+ skip_remainder_batch=skip_remainder_batch,
124
+ )
125
+
126
+ def _collate_fn(self, items: List[List[torch.Tensor]]):
127
+ # we don't use fairseq's batching functionality, so we expect a single
128
+ # item containing a List[torch.Tensor]
129
+ assert len(items) == 1
130
+
131
+ # item will have shape B x T (the last batch may have length < T)
132
+ id, item = items[0]
133
+ item = data_utils.collate_tokens(item, pad_idx=self.source_dictionary.pad())
134
+ B, T = item.size()
135
+
136
+ # shift item one position over and append a padding token for the target
137
+ target = torch.nn.functional.pad(
138
+ item[:, 1:], (0, 1, 0, 0), value=self.target_dictionary.pad()
139
+ )
140
+
141
+ # fairseq expects batches to have the following structure
142
+ return {
143
+ "id": torch.tensor([id] * item.size(0)),
144
+ "net_input": {"src_tokens": item,},
145
+ "target": target,
146
+ "nsentences": item.size(0),
147
+ "ntokens": item.numel(),
148
+ }
149
+
150
+ def build_dataset_for_inference(
151
+ self, src_tokens: List[torch.Tensor], src_lengths: List[int], **kwargs
152
+ ) -> torch.utils.data.Dataset:
153
+ eos = self.source_dictionary.eos()
154
+ dataset = TokenBlockDataset(
155
+ src_tokens,
156
+ src_lengths,
157
+ block_size=None, # ignored for "eos" break mode
158
+ pad=self.source_dictionary.pad(),
159
+ eos=eos,
160
+ break_mode="eos",
161
+ )
162
+
163
+ class Dataset(torch.utils.data.Dataset):
164
+ def __getitem__(self, i):
165
+ item = dataset[i]
166
+ if item[-1] == eos:
167
+ # remove eos to support generating with a prefix
168
+ item = item[:-1]
169
+ return (i, [item])
170
+
171
+ def __len__(self):
172
+ return len(dataset)
173
+
174
+ return Dataset()
175
+
176
+ def inference_step(
177
+ self, generator, models, sample, prefix_tokens=None, constraints=None
178
+ ):
179
+ with torch.no_grad():
180
+ if constraints is not None:
181
+ raise NotImplementedError
182
+
183
+ # SequenceGenerator doesn't use *src_tokens* directly, we need to
184
+ # pass the *prefix_tokens* argument instead.
185
+ if prefix_tokens is None and sample["net_input"]["src_tokens"].nelement():
186
+ prefix_tokens = sample["net_input"]["src_tokens"]
187
+
188
+ # begin generation with the end-of-sentence token
189
+ bos_token = self.source_dictionary.eos()
190
+
191
+ return generator.generate(
192
+ models, sample, prefix_tokens=prefix_tokens, bos_token=bos_token
193
+ )
194
+
195
+ def eval_lm_dataloader(
196
+ self,
197
+ dataset,
198
+ max_tokens: Optional[int] = 36000,
199
+ batch_size: Optional[int] = None,
200
+ max_positions: Optional[int] = None,
201
+ num_shards: int = 1,
202
+ shard_id: int = 0,
203
+ num_workers: int = 1,
204
+ data_buffer_size: int = 10,
205
+ context_window: int = 0,
206
+ ):
207
+ if context_window > 0:
208
+ raise NotImplementedError(
209
+ "Transformer-XL doesn't need --context-window, try "
210
+ "--model-overrides '{\"mem_len\":42}' instead "
211
+ )
212
+ return self.get_batch_iterator(
213
+ dataset=dataset,
214
+ max_tokens=max_tokens,
215
+ max_sentences=batch_size,
216
+ max_positions=max_positions,
217
+ ignore_invalid_inputs=True,
218
+ num_shards=num_shards,
219
+ shard_id=shard_id,
220
+ num_workers=num_workers,
221
+ data_buffer_size=data_buffer_size,
222
+ ).next_epoch_itr(shuffle=False)
223
+
224
+ @property
225
+ def source_dictionary(self):
226
+ return self.dictionary
227
+
228
+ @property
229
+ def target_dictionary(self):
230
+ return self.dictionary
231
+
232
+
233
+ class TruncatedBPTTDataset(torch.utils.data.Dataset):
234
+ def __init__(
235
+ self,
236
+ data: List[torch.Tensor], # ordered list of items
237
+ bsz_per_shard, # number of items processed per GPUs per forward
238
+ shard_id, # current GPU ID
239
+ num_shards, # number of GPUs
240
+ ):
241
+ super().__init__()
242
+ self.data = data
243
+
244
+ def batchify(data, bsz):
245
+ # Work out how cleanly we can divide the dataset into bsz parts.
246
+ nbatch = data.size(0) // bsz
247
+ # Trim off any extra elements that wouldn't cleanly fit (remainders).
248
+ data = data.narrow(0, 0, nbatch * bsz)
249
+ # Evenly divide the data across the bsz batches.
250
+ data = data.view(bsz, -1).contiguous()
251
+ return data
252
+
253
+ # total number of sequences processed by all GPUs in each forward pass
254
+ global_batch_size = bsz_per_shard * num_shards
255
+
256
+ """
257
+ With a 16 item dataset, bsz_per_shard=2 and num_shards=3,
258
+ *indices* might look like:
259
+
260
+ indices = [[0, 1],
261
+ [2, 3],
262
+ [4, 5],
263
+ [6, 7],
264
+ [8, 9],
265
+ [10, 11]]
266
+
267
+ The size of the TruncatedBPTTDataset instance will be 2,
268
+ and shard 1 will see items:
269
+
270
+ [(0, [data[4], data[6]]),
271
+ (1, [data[5], data[7]])]
272
+ """
273
+ indices = batchify(torch.arange(len(data)), global_batch_size)
274
+ assert indices.size(0) == global_batch_size
275
+
276
+ self.my_indices = indices[
277
+ shard_id * bsz_per_shard : (shard_id + 1) * bsz_per_shard
278
+ ]
279
+ assert self.my_indices.size(0) == bsz_per_shard
280
+
281
+ def __len__(self):
282
+ return self.my_indices.size(1)
283
+
284
+ def __getitem__(self, i) -> Tuple[int, List[torch.Tensor]]:
285
+ return (i, [self.data[idx] for idx in self.my_indices[:, i]])
fairseq/examples/unsupervised_quality_estimation/aggregate_scores.py ADDED
@@ -0,0 +1,41 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import argparse
7
+ import sys
8
+
9
+ import numpy as np
10
+
11
+
12
+ aggregate_funcs = {
13
+ "std": np.std,
14
+ "var": np.var,
15
+ "median": np.median,
16
+ "mean": np.mean,
17
+ "min": np.min,
18
+ "max": np.max,
19
+ }
20
+
21
+
22
+ def main():
23
+ parser = argparse.ArgumentParser()
24
+ parser.add_argument("-i", "--input_file", required=True, type=str)
25
+ parser.add_argument("-n", "--repeat_times", required=True, type=int)
26
+ parser.add_argument("-o", "--output_file", required=False)
27
+ parser.add_argument("-f", "--func", required=False, default="mean")
28
+ args = parser.parse_args()
29
+
30
+ stream = open(args.output_file, "w") if args.output_file else sys.stdout
31
+
32
+ segment_scores = []
33
+ for line in open(args.input_file):
34
+ segment_scores.append(float(line.strip()))
35
+ if len(segment_scores) == args.repeat_times:
36
+ stream.write("{}\n".format(aggregate_funcs[args.func](segment_scores)))
37
+ segment_scores = []
38
+
39
+
40
+ if __name__ == "__main__":
41
+ main()
fairseq/examples/unsupervised_quality_estimation/meteor.py ADDED
@@ -0,0 +1,109 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import argparse
7
+ import math
8
+ import os
9
+ import subprocess
10
+ import sys
11
+ import tempfile
12
+ from collections import defaultdict
13
+ from itertools import combinations
14
+
15
+
16
+ def read_translations(path, n_repeats):
17
+ segment_counter = 0
18
+ segment_translations = []
19
+ translations = defaultdict(list)
20
+ for line in open(path):
21
+ segment_translations.append(" ".join(line.split()))
22
+ if len(segment_translations) == n_repeats:
23
+ translations[segment_counter] = segment_translations
24
+ segment_translations = []
25
+ segment_counter += 1
26
+ return translations
27
+
28
+
29
+ def generate_input(translations, n_repeats):
30
+ _, ref_path = tempfile.mkstemp()
31
+ _, mt_path = tempfile.mkstemp()
32
+ ref_fh = open(ref_path, "w")
33
+ mt_fh = open(mt_path, "w")
34
+ for segid in sorted(translations.keys()):
35
+ assert len(translations[segid]) == n_repeats
36
+ indexes = combinations(range(n_repeats), 2)
37
+ for idx1, idx2 in indexes:
38
+ mt_fh.write(translations[segid][idx1].strip() + "\n")
39
+ ref_fh.write(translations[segid][idx2].strip() + "\n")
40
+ sys.stderr.write("\nSaved translations to %s and %s" % (ref_path, mt_path))
41
+ return ref_path, mt_path
42
+
43
+
44
+ def run_meteor(ref_path, mt_path, metric_path, lang="en"):
45
+ _, out_path = tempfile.mkstemp()
46
+ subprocess.call(
47
+ [
48
+ "java",
49
+ "-Xmx2G",
50
+ "-jar",
51
+ metric_path,
52
+ mt_path,
53
+ ref_path,
54
+ "-p",
55
+ "0.5 0.2 0.6 0.75", # default parameters, only changed alpha to give equal weight to P and R
56
+ "-norm",
57
+ "-l",
58
+ lang,
59
+ ],
60
+ stdout=open(out_path, "w"),
61
+ )
62
+ os.remove(ref_path)
63
+ os.remove(mt_path)
64
+ sys.stderr.write("\nSaved Meteor output to %s" % out_path)
65
+ return out_path
66
+
67
+
68
+ def read_output(meteor_output_path, n_repeats):
69
+ n_combinations = math.factorial(n_repeats) / (
70
+ math.factorial(2) * math.factorial(n_repeats - 2)
71
+ )
72
+ raw_scores = []
73
+ average_scores = []
74
+ for line in open(meteor_output_path):
75
+ if not line.startswith("Segment "):
76
+ continue
77
+ score = float(line.strip().split("\t")[1])
78
+ raw_scores.append(score)
79
+ if len(raw_scores) == n_combinations:
80
+ average_scores.append(sum(raw_scores) / n_combinations)
81
+ raw_scores = []
82
+ os.remove(meteor_output_path)
83
+ return average_scores
84
+
85
+
86
+ def main():
87
+ parser = argparse.ArgumentParser()
88
+ parser.add_argument("-i", "--infile")
89
+ parser.add_argument("-n", "--repeat_times", type=int)
90
+ parser.add_argument("-m", "--meteor")
91
+ parser.add_argument("-o", "--output")
92
+ args = parser.parse_args()
93
+
94
+ translations = read_translations(args.infile, args.repeat_times)
95
+ sys.stderr.write("\nGenerating input for Meteor...")
96
+ ref_path, mt_path = generate_input(translations, args.repeat_times)
97
+ sys.stderr.write("\nRunning Meteor...")
98
+ out_path = run_meteor(ref_path, mt_path, args.meteor)
99
+ sys.stderr.write("\nReading output...")
100
+ scores = read_output(out_path, args.repeat_times)
101
+ sys.stderr.write("\nWriting results...")
102
+ with open(args.output, "w") as o:
103
+ for scr in scores:
104
+ o.write("{}\n".format(scr))
105
+ o.close()
106
+
107
+
108
+ if __name__ == "__main__":
109
+ main()
fairseq/examples/unsupervised_quality_estimation/repeat_lines.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import argparse
7
+ import sys
8
+
9
+
10
+ def _normalize_spaces(line):
11
+ return " ".join(line.split())
12
+
13
+
14
+ def main():
15
+ parser = argparse.ArgumentParser()
16
+ parser.add_argument("-i", "--input_file", required=True, type=str)
17
+ parser.add_argument("-n", "--repeat_times", required=True, type=int)
18
+ parser.add_argument("-o", "--output_file", required=False, type=str)
19
+ args = parser.parse_args()
20
+ stream = open(args.output_file, "w") if args.output_file else sys.stdout
21
+
22
+ for line in open(args.input_file):
23
+ for _ in range(args.repeat_times):
24
+ stream.write(_normalize_spaces(line) + "\n")
25
+
26
+
27
+ if __name__ == "__main__":
28
+ main()
fairseq/examples/wav2vec/__init__.py ADDED
File without changes
fairseq/examples/wav2vec/config/finetuning/base_10m.yaml ADDED
@@ -0,0 +1,63 @@
1
+ # @package _group_
2
+
3
+ common:
4
+ fp16: true
5
+ log_format: json
6
+ log_interval: 200
7
+
8
+ checkpoint:
9
+ save_interval: 1000
10
+ save_interval_updates: 50
11
+ keep_interval_updates: 1
12
+ no_epoch_checkpoints: true
13
+ best_checkpoint_metric: wer
14
+
15
+ task:
16
+ _name: audio_finetuning
17
+ data: ???
18
+ normalize: false
19
+ labels: ltr
20
+
21
+ dataset:
22
+ num_workers: 6
23
+ max_tokens: 3200000
24
+ skip_invalid_size_inputs_valid_test: true
25
+ validate_after_updates: 10000
26
+ validate_interval: 1000
27
+ valid_subset: dev_other
28
+
29
+ distributed_training:
30
+ ddp_backend: legacy_ddp
31
+ distributed_world_size: 2
32
+
33
+ criterion:
34
+ _name: ctc
35
+ zero_infinity: true
36
+
37
+ optimization:
38
+ max_update: 13000
39
+ lr: [0.00005]
40
+ sentence_avg: true
41
+ update_freq: [4]
42
+
43
+ optimizer:
44
+ _name: adam
45
+ adam_betas: (0.9,0.98)
46
+ adam_eps: 1e-08
47
+
48
+ lr_scheduler:
49
+ _name: tri_stage
50
+ phase_ratio: [0.1, 0.4, 0.5]
51
+ final_lr_scale: 0.05
52
+
53
+ model:
54
+ _name: wav2vec_ctc
55
+ w2v_path: ???
56
+ apply_mask: true
57
+ mask_prob: 0.65
58
+ mask_channel_prob: 0.25
59
+ mask_channel_length: 64
60
+ layerdrop: 0.1
61
+ activation_dropout: 0.1
62
+ feature_grad_mult: 0.0
63
+ freeze_finetune_updates: 10000
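For reference, the tri_stage schedule implied by this config works out to roughly 1,300 warmup, 5,200 hold and 6,500 decay updates (assuming, as a sketch, that the phase_ratio fractions are applied to optimization.max_update):

```python
# hedged sketch of how phase_ratio maps onto update counts in this config
max_update = 13000
phase_ratio = [0.1, 0.4, 0.5]  # warmup / hold / decay fractions
warmup, hold, decay = (int(r * max_update) for r in phase_ratio)
print(warmup, hold, decay)  # 1300 5200 6500
```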
fairseq/examples/wav2vec/config/finetuning/base_1h.yaml ADDED
@@ -0,0 +1,63 @@
1
+ # @package _group_
2
+
3
+ common:
4
+ fp16: true
5
+ log_format: json
6
+ log_interval: 200
7
+
8
+ checkpoint:
9
+ save_interval: 50
10
+ save_interval_updates: 1000
11
+ keep_interval_updates: 1
12
+ no_epoch_checkpoints: true
13
+ best_checkpoint_metric: wer
14
+
15
+ task:
16
+ _name: audio_finetuning
17
+ data: ???
18
+ normalize: false
19
+ labels: ltr
20
+
21
+ dataset:
22
+ num_workers: 6
23
+ max_tokens: 3200000
24
+ skip_invalid_size_inputs_valid_test: true
25
+ validate_after_updates: 10000
26
+ validate_interval: 1000
27
+ valid_subset: dev_other
28
+
29
+ distributed_training:
30
+ ddp_backend: legacy_ddp
31
+ distributed_world_size: 2
32
+
33
+ criterion:
34
+ _name: ctc
35
+ zero_infinity: true
36
+
37
+ optimization:
38
+ max_update: 13000
39
+ lr: [0.00005]
40
+ sentence_avg: true
41
+ update_freq: [4]
42
+
43
+ optimizer:
44
+ _name: adam
45
+ adam_betas: (0.9,0.98)
46
+ adam_eps: 1e-08
47
+
48
+ lr_scheduler:
49
+ _name: tri_stage
50
+ phase_ratio: [0.1, 0.4, 0.5]
51
+ final_lr_scale: 0.05
52
+
53
+ model:
54
+ _name: wav2vec_ctc
55
+ w2v_path: ???
56
+ apply_mask: true
57
+ mask_prob: 0.65
58
+ mask_channel_prob: 0.25
59
+ mask_channel_length: 64
60
+ layerdrop: 0.1
61
+ activation_dropout: 0.1
62
+ feature_grad_mult: 0.0
63
+ freeze_finetune_updates: 10000
fairseq/examples/wav2vec/config/finetuning/run_config/slurm_1.yaml ADDED
@@ -0,0 +1,26 @@
1
+ # @package _global_
2
+
3
+ hydra:
4
+ job:
5
+ config:
6
+ override_dirname:
7
+ kv_sep: ':'
8
+ item_sep: '__'
9
+ exclude_keys:
10
+ - run_config
11
+ - distributed_training.distributed_port
12
+ sweep:
13
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
14
+ subdir: ${hydra.job.num}
15
+ launcher:
16
+ submitit_folder: ${hydra.sweep.dir}
17
+ timeout_min: 4320
18
+ cpus_per_task: 10
19
+ gpus_per_node: 8
20
+ tasks_per_node: 8
21
+ mem_gb: 450
22
+ nodes: 1
23
+ name: ${env:PREFIX}_${hydra.job.config_name}
24
+ partition: devlab,learnlab,learnfair,scavenge
25
+ constraint: volta32gb
26
+ max_num_timeout: 30
fairseq/examples/wav2vec/config/finetuning/run_config/slurm_16.yaml ADDED
@@ -0,0 +1,27 @@
1
+ # @package _global_
2
+
3
+ hydra:
4
+ job:
5
+ config:
6
+ override_dirname:
7
+ kv_sep: ':'
8
+ item_sep: '__'
9
+ exclude_keys:
10
+ - run_config
11
+ - distributed_training.distributed_port
12
+ sweep:
13
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
14
+ subdir: ${hydra.job.num}
15
+ launcher:
16
+ submitit_folder: ${hydra.sweep.dir}
17
+ timeout_min: 4320
18
+ cpus_per_task: 80
19
+ gpus_per_node: 8
20
+ tasks_per_node: 1
21
+ mem_gb: 450
22
+ nodes: 16
23
+ name: ${env:PREFIX}_${hydra.job.config_name}
24
+ partition: learnlab,learnfair,scavenge
25
+ constraint: volta32gb
26
+ max_num_timeout: 30
27
+ exclude: learnfair1381,learnfair5192,learnfair2304
fairseq/examples/wav2vec/config/finetuning/run_config/slurm_1_aws.yaml ADDED
@@ -0,0 +1,37 @@
1
+ # @package _global_
2
+
3
+ hydra:
4
+ job:
5
+ config:
6
+ override_dirname:
7
+ kv_sep: ':'
8
+ item_sep: '/'
9
+ exclude_keys:
10
+ - run_config
11
+ - distributed_training.distributed_port
12
+ - distributed_training.distributed_world_size
13
+ - model.pretrained_model_path
14
+ - model.target_network_path
15
+ - next_script
16
+ - task.cache_in_scratch
17
+ - task.local_cache_path
18
+ - task.data
19
+ - checkpoint.save_interval_updates
20
+ - checkpoint.keep_interval_updates
21
+ - checkpoint.save_on_overflow
22
+ - common.log_interval
23
+ - common.user_dir
24
+ sweep:
25
+ dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
26
+ subdir: ''
27
+ launcher:
28
+ submitit_folder: ${hydra.sweep.dir}
29
+ timeout_min: 4320
30
+ cpus_per_task: 80
31
+ gpus_per_node: 8
32
+ tasks_per_node: 1
33
+ mem_gb: 0
34
+ nodes: 1
35
+ name: ${env:PREFIX}_${hydra.job.config_name}
36
+ partition: wav2vec,learnlab,learnfair
37
+ max_num_timeout: 30
fairseq/examples/wav2vec/config/finetuning/run_config/slurm_1_old.yaml ADDED
@@ -0,0 +1,27 @@
1
+ # @package _global_
2
+
3
+ hydra:
4
+ job:
5
+ config:
6
+ override_dirname:
7
+ kv_sep: ':'
8
+ item_sep: '__'
9
+ exclude_keys:
10
+ - run_config
11
+ - distributed_training.distributed_port
12
+ sweep:
13
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
14
+ subdir: ${hydra.job.num}
15
+ launcher:
16
+ submitit_folder: ${hydra.sweep.dir}
17
+ timeout_min: 4320
18
+ cpus_per_task: 80
19
+ gpus_per_node: 8
20
+ tasks_per_node: 1
21
+ mem_gb: 450
22
+ nodes: 1
23
+ name: ${env:PREFIX}_wav2vec3_small_librispeech
24
+ partition: devlab,learnlab,learnfair,scavenge
25
+ constraint: volta32gb
26
+ max_num_timeout: 30
27
+ exclude: learnfair1381
fairseq/examples/wav2vec/config/finetuning/run_config/slurm_2.yaml ADDED
@@ -0,0 +1,27 @@
1
+ # @package _global_
2
+
3
+ hydra:
4
+ job:
5
+ config:
6
+ override_dirname:
7
+ kv_sep: ':'
8
+ item_sep: '__'
9
+ exclude_keys:
10
+ - run_config
11
+ - distributed_training.distributed_port
12
+ sweep:
13
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
14
+ subdir: ${hydra.job.num}
15
+ launcher:
16
+ submitit_folder: ${hydra.sweep.dir}
17
+ timeout_min: 4320
18
+ cpus_per_task: 10
19
+ gpus_per_node: 8
20
+ tasks_per_node: 8
21
+ mem_gb: 450
22
+ nodes: 2
23
+ name: ${env:PREFIX}_${hydra.job.config_name}
24
+ partition: devlab,learnlab,learnfair,scavenge
25
+ constraint: volta32gb
26
+ max_num_timeout: 30
27
+ exclude: learnfair7491,learnfair7477,learnfair7487
fairseq/examples/wav2vec/config/finetuning/run_config/slurm_2_aws.yaml ADDED
@@ -0,0 +1,37 @@
1
+ # @package _global_
2
+
3
+ hydra:
4
+ job:
5
+ config:
6
+ override_dirname:
7
+ kv_sep: ':'
8
+ item_sep: '/'
9
+ exclude_keys:
10
+ - run_config
11
+ - distributed_training.distributed_port
12
+ - distributed_training.distributed_world_size
13
+ - model.pretrained_model_path
14
+ - model.target_network_path
15
+ - next_script
16
+ - task.cache_in_scratch
17
+ - task.local_cache_path
18
+ - task.data
19
+ - checkpoint.save_interval_updates
20
+ - checkpoint.keep_interval_updates
21
+ - checkpoint.save_on_overflow
22
+ - common.log_interval
23
+ - common.user_dir
24
+ sweep:
25
+ dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
26
+ subdir: ''
27
+ launcher:
28
+ submitit_folder: ${hydra.sweep.dir}
29
+ timeout_min: 4320
30
+ cpus_per_task: 80
31
+ gpus_per_node: 8
32
+ tasks_per_node: 1
33
+ mem_gb: 0
34
+ nodes: 2
35
+ name: ${env:PREFIX}_${hydra.job.config_name}
36
+ partition: wav2vec,learnlab,learnfair
37
+ max_num_timeout: 30
fairseq/examples/wav2vec/config/finetuning/run_config/slurm_2g.yaml ADDED
@@ -0,0 +1,26 @@
1
+ # @package _global_
2
+
3
+ hydra:
4
+ job:
5
+ config:
6
+ override_dirname:
7
+ kv_sep: ':'
8
+ item_sep: '__'
9
+ exclude_keys:
10
+ - run_config
11
+ - distributed_training.distributed_port
12
+ sweep:
13
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
14
+ subdir: ${hydra.job.num}
15
+ launcher:
16
+ submitit_folder: ${hydra.sweep.dir}
17
+ timeout_min: 4320
18
+ cpus_per_task: 10
19
+ gpus_per_node: 2
20
+ tasks_per_node: 2
21
+ mem_gb: 200
22
+ nodes: 1
23
+ name: ${env:PREFIX}_${hydra.job.config_name}
24
+ partition: devlab,learnlab,learnfair,scavenge
25
+ constraint: volta32gb
26
+ max_num_timeout: 30
fairseq/examples/wav2vec/config/finetuning/run_config/slurm_3.yaml ADDED
@@ -0,0 +1,27 @@
1
+ # @package _global_
2
+
3
+ hydra:
4
+ job:
5
+ config:
6
+ override_dirname:
7
+ kv_sep: ':'
8
+ item_sep: '__'
9
+ exclude_keys:
10
+ - run_config
11
+ - distributed_training.distributed_port
12
+ sweep:
13
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
14
+ subdir: ${hydra.job.num}
15
+ launcher:
16
+ submitit_folder: ${hydra.sweep.dir}
17
+ timeout_min: 4320
18
+ cpus_per_task: 10
19
+ gpus_per_node: 8
20
+ tasks_per_node: 8
21
+ mem_gb: 450
22
+ nodes: 3
23
+ name: ${env:PREFIX}_${hydra.job.config_name}
24
+ partition: devlab,learnlab,learnfair,scavenge
25
+ constraint: volta32gb
26
+ max_num_timeout: 30
27
+ exclude: learnfair7491,learnfair7477,learnfair7487
fairseq/examples/wav2vec/config/finetuning/run_config/slurm_4g.yaml ADDED
@@ -0,0 +1,26 @@
1
+ # @package _global_
2
+
3
+ hydra:
4
+ job:
5
+ config:
6
+ override_dirname:
7
+ kv_sep: ':'
8
+ item_sep: '__'
9
+ exclude_keys:
10
+ - run_config
11
+ - distributed_training.distributed_port
12
+ sweep:
13
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
14
+ subdir: ${hydra.job.num}
15
+ launcher:
16
+ submitit_folder: ${hydra.sweep.dir}
17
+ timeout_min: 4320
18
+ cpus_per_task: 10
19
+ gpus_per_node: 4
20
+ tasks_per_node: 4
21
+ mem_gb: 200
22
+ nodes: 1
23
+ name: ${env:PREFIX}_${hydra.job.config_name}
24
+ partition: devlab,learnlab,learnfair,scavenge
25
+ constraint: volta32gb
26
+ max_num_timeout: 30
fairseq/examples/wav2vec/config/finetuning/run_config/slurm_4g_aws.yaml ADDED
@@ -0,0 +1,37 @@
1
+ # @package _global_
2
+
3
+ hydra:
4
+ job:
5
+ config:
6
+ override_dirname:
7
+ kv_sep: ':'
8
+ item_sep: '/'
9
+ exclude_keys:
10
+ - run_config
11
+ - distributed_training.distributed_port
12
+ - distributed_training.distributed_world_size
13
+ - model.pretrained_model_path
14
+ - model.target_network_path
15
+ - next_script
16
+ - task.cache_in_scratch
17
+ - task.local_cache_path
18
+ - task.data
19
+ - checkpoint.save_interval_updates
20
+ - checkpoint.keep_interval_updates
21
+ - checkpoint.save_on_overflow
22
+ - common.log_interval
23
+ - common.user_dir
24
+ sweep:
25
+ dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
26
+ subdir: ''
27
+ launcher:
28
+ submitit_folder: ${hydra.sweep.dir}
29
+ timeout_min: 4320
30
+ cpus_per_task: 80
31
+ gpus_per_node: 4
32
+ tasks_per_node: 1
33
+ mem_gb: 0
34
+ nodes: 1
35
+ name: ${env:PREFIX}_${hydra.job.config_name}
36
+ partition: wav2vec,learnlab,learnfair
37
+ max_num_timeout: 30
fairseq/examples/wav2vec/config/finetuning/run_config/slurm_8.yaml ADDED
@@ -0,0 +1,26 @@
1
+ # @package _global_
2
+
3
+ hydra:
4
+ job:
5
+ config:
6
+ override_dirname:
7
+ kv_sep: ':'
8
+ item_sep: '__'
9
+ exclude_keys:
10
+ - run_config
11
+ - distributed_training.distributed_port
12
+ sweep:
13
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
14
+ subdir: ${hydra.job.num}
15
+ launcher:
16
+ submitit_folder: ${hydra.sweep.dir}
17
+ timeout_min: 4320
18
+ cpus_per_task: 10
19
+ gpus_per_node: 8
20
+ tasks_per_node: 8
21
+ mem_gb: 400
22
+ nodes: 8
23
+ name: ${env:PREFIX}_${hydra.job.config_name}
24
+ partition: devlab,learnlab,learnfair,scavenge
25
+ constraint: volta32gb
26
+ max_num_timeout: 30
fairseq/examples/wav2vec/config/finetuning/vox_10h_2_aws.yaml ADDED
@@ -0,0 +1,81 @@
1
+ # @package _group_
2
+
3
+ common:
4
+ fp16: true
5
+ log_format: json
6
+ log_interval: 200
7
+ user_dir: /data/home/abaevski/fairseq-py/examples/data2vec
8
+ # tensorboard_logdir: tb
9
+
10
+ checkpoint:
11
+ save_interval: 10
12
+ no_epoch_checkpoints: true
13
+ best_checkpoint_metric: wer
14
+
15
+ task:
16
+ _name: audio_finetuning
17
+ data: /fsx-wav2vec/abaevski/data/libri/10h/wav2vec/raw
18
+ labels: ltr
19
+ normalize: true
20
+
21
+ dataset:
22
+ num_workers: 6
23
+ max_tokens: 1280000
24
+ skip_invalid_size_inputs_valid_test: true
25
+ validate_after_updates: 100
26
+ validate_interval: 10
27
+ valid_subset: dev_other
28
+ required_batch_size_multiple: 1
29
+
30
+ distributed_training:
31
+ ddp_backend: legacy_ddp
32
+ distributed_world_size: 4
33
+
34
+ criterion:
35
+ _name: ctc
36
+ zero_infinity: true
37
+ post_process: letter
38
+ wer_kenlm_model: /fsx-wav2vec/abaevski/data/libri/4-gram.bin
39
+ wer_lexicon: /fsx-wav2vec/abaevski/data/libri/10h/wav2vec/raw/lexicon_ltr2.lst
40
+ wer_lm_weight: 2.0
41
+ wer_word_score: 4
42
+ wer_sil_weight: -5
43
+
44
+ optimization:
45
+ max_update: 60000
46
+ lr: [1e-5]
47
+ # lr: [1e-5] # base 10h wer
48
+ sentence_avg: true
49
+ update_freq: [1] # base 10h we -> 2/4
50
+
51
+ optimizer:
52
+ _name: adam
53
+ adam_betas: (0.9,0.98)
54
+ adam_eps: 1e-08
55
+
56
+ lr_scheduler:
57
+ _name: tri_stage
58
+ phase_ratio: null
59
+ warmup_steps: 8000
60
+ hold_steps: 0
61
+ decay_steps: 72000
62
+ final_lr_scale: 0.05
63
+
64
+ model:
65
+ _name: wav2vec_ctc
66
+ w2v_path: ???
67
+ apply_mask: true
68
+ mask_prob: 0.75
69
+ mask_length: 5
70
+ # mask_prob: 0.65 # base 10h wer
71
+ mask_channel_prob: 0.1
72
+ # mask_channel_prob: 0.6 # base 10h wer
73
+ mask_channel_length: 64
74
+ layerdrop: 0
75
+ # layerdrop: 0.05 # base 10h wer
76
+ activation_dropout: 0.1
77
+ feature_grad_mult: 0.0
78
+ freeze_finetune_updates: 100
79
+ dropout: 0
80
+ final_dropout: 0
81
+ attention_dropout: 0
fairseq/examples/wav2vec/config/finetuning/vox_10h_aws.yaml ADDED
@@ -0,0 +1,104 @@
1
+ # @package _group_
2
+
3
+ common:
4
+ fp16: true
5
+ log_format: json
6
+ log_interval: 200
7
+ user_dir: /data/home/abaevski/fairseq-py/examples/data2vec
8
+ # tensorboard_logdir: tb
9
+
10
+ checkpoint:
11
+ save_interval: 10
12
+ no_epoch_checkpoints: true
13
+ best_checkpoint_metric: wer
14
+
15
+ task:
16
+ _name: audio_finetuning
17
+ data: /fsx-wav2vec/abaevski/data/libri/10h/wav2vec/raw
18
+ labels: ltr
19
+ normalize: true
20
+
21
+ dataset:
22
+ num_workers: 6
23
+ max_tokens: 1280000
24
+ skip_invalid_size_inputs_valid_test: true
25
+ validate_after_updates: 100
26
+ validate_interval: 10
27
+ valid_subset: dev_other
28
+ required_batch_size_multiple: 1
29
+
30
+ distributed_training:
31
+ ddp_backend: legacy_ddp
32
+ distributed_world_size: 4
33
+
34
+ criterion:
35
+ _name: ctc
36
+ zero_infinity: true
37
+ post_process: letter
38
+ # wer_kenlm_model: /fsx-wav2vec/abaevski/data/libri/4-gram.bin
39
+ # wer_lexicon: /fsx-wav2vec/abaevski/data/libri/10h/wav2vec/raw/lexicon_ltr2.lst
40
+ # wer_lm_weight: 2.0
41
+ # wer_word_score: -1.0
42
+
43
+ optimization:
44
+ max_update: 60000
45
+ lr: [2e-5]
46
+ # lr: [1e-5] # base 10h wer
47
+ sentence_avg: true
48
+ update_freq: [1] # base 10h we -> 2/4
49
+
50
+ optimizer:
51
+ _name: adam
52
+ adam_betas: (0.9,0.98)
53
+ adam_eps: 1e-08
54
+
55
+ lr_scheduler:
56
+ _name: tri_stage
57
+ phase_ratio: null
58
+ warmup_steps: 8000
59
+ hold_steps: 0
60
+ decay_steps: 72000
61
+ final_lr_scale: 0.05
62
+
63
+ model:
64
+ _name: wav2vec_ctc
65
+ w2v_path: ???
66
+ apply_mask: true
67
+ mask_prob: 0.4
68
+ mask_length: 5
69
+ # mask_prob: 0.65 # base 10h wer
70
+ mask_channel_prob: 0.1
71
+ # mask_channel_prob: 0.6 # base 10h wer
72
+ mask_channel_length: 64
73
+ layerdrop: 0.1
74
+ # layerdrop: 0.05 # base 10h wer
75
+ activation_dropout: 0.1
76
+ feature_grad_mult: 0.0
77
+ freeze_finetune_updates: 100
78
+ dropout: 0
79
+ final_dropout: 0
80
+ attention_dropout: 0
81
+
82
+ hydra:
83
+ job:
84
+ config:
85
+ override_dirname:
86
+ kv_sep: ':'
87
+ item_sep: '__'
88
+ exclude_keys:
89
+ - run_config
90
+ - distributed_training.distributed_port
91
+ sweep:
92
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname}
93
+ subdir: ${hydra.job.num}
94
+ launcher:
95
+ submitit_folder: ${hydra.sweep.dir}
96
+ timeout_min: 3000
97
+ cpus_per_task: 10
98
+ gpus_per_node: 4
99
+ tasks_per_node: 4
100
+ mem_gb: 0
101
+ nodes: 1
102
+ name: ${env:PREFIX}_${hydra.job.config_name}
103
+ partition: wav2vec,learnlab
104
+ max_num_timeout: 30
fairseq/examples/wav2vec/config/finetuning/vox_10m_2.yaml ADDED
@@ -0,0 +1,114 @@
1
+ # @package _group_
2
+
3
+ common:
4
+ fp16: true
5
+ fp16_no_flatten_grads: true
6
+ log_format: json
7
+ log_interval: 200
8
+ user_dir: /private/home/abaevski/fairseq-py/examples/data2vec
9
+ # tensorboard_logdir: tb
10
+
11
+ checkpoint:
12
+ save_interval: 500
13
+ save_interval_updates: 500
14
+ keep_interval_updates: 1
15
+ no_epoch_checkpoints: true
16
+ best_checkpoint_metric: wer
17
+
18
+ task:
19
+ _name: audio_finetuning
20
+ data: /checkpoint/abaevski/data/speech/libri/10m/wav2vec/raw
21
+ labels: ltr
22
+ normalize: true
23
+
24
+ dataset:
25
+ num_workers: 6
26
+ max_tokens: 1000000
27
+ skip_invalid_size_inputs_valid_test: true
28
+ validate_after_updates: 100
29
+ validate_interval: 500
30
+ valid_subset: dev_other
31
+ required_batch_size_multiple: 1
32
+
33
+ distributed_training:
34
+ ddp_backend: legacy_ddp
35
+ distributed_world_size: 4
36
+
37
+ criterion:
38
+ _name: ctc
39
+ zero_infinity: true
40
+ post_process: letter
41
+ wer_kenlm_model: /checkpoint/abaevski/data/speech/libri/4-gram.bin
42
+ wer_lexicon: /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw/lexicon_ltr2.lst
43
+ wer_lm_weight: 5
44
+ wer_word_score: 2
45
+ wer_sil_weight: -2
46
+
47
+ optimization:
48
+ max_update: 10000
49
+ lr: [2e-6]
50
+ # lr: [1e-5] # base 10h wer
51
+ sentence_avg: true
52
+ update_freq: [4] # base 10h we -> 2/4
53
+
54
+ optimizer:
55
+ _name: composite
56
+ dynamic_groups: true
57
+ groups:
58
+ default:
59
+ lr_float: 2e-6
60
+ optimizer:
61
+ _name: adam
62
+ adam_betas: [0.9,0.95]
63
+ lr_scheduler:
64
+ _name: cosine
65
+ warmup_updates: 1000
66
+
67
+ lr_scheduler: pass_through
68
+
69
+ model:
70
+ _name: wav2vec_ctc
71
+ w2v_path: ???
72
+ apply_mask: true
73
+ mask_prob: 0.4
74
+ mask_length: 3
75
+ # mask_prob: 0.65 # base 10h wer
76
+ mask_channel_prob: 0.25
77
+ # mask_channel_prob: 0.6 # base 10h wer
78
+ mask_channel_length: 64
79
+ layerdrop: 0.1
80
+ # layerdrop: 0.05 # base 10h wer
81
+ freeze_finetune_updates: 100
82
+
83
+ zero_mask: true
84
+ feature_grad_mult: 0.0
85
+ activation_dropout: 0.1
86
+ dropout: 0
87
+ final_dropout: 0
88
+ attention_dropout: 0
89
+ update_alibi: false
90
+
91
+ #hydra:
92
+ # job:
93
+ # config:
94
+ # override_dirname:
95
+ # kv_sep: ':'
96
+ # item_sep: '__'
97
+ # exclude_keys:
98
+ # - run_config
99
+ # - distributed_training.distributed_port
100
+ # sweep:
101
+ # dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname}
102
+ # subdir: ${hydra.job.num}
103
+ # launcher:
104
+ # submitit_folder: ${hydra.sweep.dir}
105
+ # timeout_min: 3000
106
+ # cpus_per_task: 10
107
+ # gpus_per_node: 4
108
+ # tasks_per_node: 4
109
+ # mem_gb: 250
110
+ # nodes: 1
111
+ # name: ${env:PREFIX}_${hydra.job.config_name}
112
+ # partition: devlab,learnlab,learnfair,scavenge
113
+ # constraint: volta32gb
114
+ # max_num_timeout: 30
fairseq/examples/wav2vec/config/finetuning/vox_10m_2_aws.yaml ADDED
@@ -0,0 +1,114 @@
1
+ # @package _group_
2
+
3
+ common:
4
+ fp16: true
5
+ fp16_no_flatten_grads: true
6
+ log_format: json
7
+ log_interval: 200
8
+ user_dir: /data/home/abaevski/fairseq-py/examples/data2vec
9
+ # tensorboard_logdir: tb
10
+
11
+ checkpoint:
12
+ save_interval: 500
13
+ save_interval_updates: 500
14
+ keep_interval_updates: 1
15
+ no_epoch_checkpoints: true
16
+ best_checkpoint_metric: wer
17
+
18
+ task:
19
+ _name: audio_finetuning
20
+ data: /fsx-wav2vec/abaevski/data/libri/10m/wav2vec/raw
21
+ labels: ltr
22
+ normalize: true
23
+
24
+ dataset:
25
+ num_workers: 6
26
+ max_tokens: 1000000
27
+ skip_invalid_size_inputs_valid_test: true
28
+ validate_after_updates: 100
29
+ validate_interval: 500
30
+ valid_subset: dev_other
31
+ required_batch_size_multiple: 1
32
+
33
+ distributed_training:
34
+ ddp_backend: legacy_ddp
35
+ distributed_world_size: 4
36
+
37
+ criterion:
38
+ _name: ctc
39
+ zero_infinity: true
40
+ post_process: letter
41
+ wer_kenlm_model: /fsx-wav2vec/abaevski/data/libri/4-gram.bin
42
+ wer_lexicon: /fsx-wav2vec/abaevski/data/libri/10h/wav2vec/raw/lexicon_ltr2.lst
43
+ wer_lm_weight: 5
44
+ wer_word_score: 2
45
+ wer_sil_weight: -2
46
+
47
+ optimization:
48
+ max_update: 10000
49
+ lr: [2e-6]
50
+ # lr: [1e-5] # base 10h wer
51
+ sentence_avg: true
52
+ update_freq: [4] # base 10h we -> 2/4
53
+
54
+ optimizer:
55
+ _name: composite
56
+ dynamic_groups: true
57
+ groups:
58
+ default:
59
+ lr_float: 2e-6
60
+ optimizer:
61
+ _name: adam
62
+ adam_betas: [0.9,0.95]
63
+ lr_scheduler:
64
+ _name: cosine
65
+ warmup_updates: 1000
66
+
67
+ lr_scheduler: pass_through
68
+
69
+ model:
70
+ _name: wav2vec_ctc
71
+ w2v_path: ???
72
+ apply_mask: true
73
+ mask_prob: 0.4
74
+ mask_length: 3
75
+ # mask_prob: 0.65 # base 10h wer
76
+ mask_channel_prob: 0.25
77
+ # mask_channel_prob: 0.6 # base 10h wer
78
+ mask_channel_length: 64
79
+ layerdrop: 0.1
80
+ # layerdrop: 0.05 # base 10h wer
81
+ freeze_finetune_updates: 100
82
+
83
+ zero_mask: true
84
+ feature_grad_mult: 0.0
85
+ activation_dropout: 0.1
86
+ dropout: 0
87
+ final_dropout: 0
88
+ attention_dropout: 0
89
+ update_alibi: false
90
+
91
+ #hydra:
92
+ # job:
93
+ # config:
94
+ # override_dirname:
95
+ # kv_sep: ':'
96
+ # item_sep: '__'
97
+ # exclude_keys:
98
+ # - run_config
99
+ # - distributed_training.distributed_port
100
+ # sweep:
101
+ # dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname}
102
+ # subdir: ${hydra.job.num}
103
+ # launcher:
104
+ # submitit_folder: ${hydra.sweep.dir}
105
+ # timeout_min: 3000
106
+ # cpus_per_task: 10
107
+ # gpus_per_node: 4
108
+ # tasks_per_node: 4
109
+ # mem_gb: 250
110
+ # nodes: 1
111
+ # name: ${env:PREFIX}_${hydra.job.config_name}
112
+ # partition: devlab,learnlab,learnfair,scavenge
113
+ # constraint: volta32gb
114
+ # max_num_timeout: 30
fairseq/examples/wav2vec/config/finetuning/vox_10m_3.yaml ADDED
@@ -0,0 +1,105 @@
1
+ # @package _group_
2
+
3
+ common:
4
+ fp16: true
5
+ log_format: json
6
+ log_interval: 200
7
+ user_dir: /private/home/abaevski/fairseq-py/examples/data2vec
8
+ # tensorboard_logdir: tb
9
+
10
+ checkpoint:
11
+ save_interval: 1000
12
+ save_interval_updates: 100
13
+ keep_interval_updates: 1
14
+ no_epoch_checkpoints: true
15
+ best_checkpoint_metric: wer
16
+
17
+ task:
18
+ _name: audio_finetuning
19
+ data: /checkpoint/abaevski/data/speech/libri/10m/wav2vec/raw
20
+ labels: ltr
21
+ normalize: true
22
+
23
+ dataset:
24
+ num_workers: 6
25
+ max_tokens: 1280000
26
+ skip_invalid_size_inputs_valid_test: true
27
+ validate_after_updates: 10000
28
+ validate_interval: 500
29
+ valid_subset: dev_other
30
+ required_batch_size_multiple: 8
31
+
32
+ distributed_training:
33
+ ddp_backend: legacy_ddp
34
+ distributed_world_size: 4
35
+
36
+ criterion:
37
+ _name: ctc
38
+ zero_infinity: true
39
+ post_process: letter
40
+ wer_kenlm_model: /checkpoint/abaevski/data/speech/libri/4-gram.bin
41
+ wer_lexicon: /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw/lexicon_ltr2.lst
42
+ wer_lm_weight: 8
43
+ wer_word_score: 5.8
44
+ wer_sil_weight: -8
45
+
46
+ optimization:
47
+ max_update: 13000
48
+ lr: [2e-5]
49
+ # lr: [1e-5] # base 10h wer
50
+ sentence_avg: true
51
+ update_freq: [5] # base 10h we -> 2/4
52
+
53
+ optimizer:
54
+ _name: adam
55
+ adam_betas: (0.9,0.98)
56
+ adam_eps: 1e-08
57
+
58
+ lr_scheduler:
59
+ _name: tri_stage
60
+ phase_ratio: [0.1, 0.4, 0.5]
61
+ final_lr_scale: 0.05
62
+
63
+ model:
64
+ _name: wav2vec_ctc
65
+ w2v_path: ???
66
+ apply_mask: true
67
+ mask_prob: 0.65
68
+ mask_length: 10
69
+ # mask_prob: 0.65 # base 10h wer
70
+ mask_channel_prob: 0.25
71
+ # mask_channel_prob: 0.6 # base 10h wer
72
+ mask_channel_length: 64
73
+ layerdrop: 0.1
74
+ # layerdrop: 0.05 # base 10h wer
75
+ activation_dropout: 0.1
76
+ feature_grad_mult: 0.0
77
+ freeze_finetune_updates: 10000
78
+ dropout: 0
79
+ final_dropout: 0
80
+ attention_dropout: 0
81
+
82
+ hydra:
83
+ job:
84
+ config:
85
+ override_dirname:
86
+ kv_sep: ':'
87
+ item_sep: '__'
88
+ exclude_keys:
89
+ - run_config
90
+ - distributed_training.distributed_port
91
+ sweep:
92
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname}
93
+ subdir: ${hydra.job.num}
94
+ launcher:
95
+ submitit_folder: ${hydra.sweep.dir}
96
+ timeout_min: 3000
97
+ cpus_per_task: 10
98
+ gpus_per_node: 4
99
+ tasks_per_node: 4
100
+ mem_gb: 250
101
+ nodes: 1
102
+ name: ${env:PREFIX}_${hydra.job.config_name}
103
+ partition: devlab,learnlab,learnfair,scavenge
104
+ constraint: volta32gb
105
+ max_num_timeout: 30
fairseq/examples/wav2vec/config/finetuning/vox_1h.yaml ADDED
@@ -0,0 +1,63 @@
1
+ # @package _group_
2
+
3
+ common:
4
+ fp16: true
5
+ log_format: json
6
+ log_interval: 200
7
+
8
+ checkpoint:
9
+ save_interval: 1000
10
+ save_interval_updates: 50
11
+ keep_interval_updates: 1
12
+ no_epoch_checkpoints: true
13
+ best_checkpoint_metric: wer
14
+
15
+ task:
16
+ _name: audio_finetuning
17
+ data: ???
18
+ normalize: true
19
+ labels: ltr
20
+
21
+ dataset:
22
+ num_workers: 6
23
+ max_tokens: 1280000
24
+ skip_invalid_size_inputs_valid_test: true
25
+ validate_after_updates: 10000
26
+ validate_interval: 1000
27
+ valid_subset: dev_other
28
+
29
+ distributed_training:
30
+ ddp_backend: legacy_ddp
31
+ distributed_world_size: 4
32
+
33
+ criterion:
34
+ _name: ctc
35
+ zero_infinity: true
36
+
37
+ optimization:
38
+ max_update: 13000
39
+ lr: [0.0003]
40
+ sentence_avg: true
41
+ update_freq: [5]
42
+
43
+ optimizer:
44
+ _name: adam
45
+ adam_betas: (0.9,0.98)
46
+ adam_eps: 1e-08
47
+
48
+ lr_scheduler:
49
+ _name: tri_stage
50
+ phase_ratio: [0.1, 0.4, 0.5]
51
+ final_lr_scale: 0.05
52
+
53
+ model:
54
+ _name: wav2vec_ctc
55
+ w2v_path: ???
56
+ apply_mask: true
57
+ mask_prob: 0.75
58
+ mask_channel_prob: 0.25
59
+ mask_channel_length: 64
60
+ layerdrop: 0.1
61
+ activation_dropout: 0.1
62
+ feature_grad_mult: 0.0
63
+ freeze_finetune_updates: 10000
fairseq/examples/wav2vec/config/finetuning/vox_1h_2.yaml ADDED
@@ -0,0 +1,104 @@
1
+ # @package _group_
2
+
3
+ common:
4
+ fp16: true
5
+ log_format: json
6
+ log_interval: 200
7
+ user_dir: /private/home/abaevski/fairseq-py/examples/data2vec
8
+ # tensorboard_logdir: tb
9
+
10
+ checkpoint:
11
+ save_interval: 100
12
+ save_interval_updates: 500
13
+ keep_interval_updates: 1
14
+ no_epoch_checkpoints: true
15
+ best_checkpoint_metric: wer
16
+
17
+ task:
18
+ _name: audio_finetuning
19
+ data: /checkpoint/abaevski/data/speech/libri/1h/wav2vec/raw
20
+ labels: ltr
21
+ normalize: true
22
+
23
+ dataset:
24
+ num_workers: 6
25
+ max_tokens: 1000000
26
+ skip_invalid_size_inputs_valid_test: true
27
+ validate_after_updates: 100
28
+ validate_interval: 100
29
+ valid_subset: dev_other
30
+ required_batch_size_multiple: 1
31
+
32
+ distributed_training:
33
+ ddp_backend: legacy_ddp
34
+ distributed_world_size: 8
35
+
36
+ criterion:
37
+ _name: ctc
38
+ zero_infinity: true
39
+ post_process: letter
40
+ wer_kenlm_model: /checkpoint/abaevski/data/speech/libri/4-gram.bin
41
+ wer_lexicon: /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw/lexicon_ltr2.lst
42
+ wer_lm_weight: 6
43
+ wer_word_score: -0.1
44
+ wer_sil_weight: -4.7
45
+
46
+ optimization:
47
+ max_update: 60000
48
+ lr: [1e-5]
49
+ # lr: [1e-5] # base 10h wer
50
+ sentence_avg: true
51
+ update_freq: [1] # base 10h we -> 2/4
52
+
53
+ optimizer:
54
+ _name: adam
55
+ adam_betas: (0.9,0.98)
56
+ adam_eps: 1e-08
57
+
58
+ lr_scheduler:
59
+ _name: cosine
60
+ warmup_updates: 4000
61
+
62
+ model:
63
+ _name: wav2vec_ctc
64
+ w2v_path: ???
65
+ apply_mask: true
66
+ mask_prob: 0.65
67
+ mask_length: 5
68
+ # mask_prob: 0.65 # base 10h wer
69
+ mask_channel_prob: 0.25
70
+ # mask_channel_prob: 0.6 # base 10h wer
71
+ mask_channel_length: 64
72
+ layerdrop: 0.1
73
+ # layerdrop: 0.05 # base 10h wer
74
+ activation_dropout: 0.1
75
+ feature_grad_mult: 0.0
76
+ freeze_finetune_updates: 100
77
+ dropout: 0
78
+ final_dropout: 0
79
+ attention_dropout: 0
80
+
81
+ hydra:
82
+ job:
83
+ config:
84
+ override_dirname:
85
+ kv_sep: ':'
86
+ item_sep: '__'
87
+ exclude_keys:
88
+ - run_config
89
+ - distributed_training.distributed_port
90
+ sweep:
91
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname}
92
+ subdir: ${hydra.job.num}
93
+ launcher:
94
+ submitit_folder: ${hydra.sweep.dir}
95
+ timeout_min: 3000
96
+ cpus_per_task: 10
97
+ gpus_per_node: 4
98
+ tasks_per_node: 4
99
+ mem_gb: 250
100
+ nodes: 1
101
+ name: ${env:PREFIX}_${hydra.job.config_name}
102
+ partition: devlab,learnlab,learnfair,scavenge
103
+ constraint: volta32gb
104
+ max_num_timeout: 30
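The wer_* fields in the criterion above switch validation WER from greedy CTC decoding to lexicon-constrained beam search with a 4-gram KenLM model; wer_lm_weight, wer_word_score, and wer_sil_weight tune how that decoder scores hypotheses. Roughly, and only as an illustration (not fairseq's decoder code), the knobs combine like this:

def hypothesis_score(am_score, lm_score, n_words, n_silences,
                     lm_weight=6.0, word_score=-0.1, sil_weight=-4.7):
    # acoustic score plus weighted LM, word-insertion, and silence terms
    return (am_score
            + lm_weight * lm_score
            + word_score * n_words
            + sil_weight * n_silences)

print(hypothesis_score(am_score=-120.0, lm_score=-35.0, n_words=7, n_silences=2))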
fairseq/examples/wav2vec/config/finetuning/vox_1h_2_aws.yaml ADDED
@@ -0,0 +1,114 @@
+ # @package _group_
+
+ common:
+   fp16: true
+   fp16_no_flatten_grads: true
+   log_format: json
+   log_interval: 200
+   user_dir: /data/home/abaevski/fairseq-py/examples/data2vec
+   # tensorboard_logdir: tb
+
+ checkpoint:
+   save_interval: 100
+   save_interval_updates: 500
+   keep_interval_updates: 1
+   no_epoch_checkpoints: true
+   best_checkpoint_metric: wer
+
+ task:
+   _name: audio_finetuning
+   data: /fsx-wav2vec/abaevski/data/libri/1h/wav2vec/raw
+   labels: ltr
+   normalize: true
+
+ dataset:
+   num_workers: 6
+   max_tokens: 1000000
+   skip_invalid_size_inputs_valid_test: true
+   validate_after_updates: 100
+   validate_interval: 500
+   valid_subset: dev_other
+   required_batch_size_multiple: 1
+
+ distributed_training:
+   ddp_backend: legacy_ddp
+   distributed_world_size: 4
+
+ criterion:
+   _name: ctc
+   zero_infinity: true
+   post_process: letter
+   wer_kenlm_model: /fsx-wav2vec/abaevski/data/libri/4-gram.bin
+   wer_lexicon: /fsx-wav2vec/abaevski/data/libri/10h/wav2vec/raw/lexicon_ltr2.lst
+   wer_lm_weight: 5
+   wer_word_score: 0
+   wer_sil_weight: -4
+
+ optimization:
+   max_update: 10000
+   lr: [2e-6]
+   # lr: [1e-5] # base 10h wer
+   sentence_avg: true
+   update_freq: [4] # base 10h we -> 2/4
+
+ optimizer:
+   _name: composite
+   dynamic_groups: true
+   groups:
+     default:
+       lr_float: 2e-6
+       optimizer:
+         _name: adam
+         adam_betas: [0.9,0.95]
+       lr_scheduler:
+         _name: cosine
+         warmup_updates: 1000
+
+ lr_scheduler: pass_through
+
+ model:
+   _name: wav2vec_ctc
+   w2v_path: ???
+   apply_mask: true
+   mask_prob: 0.4
+   mask_length: 3
+   # mask_prob: 0.65 # base 10h wer
+   mask_channel_prob: 0.25
+   # mask_channel_prob: 0.6 # base 10h wer
+   mask_channel_length: 64
+   layerdrop: 0.1
+   # layerdrop: 0.05 # base 10h wer
+   freeze_finetune_updates: 100
+
+   zero_mask: true
+   feature_grad_mult: 0.0
+   activation_dropout: 0.1
+   dropout: 0
+   final_dropout: 0
+   attention_dropout: 0
+   update_alibi: false
+
+ #hydra:
+ # job:
+ # config:
+ # override_dirname:
+ # kv_sep: ':'
+ # item_sep: '__'
+ # exclude_keys:
+ # - run_config
+ # - distributed_training.distributed_port
+ # sweep:
+ # dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname}
+ # subdir: ${hydra.job.num}
+ # launcher:
+ # submitit_folder: ${hydra.sweep.dir}
+ # timeout_min: 3000
+ # cpus_per_task: 10
+ # gpus_per_node: 4
+ # tasks_per_node: 4
+ # mem_gb: 250
+ # nodes: 1
+ # name: ${env:PREFIX}_${hydra.job.config_name}
+ # partition: devlab,learnlab,learnfair,scavenge
+ # constraint: volta32gb
+ # max_num_timeout: 30
fairseq/examples/wav2vec/config/finetuning/vox_1h_aws.yaml ADDED
@@ -0,0 +1,80 @@
+ # @package _group_
+
+ common:
+   fp16: true
+   log_format: json
+   log_interval: 200
+   user_dir: /data/home/abaevski/fairseq-py/examples/data2vec
+   # tensorboard_logdir: tb
+
+ checkpoint:
+   save_interval: 100
+   save_interval_updates: 500
+   keep_interval_updates: 1
+   no_epoch_checkpoints: true
+   best_checkpoint_metric: wer
+
+ task:
+   _name: audio_finetuning
+   data: /fsx-wav2vec/abaevski/data/libri/10m/wav2vec/raw
+   labels: ltr
+   normalize: true
+
+ dataset:
+   num_workers: 6
+   max_tokens: 1000000
+   skip_invalid_size_inputs_valid_test: true
+   validate_after_updates: 10000
+   validate_interval: 100
+   valid_subset: dev_other
+   required_batch_size_multiple: 8
+
+ distributed_training:
+   ddp_backend: legacy_ddp
+   distributed_world_size: 8
+
+ criterion:
+   _name: ctc
+   zero_infinity: true
+   post_process: letter
+   wer_kenlm_model: /fsx-wav2vec/abaevski/data/libri/4-gram.bin
+   wer_lexicon: /fsx-wav2vec/abaevski/data/libri/10h/wav2vec/raw/lexicon_ltr2.lst
+   wer_lm_weight: 5
+   wer_word_score: -0.1
+   wer_sil_weight: -4.7
+
+ optimization:
+   max_update: 13000
+   lr: [6e-5]
+   # lr: [1e-5] # base 10h wer
+   sentence_avg: true
+   update_freq: [5] # base 10h we -> 2/4
+
+ optimizer:
+   _name: adam
+   adam_betas: (0.9,0.98)
+   adam_eps: 1e-08
+
+ lr_scheduler:
+   _name: cosine
+   warmup_updates: 4000
+
+ model:
+   _name: wav2vec_ctc
+   w2v_path: ???
+   apply_mask: true
+   mask_prob: 0.3
+   mask_length: 3
+   # mask_prob: 0.65 # base 10h wer
+   mask_channel_prob: 0.25
+   # mask_channel_prob: 0.6 # base 10h wer
+   mask_channel_length: 64
+   layerdrop: 0.1
+   # layerdrop: 0.05 # base 10h wer
+   activation_dropout: 0.1
+   feature_grad_mult: 0.0
+   freeze_finetune_updates: 10000
+   dropout: 0
+   final_dropout: 0
+   attention_dropout: 0
+   update_alibi: false
fairseq/examples/wav2vec/config/finetuning/vox_960h.yaml ADDED
@@ -0,0 +1,57 @@
+ # @package _group_
+
+ common:
+   fp16: true
+   log_format: json
+   log_interval: 200
+
+ checkpoint:
+   no_epoch_checkpoints: true
+   best_checkpoint_metric: wer
+
+ task:
+   _name: audio_finetuning
+   data: ???
+   normalize: true
+   labels: ltr
+
+ dataset:
+   num_workers: 6
+   max_tokens: 1280000
+   skip_invalid_size_inputs_valid_test: true
+   valid_subset: dev_other
+
+ distributed_training:
+   ddp_backend: legacy_ddp
+   distributed_world_size: 24
+
+ criterion:
+   _name: ctc
+   zero_infinity: true
+
+ optimization:
+   max_update: 320000
+   lr: [0.00003]
+   sentence_avg: true
+
+ optimizer:
+   _name: adam
+   adam_betas: (0.9,0.98)
+   adam_eps: 1e-08
+
+ lr_scheduler:
+   _name: tri_stage
+   phase_ratio: [0.1, 0.4, 0.5]
+   final_lr_scale: 0.05
+
+ model:
+   _name: wav2vec_ctc
+   w2v_path: ???
+   apply_mask: true
+   mask_prob: 0.5
+   mask_channel_prob: 0.25
+   mask_channel_length: 64
+   layerdrop: 0.1
+   activation_dropout: 0.1
+   feature_grad_mult: 0.0
+   freeze_finetune_updates: 10000
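For the audio tasks above, max_tokens counts raw waveform samples, so at 16 kHz the 1,280,000-token cap is about 80 seconds of audio per GPU; across 24 GPUs each update sees roughly half an hour of speech. A back-of-the-envelope sketch (assumes 16 kHz input):

sample_rate = 16_000
max_tokens = 1_280_000
world_size = 24

sec_per_gpu = max_tokens / sample_rate
minutes_per_update = sec_per_gpu * world_size / 60
print(sec_per_gpu, minutes_per_update)   # 80.0 s per GPU, ~32 min per update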
fairseq/examples/wav2vec/config/finetuning/vox_960h_2.yaml ADDED
@@ -0,0 +1,105 @@
+ # @package _group_
+
+ common:
+   fp16: true
+   log_format: json
+   log_interval: 200
+   user_dir: /private/home/abaevski/fairseq-py/examples/data2vec
+   # tensorboard_logdir: tb
+
+ checkpoint:
+   save_interval: 1
+   no_epoch_checkpoints: true
+   best_checkpoint_metric: wer
+
+ task:
+   _name: audio_finetuning
+   data: /checkpoint/abaevski/data/speech/libri/960h/wav2vec/raw
+   labels: ltr
+   normalize: true
+
+ dataset:
+   num_workers: 6
+   max_tokens: 1000000
+   skip_invalid_size_inputs_valid_test: true
+   validate_after_updates: 100
+   validate_interval: 1
+   valid_subset: dev_other
+   required_batch_size_multiple: 1
+
+ distributed_training:
+   ddp_backend: legacy_ddp
+   distributed_world_size: 16
+
+ criterion:
+   _name: ctc
+   zero_infinity: true
+   post_process: letter
+   wer_kenlm_model: /checkpoint/abaevski/data/speech/libri/4-gram.bin
+   wer_lexicon: /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw/lexicon_ltr2.lst
+   wer_lm_weight: 2.0
+   wer_word_score: -1.0
+
+ optimization:
+   max_update: 200000
+   lr: [1e-5]
+   # lr: [1e-5] # base 10h wer
+   sentence_avg: true
+   update_freq: [1] # base 10h we -> 2/4
+
+ optimizer:
+   _name: adam
+   adam_betas: (0.9,0.98)
+   adam_eps: 1e-08
+
+ lr_scheduler:
+   _name: tri_stage
+   phase_ratio: null
+   warmup_steps: 8000
+   hold_steps: 0
+   decay_steps: 200000
+   final_lr_scale: 0.05
+
+ model:
+   _name: wav2vec_ctc
+   w2v_path: ???
+   apply_mask: true
+   mask_prob: 0.4
+   mask_length: 5
+   # mask_prob: 0.65 # base 10h wer
+   mask_channel_prob: 0.1
+   # mask_channel_prob: 0.6 # base 10h wer
+   mask_channel_length: 64
+   layerdrop: 0.1
+   # layerdrop: 0.05 # base 10h wer
+   activation_dropout: 0.1
+   feature_grad_mult: 0.0
+   freeze_finetune_updates: 100
+   dropout: 0
+   final_dropout: 0
+   attention_dropout: 0
+
+ hydra:
+   job:
+     config:
+       override_dirname:
+         kv_sep: ':'
+         item_sep: '__'
+         exclude_keys:
+           - run_config
+           - distributed_training.distributed_port
+   sweep:
+     dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname}
+     subdir: ${hydra.job.num}
+   launcher:
+     submitit_folder: ${hydra.sweep.dir}
+     timeout_min: 3000
+     cpus_per_task: 10
+     gpus_per_node: 4
+     tasks_per_node: 4
+     mem_gb: 250
+     nodes: 1
+     name: ${env:PREFIX}_${hydra.job.config_name}
+     partition: devlab,learnlab,learnfair,scavenge
+     constraint: volta32gb
+     max_num_timeout: 30
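Unlike the configs that use phase_ratio, this one sets phase_ratio: null and gives the tri_stage phases explicitly (8k warmup, no hold, 200k decay). A rough sketch of the resulting schedule, assuming linear warmup and exponential decay towards peak_lr * final_lr_scale (an approximation of fairseq's tri_stage, not its code):

import math

peak_lr, final_scale = 1e-5, 0.05
warmup, hold, decay = 8_000, 0, 200_000

def lr_at(step):
    if step < warmup:
        return peak_lr * step / warmup                      # linear warmup
    t = min(step - warmup - hold, decay) / decay            # decay progress
    return peak_lr * math.exp(math.log(final_scale) * t)    # exponential decay

print(lr_at(4_000), lr_at(8_000), lr_at(208_000))            # 5e-06, 1e-05, 5e-07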
fairseq/examples/wav2vec/config/finetuning/vox_960h_2_aws.yaml ADDED
@@ -0,0 +1,82 @@
+ # @package _group_
+
+ common:
+   fp16: true
+   log_format: json
+   log_interval: 200
+   user_dir: /data/home/abaevski/fairseq-py/examples/data2vec
+   # tensorboard_logdir: tb
+
+ checkpoint:
+   save_interval: 1
+   no_epoch_checkpoints: true
+   best_checkpoint_metric: wer
+
+ task:
+   _name: audio_finetuning
+   data: /fsx-wav2vec/abaevski/data/librispeech
+   labels: ltr
+   normalize: true
+
+ dataset:
+   num_workers: 6
+   max_tokens: 1280000
+   skip_invalid_size_inputs_valid_test: true
+   validate_after_updates: 100
+   validate_interval: 1
+   valid_subset: dev_other
+   required_batch_size_multiple: 1
+
+ distributed_training:
+   ddp_backend: legacy_ddp
+   distributed_world_size: 16
+
+ criterion:
+   _name: ctc
+   zero_infinity: true
+   post_process: letter
+   wer_kenlm_model: /fsx-wav2vec/abaevski/data/libri/4-gram.bin
+   wer_lexicon: /fsx-wav2vec/abaevski/data/libri/10h/wav2vec/raw/lexicon_ltr2.lst
+   wer_lm_weight: 1.5
+   wer_word_score: 0
+   wer_sil_weight: -1
+
+ optimization:
+   max_update: 200000
+   lr: [2e-5]
+   # lr: [1e-5] # base 10h wer
+   sentence_avg: true
+   update_freq: [1] # base 10h we -> 2/4
+
+ optimizer:
+   _name: adam
+   adam_betas: (0.9,0.98)
+   adam_eps: 1e-08
+
+ lr_scheduler:
+   _name: tri_stage
+   phase_ratio: null
+   warmup_steps: 8000
+   hold_steps: 0
+   decay_steps: 192000
+   final_lr_scale: 0.05
+
+ model:
+   _name: wav2vec_ctc
+   w2v_path: ???
+   apply_mask: true
+   mask_prob: 0.3
+   mask_length: 5
+   # mask_prob: 0.65 # base 10h wer
+   mask_channel_prob: 0.1
+   # mask_channel_prob: 0.6 # base 10h wer
+   mask_channel_length: 64
+   layerdrop: 0
+   # layerdrop: 0.05 # base 10h wer
+   activation_dropout: 0.1
+   feature_grad_mult: 0.0
+   freeze_finetune_updates: 100
+   dropout: 0
+   final_dropout: 0
+   attention_dropout: 0
+
fairseq/examples/wav2vec/config/finetuning/vox_960h_3.yaml ADDED
@@ -0,0 +1,101 @@
+ # @package _group_
+
+ common:
+   fp16: true
+   log_format: json
+   log_interval: 200
+   user_dir: /private/home/abaevski/fairseq-py/examples/data2vec
+   # tensorboard_logdir: tb
+
+ checkpoint:
+   save_interval: 1
+   no_epoch_checkpoints: true
+   best_checkpoint_metric: wer
+
+ task:
+   _name: audio_finetuning
+   data: /checkpoint/abaevski/data/speech/libri/1h/wav2vec/raw
+   labels: ltr
+   normalize: true
+
+ dataset:
+   num_workers: 6
+   max_tokens: 1000000
+   skip_invalid_size_inputs_valid_test: true
+   validate_after_updates: 100
+   validate_interval: 1
+   valid_subset: dev_other
+   required_batch_size_multiple: 1
+
+ distributed_training:
+   ddp_backend: legacy_ddp
+   distributed_world_size: 16
+
+ criterion:
+   _name: ctc
+   zero_infinity: true
+   post_process: letter
+   wer_kenlm_model: /checkpoint/abaevski/data/speech/libri/4-gram.bin
+   wer_lexicon: /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw/lexicon_ltr2.lst
+   wer_lm_weight: 2.0
+   wer_word_score: -1.0
+
+ optimization:
+   max_update: 200000
+   lr: [1e-5]
+   # lr: [1e-5] # base 10h wer
+   sentence_avg: true
+   update_freq: [1] # base 10h we -> 2/4
+
+ optimizer:
+   _name: adam
+   adam_betas: (0.9,0.98)
+   adam_eps: 1e-08
+
+ lr_scheduler:
+   _name: cosine
+   warmup_updates: 8000
+
+ model:
+   _name: wav2vec_ctc
+   w2v_path: ???
+   apply_mask: true
+   mask_prob: 0.4
+   mask_length: 5
+   # mask_prob: 0.65 # base 10h wer
+   mask_channel_prob: 0.1
+   # mask_channel_prob: 0.6 # base 10h wer
+   mask_channel_length: 64
+   layerdrop: 0.1
+   # layerdrop: 0.05 # base 10h wer
+   activation_dropout: 0.1
+   feature_grad_mult: 0.0
+   freeze_finetune_updates: 100
+   dropout: 0
+   final_dropout: 0
+   attention_dropout: 0
+
+ hydra:
+   job:
+     config:
+       override_dirname:
+         kv_sep: ':'
+         item_sep: '__'
+         exclude_keys:
+           - run_config
+           - distributed_training.distributed_port
+   sweep:
+     dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname}
+     subdir: ${hydra.job.num}
+   launcher:
+     submitit_folder: ${hydra.sweep.dir}
+     timeout_min: 3000
+     cpus_per_task: 10
+     gpus_per_node: 4
+     tasks_per_node: 4
+     mem_gb: 250
+     nodes: 1
+     name: ${env:PREFIX}_${hydra.job.config_name}
+     partition: devlab,learnlab,learnfair,scavenge
+     constraint: volta32gb
+     max_num_timeout: 30
fairseq/examples/wav2vec/config/pretraining/wav2vec2_base_librispeech.yaml ADDED
@@ -0,0 +1,57 @@
+ # @package _group_
+
+ common:
+   fp16: true
+   log_format: json
+   log_interval: 200
+
+ checkpoint:
+   save_interval_updates: 25000
+   keep_interval_updates: 1
+   no_epoch_checkpoints: true
+
+ task:
+   _name: audio_pretraining
+   data: ???
+   max_sample_size: 250000
+   min_sample_size: 32000
+   normalize: false
+
+ dataset:
+   num_workers: 6
+   max_tokens: 1400000
+   skip_invalid_size_inputs_valid_test: true
+
+ distributed_training:
+   distributed_world_size: 64
+   ddp_backend: legacy_ddp
+
+ criterion:
+   _name: wav2vec
+   infonce: true
+   log_keys: ["prob_perplexity","code_perplexity","temp"]
+   loss_weights: [0.1, 10]
+
+ optimization:
+   max_update: 400000
+   lr: [0.0005]
+
+ optimizer:
+   _name: adam
+   adam_betas: (0.9,0.98)
+   adam_eps: 1e-06
+   weight_decay: 0.01
+
+ lr_scheduler:
+   _name: polynomial_decay
+   warmup_updates: 32000
+
+ model:
+   _name: wav2vec2
+   quantize_targets: true
+   final_dim: 256
+   encoder_layerdrop: 0.05
+   dropout_input: 0.1
+   dropout_features: 0.1
+   feature_grad_mult: 0.1
+   encoder_embed_dim: 768
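One quirk worth noting: adam_betas is written as the string "(0.9,0.98)" rather than a YAML list; fairseq parses such string-valued hyperparameters into Python tuples/lists when the optimizer is built. A stand-in sketch of that parsing step (illustration only, not fairseq's own parsing code):

import ast

adam_betas = "(0.9,0.98)"
beta1, beta2 = ast.literal_eval(adam_betas)
print(beta1, beta2)   # 0.9 0.98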
fairseq/examples/wav2vec/config/pretraining/wav2vec2_conformer_base_librispeech.yaml ADDED
@@ -0,0 +1,60 @@
+ # @package _group_
+
+ common:
+   fp16: true
+   log_format: json
+   log_interval: 200
+
+ checkpoint:
+   save_interval_updates: 25000
+   keep_interval_updates: 1
+   no_epoch_checkpoints: true
+
+ task:
+   _name: audio_pretraining
+   data: ???
+   max_sample_size: 250000
+   min_sample_size: 32000
+   normalize: false
+
+ dataset:
+   num_workers: 6
+   max_tokens: 1400000
+   skip_invalid_size_inputs_valid_test: true
+
+ distributed_training:
+   distributed_world_size: 64
+   ddp_backend: legacy_ddp
+
+ criterion:
+   _name: wav2vec
+   infonce: true
+   log_keys: ["prob_perplexity","code_perplexity","temp"]
+   loss_weights: [0.1, 10]
+
+ optimization:
+   max_update: 400000
+   lr: [0.0005]
+
+ optimizer:
+   _name: adam
+   adam_betas: (0.9,0.98)
+   adam_eps: 1e-06
+   weight_decay: 0.01
+
+ lr_scheduler:
+   _name: polynomial_decay
+   warmup_updates: 32000
+
+ model:
+   _name: wav2vec2
+   quantize_targets: true
+   final_dim: 256
+   encoder_layerdrop: 0.05
+   dropout_input: 0.1
+   dropout_features: 0.1
+   feature_grad_mult: 0.1
+   encoder_embed_dim: 768
+   layer_type: conformer
+   attn_type: espnet
+   pos_enc_type: rel_pos
fairseq/examples/wav2vec/config/pretraining/wav2vec2_conformer_large_librivox.yaml ADDED
@@ -0,0 +1,72 @@
+ # @package _group_
+
+ common:
+   fp16: true
+   log_format: json
+   log_interval: 200
+
+ checkpoint:
+   save_interval_updates: 25000
+   keep_interval_updates: 1
+   no_epoch_checkpoints: true
+
+ task:
+   _name: audio_pretraining
+   data: ???
+   max_sample_size: 320000
+   min_sample_size: 32000
+   normalize: true
+
+ dataset:
+   num_workers: 6
+   max_tokens: 1200000
+   skip_invalid_size_inputs_valid_test: true
+
+ distributed_training:
+   distributed_world_size: 128
+   ddp_backend: legacy_ddp
+
+ criterion:
+   _name: wav2vec
+   infonce: true
+   log_keys: ["prob_perplexity","code_perplexity","temp"]
+   loss_weights: [0.1, 0]
+
+ optimization:
+   max_update: 1000000
+   lr: [0.005]
+
+ optimizer:
+   _name: adam
+   adam_betas: (0.9,0.98)
+   adam_eps: 1e-06
+   weight_decay: 0.01
+
+ lr_scheduler:
+   _name: polynomial_decay
+   warmup_updates: 32000
+
+ model:
+   _name: wav2vec2
+   quantize_targets: true
+   extractor_mode: layer_norm
+   layer_norm_first: true
+   final_dim: 768
+   latent_temp: [2.0,0.1,0.999995]
+   encoder_layerdrop: 0.00
+   dropout_input: 0.0
+   dropout_features: 0.0
+   dropout: 0.0
+   attention_dropout: 0.0
+   conv_bias: true
+
+   encoder_layers: 24
+   encoder_embed_dim: 1024
+   encoder_ffn_embed_dim: 4096
+   encoder_attention_heads: 16
+
+   feature_grad_mult: 1.0
+
+   layer_type: conformer
+   attn_type: espnet
+   pos_enc_type: rel_pos
fairseq/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml ADDED
@@ -0,0 +1,70 @@
+ # @package _group_
+
+ common:
+   fp16: true
+   log_format: json
+   log_interval: 200
+
+ checkpoint:
+   save_interval_updates: 25000
+   keep_interval_updates: 1
+   no_epoch_checkpoints: true
+
+ task:
+   _name: audio_pretraining
+   data: ???
+   max_sample_size: 320000
+   min_sample_size: 32000
+   normalize: true
+
+ dataset:
+   batch_size: 4
+   num_workers: 6
+   max_tokens: 1200000
+   skip_invalid_size_inputs_valid_test: true
+
+ distributed_training:
+   distributed_world_size: 128
+   ddp_backend: legacy_ddp
+
+ criterion:
+   _name: wav2vec
+   infonce: true
+   log_keys: ["prob_perplexity","code_perplexity","temp"]
+   loss_weights: [0.1, 0]
+
+ optimization:
+   max_update: 1000000
+   lr: [0.005]
+
+ optimizer:
+   _name: adam
+   adam_betas: (0.9,0.98)
+   adam_eps: 1e-06
+   weight_decay: 0.01
+
+ lr_scheduler:
+   _name: polynomial_decay
+   warmup_updates: 32000
+
+ model:
+   _name: wav2vec2
+   quantize_targets: true
+   extractor_mode: layer_norm
+   layer_norm_first: true
+   final_dim: 768
+   latent_temp: [2.0,0.1,0.999995]
+   encoder_layerdrop: 0.00
+   dropout_input: 0.0
+   dropout_features: 0.0
+   dropout: 0.0
+   attention_dropout: 0.0
+   conv_bias: true
+
+   encoder_layers: 24
+   encoder_embed_dim: 1024
+   encoder_ffn_embed_dim: 4096
+   encoder_attention_heads: 16
+
+   feature_grad_mult: 1.0
+
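In the large pretraining configs, latent_temp is the (start, end, per-update decay) triple for the Gumbel-softmax temperature used with quantize_targets. A hedged sketch of the annealing implied by [2.0, 0.1, 0.999995] (illustration, not fairseq code):

import math

start, end, decay = 2.0, 0.1, 0.999995

def temp_at(update):
    # multiplicative decay, clamped at the end value
    return max(end, start * decay ** update)

print(temp_at(0), round(temp_at(100_000), 3), temp_at(600_000))
# the floor is reached after roughly log(end/start)/log(decay) updates:
print(int(math.log(end / start) / math.log(decay)))   # ~599k updates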
fairseq/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu-pod.yaml ADDED
@@ -0,0 +1,72 @@
+ # @package _group_
+
+ common:
+   tpu: true
+   fp16: false
+   log_format: json
+   log_interval: 10
+
+ checkpoint:
+   save_interval_updates: 25000
+   keep_interval_updates: 1
+   no_epoch_checkpoints: true
+
+ task:
+   _name: audio_pretraining
+   data: ???
+   max_sample_size: 250000
+   min_sample_size: 32000
+   normalize: true
+   num_batch_buckets: 3
+   precompute_mask_indices: true
+   enable_padding: true
+
+ dataset:
+   num_workers: 6
+   max_tokens: 1200000
+   skip_invalid_size_inputs_valid_test: true
+
+ distributed_training:
+   distributed_world_size: 128
+   ddp_backend: legacy_ddp
+
+ criterion:
+   _name: wav2vec
+   infonce: true
+   log_keys: ["prob_perplexity","code_perplexity","temp"]
+   loss_weights: [0.1, 0]
+
+ optimization:
+   max_update: 1000000
+   lr: [0.005]
+
+ optimizer:
+   _name: adam
+   adam_betas: (0.9,0.98)
+   adam_eps: 1e-06
+   weight_decay: 0.01
+
+ lr_scheduler:
+   _name: polynomial_decay
+   warmup_updates: 32000
+
+ model:
+   _name: wav2vec2
+   quantize_targets: true
+   extractor_mode: layer_norm
+   layer_norm_first: true
+   final_dim: 768
+   latent_temp: [2.0,0.1,0.999995]
+   encoder_layerdrop: 0.00
+   dropout_input: 0.0
+   dropout_features: 0.0
+   dropout: 0.0
+   attention_dropout: 0.0
+   conv_bias: true
+
+   encoder_layers: 24
+   encoder_embed_dim: 1024
+   encoder_ffn_embed_dim: 4096
+   encoder_attention_heads: 16
+
+   feature_grad_mult: 1.0
fairseq/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu.yaml ADDED
@@ -0,0 +1,77 @@
+ # @package _group_
+
+ common:
+   tpu: true
+   fp16: false
+   log_format: json
+   log_interval: 10
+
+ checkpoint:
+   save_interval_updates: 25000
+   keep_interval_updates: 1
+   no_epoch_checkpoints: true
+
+ task:
+   _name: audio_pretraining
+   data: ???
+   max_sample_size: 250000
+   min_sample_size: 32000
+   normalize: true
+   num_batch_buckets: 3
+   precompute_mask_indices: true
+   enable_padding: true
+   inferred_w2v_config:
+     mask_prob: 0.65
+     mask_selection: 'static'
+     mask_other: 0
+     mask_channel_prob: 0.1
+
+ dataset:
+   num_workers: 6
+   max_tokens: 1200000
+   skip_invalid_size_inputs_valid_test: true
+
+ distributed_training:
+   distributed_world_size: 8
+   ddp_backend: legacy_ddp
+
+ criterion:
+   _name: wav2vec
+   infonce: true
+   log_keys: ["prob_perplexity","code_perplexity","temp"]
+   loss_weights: [0.1, 0]
+
+ optimization:
+   max_update: 1000000
+   lr: [0.005]
+
+ optimizer:
+   _name: adam
+   adam_betas: (0.9,0.98)
+   adam_eps: 1e-06
+   weight_decay: 0.01
+
+ lr_scheduler:
+   _name: polynomial_decay
+   warmup_updates: 32000
+
+ model:
+   _name: wav2vec2
+   quantize_targets: true
+   extractor_mode: layer_norm
+   layer_norm_first: true
+   final_dim: 768
+   latent_temp: [2.0,0.1,0.999995]
+   encoder_layerdrop: 0.00
+   dropout_input: 0.0
+   dropout_features: 0.0
+   dropout: 0.0
+   attention_dropout: 0.0
+   conv_bias: true
+
+   encoder_layers: 24
+   encoder_embed_dim: 1024
+   encoder_ffn_embed_dim: 4096
+   encoder_attention_heads: 16
+
+   feature_grad_mult: 1.0
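The TPU variants set tpu: true, disable fp16, and add num_batch_buckets, precompute_mask_indices, and enable_padding so that batch shapes stay static for the XLA compiler. The idea behind the length bucketing, sketched with made-up utterance lengths (illustration only):

lengths = [41_000, 87_500, 120_000, 243_000, 55_000]   # made-up sample counts
num_buckets, max_len = 3, 250_000

boundaries = [max_len * (i + 1) // num_buckets for i in range(num_buckets)]
padded = [next(b for b in boundaries if n <= b) for n in lengths]
print(boundaries)   # [83333, 166666, 250000]
print(padded)       # every utterance is padded up to its bucket boundary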