Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- fairseq/docs/fairseq.gif +3 -0
- fairseq/examples/data2vec/scripts/text/finetune_all_large_fair_aws_local_lr.sh +17 -0
- fairseq/examples/data2vec/scripts/text/finetune_all_large_fair_local_lr.sh +26 -0
- fairseq/examples/data2vec/scripts/text/finetune_all_large_fair_nodep_aws_local_lr.sh +15 -0
- fairseq/examples/data2vec/scripts/text/finetune_sst2_qnli_sweep_fair_nodep.sh +20 -0
- fairseq/examples/data2vec/scripts/text/glue.py +34 -0
- fairseq/examples/data2vec/scripts/text/glue_lr.py +143 -0
- fairseq/examples/data2vec/tasks/audio_classification.py +167 -0
- fairseq/examples/data2vec/tasks/image_classification.py +129 -0
- fairseq/examples/data2vec/tasks/image_pretraining.py +110 -0
- fairseq/examples/data2vec/tasks/mae_image_pretraining.py +119 -0
- fairseq/examples/emotion_conversion/emotion_models/__init__.py +0 -0
- fairseq/examples/emotion_conversion/emotion_models/duration_predictor.py +243 -0
- fairseq/examples/emotion_conversion/emotion_models/duration_predictor.yaml +48 -0
- fairseq/examples/emotion_conversion/emotion_models/pitch_predictor.py +559 -0
- fairseq/examples/emotion_conversion/emotion_models/utils.py +78 -0
- fairseq/examples/emotion_conversion/fairseq_models/__init__.py +226 -0
- fairseq/examples/emotion_conversion/preprocess/__init__.py +0 -0
- fairseq/examples/emotion_conversion/preprocess/build_hifigan_manifest.py +38 -0
- fairseq/examples/emotion_conversion/preprocess/build_translation_manifests.py +258 -0
- fairseq/examples/emotion_conversion/preprocess/create_core_manifest.py +91 -0
- fairseq/examples/emotion_conversion/preprocess/extract_f0.py +57 -0
- fairseq/examples/emotion_conversion/preprocess/process_km.py +40 -0
- fairseq/examples/emotion_conversion/preprocess/split_emov_km_tsv_by_uttid.py +70 -0
- fairseq/examples/emotion_conversion/preprocess/split_km.py +50 -0
- fairseq/examples/emotion_conversion/preprocess/split_km_tsv.py +65 -0
- fairseq/examples/fast_noisy_channel/README.md +345 -0
- fairseq/examples/fast_noisy_channel/__init__.py +8 -0
- fairseq/examples/fast_noisy_channel/noisy_channel_beam_search.py +71 -0
- fairseq/examples/fast_noisy_channel/noisy_channel_sequence_generator.py +842 -0
- fairseq/examples/fast_noisy_channel/noisy_channel_translation.py +127 -0
- fairseq/examples/flores101/README.md +223 -0
- fairseq/examples/flores101/flores_logo.png +0 -0
- fairseq/examples/fully_sharded_data_parallel/README.md +177 -0
- fairseq/examples/gottbert/README.md +64 -0
- fairseq/examples/hubert/README.md +116 -0
- fairseq/examples/hubert/config/decode/ax_sweep/ngram.yaml +33 -0
- fairseq/examples/hubert/config/decode/ax_sweep/transformer.yaml +33 -0
- fairseq/examples/hubert/config/decode/infer_fsqlm.yaml +36 -0
- fairseq/examples/hubert/config/decode/infer_kenlm.yaml +36 -0
- fairseq/examples/hubert/config/decode/infer_viterbi.yaml +29 -0
- fairseq/examples/hubert/config/decode/run/submitit_slurm.yaml +17 -0
- fairseq/examples/hubert/config/decode/run/submitit_slurm_8gpu.yaml +17 -0
- fairseq/examples/hubert/config/finetune/base_10h.yaml +100 -0
- fairseq/examples/hubert/config/finetune/ckpt/it1.yaml +7 -0
- fairseq/examples/hubert/config/finetune/lm/ls_4gram.yaml +7 -0
- fairseq/examples/hubert/config/finetune/run/submitit_reg.yaml +20 -0
- fairseq/examples/hubert/config/pretrain/data/iter1.yaml +8 -0
- fairseq/examples/hubert/config/pretrain/data/iter2.yaml +8 -0
.gitattributes
CHANGED
@@ -37,3 +37,4 @@ fairseq/examples/MMPT/vlm.png filter=lfs diff=lfs merge=lfs -text
|
|
37 |
fairseq/examples/MMPT/videoclip.png filter=lfs diff=lfs merge=lfs -text
|
38 |
fairseq/alignment_train_cuda_binding.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
39 |
fairseq/alignment_train_cpu_binding.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
|
|
|
37 |
fairseq/examples/MMPT/videoclip.png filter=lfs diff=lfs merge=lfs -text
|
38 |
fairseq/alignment_train_cuda_binding.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
39 |
fairseq/alignment_train_cpu_binding.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
40 |
+
fairseq/docs/fairseq.gif filter=lfs diff=lfs merge=lfs -text
|
fairseq/docs/fairseq.gif
ADDED
![]() |
Git LFS Details
|
fairseq/examples/data2vec/scripts/text/finetune_all_large_fair_aws_local_lr.sh
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
set -eu
|
4 |
+
|
5 |
+
job_id="$1"
|
6 |
+
task_id="$2"
|
7 |
+
dir="$3"
|
8 |
+
|
9 |
+
echo "job_id: $job_id, task_id: $task_id, dir: $dir"
|
10 |
+
|
11 |
+
mkdir -p "$dir/log"
|
12 |
+
sbatch_args="-p wav2vec --nodes=1 --ntasks-per-node=1"
|
13 |
+
sbatch_args="$sbatch_args --gpus-per-node=1 --cpus-per-task=8 --mem=0 --time=24:00:00"
|
14 |
+
sbatch_args="$sbatch_args -d afterok:$job_id -o $dir/log/decode_sweep_%A.out"
|
15 |
+
sbatch_args="$sbatch_args -e $dir/log/decode_sweep_%A.err"
|
16 |
+
|
17 |
+
sbatch $sbatch_args examples/data2vec/scripts/text/finetune_all_large_fair_local_lr.sh $dir
|
fairseq/examples/data2vec/scripts/text/finetune_all_large_fair_local_lr.sh
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env zsh
|
2 |
+
|
3 |
+
dir="$1"
|
4 |
+
cp="$dir/checkpoints/checkpoint_last.pt"
|
5 |
+
|
6 |
+
echo "dir: $dir"
|
7 |
+
|
8 |
+
declare -A tasks
|
9 |
+
tasks[cola]="/fsx-wav2vec/abaevski/data/nlp/GLUE/CoLA-bin"
|
10 |
+
tasks[qnli]="/fsx-wav2vec/abaevski/data/nlp/GLUE/QNLI-bin"
|
11 |
+
tasks[mrpc]="/fsx-wav2vec/abaevski/data/nlp/GLUE/MRPC-bin"
|
12 |
+
tasks[rte]="/fsx-wav2vec/abaevski/data/nlp/GLUE/RTE-bin"
|
13 |
+
tasks[sst_2]="/fsx-wav2vec/abaevski/data/nlp/GLUE/SST-2-bin"
|
14 |
+
|
15 |
+
lrs=(5e-6 8e-6 1e-5 2e-5)
|
16 |
+
|
17 |
+
for task data_path in ${(kv)tasks}; do
|
18 |
+
for lr in $lrs; do
|
19 |
+
echo $lr $task
|
20 |
+
PYTHONPATH=. PREFIX="${PREFIX}" SUFFIX="" \
|
21 |
+
python fairseq_cli/hydra_train.py -m --config-dir examples/roberta/config/finetuning \
|
22 |
+
--config-name $task +run_config=local task.data="$data_path" common.log_interval=200 dataset.num_workers=1 \
|
23 |
+
checkpoint.restore_file="$cp" hydra.sweep.dir="$dir/finetune_lr/$task/$lr" "optimization.lr=[${lr}]" \
|
24 |
+
model._name=roberta_large
|
25 |
+
done
|
26 |
+
done
|
fairseq/examples/data2vec/scripts/text/finetune_all_large_fair_nodep_aws_local_lr.sh
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
set -eu
|
4 |
+
|
5 |
+
dir="$1"
|
6 |
+
|
7 |
+
echo "dir: $dir"
|
8 |
+
|
9 |
+
mkdir -p "$dir/log"
|
10 |
+
sbatch_args="-p wav2vec --nodes=1 --ntasks-per-node=1"
|
11 |
+
sbatch_args="$sbatch_args --gpus-per-node=1 --cpus-per-task=8 --mem=0 --time=24:00:00"
|
12 |
+
sbatch_args="$sbatch_args -o $dir/log/decode_sweep_%A.out"
|
13 |
+
sbatch_args="$sbatch_args -e $dir/log/decode_sweep_%A.err"
|
14 |
+
|
15 |
+
sbatch $sbatch_args examples/data2vec/scripts/text/finetune_all_large_fair_local_lr.sh $dir
|
fairseq/examples/data2vec/scripts/text/finetune_sst2_qnli_sweep_fair_nodep.sh
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env zsh
|
2 |
+
|
3 |
+
dir="$1"
|
4 |
+
cp="$dir/checkpoints/checkpoint_last.pt"
|
5 |
+
|
6 |
+
echo "dir: $dir"
|
7 |
+
|
8 |
+
declare -A tasks
|
9 |
+
tasks[qnli]="/private/home/jgu/data/GLUE/QNLI-bin"
|
10 |
+
tasks[sst_2]="/private/home/jgu/data/GLUE/SST-2-bin"
|
11 |
+
|
12 |
+
lrs="5e-6 1e-5 2e-5 5e-5 1e-4 2e-4 5e-4 1e-3"
|
13 |
+
|
14 |
+
for task data_path in ${(kv)tasks}; do
|
15 |
+
for lr in $(echo "$lrs"); do
|
16 |
+
PYTHONPATH=. PREFIX="${PREFIX}" SUFFIX="" nohup python fairseq_cli/hydra_train.py -m --config-dir examples/roberta/config/finetuning \
|
17 |
+
--config-name $task hydra/launcher=submitit_slurm +run_config=slurm_1g task.data="$data_path" hydra.launcher.name=finetune_${task}_${PREFIX} \
|
18 |
+
checkpoint.restore_file="$cp" hydra.sweep.dir="$dir/finetune_sweep/$task/lr_$lr" "optimization.lr=[${lr}]" &
|
19 |
+
done
|
20 |
+
done
|
fairseq/examples/data2vec/scripts/text/glue.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from valids import parser, main as valids_main
|
2 |
+
import os.path as osp
|
3 |
+
|
4 |
+
|
5 |
+
args = parser.parse_args()
|
6 |
+
args.target = "valid_accuracy"
|
7 |
+
args.best_biggest = True
|
8 |
+
args.best = True
|
9 |
+
args.last = 0
|
10 |
+
args.path_contains = None
|
11 |
+
|
12 |
+
res = valids_main(args, print_output=False)
|
13 |
+
|
14 |
+
grouped = {}
|
15 |
+
for k, v in res.items():
|
16 |
+
k = osp.dirname(k)
|
17 |
+
run = osp.dirname(k)
|
18 |
+
task = osp.basename(k)
|
19 |
+
val = v["valid_accuracy"]
|
20 |
+
|
21 |
+
if run not in grouped:
|
22 |
+
grouped[run] = {}
|
23 |
+
|
24 |
+
grouped[run][task] = val
|
25 |
+
|
26 |
+
for run, tasks in grouped.items():
|
27 |
+
print(run)
|
28 |
+
avg = sum(float(v) for v in tasks.values()) / len(tasks)
|
29 |
+
avg_norte = sum(float(v) for k,v in tasks.items() if k != 'rte') / (len(tasks) -1)
|
30 |
+
try:
|
31 |
+
print(f"{tasks['cola']}\t{tasks['qnli']}\t{tasks['mrpc']}\t{tasks['rte']}\t{tasks['sst_2']}\t{avg:.2f}\t{avg_norte:.2f}")
|
32 |
+
except:
|
33 |
+
print(tasks)
|
34 |
+
print()
|
fairseq/examples/data2vec/scripts/text/glue_lr.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os.path as osp
|
2 |
+
import re
|
3 |
+
from collections import defaultdict
|
4 |
+
|
5 |
+
from valids import parser, main as valids_main
|
6 |
+
|
7 |
+
|
8 |
+
TASK_TO_METRIC = {
|
9 |
+
"cola": "mcc",
|
10 |
+
"qnli": "accuracy",
|
11 |
+
"mrpc": "acc_and_f1",
|
12 |
+
"rte": "accuracy",
|
13 |
+
"sst_2": "accuracy",
|
14 |
+
"mnli": "accuracy",
|
15 |
+
"qqp": "acc_and_f1",
|
16 |
+
"sts_b": "pearson_and_spearman",
|
17 |
+
}
|
18 |
+
TASKS = ["cola", "qnli", "mrpc", "rte", "sst_2", "mnli", "qqp", "sts_b"]
|
19 |
+
|
20 |
+
|
21 |
+
def get_best_stat_str(task_vals, show_subdir):
|
22 |
+
task_to_best_val = {}
|
23 |
+
task_to_best_dir = {}
|
24 |
+
for task, subdir_to_val in task_vals.items():
|
25 |
+
task_to_best_val[task] = max(subdir_to_val.values())
|
26 |
+
task_to_best_dir[task] = max(subdir_to_val.keys(), key=lambda x: subdir_to_val[x])
|
27 |
+
|
28 |
+
# import pdb; pdb.set_trace()
|
29 |
+
N1 = len(task_to_best_val)
|
30 |
+
N2 = len([k for k in task_to_best_val if k != "rte"])
|
31 |
+
avg1 = sum(task_to_best_val.values()) / N1
|
32 |
+
avg2 = sum(v for task, v in task_to_best_val.items() if task != "rte") / N2
|
33 |
+
|
34 |
+
try:
|
35 |
+
msg = ""
|
36 |
+
for task in TASKS:
|
37 |
+
dir = task_to_best_dir.get(task, 'null')
|
38 |
+
val = task_to_best_val.get(task, -100)
|
39 |
+
msg += f"({dir}, {val})\t" if show_subdir else f"{val}\t"
|
40 |
+
msg += f"{avg1:.2f}\t{avg2:.2f}"
|
41 |
+
except Exception as e:
|
42 |
+
msg = str(e)
|
43 |
+
msg += str(sorted(task_vals.items()))
|
44 |
+
return msg
|
45 |
+
|
46 |
+
def get_all_stat_str(task_vals):
|
47 |
+
msg = ""
|
48 |
+
for task in [task for task in TASKS if task in task_vals]:
|
49 |
+
msg += f"=== {task}\n"
|
50 |
+
for subdir in sorted(task_vals[task].keys()):
|
51 |
+
msg += f"\t{subdir}\t{task_vals[task][subdir]}\n"
|
52 |
+
return msg
|
53 |
+
|
54 |
+
def get_tabular_stat_str(task_vals):
|
55 |
+
"""assume subdir is <param>/run_*/0"""
|
56 |
+
msg = ""
|
57 |
+
for task in [task for task in TASKS if task in task_vals]:
|
58 |
+
msg += f"=== {task}\n"
|
59 |
+
param_to_runs = defaultdict(dict)
|
60 |
+
for subdir in task_vals[task]:
|
61 |
+
match = re.match("(.*)/(run_.*)/0", subdir)
|
62 |
+
assert match, "subdir"
|
63 |
+
param, run = match.groups()
|
64 |
+
param_to_runs[param][run] = task_vals[task][subdir]
|
65 |
+
params = sorted(param_to_runs, key=lambda x: float(x))
|
66 |
+
runs = sorted(set(run for runs in param_to_runs.values() for run in runs))
|
67 |
+
msg += ("runs:" + "\t".join(runs) + "\n")
|
68 |
+
msg += ("params:" + "\t".join(params) + "\n")
|
69 |
+
for param in params:
|
70 |
+
msg += "\t".join([str(param_to_runs[param].get(run, None)) for run in runs])
|
71 |
+
msg += "\n"
|
72 |
+
# for subdir in sorted(task_vals[task].keys()):
|
73 |
+
# msg += f"\t{subdir}\t{task_vals[task][subdir]}\n"
|
74 |
+
return msg
|
75 |
+
|
76 |
+
|
77 |
+
|
78 |
+
def main():
|
79 |
+
parser.add_argument("--show_glue", action="store_true", help="show glue metric for each task instead of accuracy")
|
80 |
+
parser.add_argument("--print_mode", default="best", help="best|all|tabular")
|
81 |
+
parser.add_argument("--show_subdir", action="store_true", help="print the subdir that has the best results for each run")
|
82 |
+
parser.add_argument("--override_target", default="valid_accuracy", help="override target")
|
83 |
+
|
84 |
+
args = parser.parse_args()
|
85 |
+
args.target = args.override_target
|
86 |
+
args.best_biggest = True
|
87 |
+
args.best = True
|
88 |
+
args.last = 0
|
89 |
+
args.path_contains = None
|
90 |
+
|
91 |
+
res = valids_main(args, print_output=False)
|
92 |
+
grouped_acc = {}
|
93 |
+
grouped_met = {} # use official metric for each task
|
94 |
+
for path, v in res.items():
|
95 |
+
path = "/".join([args.base, path])
|
96 |
+
path = re.sub("//*", "/", path)
|
97 |
+
match = re.match("(.*)finetune[^/]*/([^/]*)/(.*)", path)
|
98 |
+
if not match:
|
99 |
+
continue
|
100 |
+
run, task, subdir = match.groups()
|
101 |
+
|
102 |
+
if run not in grouped_acc:
|
103 |
+
grouped_acc[run] = {}
|
104 |
+
grouped_met[run] = {}
|
105 |
+
if task not in grouped_acc[run]:
|
106 |
+
grouped_acc[run][task] = {}
|
107 |
+
grouped_met[run][task] = {}
|
108 |
+
|
109 |
+
if v is not None:
|
110 |
+
grouped_acc[run][task][subdir] = float(v.get("valid_accuracy", -100))
|
111 |
+
grouped_met[run][task][subdir] = float(v.get(f"valid_{TASK_TO_METRIC[task]}", -100))
|
112 |
+
else:
|
113 |
+
print(f"{path} has None return")
|
114 |
+
|
115 |
+
header = "\t".join(TASKS)
|
116 |
+
for run in sorted(grouped_acc):
|
117 |
+
print(run)
|
118 |
+
if args.print_mode == "all":
|
119 |
+
if args.show_glue:
|
120 |
+
print("===== GLUE =====")
|
121 |
+
print(get_all_stat_str(grouped_met[run]))
|
122 |
+
else:
|
123 |
+
print("===== ACC =====")
|
124 |
+
print(get_all_stat_str(grouped_acc[run]))
|
125 |
+
elif args.print_mode == "best":
|
126 |
+
print(f" {header}")
|
127 |
+
if args.show_glue:
|
128 |
+
print(f"GLEU: {get_best_stat_str(grouped_met[run], args.show_subdir)}")
|
129 |
+
else:
|
130 |
+
print(f"ACC: {get_best_stat_str(grouped_acc[run], args.show_subdir)}")
|
131 |
+
elif args.print_mode == "tabular":
|
132 |
+
if args.show_glue:
|
133 |
+
print("===== GLUE =====")
|
134 |
+
print(get_tabular_stat_str(grouped_met[run]))
|
135 |
+
else:
|
136 |
+
print("===== ACC =====")
|
137 |
+
print(get_tabular_stat_str(grouped_acc[run]))
|
138 |
+
else:
|
139 |
+
raise ValueError(args.print_mode)
|
140 |
+
print()
|
141 |
+
|
142 |
+
if __name__ == "__main__":
|
143 |
+
main()
|
fairseq/examples/data2vec/tasks/audio_classification.py
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2017-present, Facebook, Inc.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the LICENSE file in
|
5 |
+
# the root directory of this source tree. An additional grant of patent rights
|
6 |
+
# can be found in the PATENTS file in the same directory.
|
7 |
+
|
8 |
+
import logging
|
9 |
+
import os
|
10 |
+
import numpy as np
|
11 |
+
import math
|
12 |
+
import torch
|
13 |
+
|
14 |
+
from sklearn import metrics as sklearn_metrics
|
15 |
+
from dataclasses import dataclass
|
16 |
+
|
17 |
+
from fairseq.tasks.audio_pretraining import AudioPretrainingTask, AudioPretrainingConfig
|
18 |
+
from fairseq.tasks import register_task
|
19 |
+
from fairseq.logging import metrics
|
20 |
+
|
21 |
+
from ..data.add_class_target_dataset import AddClassTargetDataset
|
22 |
+
|
23 |
+
|
24 |
+
logger = logging.getLogger(__name__)
|
25 |
+
|
26 |
+
|
27 |
+
@dataclass
|
28 |
+
class AudioClassificationConfig(AudioPretrainingConfig):
|
29 |
+
label_descriptors: str = "label_descriptors.csv"
|
30 |
+
labels: str = "lbl"
|
31 |
+
|
32 |
+
|
33 |
+
@register_task("audio_classification", dataclass=AudioClassificationConfig)
|
34 |
+
class AudioClassificationTask(AudioPretrainingTask):
|
35 |
+
""" """
|
36 |
+
|
37 |
+
cfg: AudioClassificationConfig
|
38 |
+
|
39 |
+
def __init__(
|
40 |
+
self,
|
41 |
+
cfg: AudioClassificationConfig,
|
42 |
+
):
|
43 |
+
super().__init__(cfg)
|
44 |
+
|
45 |
+
self.state.add_factory("labels", self.load_labels)
|
46 |
+
|
47 |
+
def load_labels(self):
|
48 |
+
labels = {}
|
49 |
+
path = os.path.join(self.cfg.data, self.cfg.label_descriptors)
|
50 |
+
with open(path, "r") as ldf:
|
51 |
+
for line in ldf:
|
52 |
+
if line.strip() == "":
|
53 |
+
continue
|
54 |
+
items = line.split(",")
|
55 |
+
idx = items[0]
|
56 |
+
lbl = items[1]
|
57 |
+
assert lbl not in labels, lbl
|
58 |
+
labels[lbl] = idx
|
59 |
+
return labels
|
60 |
+
|
61 |
+
@property
|
62 |
+
def labels(self):
|
63 |
+
return self.state.labels
|
64 |
+
|
65 |
+
def load_dataset(
|
66 |
+
self, split: str, task_cfg: AudioClassificationConfig = None, **kwargs
|
67 |
+
):
|
68 |
+
super().load_dataset(split, task_cfg, **kwargs)
|
69 |
+
|
70 |
+
task_cfg = task_cfg or self.cfg
|
71 |
+
|
72 |
+
data_path = self.cfg.data
|
73 |
+
label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}")
|
74 |
+
skipped_indices = getattr(self.datasets[split], "skipped_indices", set())
|
75 |
+
labels = []
|
76 |
+
with open(label_path, "r") as f:
|
77 |
+
for i, line in enumerate(f):
|
78 |
+
if i not in skipped_indices:
|
79 |
+
lbl_items = line.rstrip().split("\t")
|
80 |
+
labels.append([int(x) for x in lbl_items[2].split(",")])
|
81 |
+
|
82 |
+
assert len(labels) == len(self.datasets[split]), (
|
83 |
+
f"labels length ({len(labels)}) and dataset length "
|
84 |
+
f"({len(self.datasets[split])}) do not match"
|
85 |
+
)
|
86 |
+
|
87 |
+
self.datasets[split] = AddClassTargetDataset(
|
88 |
+
self.datasets[split],
|
89 |
+
labels,
|
90 |
+
multi_class=True,
|
91 |
+
add_to_input=True,
|
92 |
+
num_classes=len(self.labels),
|
93 |
+
)
|
94 |
+
|
95 |
+
def calculate_stats(self, output, target):
|
96 |
+
|
97 |
+
classes_num = target.shape[-1]
|
98 |
+
stats = []
|
99 |
+
|
100 |
+
# Accuracy, only used for single-label classification such as esc-50, not for multiple label one such as AudioSet
|
101 |
+
# acc = sklearn_metrics.accuracy_score(np.argmax(target, 1), np.argmax(output, 1))
|
102 |
+
|
103 |
+
# Class-wise statistics
|
104 |
+
for k in range(classes_num):
|
105 |
+
# Average precision
|
106 |
+
avg_precision = sklearn_metrics.average_precision_score(
|
107 |
+
target[:, k], output[:, k], average=None
|
108 |
+
)
|
109 |
+
|
110 |
+
dict = {
|
111 |
+
"AP": avg_precision,
|
112 |
+
}
|
113 |
+
|
114 |
+
# # AUC
|
115 |
+
# try:
|
116 |
+
# auc = sklearn_metrics.roc_auc_score(target[:, k], output[:, k], average=None)
|
117 |
+
# except:
|
118 |
+
# auc = 0
|
119 |
+
#
|
120 |
+
# # Precisions, recalls
|
121 |
+
# (precisions, recalls, thresholds) = sklearn_metrics.precision_recall_curve(
|
122 |
+
# target[:, k], output[:, k]
|
123 |
+
# )
|
124 |
+
#
|
125 |
+
# # FPR, TPR
|
126 |
+
# (fpr, tpr, thresholds) = sklearn_metrics.roc_curve(target[:, k], output[:, k])
|
127 |
+
#
|
128 |
+
# save_every_steps = 1000 # Sample statistics to reduce size
|
129 |
+
# dict = {
|
130 |
+
# "precisions": precisions[0::save_every_steps],
|
131 |
+
# "recalls": recalls[0::save_every_steps],
|
132 |
+
# "AP": avg_precision,
|
133 |
+
# "fpr": fpr[0::save_every_steps],
|
134 |
+
# "fnr": 1.0 - tpr[0::save_every_steps],
|
135 |
+
# "auc": auc,
|
136 |
+
# # note acc is not class-wise, this is just to keep consistent with other metrics
|
137 |
+
# "acc": acc,
|
138 |
+
# }
|
139 |
+
stats.append(dict)
|
140 |
+
|
141 |
+
return stats
|
142 |
+
|
143 |
+
def valid_step(self, sample, model, criterion):
|
144 |
+
loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
|
145 |
+
return loss, sample_size, logging_output
|
146 |
+
|
147 |
+
def reduce_metrics(self, logging_outputs, criterion):
|
148 |
+
super().reduce_metrics(logging_outputs, criterion)
|
149 |
+
if "_predictions" in logging_outputs[0]:
|
150 |
+
metrics.log_concat_tensor(
|
151 |
+
"_predictions",
|
152 |
+
torch.cat([l["_predictions"].cpu() for l in logging_outputs], dim=0),
|
153 |
+
)
|
154 |
+
metrics.log_concat_tensor(
|
155 |
+
"_targets",
|
156 |
+
torch.cat([l["_targets"].cpu() for l in logging_outputs], dim=0),
|
157 |
+
)
|
158 |
+
|
159 |
+
def compute_stats(meters):
|
160 |
+
if meters["_predictions"].tensor.shape[0] < 100:
|
161 |
+
return 0
|
162 |
+
stats = self.calculate_stats(
|
163 |
+
meters["_predictions"].tensor, meters["_targets"].tensor
|
164 |
+
)
|
165 |
+
return np.nanmean([stat["AP"] for stat in stats])
|
166 |
+
|
167 |
+
metrics.log_derived("mAP", compute_stats)
|
fairseq/examples/data2vec/tasks/image_classification.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2017-present, Facebook, Inc.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the LICENSE file in
|
5 |
+
# the root directory of this source tree. An additional grant of patent rights
|
6 |
+
# can be found in the PATENTS file in the same directory.
|
7 |
+
|
8 |
+
import os.path as osp
|
9 |
+
import logging
|
10 |
+
|
11 |
+
from dataclasses import dataclass
|
12 |
+
import torch
|
13 |
+
from torchvision import transforms
|
14 |
+
|
15 |
+
from fairseq.dataclass import FairseqDataclass
|
16 |
+
from fairseq.tasks import register_task
|
17 |
+
from fairseq.logging import metrics
|
18 |
+
|
19 |
+
try:
|
20 |
+
from ..data import ImageDataset
|
21 |
+
except:
|
22 |
+
import sys
|
23 |
+
|
24 |
+
sys.path.append("..")
|
25 |
+
from data import ImageDataset
|
26 |
+
|
27 |
+
from .image_pretraining import (
|
28 |
+
ImagePretrainingConfig,
|
29 |
+
ImagePretrainingTask,
|
30 |
+
IMG_EXTENSIONS,
|
31 |
+
)
|
32 |
+
|
33 |
+
logger = logging.getLogger(__name__)
|
34 |
+
|
35 |
+
|
36 |
+
@dataclass
|
37 |
+
class ImageClassificationConfig(ImagePretrainingConfig):
|
38 |
+
pass
|
39 |
+
|
40 |
+
|
41 |
+
@register_task("image_classification", dataclass=ImageClassificationConfig)
|
42 |
+
class ImageClassificationTask(ImagePretrainingTask):
|
43 |
+
|
44 |
+
cfg: ImageClassificationConfig
|
45 |
+
|
46 |
+
@classmethod
|
47 |
+
def setup_task(cls, cfg: ImageClassificationConfig, **kwargs):
|
48 |
+
return cls(cfg)
|
49 |
+
|
50 |
+
def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs):
|
51 |
+
data_path = self.cfg.data
|
52 |
+
cfg = task_cfg or self.cfg
|
53 |
+
|
54 |
+
path_with_split = osp.join(data_path, split)
|
55 |
+
if osp.exists(path_with_split):
|
56 |
+
data_path = path_with_split
|
57 |
+
|
58 |
+
from timm.data import create_transform
|
59 |
+
|
60 |
+
if split == "train":
|
61 |
+
# this should always dispatch to transforms_imagenet_train
|
62 |
+
transform = create_transform(
|
63 |
+
input_size=cfg.input_size,
|
64 |
+
is_training=True,
|
65 |
+
auto_augment="rand-m9-mstd0.5-inc1",
|
66 |
+
interpolation="bicubic",
|
67 |
+
re_prob=0.25,
|
68 |
+
re_mode="pixel",
|
69 |
+
re_count=1,
|
70 |
+
mean=cfg.normalization_mean,
|
71 |
+
std=cfg.normalization_std,
|
72 |
+
)
|
73 |
+
if not cfg.input_size > 32:
|
74 |
+
transform.transforms[0] = transforms.RandomCrop(
|
75 |
+
cfg.input_size, padding=4
|
76 |
+
)
|
77 |
+
else:
|
78 |
+
t = []
|
79 |
+
if cfg.input_size > 32:
|
80 |
+
crop_pct = 1
|
81 |
+
if cfg.input_size < 384:
|
82 |
+
crop_pct = 224 / 256
|
83 |
+
size = int(cfg.input_size / crop_pct)
|
84 |
+
t.append(
|
85 |
+
transforms.Resize(
|
86 |
+
size, interpolation=3
|
87 |
+
), # to maintain same ratio w.r.t. 224 images
|
88 |
+
)
|
89 |
+
t.append(transforms.CenterCrop(cfg.input_size))
|
90 |
+
|
91 |
+
t.append(transforms.ToTensor())
|
92 |
+
t.append(
|
93 |
+
transforms.Normalize(cfg.normalization_mean, cfg.normalization_std)
|
94 |
+
)
|
95 |
+
transform = transforms.Compose(t)
|
96 |
+
logger.info(transform)
|
97 |
+
|
98 |
+
self.datasets[split] = ImageDataset(
|
99 |
+
root=data_path,
|
100 |
+
extensions=IMG_EXTENSIONS,
|
101 |
+
load_classes=True,
|
102 |
+
transform=transform,
|
103 |
+
)
|
104 |
+
for k in self.datasets.keys():
|
105 |
+
if k != split:
|
106 |
+
assert self.datasets[k].classes == self.datasets[split].classes
|
107 |
+
|
108 |
+
def build_model(self, model_cfg: FairseqDataclass, from_checkpoint=False):
|
109 |
+
model = super().build_model(model_cfg, from_checkpoint)
|
110 |
+
|
111 |
+
actualized_cfg = getattr(model, "cfg", None)
|
112 |
+
if actualized_cfg is not None:
|
113 |
+
if hasattr(actualized_cfg, "pretrained_model_args"):
|
114 |
+
model_cfg.pretrained_model_args = actualized_cfg.pretrained_model_args
|
115 |
+
|
116 |
+
return model
|
117 |
+
|
118 |
+
def reduce_metrics(self, logging_outputs, criterion):
|
119 |
+
super().reduce_metrics(logging_outputs, criterion)
|
120 |
+
|
121 |
+
if "correct" in logging_outputs[0]:
|
122 |
+
zero = torch.scalar_tensor(0.0)
|
123 |
+
correct = sum(log.get("correct", zero) for log in logging_outputs)
|
124 |
+
metrics.log_scalar_sum("_correct", correct)
|
125 |
+
|
126 |
+
metrics.log_derived(
|
127 |
+
"accuracy",
|
128 |
+
lambda meters: 100 * meters["_correct"].sum / meters["sample_size"].sum,
|
129 |
+
)
|
fairseq/examples/data2vec/tasks/image_pretraining.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2017-present, Facebook, Inc.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the LICENSE file in
|
5 |
+
# the root directory of this source tree. An additional grant of patent rights
|
6 |
+
# can be found in the PATENTS file in the same directory.
|
7 |
+
|
8 |
+
import logging
|
9 |
+
import sys
|
10 |
+
import os.path as osp
|
11 |
+
|
12 |
+
from dataclasses import dataclass, field
|
13 |
+
from typing import List
|
14 |
+
from omegaconf import MISSING
|
15 |
+
|
16 |
+
import torch
|
17 |
+
from torchvision import transforms
|
18 |
+
|
19 |
+
from fairseq.dataclass import FairseqDataclass
|
20 |
+
from fairseq.tasks import FairseqTask, register_task
|
21 |
+
|
22 |
+
try:
|
23 |
+
from ..data import ImageDataset
|
24 |
+
except:
|
25 |
+
sys.path.append("..")
|
26 |
+
from data import ImageDataset
|
27 |
+
|
28 |
+
logger = logging.getLogger(__name__)
|
29 |
+
|
30 |
+
IMG_EXTENSIONS = {
|
31 |
+
".jpg",
|
32 |
+
".jpeg",
|
33 |
+
".png",
|
34 |
+
".ppm",
|
35 |
+
".bmp",
|
36 |
+
".pgm",
|
37 |
+
".tif",
|
38 |
+
".tiff",
|
39 |
+
".webp",
|
40 |
+
}
|
41 |
+
|
42 |
+
|
43 |
+
@dataclass
|
44 |
+
class ImagePretrainingConfig(FairseqDataclass):
|
45 |
+
data: str = field(default=MISSING, metadata={"help": "path to data directory"})
|
46 |
+
input_size: int = 224
|
47 |
+
normalization_mean: List[float] = (0.485, 0.456, 0.406)
|
48 |
+
normalization_std: List[float] = (0.229, 0.224, 0.225)
|
49 |
+
|
50 |
+
|
51 |
+
@register_task("image_pretraining", dataclass=ImagePretrainingConfig)
|
52 |
+
class ImagePretrainingTask(FairseqTask):
|
53 |
+
""" """
|
54 |
+
|
55 |
+
cfg: ImagePretrainingConfig
|
56 |
+
|
57 |
+
@classmethod
|
58 |
+
def setup_task(cls, cfg: ImagePretrainingConfig, **kwargs):
|
59 |
+
"""Setup the task (e.g., load dictionaries).
|
60 |
+
|
61 |
+
Args:
|
62 |
+
cfg (AudioPretrainingConfig): configuration of this task
|
63 |
+
"""
|
64 |
+
|
65 |
+
return cls(cfg)
|
66 |
+
|
67 |
+
def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs):
|
68 |
+
data_path = self.cfg.data
|
69 |
+
cfg = task_cfg or self.cfg
|
70 |
+
|
71 |
+
path_with_split = osp.join(data_path, split)
|
72 |
+
if osp.exists(path_with_split):
|
73 |
+
data_path = path_with_split
|
74 |
+
|
75 |
+
transform = transforms.Compose(
|
76 |
+
[
|
77 |
+
transforms.ColorJitter(0.4, 0.4, 0.4),
|
78 |
+
transforms.RandomHorizontalFlip(p=0.5),
|
79 |
+
transforms.RandomResizedCrop(
|
80 |
+
size=cfg.input_size,
|
81 |
+
interpolation=transforms.InterpolationMode.BICUBIC,
|
82 |
+
),
|
83 |
+
transforms.ToTensor(),
|
84 |
+
transforms.Normalize(
|
85 |
+
mean=torch.tensor(cfg.normalization_mean),
|
86 |
+
std=torch.tensor(cfg.normalization_std),
|
87 |
+
),
|
88 |
+
]
|
89 |
+
)
|
90 |
+
|
91 |
+
logger.info(transform)
|
92 |
+
|
93 |
+
self.datasets[split] = ImageDataset(
|
94 |
+
root=data_path,
|
95 |
+
extensions=IMG_EXTENSIONS,
|
96 |
+
load_classes=False,
|
97 |
+
transform=transform,
|
98 |
+
)
|
99 |
+
|
100 |
+
@property
|
101 |
+
def source_dictionary(self):
|
102 |
+
return None
|
103 |
+
|
104 |
+
@property
|
105 |
+
def target_dictionary(self):
|
106 |
+
return None
|
107 |
+
|
108 |
+
def max_positions(self):
|
109 |
+
"""Maximum input length supported by the encoder."""
|
110 |
+
return sys.maxsize, sys.maxsize
|
fairseq/examples/data2vec/tasks/mae_image_pretraining.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2017-present, Facebook, Inc.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the LICENSE file in
|
5 |
+
# the root directory of this source tree. An additional grant of patent rights
|
6 |
+
# can be found in the PATENTS file in the same directory.
|
7 |
+
|
8 |
+
import logging
|
9 |
+
import sys
|
10 |
+
|
11 |
+
from typing import Optional, List
|
12 |
+
from dataclasses import dataclass, field
|
13 |
+
from omegaconf import MISSING, II
|
14 |
+
|
15 |
+
from fairseq.data import SubsampleDataset
|
16 |
+
from fairseq.dataclass import FairseqDataclass
|
17 |
+
from fairseq.tasks import FairseqTask, register_task
|
18 |
+
|
19 |
+
try:
|
20 |
+
from ..data import MaeImageDataset
|
21 |
+
except:
|
22 |
+
sys.path.append("..")
|
23 |
+
from data import MaeImageDataset
|
24 |
+
|
25 |
+
logger = logging.getLogger(__name__)
|
26 |
+
|
27 |
+
|
28 |
+
@dataclass
|
29 |
+
class ImageMaskingConfig:
|
30 |
+
patch_size: int = II("model.modalities.image.patch_size")
|
31 |
+
mask_prob: float = II("model.modalities.image.mask_prob")
|
32 |
+
mask_prob_adjust: float = II("model.modalities.image.mask_prob_adjust")
|
33 |
+
mask_length: int = II("model.modalities.image.mask_length")
|
34 |
+
inverse_mask: bool = II("model.modalities.image.inverse_mask")
|
35 |
+
mask_dropout: float = II("model.modalities.image.mask_dropout")
|
36 |
+
clone_batch: int = II("model.clone_batch")
|
37 |
+
expand_adjacent: bool = False
|
38 |
+
non_overlapping: bool = False
|
39 |
+
|
40 |
+
|
41 |
+
@dataclass
|
42 |
+
class MaeImagePretrainingConfig(FairseqDataclass):
|
43 |
+
data: str = field(default=MISSING, metadata={"help": "path to data directory"})
|
44 |
+
multi_data: Optional[List[str]] = None
|
45 |
+
input_size: int = 224
|
46 |
+
local_cache_path: Optional[str] = None
|
47 |
+
key: str = "imgs"
|
48 |
+
|
49 |
+
beit_transforms: bool = False
|
50 |
+
target_transform: bool = False
|
51 |
+
no_transform: bool = False
|
52 |
+
|
53 |
+
rebuild_batches: bool = True
|
54 |
+
|
55 |
+
precompute_mask_config: Optional[ImageMaskingConfig] = None
|
56 |
+
|
57 |
+
subsample: float = 1
|
58 |
+
seed: int = II("common.seed")
|
59 |
+
dataset_type: str = "imagefolder"
|
60 |
+
|
61 |
+
|
62 |
+
@register_task("mae_image_pretraining", dataclass=MaeImagePretrainingConfig)
|
63 |
+
class MaeImagePretrainingTask(FairseqTask):
|
64 |
+
""" """
|
65 |
+
|
66 |
+
cfg: MaeImagePretrainingConfig
|
67 |
+
|
68 |
+
@classmethod
|
69 |
+
def setup_task(cls, cfg: MaeImagePretrainingConfig, **kwargs):
|
70 |
+
"""Setup the task (e.g., load dictionaries).
|
71 |
+
|
72 |
+
Args:
|
73 |
+
cfg (AudioPretrainingConfig): configuration of this task
|
74 |
+
"""
|
75 |
+
|
76 |
+
return cls(cfg)
|
77 |
+
|
78 |
+
def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs):
|
79 |
+
data_path = self.cfg.data
|
80 |
+
cfg = task_cfg or self.cfg
|
81 |
+
|
82 |
+
compute_mask = cfg.precompute_mask_config is not None
|
83 |
+
mask_args = {}
|
84 |
+
if compute_mask:
|
85 |
+
mask_args = cfg.precompute_mask_config
|
86 |
+
|
87 |
+
self.datasets[split] = MaeImageDataset(
|
88 |
+
root=data_path if cfg.multi_data is None else cfg.multi_data,
|
89 |
+
split=split,
|
90 |
+
input_size=cfg.input_size,
|
91 |
+
local_cache_path=cfg.local_cache_path,
|
92 |
+
key=cfg.key,
|
93 |
+
beit_transforms=cfg.beit_transforms,
|
94 |
+
target_transform=cfg.target_transform,
|
95 |
+
no_transform=cfg.no_transform,
|
96 |
+
compute_mask=compute_mask,
|
97 |
+
dataset_type=cfg.dataset_type,
|
98 |
+
**mask_args,
|
99 |
+
)
|
100 |
+
|
101 |
+
if cfg.subsample < 1:
|
102 |
+
self.datasets[split] = SubsampleDataset(
|
103 |
+
self.datasets[split],
|
104 |
+
cfg.subsample,
|
105 |
+
shuffle=True,
|
106 |
+
seed=cfg.seed,
|
107 |
+
)
|
108 |
+
|
109 |
+
@property
|
110 |
+
def source_dictionary(self):
|
111 |
+
return None
|
112 |
+
|
113 |
+
@property
|
114 |
+
def target_dictionary(self):
|
115 |
+
return None
|
116 |
+
|
117 |
+
def max_positions(self):
|
118 |
+
"""Maximum input length supported by the encoder."""
|
119 |
+
return sys.maxsize, sys.maxsize
|
fairseq/examples/emotion_conversion/emotion_models/__init__.py
ADDED
File without changes
|
fairseq/examples/emotion_conversion/emotion_models/duration_predictor.py
ADDED
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
|
4 |
+
import hydra
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from einops.layers.torch import Rearrange
|
9 |
+
from torch.utils.data import DataLoader, Dataset
|
10 |
+
|
11 |
+
from .utils import Accuracy
|
12 |
+
|
13 |
+
logger = logging.getLogger(__name__)
|
14 |
+
|
15 |
+
|
16 |
+
def save_ckpt(model, path, model_class):
|
17 |
+
ckpt = {
|
18 |
+
"state_dict": model.state_dict(),
|
19 |
+
"padding_token": model.padding_token,
|
20 |
+
"model_class": model_class,
|
21 |
+
}
|
22 |
+
torch.save(ckpt, path)
|
23 |
+
|
24 |
+
|
25 |
+
def load_ckpt(path):
|
26 |
+
ckpt = torch.load(path)
|
27 |
+
ckpt["model_class"]["_target_"] = "emotion_models.duration_predictor.CnnPredictor"
|
28 |
+
model = hydra.utils.instantiate(ckpt["model_class"])
|
29 |
+
model.load_state_dict(ckpt["state_dict"])
|
30 |
+
model.padding_token = ckpt["padding_token"]
|
31 |
+
model = model.cpu()
|
32 |
+
model.eval()
|
33 |
+
return model
|
34 |
+
|
35 |
+
|
36 |
+
class Collator:
|
37 |
+
def __init__(self, padding_idx):
|
38 |
+
self.padding_idx = padding_idx
|
39 |
+
|
40 |
+
def __call__(self, batch):
|
41 |
+
x = [item[0] for item in batch]
|
42 |
+
lengths = [len(item) for item in x]
|
43 |
+
x = torch.nn.utils.rnn.pad_sequence(x, batch_first=True, padding_value=self.padding_idx)
|
44 |
+
y = [item[1] for item in batch]
|
45 |
+
y = torch.nn.utils.rnn.pad_sequence(y, batch_first=True, padding_value=self.padding_idx)
|
46 |
+
mask = (x != self.padding_idx)
|
47 |
+
return x, y, mask, lengths
|
48 |
+
|
49 |
+
|
50 |
+
class Predictor(nn.Module):
|
51 |
+
def __init__(self, n_tokens, emb_dim):
|
52 |
+
super(Predictor, self).__init__()
|
53 |
+
self.n_tokens = n_tokens
|
54 |
+
self.emb_dim = emb_dim
|
55 |
+
self.padding_token = n_tokens
|
56 |
+
# add 1 extra embedding for padding token, set the padding index to be the last token
|
57 |
+
# (tokens from the clustering start at index 0)
|
58 |
+
self.emb = nn.Embedding(n_tokens + 1, emb_dim, padding_idx=self.padding_token)
|
59 |
+
|
60 |
+
def inflate_input(self, batch):
|
61 |
+
""" get a sequence of tokens, predict their durations
|
62 |
+
and inflate them accordingly """
|
63 |
+
batch_durs = self.forward(batch)
|
64 |
+
batch_durs = torch.exp(batch_durs) - 1
|
65 |
+
batch_durs = batch_durs.round()
|
66 |
+
output = []
|
67 |
+
for seq, durs in zip(batch, batch_durs):
|
68 |
+
inflated_seq = []
|
69 |
+
for token, n in zip(seq, durs):
|
70 |
+
if token == self.padding_token:
|
71 |
+
break
|
72 |
+
n = int(n.item())
|
73 |
+
token = int(token.item())
|
74 |
+
inflated_seq.extend([token for _ in range(n)])
|
75 |
+
output.append(inflated_seq)
|
76 |
+
output = torch.LongTensor(output)
|
77 |
+
return output
|
78 |
+
|
79 |
+
|
80 |
+
class CnnPredictor(Predictor):
|
81 |
+
def __init__(self, n_tokens, emb_dim, channels, kernel, output_dim, dropout, n_layers):
|
82 |
+
super(CnnPredictor, self).__init__(n_tokens=n_tokens, emb_dim=emb_dim)
|
83 |
+
layers = [
|
84 |
+
Rearrange("b t c -> b c t"),
|
85 |
+
nn.Conv1d(emb_dim, channels, kernel_size=kernel, padding=(kernel - 1) // 2),
|
86 |
+
Rearrange("b c t -> b t c"),
|
87 |
+
nn.ReLU(),
|
88 |
+
nn.LayerNorm(channels),
|
89 |
+
nn.Dropout(dropout),
|
90 |
+
]
|
91 |
+
for _ in range(n_layers-1):
|
92 |
+
layers += [
|
93 |
+
Rearrange("b t c -> b c t"),
|
94 |
+
nn.Conv1d(channels, channels, kernel_size=kernel, padding=(kernel - 1) // 2),
|
95 |
+
Rearrange("b c t -> b t c"),
|
96 |
+
nn.ReLU(),
|
97 |
+
nn.LayerNorm(channels),
|
98 |
+
nn.Dropout(dropout),
|
99 |
+
]
|
100 |
+
self.conv_layer = nn.Sequential(*layers)
|
101 |
+
self.proj = nn.Linear(channels, output_dim)
|
102 |
+
|
103 |
+
def forward(self, x):
|
104 |
+
x = self.emb(x)
|
105 |
+
x = self.conv_layer(x)
|
106 |
+
x = self.proj(x)
|
107 |
+
x = x.squeeze(-1)
|
108 |
+
return x
|
109 |
+
|
110 |
+
|
111 |
+
def l2_log_loss(input, target):
|
112 |
+
return F.mse_loss(
|
113 |
+
input=input.float(),
|
114 |
+
target=torch.log(target.float() + 1),
|
115 |
+
reduce=False
|
116 |
+
)
|
117 |
+
|
118 |
+
|
119 |
+
class DurationDataset(Dataset):
|
120 |
+
def __init__(self, tsv_path, km_path, substring=""):
|
121 |
+
lines = open(tsv_path, "r").readlines()
|
122 |
+
self.root, self.tsv = lines[0], lines[1:]
|
123 |
+
self.km = open(km_path, "r").readlines()
|
124 |
+
logger.info(f"loaded {len(self.km)} files")
|
125 |
+
|
126 |
+
if substring != "":
|
127 |
+
tsv, km = [], []
|
128 |
+
for tsv_line, km_line in zip(self.tsv, self.km):
|
129 |
+
if substring.lower() in tsv_line.lower():
|
130 |
+
tsv.append(tsv_line)
|
131 |
+
km.append(km_line)
|
132 |
+
self.tsv, self.km = tsv, km
|
133 |
+
logger.info(f"after filtering: {len(self.km)} files")
|
134 |
+
|
135 |
+
def __len__(self):
|
136 |
+
return len(self.km)
|
137 |
+
|
138 |
+
def __getitem__(self, i):
|
139 |
+
x = self.km[i]
|
140 |
+
x = x.split(" ")
|
141 |
+
x = list(map(int, x))
|
142 |
+
|
143 |
+
y = []
|
144 |
+
xd = []
|
145 |
+
count = 1
|
146 |
+
for x1, x2 in zip(x[:-1], x[1:]):
|
147 |
+
if x1 == x2:
|
148 |
+
count += 1
|
149 |
+
continue
|
150 |
+
else:
|
151 |
+
y.append(count)
|
152 |
+
xd.append(x1)
|
153 |
+
count = 1
|
154 |
+
|
155 |
+
xd = torch.LongTensor(xd)
|
156 |
+
y = torch.LongTensor(y)
|
157 |
+
return xd, y
|
158 |
+
|
159 |
+
|
160 |
+
def train(cfg):
|
161 |
+
device = "cuda:0"
|
162 |
+
model = hydra.utils.instantiate(cfg[cfg.model]).to(device)
|
163 |
+
optimizer = hydra.utils.instantiate(cfg.optimizer, model.parameters())
|
164 |
+
# add 1 extra embedding for padding token, set the padding index to be the last token
|
165 |
+
# (tokens from the clustering start at index 0)
|
166 |
+
collate_fn = Collator(padding_idx=model.padding_token)
|
167 |
+
logger.info(f"data: {cfg.train_tsv}")
|
168 |
+
train_ds = DurationDataset(cfg.train_tsv, cfg.train_km, substring=cfg.substring)
|
169 |
+
valid_ds = DurationDataset(cfg.valid_tsv, cfg.valid_km, substring=cfg.substring)
|
170 |
+
train_dl = DataLoader(train_ds, batch_size=32, shuffle=True, collate_fn=collate_fn)
|
171 |
+
valid_dl = DataLoader(valid_ds, batch_size=32, shuffle=False, collate_fn=collate_fn)
|
172 |
+
|
173 |
+
best_loss = float("inf")
|
174 |
+
for epoch in range(cfg.epochs):
|
175 |
+
train_loss, train_loss_scaled = train_epoch(model, train_dl, l2_log_loss, optimizer, device)
|
176 |
+
valid_loss, valid_loss_scaled, *acc = valid_epoch(model, valid_dl, l2_log_loss, device)
|
177 |
+
acc0, acc1, acc2, acc3 = acc
|
178 |
+
if valid_loss_scaled < best_loss:
|
179 |
+
path = f"{os.getcwd()}/{cfg.substring}.ckpt"
|
180 |
+
save_ckpt(model, path, cfg[cfg.model])
|
181 |
+
best_loss = valid_loss_scaled
|
182 |
+
logger.info(f"saved checkpoint: {path}")
|
183 |
+
logger.info(f"[epoch {epoch}] train loss: {train_loss:.3f}, train scaled: {train_loss_scaled:.3f}")
|
184 |
+
logger.info(f"[epoch {epoch}] valid loss: {valid_loss:.3f}, valid scaled: {valid_loss_scaled:.3f}")
|
185 |
+
logger.info(f"acc: {acc0,acc1,acc2,acc3}")
|
186 |
+
|
187 |
+
|
188 |
+
def train_epoch(model, loader, criterion, optimizer, device):
|
189 |
+
model.train()
|
190 |
+
epoch_loss = 0
|
191 |
+
epoch_loss_scaled = 0
|
192 |
+
for x, y, mask, _ in loader:
|
193 |
+
x, y, mask = x.to(device), y.to(device), mask.to(device)
|
194 |
+
yhat = model(x)
|
195 |
+
loss = criterion(yhat, y) * mask
|
196 |
+
loss = torch.mean(loss)
|
197 |
+
loss.backward()
|
198 |
+
nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
199 |
+
optimizer.step()
|
200 |
+
epoch_loss += loss.item()
|
201 |
+
# get normal scale loss
|
202 |
+
yhat_scaled = torch.exp(yhat) - 1
|
203 |
+
yhat_scaled = torch.round(yhat_scaled)
|
204 |
+
scaled_loss = torch.mean(torch.abs(yhat_scaled - y) * mask)
|
205 |
+
epoch_loss_scaled += scaled_loss.item()
|
206 |
+
return epoch_loss / len(loader), epoch_loss_scaled / len(loader)
|
207 |
+
|
208 |
+
|
209 |
+
def valid_epoch(model, loader, criterion, device):
|
210 |
+
model.eval()
|
211 |
+
epoch_loss = 0
|
212 |
+
epoch_loss_scaled = 0
|
213 |
+
acc = Accuracy()
|
214 |
+
for x, y, mask, _ in loader:
|
215 |
+
x, y, mask = x.to(device), y.to(device), mask.to(device)
|
216 |
+
yhat = model(x)
|
217 |
+
loss = criterion(yhat, y) * mask
|
218 |
+
loss = torch.mean(loss)
|
219 |
+
epoch_loss += loss.item()
|
220 |
+
# get normal scale loss
|
221 |
+
yhat_scaled = torch.exp(yhat) - 1
|
222 |
+
yhat_scaled = torch.round(yhat_scaled)
|
223 |
+
scaled_loss = torch.sum(torch.abs(yhat_scaled - y) * mask) / mask.sum()
|
224 |
+
acc.update(yhat_scaled[mask].view(-1).float(), y[mask].view(-1).float())
|
225 |
+
epoch_loss_scaled += scaled_loss.item()
|
226 |
+
logger.info(f"example y: {y[0, :10].tolist()}")
|
227 |
+
logger.info(f"example yhat: {yhat_scaled[0, :10].tolist()}")
|
228 |
+
acc0 = acc.acc(tol=0)
|
229 |
+
acc1 = acc.acc(tol=1)
|
230 |
+
acc2 = acc.acc(tol=2)
|
231 |
+
acc3 = acc.acc(tol=3)
|
232 |
+
logger.info(f"accs: {acc0,acc1,acc2,acc3}")
|
233 |
+
return epoch_loss / len(loader), epoch_loss_scaled / len(loader), acc0, acc1, acc2, acc3
|
234 |
+
|
235 |
+
|
236 |
+
@hydra.main(config_path=".", config_name="duration_predictor.yaml")
|
237 |
+
def main(cfg):
|
238 |
+
logger.info(f"{cfg}")
|
239 |
+
train(cfg)
|
240 |
+
|
241 |
+
|
242 |
+
if __name__ == "__main__":
|
243 |
+
main()
|
fairseq/examples/emotion_conversion/emotion_models/duration_predictor.yaml
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
train_tsv: "<your-processed-data>/denoising/emov/train.tsv"
|
2 |
+
train_km: "<your-processed-data>/denoising/emov/train.km"
|
3 |
+
valid_tsv: "<your-processed-data>/denoising/emov/valid.tsv"
|
4 |
+
valid_km: "<your-processed-data>/denoising/emov/valid.km"
|
5 |
+
|
6 |
+
n_tokens: 200
|
7 |
+
batch_size: 32
|
8 |
+
lr: 0.0001
|
9 |
+
epochs: 300
|
10 |
+
model: "cnn"
|
11 |
+
substring: ""
|
12 |
+
|
13 |
+
rnn:
|
14 |
+
_target_: emotion_models.duration_predictor.RnnPredictor
|
15 |
+
n_tokens: ${n_tokens}
|
16 |
+
emb_dim: 128
|
17 |
+
rnn_hidden: 128
|
18 |
+
output_dim: 1
|
19 |
+
dropout: 0
|
20 |
+
n_layers: 1
|
21 |
+
|
22 |
+
optimizer:
|
23 |
+
_target_: torch.optim.Adam
|
24 |
+
lr: ${lr}
|
25 |
+
betas: [0.9, 0.98]
|
26 |
+
eps: 0.000000001
|
27 |
+
weight_decay: 0
|
28 |
+
|
29 |
+
cnn:
|
30 |
+
_target_: emotion_models.duration_predictor.CnnPredictor
|
31 |
+
n_tokens: ${n_tokens}
|
32 |
+
emb_dim: 128
|
33 |
+
channels: 256
|
34 |
+
kernel: 3
|
35 |
+
output_dim: 1
|
36 |
+
dropout: 0.5
|
37 |
+
n_layers: 1
|
38 |
+
|
39 |
+
hydra:
|
40 |
+
run:
|
41 |
+
dir: /checkpoint/felixkreuk/experiments/duration_predictor/${hydra.job.override_dirname}
|
42 |
+
job:
|
43 |
+
config:
|
44 |
+
# configuration for the ${hydra.job.override_dirname} runtime variable
|
45 |
+
override_dirname:
|
46 |
+
kv_sep: '='
|
47 |
+
item_sep: ','
|
48 |
+
exclude_keys: ['train_tsv', 'train_km', 'valid_tsv', 'valid_km']
|
fairseq/examples/emotion_conversion/emotion_models/pitch_predictor.py
ADDED
@@ -0,0 +1,559 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
import sys
|
5 |
+
from collections import defaultdict
|
6 |
+
|
7 |
+
import hydra
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
from einops import rearrange
|
13 |
+
from einops.layers.torch import Rearrange
|
14 |
+
from scipy.io.wavfile import read
|
15 |
+
from scipy.ndimage import gaussian_filter1d
|
16 |
+
from torch.utils.data import DataLoader, Dataset
|
17 |
+
from tqdm import tqdm
|
18 |
+
|
19 |
+
dir_path = os.path.dirname(__file__)
|
20 |
+
resynth_path = os.path.dirname(dir_path) + "/speech-resynthesis"
|
21 |
+
sys.path.append(resynth_path)
|
22 |
+
from dataset import parse_speaker, parse_style
|
23 |
+
from .utils import F0Stat
|
24 |
+
|
25 |
+
MAX_WAV_VALUE = 32768.0
|
26 |
+
logger = logging.getLogger(__name__)
|
27 |
+
|
28 |
+
|
29 |
+
def quantize_f0(speaker_to_f0, nbins, normalize, log):
|
30 |
+
f0_all = []
|
31 |
+
for speaker, f0 in speaker_to_f0.items():
|
32 |
+
f0 = f0.raw_data
|
33 |
+
if log:
|
34 |
+
f0 = f0.log()
|
35 |
+
mean = speaker_to_f0[speaker].mean_log if log else speaker_to_f0[speaker].mean
|
36 |
+
std = speaker_to_f0[speaker].std_log if log else speaker_to_f0[speaker].std
|
37 |
+
if normalize == "mean":
|
38 |
+
f0 = f0 - mean
|
39 |
+
elif normalize == "meanstd":
|
40 |
+
f0 = (f0 - mean) / std
|
41 |
+
f0_all.extend(f0.tolist())
|
42 |
+
|
43 |
+
hist, bin_x = np.histogram(f0_all, 100000)
|
44 |
+
cum_hist = np.cumsum(hist) / len(f0_all) * 100
|
45 |
+
|
46 |
+
bin_offset = []
|
47 |
+
bin_size = 100 / nbins
|
48 |
+
threshold = bin_size
|
49 |
+
for i in range(nbins - 1):
|
50 |
+
index = (np.abs(cum_hist - threshold)).argmin()
|
51 |
+
bin_offset.append(bin_x[index])
|
52 |
+
threshold += bin_size
|
53 |
+
bins = np.array(bin_offset)
|
54 |
+
bins = torch.FloatTensor(bins)
|
55 |
+
|
56 |
+
return bins
|
57 |
+
|
58 |
+
|
59 |
+
def save_ckpt(model, path, model_class, f0_min, f0_max, f0_bins, speaker_stats):
|
60 |
+
ckpt = {
|
61 |
+
"state_dict": model.state_dict(),
|
62 |
+
"padding_token": model.padding_token,
|
63 |
+
"model_class": model_class,
|
64 |
+
"speaker_stats": speaker_stats,
|
65 |
+
"f0_min": f0_min,
|
66 |
+
"f0_max": f0_max,
|
67 |
+
"f0_bins": f0_bins,
|
68 |
+
}
|
69 |
+
torch.save(ckpt, path)
|
70 |
+
|
71 |
+
|
72 |
+
def load_ckpt(path):
|
73 |
+
ckpt = torch.load(path)
|
74 |
+
ckpt["model_class"]["_target_"] = "emotion_models.pitch_predictor.CnnPredictor"
|
75 |
+
model = hydra.utils.instantiate(ckpt["model_class"])
|
76 |
+
model.load_state_dict(ckpt["state_dict"])
|
77 |
+
model.setup_f0_stats(
|
78 |
+
ckpt["f0_min"],
|
79 |
+
ckpt["f0_max"],
|
80 |
+
ckpt["f0_bins"],
|
81 |
+
ckpt["speaker_stats"],
|
82 |
+
)
|
83 |
+
return model
|
84 |
+
|
85 |
+
|
86 |
+
def freq2bin(f0, f0_min, f0_max, bins):
|
87 |
+
f0 = f0.clone()
|
88 |
+
f0[f0 < f0_min] = f0_min
|
89 |
+
f0[f0 > f0_max] = f0_max
|
90 |
+
f0 = torch.bucketize(f0, bins)
|
91 |
+
return f0
|
92 |
+
|
93 |
+
|
94 |
+
def bin2freq(x, f0_min, f0_max, bins, mode):
|
95 |
+
n_bins = len(bins) + 1
|
96 |
+
assert x.shape[-1] == n_bins
|
97 |
+
bins = torch.cat([torch.tensor([f0_min]), bins]).to(x.device)
|
98 |
+
if mode == "mean":
|
99 |
+
f0 = (x * bins).sum(-1, keepdims=True) / x.sum(-1, keepdims=True)
|
100 |
+
elif mode == "argmax":
|
101 |
+
idx = F.one_hot(x.argmax(-1), num_classes=n_bins)
|
102 |
+
f0 = (idx * bins).sum(-1, keepdims=True)
|
103 |
+
else:
|
104 |
+
raise NotImplementedError()
|
105 |
+
return f0[..., 0]
|
106 |
+
|
107 |
+
|
108 |
+
def load_wav(full_path):
|
109 |
+
sampling_rate, data = read(full_path)
|
110 |
+
return data, sampling_rate
|
111 |
+
|
112 |
+
|
113 |
+
def l1_loss(input, target):
|
114 |
+
return F.l1_loss(input=input.float(), target=target.float(), reduce=False)
|
115 |
+
|
116 |
+
|
117 |
+
def l2_loss(input, target):
|
118 |
+
return F.mse_loss(input=input.float(), target=target.float(), reduce=False)
|
119 |
+
|
120 |
+
|
121 |
+
class Collator:
|
122 |
+
def __init__(self, padding_idx):
|
123 |
+
self.padding_idx = padding_idx
|
124 |
+
|
125 |
+
def __call__(self, batch):
|
126 |
+
tokens = [item[0] for item in batch]
|
127 |
+
lengths = [len(item) for item in tokens]
|
128 |
+
tokens = torch.nn.utils.rnn.pad_sequence(
|
129 |
+
tokens, batch_first=True, padding_value=self.padding_idx
|
130 |
+
)
|
131 |
+
f0 = [item[1] for item in batch]
|
132 |
+
f0 = torch.nn.utils.rnn.pad_sequence(
|
133 |
+
f0, batch_first=True, padding_value=self.padding_idx
|
134 |
+
)
|
135 |
+
f0_raw = [item[2] for item in batch]
|
136 |
+
f0_raw = torch.nn.utils.rnn.pad_sequence(
|
137 |
+
f0_raw, batch_first=True, padding_value=self.padding_idx
|
138 |
+
)
|
139 |
+
spk = [item[3] for item in batch]
|
140 |
+
spk = torch.LongTensor(spk)
|
141 |
+
gst = [item[4] for item in batch]
|
142 |
+
gst = torch.LongTensor(gst)
|
143 |
+
mask = tokens != self.padding_idx
|
144 |
+
return tokens, f0, f0_raw, spk, gst, mask, lengths
|
145 |
+
|
146 |
+
|
147 |
+
class CnnPredictor(nn.Module):
|
148 |
+
def __init__(
|
149 |
+
self,
|
150 |
+
n_tokens,
|
151 |
+
emb_dim,
|
152 |
+
channels,
|
153 |
+
kernel,
|
154 |
+
dropout,
|
155 |
+
n_layers,
|
156 |
+
spk_emb,
|
157 |
+
gst_emb,
|
158 |
+
n_bins,
|
159 |
+
f0_pred,
|
160 |
+
f0_log,
|
161 |
+
f0_norm,
|
162 |
+
):
|
163 |
+
super(CnnPredictor, self).__init__()
|
164 |
+
self.n_tokens = n_tokens
|
165 |
+
self.emb_dim = emb_dim
|
166 |
+
self.f0_log = f0_log
|
167 |
+
self.f0_pred = f0_pred
|
168 |
+
self.padding_token = n_tokens
|
169 |
+
self.f0_norm = f0_norm
|
170 |
+
# add 1 extra embedding for padding token, set the padding index to be the last token
|
171 |
+
# (tokens from the clustering start at index 0)
|
172 |
+
self.token_emb = nn.Embedding(
|
173 |
+
n_tokens + 1, emb_dim, padding_idx=self.padding_token
|
174 |
+
)
|
175 |
+
|
176 |
+
self.spk_emb = spk_emb
|
177 |
+
self.gst_emb = nn.Embedding(20, gst_emb)
|
178 |
+
self.setup = False
|
179 |
+
|
180 |
+
feats = emb_dim + gst_emb
|
181 |
+
# feats = emb_dim + gst_emb + (256 if spk_emb else 0)
|
182 |
+
layers = [
|
183 |
+
nn.Sequential(
|
184 |
+
Rearrange("b t c -> b c t"),
|
185 |
+
nn.Conv1d(
|
186 |
+
feats, channels, kernel_size=kernel, padding=(kernel - 1) // 2
|
187 |
+
),
|
188 |
+
Rearrange("b c t -> b t c"),
|
189 |
+
nn.ReLU(),
|
190 |
+
nn.LayerNorm(channels),
|
191 |
+
nn.Dropout(dropout),
|
192 |
+
)
|
193 |
+
]
|
194 |
+
for _ in range(n_layers - 1):
|
195 |
+
layers += [
|
196 |
+
nn.Sequential(
|
197 |
+
Rearrange("b t c -> b c t"),
|
198 |
+
nn.Conv1d(
|
199 |
+
channels,
|
200 |
+
channels,
|
201 |
+
kernel_size=kernel,
|
202 |
+
padding=(kernel - 1) // 2,
|
203 |
+
),
|
204 |
+
Rearrange("b c t -> b t c"),
|
205 |
+
nn.ReLU(),
|
206 |
+
nn.LayerNorm(channels),
|
207 |
+
nn.Dropout(dropout),
|
208 |
+
)
|
209 |
+
]
|
210 |
+
self.conv_layer = nn.ModuleList(layers)
|
211 |
+
self.proj = nn.Linear(channels, n_bins)
|
212 |
+
|
213 |
+
def forward(self, x, gst=None):
|
214 |
+
x = self.token_emb(x)
|
215 |
+
feats = [x]
|
216 |
+
|
217 |
+
if gst is not None:
|
218 |
+
gst = self.gst_emb(gst)
|
219 |
+
gst = rearrange(gst, "b c -> b c 1")
|
220 |
+
gst = F.interpolate(gst, x.shape[1])
|
221 |
+
gst = rearrange(gst, "b c t -> b t c")
|
222 |
+
feats.append(gst)
|
223 |
+
|
224 |
+
x = torch.cat(feats, dim=-1)
|
225 |
+
|
226 |
+
for i, conv in enumerate(self.conv_layer):
|
227 |
+
if i != 0:
|
228 |
+
x = conv(x) + x
|
229 |
+
else:
|
230 |
+
x = conv(x)
|
231 |
+
|
232 |
+
x = self.proj(x)
|
233 |
+
x = x.squeeze(-1)
|
234 |
+
|
235 |
+
if self.f0_pred == "mean":
|
236 |
+
x = torch.sigmoid(x)
|
237 |
+
elif self.f0_pred == "argmax":
|
238 |
+
x = torch.softmax(x, dim=-1)
|
239 |
+
else:
|
240 |
+
raise NotImplementedError
|
241 |
+
return x
|
242 |
+
|
243 |
+
def setup_f0_stats(self, f0_min, f0_max, f0_bins, speaker_stats):
|
244 |
+
self.f0_min = f0_min
|
245 |
+
self.f0_max = f0_max
|
246 |
+
self.f0_bins = f0_bins
|
247 |
+
self.speaker_stats = speaker_stats
|
248 |
+
self.setup = True
|
249 |
+
|
250 |
+
def inference(self, x, spk_id=None, gst=None):
|
251 |
+
assert (
|
252 |
+
self.setup == True
|
253 |
+
), "make sure that `setup_f0_stats` was called before inference!"
|
254 |
+
probs = self(x, gst)
|
255 |
+
f0 = bin2freq(probs, self.f0_min, self.f0_max, self.f0_bins, self.f0_pred)
|
256 |
+
for i in range(f0.shape[0]):
|
257 |
+
mean = (
|
258 |
+
self.speaker_stats[spk_id[i].item()].mean_log
|
259 |
+
if self.f0_log
|
260 |
+
else self.speaker_stats[spk_id[i].item()].mean
|
261 |
+
)
|
262 |
+
std = (
|
263 |
+
self.speaker_stats[spk_id[i].item()].std_log
|
264 |
+
if self.f0_log
|
265 |
+
else self.speaker_stats[spk_id[i].item()].std
|
266 |
+
)
|
267 |
+
if self.f0_norm == "mean":
|
268 |
+
f0[i] = f0[i] + mean
|
269 |
+
if self.f0_norm == "meanstd":
|
270 |
+
f0[i] = (f0[i] * std) + mean
|
271 |
+
if self.f0_log:
|
272 |
+
f0 = f0.exp()
|
273 |
+
return f0
|
274 |
+
|
275 |
+
|
276 |
+
class PitchDataset(Dataset):
|
277 |
+
def __init__(
|
278 |
+
self,
|
279 |
+
tsv_path,
|
280 |
+
km_path,
|
281 |
+
substring,
|
282 |
+
spk,
|
283 |
+
spk2id,
|
284 |
+
gst,
|
285 |
+
gst2id,
|
286 |
+
f0_bins,
|
287 |
+
f0_bin_type,
|
288 |
+
f0_smoothing,
|
289 |
+
f0_norm,
|
290 |
+
f0_log,
|
291 |
+
):
|
292 |
+
lines = open(tsv_path, "r").readlines()
|
293 |
+
self.root, self.tsv = lines[0], lines[1:]
|
294 |
+
self.root = self.root.strip()
|
295 |
+
self.km = open(km_path, "r").readlines()
|
296 |
+
print(f"loaded {len(self.km)} files")
|
297 |
+
|
298 |
+
self.spk = spk
|
299 |
+
self.spk2id = spk2id
|
300 |
+
self.gst = gst
|
301 |
+
self.gst2id = gst2id
|
302 |
+
|
303 |
+
self.f0_bins = f0_bins
|
304 |
+
self.f0_smoothing = f0_smoothing
|
305 |
+
self.f0_norm = f0_norm
|
306 |
+
self.f0_log = f0_log
|
307 |
+
|
308 |
+
if substring != "":
|
309 |
+
tsv, km = [], []
|
310 |
+
for tsv_line, km_line in zip(self.tsv, self.km):
|
311 |
+
if substring.lower() in tsv_line.lower():
|
312 |
+
tsv.append(tsv_line)
|
313 |
+
km.append(km_line)
|
314 |
+
self.tsv, self.km = tsv, km
|
315 |
+
print(f"after filtering: {len(self.km)} files")
|
316 |
+
|
317 |
+
self.speaker_stats = self._compute_f0_stats()
|
318 |
+
self.f0_min, self.f0_max = self._compute_f0_minmax()
|
319 |
+
if f0_bin_type == "adaptive":
|
320 |
+
self.f0_bins = quantize_f0(
|
321 |
+
self.speaker_stats, self.f0_bins, self.f0_norm, self.f0_log
|
322 |
+
)
|
323 |
+
elif f0_bin_type == "uniform":
|
324 |
+
self.f0_bins = torch.linspace(self.f0_min, self.f0_max, self.f0_bins + 1)[
|
325 |
+
1:-1
|
326 |
+
]
|
327 |
+
else:
|
328 |
+
raise NotImplementedError
|
329 |
+
print(f"f0 min: {self.f0_min}, f0 max: {self.f0_max}")
|
330 |
+
print(f"bins: {self.f0_bins} (shape: {self.f0_bins.shape})")
|
331 |
+
|
332 |
+
def __len__(self):
|
333 |
+
return len(self.km)
|
334 |
+
|
335 |
+
def _load_f0(self, tsv_line):
|
336 |
+
tsv_line = tsv_line.split("\t")[0]
|
337 |
+
f0 = self.root + "/" + tsv_line.replace(".wav", ".yaapt.f0.npy")
|
338 |
+
f0 = np.load(f0)
|
339 |
+
f0 = torch.FloatTensor(f0)
|
340 |
+
return f0
|
341 |
+
|
342 |
+
def _preprocess_f0(self, f0, spk):
|
343 |
+
mask = f0 != -999999 # process all frames
|
344 |
+
# mask = (f0 != 0) # only process voiced frames
|
345 |
+
mean = (
|
346 |
+
self.speaker_stats[spk].mean_log
|
347 |
+
if self.f0_log
|
348 |
+
else self.speaker_stats[spk].mean
|
349 |
+
)
|
350 |
+
std = (
|
351 |
+
self.speaker_stats[spk].std_log
|
352 |
+
if self.f0_log
|
353 |
+
else self.speaker_stats[spk].std
|
354 |
+
)
|
355 |
+
if self.f0_log:
|
356 |
+
f0[f0 == 0] = 1e-5
|
357 |
+
f0[mask] = f0[mask].log()
|
358 |
+
if self.f0_norm == "mean":
|
359 |
+
f0[mask] = f0[mask] - mean
|
360 |
+
if self.f0_norm == "meanstd":
|
361 |
+
f0[mask] = (f0[mask] - mean) / std
|
362 |
+
return f0
|
363 |
+
|
364 |
+
def _compute_f0_minmax(self):
|
365 |
+
f0_min, f0_max = float("inf"), -float("inf")
|
366 |
+
for tsv_line in tqdm(self.tsv, desc="computing f0 minmax"):
|
367 |
+
spk = self.spk2id[parse_speaker(tsv_line, self.spk)]
|
368 |
+
f0 = self._load_f0(tsv_line)
|
369 |
+
f0 = self._preprocess_f0(f0, spk)
|
370 |
+
f0_min = min(f0_min, f0.min().item())
|
371 |
+
f0_max = max(f0_max, f0.max().item())
|
372 |
+
return f0_min, f0_max
|
373 |
+
|
374 |
+
def _compute_f0_stats(self):
|
375 |
+
from functools import partial
|
376 |
+
|
377 |
+
speaker_stats = defaultdict(partial(F0Stat, True))
|
378 |
+
for tsv_line in tqdm(self.tsv, desc="computing speaker stats"):
|
379 |
+
spk = self.spk2id[parse_speaker(tsv_line, self.spk)]
|
380 |
+
f0 = self._load_f0(tsv_line)
|
381 |
+
mask = f0 != 0
|
382 |
+
f0 = f0[mask] # compute stats only on voiced parts
|
383 |
+
speaker_stats[spk].update(f0)
|
384 |
+
return speaker_stats
|
385 |
+
|
386 |
+
def __getitem__(self, i):
|
387 |
+
x = self.km[i]
|
388 |
+
x = x.split(" ")
|
389 |
+
x = list(map(int, x))
|
390 |
+
x = torch.LongTensor(x)
|
391 |
+
|
392 |
+
gst = parse_style(self.tsv[i], self.gst)
|
393 |
+
gst = self.gst2id[gst]
|
394 |
+
spk = parse_speaker(self.tsv[i], self.spk)
|
395 |
+
spk = self.spk2id[spk]
|
396 |
+
|
397 |
+
f0_raw = self._load_f0(self.tsv[i])
|
398 |
+
f0 = self._preprocess_f0(f0_raw.clone(), spk)
|
399 |
+
|
400 |
+
f0 = F.interpolate(f0.unsqueeze(0).unsqueeze(0), x.shape[0])[0, 0]
|
401 |
+
f0_raw = F.interpolate(f0_raw.unsqueeze(0).unsqueeze(0), x.shape[0])[0, 0]
|
402 |
+
|
403 |
+
f0 = freq2bin(f0, f0_min=self.f0_min, f0_max=self.f0_max, bins=self.f0_bins)
|
404 |
+
f0 = F.one_hot(f0.long(), num_classes=len(self.f0_bins) + 1).float()
|
405 |
+
if self.f0_smoothing > 0:
|
406 |
+
f0 = torch.tensor(
|
407 |
+
gaussian_filter1d(f0.float().numpy(), sigma=self.f0_smoothing)
|
408 |
+
)
|
409 |
+
return x, f0, f0_raw, spk, gst
|
410 |
+
|
411 |
+
|
412 |
+
def train(cfg):
|
413 |
+
device = "cuda:0"
|
414 |
+
# add 1 extra embedding for padding token, set the padding index to be the last token
|
415 |
+
# (tokens from the clustering start at index 0)
|
416 |
+
padding_token = cfg.n_tokens
|
417 |
+
collate_fn = Collator(padding_idx=padding_token)
|
418 |
+
train_ds = PitchDataset(
|
419 |
+
cfg.train_tsv,
|
420 |
+
cfg.train_km,
|
421 |
+
substring=cfg.substring,
|
422 |
+
spk=cfg.spk,
|
423 |
+
spk2id=cfg.spk2id,
|
424 |
+
gst=cfg.gst,
|
425 |
+
gst2id=cfg.gst2id,
|
426 |
+
f0_bins=cfg.f0_bins,
|
427 |
+
f0_bin_type=cfg.f0_bin_type,
|
428 |
+
f0_smoothing=cfg.f0_smoothing,
|
429 |
+
f0_norm=cfg.f0_norm,
|
430 |
+
f0_log=cfg.f0_log,
|
431 |
+
)
|
432 |
+
valid_ds = PitchDataset(
|
433 |
+
cfg.valid_tsv,
|
434 |
+
cfg.valid_km,
|
435 |
+
substring=cfg.substring,
|
436 |
+
spk=cfg.spk,
|
437 |
+
spk2id=cfg.spk2id,
|
438 |
+
gst=cfg.gst,
|
439 |
+
gst2id=cfg.gst2id,
|
440 |
+
f0_bins=cfg.f0_bins,
|
441 |
+
f0_bin_type=cfg.f0_bin_type,
|
442 |
+
f0_smoothing=cfg.f0_smoothing,
|
443 |
+
f0_norm=cfg.f0_norm,
|
444 |
+
f0_log=cfg.f0_log,
|
445 |
+
)
|
446 |
+
train_dl = DataLoader(
|
447 |
+
train_ds,
|
448 |
+
num_workers=0,
|
449 |
+
batch_size=cfg.batch_size,
|
450 |
+
shuffle=True,
|
451 |
+
collate_fn=collate_fn,
|
452 |
+
)
|
453 |
+
valid_dl = DataLoader(
|
454 |
+
valid_ds, num_workers=0, batch_size=16, shuffle=False, collate_fn=collate_fn
|
455 |
+
)
|
456 |
+
|
457 |
+
f0_min = train_ds.f0_min
|
458 |
+
f0_max = train_ds.f0_max
|
459 |
+
f0_bins = train_ds.f0_bins
|
460 |
+
speaker_stats = train_ds.speaker_stats
|
461 |
+
|
462 |
+
model = hydra.utils.instantiate(cfg["model"]).to(device)
|
463 |
+
model.setup_f0_stats(f0_min, f0_max, f0_bins, speaker_stats)
|
464 |
+
|
465 |
+
optimizer = hydra.utils.instantiate(cfg.optimizer, model.parameters())
|
466 |
+
|
467 |
+
best_loss = float("inf")
|
468 |
+
for epoch in range(cfg.epochs):
|
469 |
+
train_loss, train_l2_loss, train_l2_voiced_loss = run_epoch(
|
470 |
+
model, train_dl, optimizer, device, cfg, mode="train"
|
471 |
+
)
|
472 |
+
valid_loss, valid_l2_loss, valid_l2_voiced_loss = run_epoch(
|
473 |
+
model, valid_dl, None, device, cfg, mode="valid"
|
474 |
+
)
|
475 |
+
print(
|
476 |
+
f"[epoch {epoch}] train loss: {train_loss:.3f}, l2 loss: {train_l2_loss:.3f}, l2 voiced loss: {train_l2_voiced_loss:.3f}"
|
477 |
+
)
|
478 |
+
print(
|
479 |
+
f"[epoch {epoch}] valid loss: {valid_loss:.3f}, l2 loss: {valid_l2_loss:.3f}, l2 voiced loss: {valid_l2_voiced_loss:.3f}"
|
480 |
+
)
|
481 |
+
if valid_l2_voiced_loss < best_loss:
|
482 |
+
path = f"{os.getcwd()}/pitch_predictor.ckpt"
|
483 |
+
save_ckpt(model, path, cfg["model"], f0_min, f0_max, f0_bins, speaker_stats)
|
484 |
+
best_loss = valid_l2_voiced_loss
|
485 |
+
print(f"saved checkpoint: {path}")
|
486 |
+
print(f"[epoch {epoch}] best loss: {best_loss:.3f}")
|
487 |
+
|
488 |
+
|
489 |
+
def run_epoch(model, loader, optimizer, device, cfg, mode):
|
490 |
+
if mode == "train":
|
491 |
+
model.train()
|
492 |
+
else:
|
493 |
+
model.eval()
|
494 |
+
|
495 |
+
epoch_loss = 0
|
496 |
+
l1 = 0
|
497 |
+
l1_voiced = 0
|
498 |
+
for x, f0_bin, f0_raw, spk_id, gst, mask, _ in tqdm(loader):
|
499 |
+
x, f0_bin, f0_raw, spk_id, gst, mask = (
|
500 |
+
x.to(device),
|
501 |
+
f0_bin.to(device),
|
502 |
+
f0_raw.to(device),
|
503 |
+
spk_id.to(device),
|
504 |
+
gst.to(device),
|
505 |
+
mask.to(device),
|
506 |
+
)
|
507 |
+
b, t, n_bins = f0_bin.shape
|
508 |
+
yhat = model(x, gst)
|
509 |
+
nonzero_mask = (f0_raw != 0).logical_and(mask)
|
510 |
+
yhat_raw = model.inference(x, spk_id, gst)
|
511 |
+
expanded_mask = mask.unsqueeze(-1).expand(-1, -1, n_bins)
|
512 |
+
if cfg.f0_pred == "mean":
|
513 |
+
loss = F.binary_cross_entropy(
|
514 |
+
yhat[expanded_mask], f0_bin[expanded_mask]
|
515 |
+
).mean()
|
516 |
+
elif cfg.f0_pred == "argmax":
|
517 |
+
loss = F.cross_entropy(
|
518 |
+
rearrange(yhat, "b t d -> (b t) d"),
|
519 |
+
rearrange(f0_bin.argmax(-1), "b t -> (b t)"),
|
520 |
+
reduce=False,
|
521 |
+
)
|
522 |
+
loss = rearrange(loss, "(b t) -> b t", b=b, t=t)
|
523 |
+
loss = (loss * mask).sum() / mask.float().sum()
|
524 |
+
else:
|
525 |
+
raise NotImplementedError
|
526 |
+
l1 += F.l1_loss(yhat_raw[mask], f0_raw[mask]).item()
|
527 |
+
l1_voiced += F.l1_loss(yhat_raw[nonzero_mask], f0_raw[nonzero_mask]).item()
|
528 |
+
epoch_loss += loss.item()
|
529 |
+
|
530 |
+
if mode == "train":
|
531 |
+
loss.backward()
|
532 |
+
nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
533 |
+
optimizer.step()
|
534 |
+
|
535 |
+
print(f"{mode} example y: {f0_bin.argmax(-1)[0, 50:60].tolist()}")
|
536 |
+
print(f"{mode} example yhat: {yhat.argmax(-1)[0, 50:60].tolist()}")
|
537 |
+
print(f"{mode} example y: {f0_raw[0, 50:60].round().tolist()}")
|
538 |
+
print(f"{mode} example yhat: {yhat_raw[0, 50:60].round().tolist()}")
|
539 |
+
return epoch_loss / len(loader), l1 / len(loader), l1_voiced / len(loader)
|
540 |
+
|
541 |
+
|
542 |
+
@hydra.main(config_path=dir_path, config_name="pitch_predictor.yaml")
|
543 |
+
def main(cfg):
|
544 |
+
np.random.seed(1)
|
545 |
+
random.seed(1)
|
546 |
+
torch.manual_seed(1)
|
547 |
+
from hydra.core.hydra_config import HydraConfig
|
548 |
+
|
549 |
+
overrides = {
|
550 |
+
x.split("=")[0]: x.split("=")[1]
|
551 |
+
for x in HydraConfig.get().overrides.task
|
552 |
+
if "/" not in x
|
553 |
+
}
|
554 |
+
print(f"{cfg}")
|
555 |
+
train(cfg)
|
556 |
+
|
557 |
+
|
558 |
+
if __name__ == "__main__":
|
559 |
+
main()
|
fairseq/examples/emotion_conversion/emotion_models/utils.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
class Stat:
|
5 |
+
def __init__(self, keep_raw=False):
|
6 |
+
self.x = 0.0
|
7 |
+
self.x2 = 0.0
|
8 |
+
self.z = 0.0 # z = logx
|
9 |
+
self.z2 = 0.0
|
10 |
+
self.n = 0.0
|
11 |
+
self.u = 0.0
|
12 |
+
self.keep_raw = keep_raw
|
13 |
+
self.raw = []
|
14 |
+
|
15 |
+
def update(self, new_x):
|
16 |
+
new_z = new_x.log()
|
17 |
+
|
18 |
+
self.x += new_x.sum()
|
19 |
+
self.x2 += (new_x**2).sum()
|
20 |
+
self.z += new_z.sum()
|
21 |
+
self.z2 += (new_z**2).sum()
|
22 |
+
self.n += len(new_x)
|
23 |
+
self.u += 1
|
24 |
+
|
25 |
+
if self.keep_raw:
|
26 |
+
self.raw.append(new_x)
|
27 |
+
|
28 |
+
@property
|
29 |
+
def mean(self):
|
30 |
+
return self.x / self.n
|
31 |
+
|
32 |
+
@property
|
33 |
+
def std(self):
|
34 |
+
return (self.x2 / self.n - self.mean**2) ** 0.5
|
35 |
+
|
36 |
+
@property
|
37 |
+
def mean_log(self):
|
38 |
+
return self.z / self.n
|
39 |
+
|
40 |
+
@property
|
41 |
+
def std_log(self):
|
42 |
+
return (self.z2 / self.n - self.mean_log**2) ** 0.5
|
43 |
+
|
44 |
+
@property
|
45 |
+
def n_frms(self):
|
46 |
+
return self.n
|
47 |
+
|
48 |
+
@property
|
49 |
+
def n_utts(self):
|
50 |
+
return self.u
|
51 |
+
|
52 |
+
@property
|
53 |
+
def raw_data(self):
|
54 |
+
assert self.keep_raw, "does not support storing raw data!"
|
55 |
+
return torch.cat(self.raw)
|
56 |
+
|
57 |
+
|
58 |
+
class F0Stat(Stat):
|
59 |
+
def update(self, new_x):
|
60 |
+
# assume unvoiced frames are 0 and consider only voiced frames
|
61 |
+
if new_x is not None:
|
62 |
+
super().update(new_x[new_x != 0])
|
63 |
+
|
64 |
+
|
65 |
+
class Accuracy:
|
66 |
+
def __init__(self):
|
67 |
+
self.y, self.yhat = [], []
|
68 |
+
|
69 |
+
def update(self, yhat, y):
|
70 |
+
self.yhat.append(yhat)
|
71 |
+
self.y.append(y)
|
72 |
+
|
73 |
+
def acc(self, tol):
|
74 |
+
yhat = torch.cat(self.yhat)
|
75 |
+
y = torch.cat(self.y)
|
76 |
+
acc = torch.abs(yhat - y) <= tol
|
77 |
+
acc = acc.float().mean().item()
|
78 |
+
return acc
|
fairseq/examples/emotion_conversion/fairseq_models/__init__.py
ADDED
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from fairseq import utils
|
7 |
+
from fairseq.models import (
|
8 |
+
FairseqMultiModel,
|
9 |
+
register_model,
|
10 |
+
register_model_architecture,
|
11 |
+
)
|
12 |
+
from fairseq.models.transformer import (
|
13 |
+
Embedding,
|
14 |
+
base_architecture,
|
15 |
+
)
|
16 |
+
from fairseq.models.multilingual_transformer import (
|
17 |
+
MultilingualTransformerModel,
|
18 |
+
base_multilingual_architecture,
|
19 |
+
)
|
20 |
+
from fairseq.utils import safe_hasattr
|
21 |
+
from collections import OrderedDict
|
22 |
+
|
23 |
+
|
24 |
+
@register_model("multilingual_transformer_from_mbart")
|
25 |
+
class MultilingualTransformerModelFromMbart(MultilingualTransformerModel):
|
26 |
+
@classmethod
|
27 |
+
def build_model(cls, args, task):
|
28 |
+
"""Build a new model instance."""
|
29 |
+
from fairseq.tasks.multilingual_translation import MultilingualTranslationTask
|
30 |
+
|
31 |
+
assert isinstance(task, MultilingualTranslationTask)
|
32 |
+
|
33 |
+
# make sure all arguments are present in older models
|
34 |
+
base_multilingual_architecture(args)
|
35 |
+
|
36 |
+
if not safe_hasattr(args, "max_source_positions"):
|
37 |
+
args.max_source_positions = 1024
|
38 |
+
if not safe_hasattr(args, "max_target_positions"):
|
39 |
+
args.max_target_positions = 1024
|
40 |
+
|
41 |
+
src_langs = [lang_pair.split("-")[0] for lang_pair in task.model_lang_pairs]
|
42 |
+
tgt_langs = [lang_pair.split("-")[1] for lang_pair in task.model_lang_pairs]
|
43 |
+
|
44 |
+
if args.share_encoders:
|
45 |
+
args.share_encoder_embeddings = True
|
46 |
+
if args.share_decoders:
|
47 |
+
args.share_decoder_embeddings = True
|
48 |
+
|
49 |
+
def build_embedding(dictionary, embed_dim, path=None):
|
50 |
+
num_embeddings = len(dictionary)
|
51 |
+
padding_idx = dictionary.pad()
|
52 |
+
emb = Embedding(num_embeddings, embed_dim, padding_idx)
|
53 |
+
# if provided, load from preloaded dictionaries
|
54 |
+
if path:
|
55 |
+
embed_dict = utils.parse_embedding(path)
|
56 |
+
utils.load_embedding(embed_dict, dictionary, emb)
|
57 |
+
return emb
|
58 |
+
|
59 |
+
# build shared embeddings (if applicable)
|
60 |
+
shared_encoder_embed_tokens, shared_decoder_embed_tokens = None, None
|
61 |
+
if args.share_all_embeddings:
|
62 |
+
if args.encoder_embed_dim != args.decoder_embed_dim:
|
63 |
+
raise ValueError(
|
64 |
+
"--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
|
65 |
+
)
|
66 |
+
if args.decoder_embed_path and (
|
67 |
+
args.decoder_embed_path != args.encoder_embed_path
|
68 |
+
):
|
69 |
+
raise ValueError(
|
70 |
+
"--share-all-embeddings not compatible with --decoder-embed-path"
|
71 |
+
)
|
72 |
+
shared_encoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
|
73 |
+
dicts=task.dicts,
|
74 |
+
langs=task.langs,
|
75 |
+
embed_dim=args.encoder_embed_dim,
|
76 |
+
build_embedding=build_embedding,
|
77 |
+
pretrained_embed_path=args.encoder_embed_path,
|
78 |
+
)
|
79 |
+
shared_decoder_embed_tokens = shared_encoder_embed_tokens
|
80 |
+
args.share_decoder_input_output_embed = True
|
81 |
+
else:
|
82 |
+
if args.share_encoder_embeddings:
|
83 |
+
shared_encoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
|
84 |
+
dicts=task.dicts,
|
85 |
+
langs=src_langs,
|
86 |
+
embed_dim=args.encoder_embed_dim,
|
87 |
+
build_embedding=build_embedding,
|
88 |
+
pretrained_embed_path=args.encoder_embed_path,
|
89 |
+
)
|
90 |
+
if args.share_decoder_embeddings:
|
91 |
+
shared_decoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
|
92 |
+
dicts=task.dicts,
|
93 |
+
langs=tgt_langs,
|
94 |
+
embed_dim=args.decoder_embed_dim,
|
95 |
+
build_embedding=build_embedding,
|
96 |
+
pretrained_embed_path=args.decoder_embed_path,
|
97 |
+
)
|
98 |
+
|
99 |
+
# encoders/decoders for each language
|
100 |
+
lang_encoders, lang_decoders = {}, {}
|
101 |
+
|
102 |
+
def get_encoder(lang):
|
103 |
+
if lang not in lang_encoders:
|
104 |
+
if shared_encoder_embed_tokens is not None:
|
105 |
+
encoder_embed_tokens = shared_encoder_embed_tokens
|
106 |
+
else:
|
107 |
+
encoder_embed_tokens = build_embedding(
|
108 |
+
task.dicts[lang],
|
109 |
+
args.encoder_embed_dim,
|
110 |
+
args.encoder_embed_path,
|
111 |
+
)
|
112 |
+
lang_encoders[lang] = MultilingualTransformerModel._get_module_class(
|
113 |
+
True, args, task.dicts[lang], encoder_embed_tokens, src_langs
|
114 |
+
)
|
115 |
+
return lang_encoders[lang]
|
116 |
+
|
117 |
+
def get_decoder(lang):
|
118 |
+
if lang not in lang_decoders:
|
119 |
+
if shared_decoder_embed_tokens is not None:
|
120 |
+
decoder_embed_tokens = shared_decoder_embed_tokens
|
121 |
+
else:
|
122 |
+
decoder_embed_tokens = build_embedding(
|
123 |
+
task.dicts[lang],
|
124 |
+
args.decoder_embed_dim,
|
125 |
+
args.decoder_embed_path,
|
126 |
+
)
|
127 |
+
lang_decoders[lang] = MultilingualTransformerModel._get_module_class(
|
128 |
+
False, args, task.dicts[lang], decoder_embed_tokens, tgt_langs
|
129 |
+
)
|
130 |
+
return lang_decoders[lang]
|
131 |
+
|
132 |
+
# shared encoders/decoders (if applicable)
|
133 |
+
shared_encoder, shared_decoder = None, None
|
134 |
+
if args.share_encoders:
|
135 |
+
shared_encoder = get_encoder(src_langs[0])
|
136 |
+
if args.share_decoders:
|
137 |
+
shared_decoder = get_decoder(tgt_langs[0])
|
138 |
+
|
139 |
+
encoders, decoders = OrderedDict(), OrderedDict()
|
140 |
+
for lang_pair, src, tgt in zip(task.model_lang_pairs, src_langs, tgt_langs):
|
141 |
+
encoders[lang_pair] = (
|
142 |
+
shared_encoder if shared_encoder is not None else get_encoder(src)
|
143 |
+
)
|
144 |
+
decoders[lang_pair] = (
|
145 |
+
shared_decoder if shared_decoder is not None else get_decoder(tgt)
|
146 |
+
)
|
147 |
+
|
148 |
+
return MultilingualTransformerModelFromMbart(encoders, decoders)
|
149 |
+
|
150 |
+
def load_state_dict(self, state_dict, strict=True, model_cfg=None):
|
151 |
+
state_dict_subset = state_dict.copy()
|
152 |
+
lang_pairs = set([x.split(".")[1] for x in state_dict.keys()])
|
153 |
+
finetune_mode = not any("neutral" in lp for lp in lang_pairs)
|
154 |
+
|
155 |
+
if finetune_mode:
|
156 |
+
# load a pre-trained mBART/BART model
|
157 |
+
# we need this code because mBART/BART are not of type FairseqMultiModel but FairseqModel
|
158 |
+
# so we hackishly load the weights by replicating them for all lang pairs
|
159 |
+
print("loading pre-trained BART")
|
160 |
+
self_state_dict = self.state_dict()
|
161 |
+
for k, v in state_dict.items():
|
162 |
+
for lang_pair in self.models:
|
163 |
+
new_key = k if "models." in k else f"models.{lang_pair}.{k}"
|
164 |
+
# print(new_key)
|
165 |
+
if self_state_dict[new_key].shape == v.shape:
|
166 |
+
state_dict_subset[new_key] = v
|
167 |
+
elif any(
|
168 |
+
w in k
|
169 |
+
for w in [
|
170 |
+
"encoder.embed_tokens.weight",
|
171 |
+
"decoder.embed_tokens.weight",
|
172 |
+
"decoder.output_projection.weight",
|
173 |
+
]
|
174 |
+
):
|
175 |
+
# why vocab_size - 5? because there are `vocab_size` tokens from the language
|
176 |
+
# and 5 additional tokens in the denoising task: eos,bos,pad,unk,mask.
|
177 |
+
# but in the translation task there are only `vocab_size` + 4 (no mask).
|
178 |
+
print(
|
179 |
+
f"{k}: {self_state_dict[new_key].shape} != {v.shape}",
|
180 |
+
end="",
|
181 |
+
flush=True,
|
182 |
+
)
|
183 |
+
vocab_size = v.shape[0] - 5
|
184 |
+
state_dict_subset[new_key] = self_state_dict[new_key]
|
185 |
+
state_dict_subset[new_key] = v[: vocab_size + 4]
|
186 |
+
print(f" => fixed by using first {vocab_size + 4} dims")
|
187 |
+
else:
|
188 |
+
raise ValueError("unable to load model due to mimatched dims!")
|
189 |
+
del state_dict_subset[k]
|
190 |
+
else:
|
191 |
+
print("loading pre-trained emotion translation model")
|
192 |
+
for k, _ in state_dict.items():
|
193 |
+
assert k.startswith("models.")
|
194 |
+
lang_pair = k.split(".")[1]
|
195 |
+
if lang_pair not in self.models:
|
196 |
+
del state_dict_subset[k]
|
197 |
+
|
198 |
+
super().load_state_dict(state_dict_subset, strict=strict, model_cfg=model_cfg)
|
199 |
+
|
200 |
+
|
201 |
+
@register_model_architecture("transformer", "transformer_small")
|
202 |
+
def transformer_small(args):
|
203 |
+
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
|
204 |
+
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 512)
|
205 |
+
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
|
206 |
+
args.encoder_layers = getattr(args, "encoder_layers", 3)
|
207 |
+
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
|
208 |
+
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 512)
|
209 |
+
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
|
210 |
+
args.decoder_layers = getattr(args, "decoder_layers", 3)
|
211 |
+
base_architecture(args)
|
212 |
+
|
213 |
+
|
214 |
+
@register_model_architecture(
|
215 |
+
"multilingual_transformer_from_mbart", "multilingual_small"
|
216 |
+
)
|
217 |
+
def multilingual_small(args):
|
218 |
+
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
|
219 |
+
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 512)
|
220 |
+
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
|
221 |
+
args.encoder_layers = getattr(args, "encoder_layers", 3)
|
222 |
+
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
|
223 |
+
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 512)
|
224 |
+
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
|
225 |
+
args.decoder_layers = getattr(args, "decoder_layers", 3)
|
226 |
+
base_multilingual_architecture(args)
|
fairseq/examples/emotion_conversion/preprocess/__init__.py
ADDED
File without changes
|
fairseq/examples/emotion_conversion/preprocess/build_hifigan_manifest.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torchaudio
|
2 |
+
import argparse
|
3 |
+
import json
|
4 |
+
|
5 |
+
def main():
|
6 |
+
parser = argparse.ArgumentParser(description="example: python create_hifigan_manifest.py --tsv /checkpoint/felixkreuk/datasets/vctk/splits/vctk_16khz/train.tsv --km /checkpoint/felixkreuk/experiments/hubert/hubert_feats/vctk_16khz_km_100/train.km --km_type hubert_100km > ~/tmp/tmp_mani.txt")
|
7 |
+
parser.add_argument("--tsv", required=True, help="path to fairseq tsv file")
|
8 |
+
parser.add_argument("--km", required=True, help="path to a km file generated by HuBERT clustering")
|
9 |
+
parser.add_argument("--km_type", required=True, help="name of the codes in the output json (for example: 'cpc_100km')")
|
10 |
+
args = parser.parse_args()
|
11 |
+
|
12 |
+
km_lines = open(args.km, "r").readlines()
|
13 |
+
tsv_lines = open(args.tsv, "r").readlines()
|
14 |
+
assert len(km_lines) == len(tsv_lines) - 1, "tsv and km files are not of the same length!"
|
15 |
+
|
16 |
+
wav_root = tsv_lines[0].strip()
|
17 |
+
tsv_lines = tsv_lines[1:]
|
18 |
+
|
19 |
+
for tsv_line, km_line in zip(tsv_lines, km_lines):
|
20 |
+
tsv_line, km_line = tsv_line.strip(), km_line.strip()
|
21 |
+
wav_basename, wav_num_frames = tsv_line.split("\t")
|
22 |
+
wav_path = wav_root + "/" + wav_basename
|
23 |
+
wav_info = torchaudio.info(wav_path)
|
24 |
+
assert int(wav_num_frames) == wav_info.num_frames, "tsv duration and actual duration don't match!"
|
25 |
+
wav_duration = wav_info.num_frames / wav_info.sample_rate
|
26 |
+
manifest_line = {"audio": wav_path, "duration": wav_duration, args.km_type: km_line}
|
27 |
+
print(json.dumps(manifest_line))
|
28 |
+
|
29 |
+
if __name__ == "__main__":
|
30 |
+
"""
|
31 |
+
usage:
|
32 |
+
python create_hifigan_manifest.py \
|
33 |
+
--tsv /checkpoint/felixkreuk/datasets/vctk/manifests/vctk_16khz/valid.tsv \
|
34 |
+
--km /checkpoint/felixkreuk/datasets/vctk/manifests/vctk_16khz/hubert_km_100/valid.km \
|
35 |
+
--km_type hubert \
|
36 |
+
> /checkpoint/felixkreuk/datasets/vctk/manifests/vctk_16khz/hubert_km_100/hifigan_valid_manifest.txt
|
37 |
+
"""
|
38 |
+
main()
|
fairseq/examples/emotion_conversion/preprocess/build_translation_manifests.py
ADDED
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from glob import glob
|
2 |
+
import argparse
|
3 |
+
from collections import defaultdict, Counter
|
4 |
+
from itertools import combinations, product, groupby
|
5 |
+
from pathlib import Path
|
6 |
+
import os
|
7 |
+
from sklearn.utils import shuffle
|
8 |
+
import numpy as np
|
9 |
+
import random
|
10 |
+
from shutil import copy
|
11 |
+
from subprocess import check_call
|
12 |
+
|
13 |
+
np.random.seed(42)
|
14 |
+
random.seed(42)
|
15 |
+
|
16 |
+
|
17 |
+
def get_fname(s):
|
18 |
+
return s.split("\t")[0]
|
19 |
+
|
20 |
+
def get_emotion(s):
|
21 |
+
return get_fname(s).split("_")[0].split("/")[1].lower()
|
22 |
+
|
23 |
+
def get_utt_id(s):
|
24 |
+
return get_fname(s).split(".")[0].split("_")[-1]
|
25 |
+
|
26 |
+
def dedup(seq):
|
27 |
+
""" >> remove_repetitions("1 2 2 3 100 2 2 1")
|
28 |
+
'1 2 3 100 2 1' """
|
29 |
+
seq = seq.strip().split(" ")
|
30 |
+
result = seq[:1]
|
31 |
+
reps = []
|
32 |
+
rep_counter = 1
|
33 |
+
for k in seq[1:]:
|
34 |
+
if k != result[-1]:
|
35 |
+
result += [k]
|
36 |
+
reps += [rep_counter]
|
37 |
+
rep_counter = 1
|
38 |
+
else:
|
39 |
+
rep_counter += 1
|
40 |
+
reps += [rep_counter]
|
41 |
+
assert len(reps) == len(result) and sum(reps) == len(seq)
|
42 |
+
return " ".join(result) + "\n" #, reps
|
43 |
+
|
44 |
+
def remove_under_k(seq, k):
|
45 |
+
""" remove tokens that repeat less then k times in a row
|
46 |
+
>> remove_under_k("a a a a b c c c", 1) ==> a a a a c c c """
|
47 |
+
seq = seq.strip().split(" ")
|
48 |
+
result = []
|
49 |
+
|
50 |
+
freqs = [(k,len(list(g))) for k, g in groupby(seq)]
|
51 |
+
for c, f in freqs:
|
52 |
+
if f > k:
|
53 |
+
result += [c for _ in range(f)]
|
54 |
+
return " ".join(result) + "\n" #, reps
|
55 |
+
|
56 |
+
|
57 |
+
def call(cmd):
|
58 |
+
print(cmd)
|
59 |
+
check_call(cmd, shell=True)
|
60 |
+
|
61 |
+
|
62 |
+
def denoising_preprocess(path, lang, dict):
|
63 |
+
bin = 'fairseq-preprocess'
|
64 |
+
cmd = [
|
65 |
+
bin,
|
66 |
+
f'--trainpref {path}/train.{lang} --validpref {path}/valid.{lang} --testpref {path}/test.{lang}',
|
67 |
+
f'--destdir {path}/tokenized/{lang}',
|
68 |
+
'--only-source',
|
69 |
+
'--task multilingual_denoising',
|
70 |
+
'--workers 40',
|
71 |
+
]
|
72 |
+
if dict != "":
|
73 |
+
cmd += [f'--srcdict {dict}']
|
74 |
+
cmd = " ".join(cmd)
|
75 |
+
call(cmd)
|
76 |
+
|
77 |
+
|
78 |
+
def translation_preprocess(path, src_lang, trg_lang, dict, only_train=False):
|
79 |
+
bin = 'fairseq-preprocess'
|
80 |
+
cmd = [
|
81 |
+
bin,
|
82 |
+
f'--source-lang {src_lang} --target-lang {trg_lang}',
|
83 |
+
f'--trainpref {path}/train',
|
84 |
+
f'--destdir {path}/tokenized',
|
85 |
+
'--workers 40',
|
86 |
+
]
|
87 |
+
if not only_train:
|
88 |
+
cmd += [f'--validpref {path}/valid --testpref {path}/test']
|
89 |
+
if dict != "":
|
90 |
+
cmd += [
|
91 |
+
f'--srcdict {dict}',
|
92 |
+
f'--tgtdict {dict}',
|
93 |
+
]
|
94 |
+
cmd = " ".join(cmd)
|
95 |
+
call(cmd)
|
96 |
+
|
97 |
+
|
98 |
+
def load_tsv_km(tsv_path, km_path):
|
99 |
+
assert tsv_path.exists() and km_path.exists()
|
100 |
+
tsv_lines = open(tsv_path, "r").readlines()
|
101 |
+
root, tsv_lines = tsv_lines[0], tsv_lines[1:]
|
102 |
+
km_lines = open(km_path, "r").readlines()
|
103 |
+
assert len(tsv_lines) == len(km_lines), ".tsv and .km should be the same length!"
|
104 |
+
return root, tsv_lines, km_lines
|
105 |
+
|
106 |
+
|
107 |
+
def main():
|
108 |
+
desc = """
|
109 |
+
this script takes as input .tsv and .km files for EMOV dataset, and a pairs of emotions.
|
110 |
+
it generates parallel .tsv and .km files for these emotions. for exmaple:
|
111 |
+
❯ python build_emov_translation_manifests.py \
|
112 |
+
/checkpoint/felixkreuk/datasets/emov/manifests/emov_16khz/train.tsv \
|
113 |
+
/checkpoint/felixkreuk/datasets/emov/manifests/emov_16khz/emov_16khz_km_100/train.km \
|
114 |
+
~/tmp/emov_pairs \
|
115 |
+
--src-emotion amused --trg-emotion neutral \
|
116 |
+
--dedup --shuffle --cross-speaker --dry-run
|
117 |
+
"""
|
118 |
+
parser = argparse.ArgumentParser(description=desc)
|
119 |
+
parser.add_argument("data", type=Path, help="path to a dir containing .tsv and .km files containing emov dataset")
|
120 |
+
parser.add_argument("output_path", type=Path, help="output directory with the manifests will be created")
|
121 |
+
parser.add_argument("-cs", "--cross-speaker", action='store_true', help="if set then translation will occur also between speakers, meaning the same sentence can be translated between different speakers (default: false)")
|
122 |
+
parser.add_argument("-dd", "--dedup", action='store_true', help="remove repeated tokens (example: 'aaabc=>abc')")
|
123 |
+
parser.add_argument("-sh", "--shuffle", action='store_true', help="shuffle the data")
|
124 |
+
parser.add_argument("-ae", "--autoencode", action='store_true', help="include training pairs from the same emotion (this includes examples of the same sentence uttered by different people and examples where the src and trg are the exact same seq)")
|
125 |
+
parser.add_argument("-dr", "--dry-run", action='store_true', help="don't write anything to disk")
|
126 |
+
parser.add_argument("-zs", "--zero-shot", action='store_true', help="if true, the denoising task will train on the same splits as the translation task (split by utterance id). if false, the denoising task will train on randomly sampled splits (not split by utterance id)")
|
127 |
+
parser.add_argument("--km-ext", default="km", help="")
|
128 |
+
parser.add_argument("--dict", default="/checkpoint/felixkreuk/datasets/emov/manifests/emov_16khz/fairseq.dict.txt", help="")
|
129 |
+
args = parser.parse_args()
|
130 |
+
SPEAKERS = ["bea", "jenie", "josh", "sam", "SAME"]
|
131 |
+
EMOTIONS = ['neutral', 'amused', 'angry', 'disgusted', 'sleepy']
|
132 |
+
|
133 |
+
suffix = ""
|
134 |
+
if args.cross_speaker: suffix += "_cross-speaker"
|
135 |
+
if args.dedup: suffix += "_dedup"
|
136 |
+
translation_suffix = ""
|
137 |
+
if args.autoencode: translation_suffix += "_autoencode"
|
138 |
+
denoising_suffix = ""
|
139 |
+
denoising_suffix += "_zeroshot" if args.zero_shot else "_nonzeroshot"
|
140 |
+
|
141 |
+
translation_dir = Path(args.output_path) / ("emov_multilingual_translation" + suffix + translation_suffix)
|
142 |
+
os.makedirs(translation_dir, exist_ok=True)
|
143 |
+
denoising_dir = Path(args.output_path) / ("emov_multilingual_denoising" + suffix + denoising_suffix)
|
144 |
+
os.makedirs(denoising_dir, exist_ok=True)
|
145 |
+
|
146 |
+
denoising_data = [p.name for p in (args.data / "denoising").glob("*") if "emov" not in p.name]
|
147 |
+
|
148 |
+
for split in ["train", "valid", "test"]:
|
149 |
+
root, tsv_lines, km_lines = load_tsv_km(
|
150 |
+
tsv_path = args.data / "denoising" / "emov" / f"{split}.tsv",
|
151 |
+
km_path = args.data / "denoising" / "emov" / f"{split}.{args.km_ext}"
|
152 |
+
)
|
153 |
+
|
154 |
+
# generate data for the multilingual denoising task
|
155 |
+
for EMOTION in EMOTIONS:
|
156 |
+
print("---")
|
157 |
+
print(split)
|
158 |
+
print(f"denoising: {EMOTION}")
|
159 |
+
emotion_tsv, emotion_km = [], []
|
160 |
+
for tsv_line, km_line in zip(tsv_lines, km_lines):
|
161 |
+
if EMOTION.lower() in tsv_line.lower():
|
162 |
+
km_line = km_line if not args.dedup else dedup(km_line)
|
163 |
+
emotion_tsv.append(tsv_line)
|
164 |
+
emotion_km.append(km_line)
|
165 |
+
print(f"{len(emotion_km)} samples")
|
166 |
+
open(denoising_dir / f"files.{split}.{EMOTION}", "w").writelines([root] + emotion_tsv)
|
167 |
+
open(denoising_dir / f"{split}.{EMOTION}", "w").writelines(emotion_km)
|
168 |
+
|
169 |
+
for data in denoising_data:
|
170 |
+
with open(args.data / "denoising" / data / f"{split}.{args.km_ext}", "r") as f1:
|
171 |
+
with open(denoising_dir / f"{split}.{data}", "w") as f2:
|
172 |
+
f2.writelines([l if not args.dedup else dedup(l) for l in f1.readlines()])
|
173 |
+
|
174 |
+
# start of translation preprocessing
|
175 |
+
root, tsv_lines, km_lines = load_tsv_km(
|
176 |
+
tsv_path = args.data / "translation" / f"{split}.tsv",
|
177 |
+
km_path = args.data / "translation" / f"{split}.{args.km_ext}"
|
178 |
+
)
|
179 |
+
|
180 |
+
# generate data for the multilingual translation task
|
181 |
+
for SRC_EMOTION in EMOTIONS:
|
182 |
+
TRG_EMOTIONS = EMOTIONS if args.autoencode else set(EMOTIONS) - set([SRC_EMOTION])
|
183 |
+
for TRG_EMOTION in TRG_EMOTIONS:
|
184 |
+
# when translating back to the same emotion - we dont want these emotion
|
185 |
+
# pairs to be part of the validation/test sets (because its not really emotion conversino)
|
186 |
+
# if SRC_EMOTION == TRG_EMOTION and split in ["valid", "test"]: continue
|
187 |
+
print("---")
|
188 |
+
print(split)
|
189 |
+
print(f"src emotions: {SRC_EMOTION}\ntrg emotions: {TRG_EMOTION}")
|
190 |
+
|
191 |
+
# create a dictionary with the following structure:
|
192 |
+
# output[SPEAKER][UTT_ID] = list with indexes of line from the tsv file
|
193 |
+
# that match the speaker and utterance id. for exmaple:
|
194 |
+
# output = {'sam': {'0493': [875, 1608, 1822], ...}, ...}
|
195 |
+
# meaning, for speaker 'sam', utterance id '0493', the indexes in tsv_lines
|
196 |
+
# are 875, 1608, 1822
|
197 |
+
spkr2utts = defaultdict(lambda: defaultdict(list))
|
198 |
+
for i, tsv_line in enumerate(tsv_lines):
|
199 |
+
speaker = tsv_line.split("/")[0]
|
200 |
+
if args.cross_speaker: speaker = "SAME"
|
201 |
+
assert speaker in SPEAKERS, "unknown speaker! make sure the .tsv contains EMOV data"
|
202 |
+
utt_id = get_utt_id(tsv_line)
|
203 |
+
spkr2utts[speaker][utt_id].append(i)
|
204 |
+
|
205 |
+
# create a tsv and km files with all the combinations for translation
|
206 |
+
src_tsv, trg_tsv, src_km, trg_km = [], [], [], []
|
207 |
+
for speaker, utt_ids in spkr2utts.items():
|
208 |
+
for utt_id, indices in utt_ids.items():
|
209 |
+
# generate all pairs
|
210 |
+
pairs = [(x,y) for x in indices for y in indices]
|
211 |
+
# self-translation
|
212 |
+
if SRC_EMOTION == TRG_EMOTION:
|
213 |
+
pairs = [(x,y) for (x,y) in pairs if x == y]
|
214 |
+
# filter according to src and trg emotions
|
215 |
+
pairs = [(x,y) for (x,y) in pairs
|
216 |
+
if get_emotion(tsv_lines[x]) == SRC_EMOTION and get_emotion(tsv_lines[y]) == TRG_EMOTION]
|
217 |
+
|
218 |
+
for idx1, idx2 in pairs:
|
219 |
+
assert get_utt_id(tsv_lines[idx1]) == get_utt_id(tsv_lines[idx2])
|
220 |
+
src_tsv.append(tsv_lines[idx1])
|
221 |
+
trg_tsv.append(tsv_lines[idx2])
|
222 |
+
km_line_idx1 = km_lines[idx1]
|
223 |
+
km_line_idx2 = km_lines[idx2]
|
224 |
+
km_line_idx1 = km_line_idx1 if not args.dedup else dedup(km_line_idx1)
|
225 |
+
km_line_idx2 = km_line_idx2 if not args.dedup else dedup(km_line_idx2)
|
226 |
+
src_km.append(km_line_idx1)
|
227 |
+
trg_km.append(km_line_idx2)
|
228 |
+
assert len(src_tsv) == len(trg_tsv) == len(src_km) == len(trg_km)
|
229 |
+
print(f"{len(src_tsv)} pairs")
|
230 |
+
|
231 |
+
if len(src_tsv) == 0:
|
232 |
+
raise Exception("ERROR: generated 0 pairs!")
|
233 |
+
|
234 |
+
if args.dry_run: continue
|
235 |
+
|
236 |
+
# create files
|
237 |
+
os.makedirs(translation_dir / f"{SRC_EMOTION}-{TRG_EMOTION}", exist_ok=True)
|
238 |
+
open(translation_dir / f"{SRC_EMOTION}-{TRG_EMOTION}" / f"files.{split}.{SRC_EMOTION}", "w").writelines([root] + src_tsv)
|
239 |
+
open(translation_dir / f"{SRC_EMOTION}-{TRG_EMOTION}" / f"files.{split}.{TRG_EMOTION}", "w").writelines([root] + trg_tsv)
|
240 |
+
open(translation_dir / f"{SRC_EMOTION}-{TRG_EMOTION}" / f"{split}.{SRC_EMOTION}", "w").writelines(src_km)
|
241 |
+
open(translation_dir / f"{SRC_EMOTION}-{TRG_EMOTION}" / f"{split}.{TRG_EMOTION}", "w").writelines(trg_km)
|
242 |
+
|
243 |
+
|
244 |
+
# fairseq-preprocess the denoising data
|
245 |
+
for EMOTION in EMOTIONS + denoising_data:
|
246 |
+
denoising_preprocess(denoising_dir, EMOTION, args.dict)
|
247 |
+
os.system(f"cp {args.dict} {denoising_dir}/tokenized/dict.txt")
|
248 |
+
|
249 |
+
# fairseq-preprocess the translation data
|
250 |
+
os.makedirs(translation_dir / "tokenized", exist_ok=True)
|
251 |
+
for SRC_EMOTION in EMOTIONS:
|
252 |
+
TRG_EMOTIONS = EMOTIONS if args.autoencode else set(EMOTIONS) - set([SRC_EMOTION])
|
253 |
+
for TRG_EMOTION in TRG_EMOTIONS:
|
254 |
+
translation_preprocess(translation_dir / f"{SRC_EMOTION}-{TRG_EMOTION}", SRC_EMOTION, TRG_EMOTION, args.dict)#, only_train=SRC_EMOTION==TRG_EMOTION)
|
255 |
+
os.system(f"cp -rf {translation_dir}/**/tokenized/* {translation_dir}/tokenized")
|
256 |
+
|
257 |
+
if __name__ == "__main__":
|
258 |
+
main()
|
fairseq/examples/emotion_conversion/preprocess/create_core_manifest.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import subprocess
|
5 |
+
import argparse
|
6 |
+
from datetime import datetime
|
7 |
+
import logging
|
8 |
+
|
9 |
+
logging.basicConfig(
|
10 |
+
level=logging.INFO,
|
11 |
+
format='%(asctime)s [%(levelname)s] %(message)s',
|
12 |
+
handlers=[logging.FileHandler('debug.log'), logging.StreamHandler()]
|
13 |
+
)
|
14 |
+
logger = logging.getLogger(__name__)
|
15 |
+
|
16 |
+
|
17 |
+
def verify_dict_size(km, dict):
|
18 |
+
logger.info(f"verifying: {km}")
|
19 |
+
dict_size = len(open(dict, "r").readlines())
|
20 |
+
km_vocab = set(open(km, "r").read().replace("\n", " ").split(" "))
|
21 |
+
if "" in km_vocab: km_vocab.remove("")
|
22 |
+
km_vocab_size = len(km_vocab)
|
23 |
+
return dict_size == km_vocab_size
|
24 |
+
|
25 |
+
|
26 |
+
def verify_files_exist(l):
|
27 |
+
for f in l:
|
28 |
+
if not f.exists():
|
29 |
+
logging.error(f"{f} doesn't exist!")
|
30 |
+
return False
|
31 |
+
return True
|
32 |
+
|
33 |
+
|
34 |
+
def run_cmd(cmd, print_output=True):
|
35 |
+
try:
|
36 |
+
out = subprocess.check_output(cmd, stderr=subprocess.STDOUT, universal_newlines=True, shell=True)
|
37 |
+
if print_output:
|
38 |
+
logger.info(f"command output:\n{out}")
|
39 |
+
return out
|
40 |
+
except subprocess.CalledProcessError as grepexc:
|
41 |
+
logger.info(f"error executing command!:\n{cmd}")
|
42 |
+
logger.info(grepexc.output)
|
43 |
+
|
44 |
+
def main():
|
45 |
+
parser = argparse.ArgumentParser()
|
46 |
+
parser.add_argument("--tsv", default="/checkpoint/felixkreuk/datasets/emov/manifests/emov_16khz/data.tsv", type=Path)
|
47 |
+
parser.add_argument("--emov-km", required=True, type=Path)
|
48 |
+
parser.add_argument("--km", nargs='+', required=True, type=Path)
|
49 |
+
parser.add_argument("--seed", type=int, default=1)
|
50 |
+
parser.add_argument("--dict", default="/checkpoint/felixkreuk/datasets/emov/manifests/emov_16khz/fairseq.dict.txt")
|
51 |
+
parser.add_argument("--manifests-dir", type=Path, default="/checkpoint/felixkreuk/datasets/emov/manifests/emov_16khz")
|
52 |
+
args = parser.parse_args()
|
53 |
+
|
54 |
+
manifests_dir = args.manifests_dir
|
55 |
+
date = datetime.now().strftime('%d%m%y')
|
56 |
+
outdir = manifests_dir / f"{date}"
|
57 |
+
|
58 |
+
# verify input and create folders
|
59 |
+
all_kms = args.km + [args.emov_km]
|
60 |
+
assert verify_files_exist(all_kms), "make sure the km dir contains: train-clean-all.km, blizzard2013.km, data.km"
|
61 |
+
for codes in all_kms:
|
62 |
+
assert verify_dict_size(codes, args.dict), "dict argument doesn't match the vocabulary of the km file!"
|
63 |
+
assert not outdir.exists(), "data dir already exists!"
|
64 |
+
outdir.mkdir(parents=True, exist_ok=True)
|
65 |
+
|
66 |
+
logger.info("generating denoising split (emov)")
|
67 |
+
run_cmd(f"python preprocess/split_km_tsv.py {args.tsv} {args.emov_km} --destdir {outdir}/denoising/emov -sh --seed {args.seed}")
|
68 |
+
for codes in args.km:
|
69 |
+
codes_name = os.path.basename(codes)
|
70 |
+
run_cmd(f"python preprocess/split_km.py {codes} --destdir {outdir}/denoising/{codes_name} -sh --seed {args.seed}")
|
71 |
+
|
72 |
+
logger.info("generating translation split")
|
73 |
+
run_cmd(f"python preprocess/split_emov_km_tsv_by_uttid.py {args.tsv} {args.emov_km} --destdir {outdir}/translation --seed {args.seed}")
|
74 |
+
|
75 |
+
emov_code_name = os.path.basename(args.emov_km)
|
76 |
+
logger.info("generating hifigan split")
|
77 |
+
run_cmd(
|
78 |
+
f"mkdir -p {outdir}/hifigan &&"
|
79 |
+
f"python preprocess/build_hifigan_manifest.py --km_type hubert --tsv {outdir}/denoising/emov/train.tsv --km {outdir}/denoising/emov/train.km > {outdir}/hifigan/train.txt &&"
|
80 |
+
f"python preprocess/build_hifigan_manifest.py --km_type hubert --tsv {outdir}/denoising/emov/valid.tsv --km {outdir}/denoising/emov/valid.km > {outdir}/hifigan/valid.txt &&"
|
81 |
+
f"python preprocess/build_hifigan_manifest.py --km_type hubert --tsv {outdir}/denoising/emov/test.tsv --km {outdir}/denoising/emov/test.km > {outdir}/hifigan/test.txt"
|
82 |
+
)
|
83 |
+
|
84 |
+
logger.info("generating fairseq manifests")
|
85 |
+
run_cmd(f"python preprocess/build_translation_manifests.py {outdir} {outdir}/fairseq-data -dd -cs --dict {args.dict}")
|
86 |
+
|
87 |
+
logger.info(f"finished processing data at:\n{outdir}")
|
88 |
+
|
89 |
+
|
90 |
+
if __name__ == "__main__":
|
91 |
+
main()
|
fairseq/examples/emotion_conversion/preprocess/extract_f0.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from tqdm import tqdm
|
3 |
+
from multiprocessing import Manager, Pool
|
4 |
+
|
5 |
+
from scipy.io.wavfile import read
|
6 |
+
from librosa.util import normalize
|
7 |
+
import numpy as np
|
8 |
+
import amfm_decompy.pYAAPT as pYAAPT
|
9 |
+
import amfm_decompy.basic_tools as basic
|
10 |
+
|
11 |
+
MAX_WAV_VALUE = 32768.0
|
12 |
+
|
13 |
+
parser = argparse.ArgumentParser(description="")
|
14 |
+
parser.add_argument("tsv", help="")
|
15 |
+
parser.add_argument("--extractor", choices=["crepe", "pyaapt"], default="pyaapt", help="")
|
16 |
+
parser.add_argument("--interp", action="store_true", help="")
|
17 |
+
parser.add_argument("--n_workers", type=int, default=40, help="")
|
18 |
+
args = parser.parse_args()
|
19 |
+
|
20 |
+
tsv_lines = open(args.tsv, "r").readlines()
|
21 |
+
root, tsv_lines = tsv_lines[0].strip(), tsv_lines[1:]
|
22 |
+
|
23 |
+
|
24 |
+
def extract_f0(tsv_line):
|
25 |
+
wav_path, _ = tsv_line.split("\t")
|
26 |
+
wav_path = root.strip() + "/" + wav_path
|
27 |
+
sr, wav = read(wav_path)
|
28 |
+
wav = wav / MAX_WAV_VALUE
|
29 |
+
wav = normalize(wav) * 0.95
|
30 |
+
|
31 |
+
if args.extractor == "pyaapt":
|
32 |
+
frame_length = 20.0
|
33 |
+
pad = int(frame_length / 1000 * sr) // 2
|
34 |
+
wav = np.pad(wav.squeeze(), (pad, pad), "constant", constant_values=0)
|
35 |
+
signal = basic.SignalObj(wav, sr)
|
36 |
+
pitch = pYAAPT.yaapt(
|
37 |
+
signal,
|
38 |
+
**{
|
39 |
+
'frame_length': frame_length,
|
40 |
+
'frame_space': 5.0,
|
41 |
+
'nccf_thresh1': 0.25,
|
42 |
+
'tda_frame_length': 25.0
|
43 |
+
})
|
44 |
+
pitch = pitch.samp_interp[None, None, :] if args.interp else pitch.samp_values[None, None, :]
|
45 |
+
pitch = pitch[0, 0]
|
46 |
+
f0_path = wav_path.replace(".wav", ".yaapt")
|
47 |
+
f0_path += ".interp.f0" if args.interp else ".f0"
|
48 |
+
np.save(f0_path, pitch)
|
49 |
+
|
50 |
+
|
51 |
+
def main():
|
52 |
+
with Pool(args.n_workers) as p:
|
53 |
+
r = list(tqdm(p.imap(extract_f0, tsv_lines), total=len(tsv_lines)))
|
54 |
+
|
55 |
+
|
56 |
+
if __name__ == "__main__":
|
57 |
+
main()
|
fairseq/examples/emotion_conversion/preprocess/process_km.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import argparse
|
3 |
+
from tqdm import tqdm
|
4 |
+
from build_emov_translation_manifests import dedup, remove_under_k
|
5 |
+
|
6 |
+
|
7 |
+
if __name__ == "__main__":
|
8 |
+
"""
|
9 |
+
this is a standalone script to process a km file
|
10 |
+
specifically, to dedup or remove tokens that repeat less
|
11 |
+
than k times in a row
|
12 |
+
"""
|
13 |
+
parser = argparse.ArgumentParser(description="")
|
14 |
+
parser.add_argument("km", type=str, help="path to km file")
|
15 |
+
parser.add_argument("--dedup", action='store_true')
|
16 |
+
parser.add_argument("--remove-under-k", type=int, default=0)
|
17 |
+
parser.add_argument("--output", default=None)
|
18 |
+
args = parser.parse_args()
|
19 |
+
|
20 |
+
if not args.dedup and args.remove_under_k == 0:
|
21 |
+
print("nothing to do! quitting...")
|
22 |
+
sys.exit(0)
|
23 |
+
|
24 |
+
km = open(args.km, "r").readlines()
|
25 |
+
out = []
|
26 |
+
for line in tqdm(km):
|
27 |
+
if args.remove_under_k > 0:
|
28 |
+
line = remove_under_k(line, args.remove_under_k)
|
29 |
+
if args.dedup:
|
30 |
+
line = dedup(line)
|
31 |
+
out.append(line)
|
32 |
+
|
33 |
+
path = args.km if args.output is None else args.output
|
34 |
+
if args.remove_under_k > 0:
|
35 |
+
path = path.replace(".km", f"-k{args.remove_under_k}.km")
|
36 |
+
if args.dedup:
|
37 |
+
path = path.replace(".km", f"-deduped.km")
|
38 |
+
|
39 |
+
open(path, "w").writelines(out)
|
40 |
+
print(f"written to {path}")
|
fairseq/examples/emotion_conversion/preprocess/split_emov_km_tsv_by_uttid.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import argparse
|
5 |
+
import random
|
6 |
+
import numpy as np
|
7 |
+
from tqdm import tqdm
|
8 |
+
from sklearn.model_selection import train_test_split
|
9 |
+
from build_translation_manifests import get_utt_id
|
10 |
+
|
11 |
+
|
12 |
+
def train_val_test_split(tsv_lines, km_lines, valid_percent, test_percent, seed=42):
|
13 |
+
utt_ids = list(sorted(set([get_utt_id(x) for x in tsv_lines])))
|
14 |
+
utt_ids, valid_utt_ids, _, _ = train_test_split(utt_ids, utt_ids, test_size=valid_percent, shuffle=True, random_state=seed)
|
15 |
+
train_utt_ids, test_utt_ids, _, _ = train_test_split(utt_ids, utt_ids, test_size=test_percent, shuffle=True, random_state=seed)
|
16 |
+
|
17 |
+
train_idx = [i for i, line in enumerate(tsv_lines) if get_utt_id(line) in train_utt_ids]
|
18 |
+
valid_idx = [i for i, line in enumerate(tsv_lines) if get_utt_id(line) in valid_utt_ids]
|
19 |
+
test_idx = [i for i, line in enumerate(tsv_lines) if get_utt_id(line) in test_utt_ids]
|
20 |
+
|
21 |
+
train_tsv, train_km = [tsv_lines[i] for i in train_idx], [km_lines[i] for i in train_idx]
|
22 |
+
valid_tsv, valid_km = [tsv_lines[i] for i in valid_idx], [km_lines[i] for i in valid_idx]
|
23 |
+
test_tsv, test_km = [tsv_lines[i] for i in test_idx], [km_lines[i] for i in test_idx]
|
24 |
+
|
25 |
+
print(f"train {len(train_km)}")
|
26 |
+
print(f"valid {len(valid_km)}")
|
27 |
+
print(f"test {len(test_km)}")
|
28 |
+
|
29 |
+
return train_tsv, train_km, valid_tsv, valid_km, test_tsv, test_km
|
30 |
+
|
31 |
+
|
32 |
+
if __name__ == "__main__":
|
33 |
+
"""
|
34 |
+
this is a standalone script to process a km file
|
35 |
+
specifically, to dedup or remove tokens that repeat less
|
36 |
+
than k times in a row
|
37 |
+
"""
|
38 |
+
parser = argparse.ArgumentParser(description="")
|
39 |
+
parser.add_argument("tsv", type=str, help="path to tsv file")
|
40 |
+
parser.add_argument("km", type=str, help="path to km file")
|
41 |
+
parser.add_argument("--destdir", required=True, type=str)
|
42 |
+
parser.add_argument("--valid-percent", type=float, default=0.05, help="percent to allocate to validation set")
|
43 |
+
parser.add_argument("--test-percent", type=float, default=0.05, help="percent to allocate to test set")
|
44 |
+
parser.add_argument("--seed", type=int, default=42, help="")
|
45 |
+
args = parser.parse_args()
|
46 |
+
|
47 |
+
np.random.seed(args.seed)
|
48 |
+
random.seed(args.seed)
|
49 |
+
|
50 |
+
os.makedirs(args.destdir, exist_ok=True)
|
51 |
+
km = open(args.km, "r").readlines()
|
52 |
+
tsv = open(args.tsv, "r").readlines()
|
53 |
+
root, tsv = tsv[0], tsv[1:]
|
54 |
+
|
55 |
+
assert args.tsv.endswith(".tsv") and args.km.endswith(".km")
|
56 |
+
assert len(tsv) == len(km)
|
57 |
+
|
58 |
+
train_tsv, train_km, valid_tsv, valid_km, test_tsv, test_km = train_val_test_split(tsv, km, args.valid_percent, args.test_percent, args.seed)
|
59 |
+
|
60 |
+
assert len(train_tsv) + len(valid_tsv) + len(test_tsv) == len(tsv)
|
61 |
+
assert len(train_tsv) == len(train_km) and len(valid_tsv) == len(valid_km) and len(test_tsv) == len(test_km)
|
62 |
+
|
63 |
+
dir = Path(args.destdir)
|
64 |
+
open(dir / f"train.tsv", "w").writelines([root] + train_tsv)
|
65 |
+
open(dir / f"valid.tsv", "w").writelines([root] + valid_tsv)
|
66 |
+
open(dir / f"test.tsv", "w").writelines([root] + test_tsv)
|
67 |
+
open(dir / f"train.km", "w").writelines(train_km)
|
68 |
+
open(dir / f"valid.km", "w").writelines(valid_km)
|
69 |
+
open(dir / f"test.km", "w").writelines(test_km)
|
70 |
+
print("done")
|
fairseq/examples/emotion_conversion/preprocess/split_km.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
import os
|
3 |
+
import argparse
|
4 |
+
import random
|
5 |
+
import numpy as np
|
6 |
+
from sklearn.utils import shuffle
|
7 |
+
|
8 |
+
|
9 |
+
if __name__ == "__main__":
|
10 |
+
"""
|
11 |
+
this is a standalone script to process a km file
|
12 |
+
specifically, to dedup or remove tokens that repeat less
|
13 |
+
than k times in a row
|
14 |
+
"""
|
15 |
+
parser = argparse.ArgumentParser(description="")
|
16 |
+
parser.add_argument("km", type=str, help="path to km file")
|
17 |
+
parser.add_argument("--destdir", required=True, type=str)
|
18 |
+
parser.add_argument("--valid-percent", type=float, default=0.05, help="percent to allocate to validation set")
|
19 |
+
parser.add_argument("--test-percent", type=float, default=0.05, help="percent to allocate to test set")
|
20 |
+
parser.add_argument("-sh", "--shuffle", action="store_true", help="path to km file")
|
21 |
+
parser.add_argument("--seed", type=int, default=42, help="")
|
22 |
+
args = parser.parse_args()
|
23 |
+
|
24 |
+
np.random.seed(args.seed)
|
25 |
+
random.seed(args.seed)
|
26 |
+
|
27 |
+
os.makedirs(args.destdir, exist_ok=True)
|
28 |
+
km = open(args.km, "r").readlines()
|
29 |
+
|
30 |
+
if args.shuffle:
|
31 |
+
km = shuffle(km)
|
32 |
+
print(f"shuffled")
|
33 |
+
|
34 |
+
N = len(km)
|
35 |
+
N_tt = int(N * args.test_percent)
|
36 |
+
N_cv = int(N * args.valid_percent)
|
37 |
+
N_tr = N - N_tt - N_cv
|
38 |
+
|
39 |
+
train_km = km[:N_tr]
|
40 |
+
valid_km = km[N_tr:N_tr + N_cv]
|
41 |
+
test_km = km[N_tr + N_cv:]
|
42 |
+
|
43 |
+
dir = Path(args.destdir)
|
44 |
+
open(dir / f"train.km", "w").writelines(train_km)
|
45 |
+
open(dir / f"valid.km", "w").writelines(valid_km)
|
46 |
+
open(dir / f"test.km", "w").writelines(test_km)
|
47 |
+
print(f"train: {len(train_km)}")
|
48 |
+
print(f"valid: {len(valid_km)}")
|
49 |
+
print(f"test: {len(test_km)}")
|
50 |
+
print("done")
|
fairseq/examples/emotion_conversion/preprocess/split_km_tsv.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
import os
|
3 |
+
import argparse
|
4 |
+
import random
|
5 |
+
import numpy as np
|
6 |
+
from sklearn.utils import shuffle
|
7 |
+
|
8 |
+
|
9 |
+
if __name__ == "__main__":
|
10 |
+
"""
|
11 |
+
this is a standalone script to process a km file
|
12 |
+
specifically, to dedup or remove tokens that repeat less
|
13 |
+
than k times in a row
|
14 |
+
"""
|
15 |
+
parser = argparse.ArgumentParser(description="")
|
16 |
+
parser.add_argument("tsv", type=str, help="path to tsv file")
|
17 |
+
parser.add_argument("km", type=str, help="path to km file")
|
18 |
+
parser.add_argument("--destdir", required=True, type=str)
|
19 |
+
parser.add_argument("--valid-percent", type=float, default=0.05, help="percent to allocate to validation set")
|
20 |
+
parser.add_argument("--test-percent", type=float, default=0.05, help="percent to allocate to test set")
|
21 |
+
parser.add_argument("-sh", "--shuffle", action="store_true", help="path to km file")
|
22 |
+
parser.add_argument("--seed", type=int, default=42, help="")
|
23 |
+
args = parser.parse_args()
|
24 |
+
|
25 |
+
np.random.seed(args.seed)
|
26 |
+
random.seed(args.seed)
|
27 |
+
|
28 |
+
os.makedirs(args.destdir, exist_ok=True)
|
29 |
+
km = open(args.km, "r").readlines()
|
30 |
+
tsv = open(args.tsv, "r").readlines()
|
31 |
+
root, tsv = tsv[0], tsv[1:]
|
32 |
+
|
33 |
+
assert args.tsv.endswith(".tsv") and args.km.endswith(".km")
|
34 |
+
assert len(tsv) == len(km)
|
35 |
+
|
36 |
+
if args.shuffle:
|
37 |
+
tsv, km = shuffle(tsv, km)
|
38 |
+
print(f"shuffled")
|
39 |
+
|
40 |
+
N = len(tsv)
|
41 |
+
N_tt = int(N * args.test_percent)
|
42 |
+
N_cv = int(N * args.valid_percent)
|
43 |
+
N_tr = N - N_tt - N_cv
|
44 |
+
|
45 |
+
train_tsv = tsv[:N_tr]
|
46 |
+
valid_tsv = tsv[N_tr:N_tr + N_cv]
|
47 |
+
test_tsv = tsv[N_tr + N_cv:]
|
48 |
+
train_km = km[:N_tr]
|
49 |
+
valid_km = km[N_tr:N_tr + N_cv]
|
50 |
+
test_km = km[N_tr + N_cv:]
|
51 |
+
|
52 |
+
assert len(train_tsv) + len(valid_tsv) + len(test_tsv) == len(tsv)
|
53 |
+
assert len(train_tsv) == len(train_km) and len(valid_tsv) == len(valid_km) and len(test_tsv) == len(test_km)
|
54 |
+
|
55 |
+
dir = Path(args.destdir)
|
56 |
+
open(dir / f"train.tsv", "w").writelines([root] + train_tsv)
|
57 |
+
open(dir / f"valid.tsv", "w").writelines([root] + valid_tsv)
|
58 |
+
open(dir / f"test.tsv", "w").writelines([root] + test_tsv)
|
59 |
+
open(dir / f"train.km", "w").writelines(train_km)
|
60 |
+
open(dir / f"valid.km", "w").writelines(valid_km)
|
61 |
+
open(dir / f"test.km", "w").writelines(test_km)
|
62 |
+
print(f"train: {len(train_km)}")
|
63 |
+
print(f"valid: {len(valid_km)}")
|
64 |
+
print(f"test: {len(test_km)}")
|
65 |
+
print("done")
|
fairseq/examples/fast_noisy_channel/README.md
ADDED
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Language Models not just for Pre-training: Fast Online Neural Noisy Channel Modeling
|
2 |
+
|
3 |
+
## Introduction
|
4 |
+
- [Yee et al. (2019)](https://www.aclweb.org/anthology/D19-1571.pdf) introduce a simple and effective noisy channel modeling approach for neural machine translation. However, the noisy channel online decoding approach introduced in this paper is too slow to be practical.
|
5 |
+
- To address this, [Bhosale et al. (2020)](http://www.statmt.org/wmt20/pdf/2020.wmt-1.68.pdf) introduces 3 simple approximations to make this approach very fast and practical without much loss in accuracy.
|
6 |
+
- This README provides intructions on how to run online decoding or generation with the noisy channel modeling approach, including ways to make it very fast without much loss in accuracy.
|
7 |
+
|
8 |
+
## Noisy Channel Modeling
|
9 |
+
|
10 |
+
[Yee et al. (2019)](https://www.aclweb.org/anthology/D19-1571.pdf) applies the Bayes Rule to predict `P(y|x)`, the probability of the target `y` given the source `x`.
|
11 |
+
```P(y|x) = P(x|y) * P(y) / P(x)```
|
12 |
+
- `P(x|y)` predicts the source `x` given the target `y` and is referred to as the **channel model**
|
13 |
+
- `P(y)` is a **language model** over the target `y`
|
14 |
+
- `P(x)` is generally not modeled since it is constant for all `y`.
|
15 |
+
|
16 |
+
We use Transformer models to parameterize the direct model `P(y|x)`, the channel model `P(x|y)` and the language model `P(y)`.
|
17 |
+
|
18 |
+
During online decoding with beam search, we generate the top `K2` candidates per beam and score them with the following linear combination of the channel model, the language model as well as the direct model scores.
|
19 |
+
|
20 |
+
```(1 / t) * log(P(y|x) + (1 / s) * ( λ1 * log(P(x|y)) + λ2 * log(P(y) ) )```
|
21 |
+
- `t` - Target Prefix Length
|
22 |
+
- `s` - Source Length
|
23 |
+
- `λ1` - Channel Model Weight
|
24 |
+
- `λ2` - Language Model Weight
|
25 |
+
|
26 |
+
The top `beam_size` candidates based on the above combined scores are chosen to continue the beams in beam search. In beam search with a direct model alone, the scores from the direct model `P(y|x)` are used to choose the top candidates in beam search.
|
27 |
+
|
28 |
+
This framework provides a great way to utlize strong target language models trained on large amounts of unlabeled data. Language models can prefer targets unrelated to the source, so we also need a channel model whose role is to ensure that the target preferred by the language model also translates back to the source.
|
29 |
+
|
30 |
+
### Training Translation Models and Language Models
|
31 |
+
|
32 |
+
For training Transformer models in fairseq for machine translation, refer to instructions [here](https://github.com/pytorch/fairseq/tree/main/examples/translation)
|
33 |
+
|
34 |
+
For training Transformer models in fairseq for language modeling, refer to instructions [here](https://github.com/pytorch/fairseq/tree/main/examples/language_model)
|
35 |
+
|
36 |
+
### Generation with Language Model for German-English translation with fairseq
|
37 |
+
|
38 |
+
Here are instructions to generate using a direct model and a target-side language model.
|
39 |
+
|
40 |
+
Note:
|
41 |
+
- Download and install fairseq as per instructions [here](https://github.com/pytorch/fairseq)
|
42 |
+
- Preprocess and binarize the dataset as per instructions in section [Test Data Preprocessing](#test-data-preprocessing)
|
43 |
+
|
44 |
+
```sh
|
45 |
+
binarized_data=data_dir/binarized
|
46 |
+
direct_model=de_en_seed4.pt
|
47 |
+
lm_model=en_lm.pt
|
48 |
+
lm_data=lm_data
|
49 |
+
wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/direct_models/seed4.pt -O ${direct_model}
|
50 |
+
wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/transformer_lm.pt -O ${lm_model}
|
51 |
+
mkdir -p ${lm_data}
|
52 |
+
wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/lm_dict/dict.txt -O ${lm_data}/dict.txt
|
53 |
+
|
54 |
+
k2=10
|
55 |
+
lenpen=0.16
|
56 |
+
lm_wt=0.14
|
57 |
+
fairseq-generate ${binarized_data} \
|
58 |
+
--user-dir examples/fast_noisy_channel \
|
59 |
+
--beam 5 \
|
60 |
+
--path ${direct_model} \
|
61 |
+
--lm-model ${lm_model} \
|
62 |
+
--lm-data ${lm_data} \
|
63 |
+
--k2 ${k2} \
|
64 |
+
--combine-method lm_only \
|
65 |
+
--task noisy_channel_translation \
|
66 |
+
--lenpen ${lenpen} \
|
67 |
+
--lm-wt ${lm_wt} \
|
68 |
+
--gen-subset valid \
|
69 |
+
--remove-bpe \
|
70 |
+
--fp16 \
|
71 |
+
--batch-size 10
|
72 |
+
```
|
73 |
+
### Noisy Channel Generation for German-English translation with fairseq
|
74 |
+
|
75 |
+
Here are instructions for noisy channel generation with a direct model, channel model and language model as explained in section [Noisy Channel Modeling](#noisy-channel-modeling).
|
76 |
+
|
77 |
+
Note:
|
78 |
+
- Download and install fairseq as per instructions [here](https://github.com/pytorch/fairseq)
|
79 |
+
- Preprocess and binarize the dataset as per instructions in section [Test Data Preprocessing](#test-data-preprocessing)
|
80 |
+
|
81 |
+
```sh
|
82 |
+
binarized_data=data_dir/binarized
|
83 |
+
direct_model=de_en_seed4.pt
|
84 |
+
lm_model=en_lm.pt
|
85 |
+
lm_data=lm_data
|
86 |
+
ch_model=en_de.big.seed4.pt
|
87 |
+
wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/direct_models/seed4.pt -O ${direct_model}
|
88 |
+
wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/transformer_lm.pt -O ${lm_model}
|
89 |
+
mkdir -p ${lm_data}
|
90 |
+
wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/lm_dict/dict.txt -O ${lm_data}/dict.txt
|
91 |
+
wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big.seed4.pt -O ${ch_model}
|
92 |
+
|
93 |
+
k2=10
|
94 |
+
lenpen=0.21
|
95 |
+
lm_wt=0.50
|
96 |
+
bw_wt=0.30
|
97 |
+
fairseq-generate ${binarized_data} \
|
98 |
+
--user-dir examples/fast_noisy_channel \
|
99 |
+
--beam 5 \
|
100 |
+
--path ${direct_model} \
|
101 |
+
--lm-model ${lm_model} \
|
102 |
+
--lm-data ${lm_data} \
|
103 |
+
--channel-model ${ch_model} \
|
104 |
+
--k2 ${k2} \
|
105 |
+
--combine-method noisy_channel \
|
106 |
+
--task noisy_channel_translation \
|
107 |
+
--lenpen ${lenpen} \
|
108 |
+
--lm-wt ${lm_wt} \
|
109 |
+
--ch-wt ${bw_wt} \
|
110 |
+
--gen-subset test \
|
111 |
+
--remove-bpe \
|
112 |
+
--fp16 \
|
113 |
+
--batch-size 1
|
114 |
+
```
|
115 |
+
## Fast Noisy Channel Modeling
|
116 |
+
|
117 |
+
[Bhosale et al. (2020)](http://www.statmt.org/wmt20/pdf/2020.wmt-1.68.pdf) introduces 3 approximations that speed up online noisy channel decoding -
|
118 |
+
- Smaller channel models (`Tranformer Base` with 1 encoder and decoder layer each vs. `Transformer Big`)
|
119 |
+
- This involves training a channel model that is possibly smaller and less accurate in terms of BLEU than a channel model of the same size as the direct model.
|
120 |
+
- Since the role of the channel model is mainly to assign low scores to generations from the language model if they don't translate back to the source, we may not need the most accurate channel model for this purpose.
|
121 |
+
- Smaller output vocabulary size for the channel model (~30,000 -> ~1000)
|
122 |
+
- The channel model doesn't need to score the full output vocabulary, it just needs to score the source tokens, which are completely known.
|
123 |
+
- This is specified using the arguments `--channel-scoring-type src_vocab --top-k-vocab 500`
|
124 |
+
- This means that the output vocabulary for the channel model will be the source tokens for all examples in the batch and the top-K most frequent tokens in the vocabulary
|
125 |
+
- This reduces the memory consumption needed to store channel model scores significantly
|
126 |
+
- Smaller number of candidates (`k2`) scored per beam
|
127 |
+
- This is specified by reducing the argument `--k2`
|
128 |
+
|
129 |
+
|
130 |
+
### Fast Noisy Channel Generation for German-English translation with fairseq
|
131 |
+
|
132 |
+
Here are instructions for **fast** noisy channel generation with a direct model, channel model and language model as explained in section [Fast Noisy Channel Modeling](#fast-noisy-channel-modeling). The main differences are that we use a smaller channel model, reduce `--k2`, set `--channel-scoring-type src_vocab --top-k-vocab 500` and increase the `--batch-size`.
|
133 |
+
|
134 |
+
Note:
|
135 |
+
- Download and install fairseq as per instructions [here](https://github.com/pytorch/fairseq)
|
136 |
+
- Preprocess and binarize the dataset as per instructions in section [Test Data Preprocessing](#test-data-preprocessing)
|
137 |
+
|
138 |
+
```sh
|
139 |
+
binarized_data=data_dir/binarized
|
140 |
+
direct_model=de_en_seed4.pt
|
141 |
+
lm_model=en_lm.pt
|
142 |
+
lm_data=lm_data
|
143 |
+
small_ch_model=en_de.base_1_1.seed4.pt
|
144 |
+
wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/direct_models/seed4.pt -O ${direct_model}
|
145 |
+
wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/transformer_lm.pt -O ${lm_model}
|
146 |
+
mkdir -p ${lm_data}
|
147 |
+
wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/lm_dict/dict.txt -O ${lm_data}/dict.txt
|
148 |
+
wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base_1_1.seed4.pt -O ${small_ch_model}
|
149 |
+
|
150 |
+
k2=3
|
151 |
+
lenpen=0.23
|
152 |
+
lm_wt=0.58
|
153 |
+
bw_wt=0.26
|
154 |
+
fairseq-generate ${binarized_data} \
|
155 |
+
--user-dir examples/fast_noisy_channel \
|
156 |
+
--beam 5 \
|
157 |
+
--path ${direct_model} \
|
158 |
+
--lm-model ${lm_model} \
|
159 |
+
--lm-data ${lm_data} \
|
160 |
+
--channel-model ${small_ch_model} \
|
161 |
+
--k2 ${k2} \
|
162 |
+
--combine-method noisy_channel \
|
163 |
+
--task noisy_channel_translation \
|
164 |
+
--lenpen ${lenpen} \
|
165 |
+
--lm-wt ${lm_wt} \
|
166 |
+
--ch-wt ${bw_wt} \
|
167 |
+
--gen-subset test \
|
168 |
+
--remove-bpe \
|
169 |
+
--fp16 \
|
170 |
+
--batch-size 50 \
|
171 |
+
--channel-scoring-type src_vocab --top-k-vocab 500
|
172 |
+
```
|
173 |
+
|
174 |
+
## Test Data Preprocessing
|
175 |
+
|
176 |
+
For preprocessing and binarizing the test sets for Romanian-English and German-English translation, we use the following script -
|
177 |
+
|
178 |
+
```sh
|
179 |
+
FAIRSEQ=/path/to/fairseq
|
180 |
+
cd $FAIRSEQ
|
181 |
+
SCRIPTS=$FAIRSEQ/mosesdecoder/scripts
|
182 |
+
if [ ! -d "${SCRIPTS}" ]; then
|
183 |
+
echo 'Cloning Moses github repository (for tokenization scripts)...'
|
184 |
+
git clone https://github.com/moses-smt/mosesdecoder.git
|
185 |
+
fi
|
186 |
+
TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
|
187 |
+
NORMALIZE=$SCRIPTS/tokenizer/normalize-punctuation.perl
|
188 |
+
|
189 |
+
s=de
|
190 |
+
t=en
|
191 |
+
test=wmt18
|
192 |
+
|
193 |
+
mkdir -p data_dir
|
194 |
+
|
195 |
+
# Tokenization
|
196 |
+
if [ $s == "ro" ] ; then
|
197 |
+
# Note: Get normalise-romanian.py and remove-diacritics.py from
|
198 |
+
# https://github.com/rsennrich/wmt16-scripts/tree/master/preprocess
|
199 |
+
sacrebleu -t $test -l $s-$t --echo src | \
|
200 |
+
$NORMALIZE -l $s | \
|
201 |
+
python normalise-romanian.py | \
|
202 |
+
python remove-diacritics.py | \
|
203 |
+
$TOKENIZER -l $s -a -q > data_dir/$test.$s-$t.$s
|
204 |
+
else
|
205 |
+
sacrebleu -t $test -l $s-$t --echo src | perl $NORMALIZE -l $s | perl $TOKENIZER -threads 8 -a -l $s > data_dir/$test.$s-$t.$s
|
206 |
+
fi
|
207 |
+
|
208 |
+
sacrebleu -t $test -l $s-$t --echo ref | perl $NORMALIZE -l $t | perl $TOKENIZER -threads 8 -a -l $t > data_dir/$test.$s-$t.$t
|
209 |
+
|
210 |
+
|
211 |
+
# Applying BPE
|
212 |
+
src_bpe_code=/path/to/source/language/bpe/code
|
213 |
+
tgt_bpe_code=/path/to/target/language/bpe/code
|
214 |
+
src_dict=/path/to/source/language/dict
|
215 |
+
tgt_dict=/path/to/target/language/dict
|
216 |
+
|
217 |
+
FASTBPE=$FAIRSEQ/fastBPE
|
218 |
+
if [ ! -d "${FASTBPE}" ] ; then
|
219 |
+
git clone https://github.com/glample/fastBPE.git
|
220 |
+
# Follow compilation instructions at https://github.com/glample/fastBPE
|
221 |
+
g++ -std=c++11 -pthread -O3 fastBPE/main.cc -IfastBPE -o fast
|
222 |
+
fi
|
223 |
+
|
224 |
+
${FASTBPE}/fast applybpe data_dir/bpe.$test.$s-$t.$s data_dir/$test.$s-$t.$s ${src_bpe_code}
|
225 |
+
${FASTBPE}/fast applybpe data_dir/bpe.$test.$s-$t.$s data_dir/$test.$s-$t.$s ${tgt_bpe_code}
|
226 |
+
|
227 |
+
fairseq-preprocess -s $s -t $t \
|
228 |
+
--testpref data_dir/bpe.$test.$s-$t \
|
229 |
+
--destdir data_dir/binarized \
|
230 |
+
--srcdict ${src_dict} \
|
231 |
+
--tgtdict ${tgt_dict}
|
232 |
+
```
|
233 |
+
|
234 |
+
## Calculating BLEU
|
235 |
+
|
236 |
+
```sh
|
237 |
+
DETOKENIZER=$SCRIPTS/tokenizer/detokenizer.perl
|
238 |
+
cat ${generation_output} | grep -P "^H" | sort -V | cut -f 3- | $DETOKENIZER -l $t -q -a | sacrebleu -t $test -l $s-$t
|
239 |
+
```
|
240 |
+
|
241 |
+
|
242 |
+
## Romanian-English Translation
|
243 |
+
|
244 |
+
The direct and channel models are trained using bitext data (WMT16) combined with backtranslated data (The monolingual data used for backtranslation comes from http://data.statmt.org/rsennrich/wmt16_backtranslations/ (Sennrich et al., 2016c))
|
245 |
+
|
246 |
+
The backtranslated data is generated using an ensemble of 3 English-Romanian models trained on bitext training data (WMT16) with unrestricted sampling.
|
247 |
+
|
248 |
+
### BPE Codes and Dictionary
|
249 |
+
|
250 |
+
We learn a joint BPE vocabulary of 18K types on the bitext training data which is used for both the source and target.
|
251 |
+
||Path|
|
252 |
+
|----------|------|
|
253 |
+
| BPE Code | [joint_bpe_18k](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/bpe_18k) |
|
254 |
+
| Dictionary | [dict](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/dict) |
|
255 |
+
|
256 |
+
### Direct Models
|
257 |
+
For Ro-En with backtranslation, the direct and channel models use a Transformer-Big architecture.
|
258 |
+
|
259 |
+
| Seed | Model |
|
260 |
+
|----|----|
|
261 |
+
| 2 | [ro_en_seed2.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/direct_models/seed2.pt)
|
262 |
+
| 4 | [ro_en_seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/direct_models/seed4.pt)
|
263 |
+
| 6 | [ro_en_seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/direct_models/seed6.pt)
|
264 |
+
|
265 |
+
### Channel Models
|
266 |
+
For channel models, we follow the same steps as for the direct models. But backtranslated data is generated in the opposite direction using [this Romanian monolingual data](http://data.statmt.org/rsennrich/wmt16_backtranslations/).
|
267 |
+
The best lenpen, LM weight and CH weight are obtained by sweeping over the validation set (wmt16/dev) using beam 5.
|
268 |
+
| Model Size | Lenpen | LM Weight | CH Weight | Seed 2 | Seed 4 | Seed 6 |
|
269 |
+
|----|----|----|----|----|----|----|
|
270 |
+
| `big` | 0.84 | 0.64 | 0.56 | [big.seed2.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/channel_models/big.seed2.pt) | [big.seed2.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/channel_models/big.seed2.pt) | [big.seed2.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/channel_models/big.seed2.pt) |
|
271 |
+
| `base_1_1` | 0.63 | 0.40 | 0.37 | [base_1_1.seed2.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/channel_models/base_1_1.seed2.pt) | [base_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/channel_models/base_1_1.seed4.pt) | [base_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/channel_models/base_1_1.seed6.pt) |
|
272 |
+
|
273 |
+
### Language Model
|
274 |
+
The model is trained on de-duplicated English Newscrawl data from 2007-2018 comprising 186 million sentences or 4.5B words after normalization and tokenization.
|
275 |
+
| | Path |
|
276 |
+
|----|----|
|
277 |
+
| `--lm-model` | [transformer_en_lm](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/lm_model/transformer_lm.pt) |
|
278 |
+
| `--lm-data` | [lm_data](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/lm_model/lm_dict)
|
279 |
+
|
280 |
+
## German-English Translation
|
281 |
+
|
282 |
+
### BPE Codes and Dictionaries
|
283 |
+
|
284 |
+
| | Path|
|
285 |
+
|----------|------|
|
286 |
+
| Source BPE Code | [de_bpe_code_24K](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/de_bpe_code_24K) |
|
287 |
+
| Target BPE Code | [en_bpe_code_24K](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/en_bpe_code_24K)
|
288 |
+
| Source Dictionary | [de_dict](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/de_dict) |
|
289 |
+
| Target Dictionary | [en_dict](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/en_dict) |
|
290 |
+
|
291 |
+
### Direct Models
|
292 |
+
We train on WMT’19 training data. Following [Ng et al., 2019](http://statmt.org/wmt19/pdf/53/WMT33.pdf), we apply language identification filtering and remove sentences longer than 250 tokens as well as sentence pairs with a source/target length ratio exceeding 1.5. This results in 26.8M sentence pairs.
|
293 |
+
We use the Transformer-Big architecture for the direct model.
|
294 |
+
|
295 |
+
| Seed | Model |
|
296 |
+
|:----:|----|
|
297 |
+
| 4 | [de_en_seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/direct_models/seed4.pt)
|
298 |
+
| 5 | [de_en_seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/direct_models/seed5.pt)
|
299 |
+
| 6 | [de_en_seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/direct_models/seed6.pt)
|
300 |
+
|
301 |
+
### Channel Models
|
302 |
+
|
303 |
+
We train on WMT’19 training data. Following [Ng et al., 2019](http://statmt.org/wmt19/pdf/53/WMT33.pdf), we apply language identification filtering and remove sentences longer than 250 tokens as well as sentence pairs with a source/target length ratio exceeding 1.5. This results in 26.8M sentence pairs.
|
304 |
+
|
305 |
+
| Model Size | Seed 4 | Seed 5 | Seed 6 |
|
306 |
+
|----|----|----|----|
|
307 |
+
| `big` | [big.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big.seed4.pt) | [big.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big.seed5.pt) | [big.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big.seed6.pt) |
|
308 |
+
| `big_1_1` | [big_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big_1_1.seed4.pt) | [big_1_1.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big_1_1.seed5.pt) | [big_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big_1_1.seed6.pt) |
|
309 |
+
| `base` | [base.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base.seed4.pt) | [base.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base.seed5.pt) | [base.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base.seed6.pt) |
|
310 |
+
| `base_1_1` | [base_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base_1_1.seed4.pt) | [base_1_1.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base_1_1.seed5.pt) | [base_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base_1_1.seed6.pt) |
|
311 |
+
| `half` | [half.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/half.seed4.pt) | [half.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/half.seed5.pt) | [half.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/half.seed6.pt) |
|
312 |
+
| `half_1_1` | [half_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/half_1_1.seed4.pt) | [half_1_1.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/half_1_1.seed5.pt) | [half_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/half_1_1.seed6.pt) |
|
313 |
+
| `quarter` | [quarter.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/quarter.seed4.pt) | [quarter.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/quarter.seed5.pt) | [quarter.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/quarter.seed6.pt) |
|
314 |
+
| `quarter_1_1` | [quarter_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/quarter_1_1.seed4.pt) | [quarter_1_1.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/quarter_1_1.seed5.pt) | [quarter_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/quarter_1_1.seed6.pt) |
|
315 |
+
| `8th` | [8th.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/8th.seed4.pt) | [8th.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/8th.seed5.pt) | [8th.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/8th.seed6.pt) |
|
316 |
+
| `8th_1_1` | [8th_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/8th_1_1.seed4.pt) | [8th_1_1.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/8th_1_1.seed5.pt) | [8th_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/8th_1_1.seed6.pt) |
|
317 |
+
| `16th` | [16th.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/16th.seed4.pt) | [16th.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/16th.seed5.pt) | [16th.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/16th.seed6.pt) |
|
318 |
+
| `16th_1_1` | [16th_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/16th_1_1.seed4.pt) | [16th_1_1.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/16th_1_1.seed5.pt) | [16th_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/16th_1_1.seed6.pt) |
|
319 |
+
|
320 |
+
### Language Model
|
321 |
+
The model is trained on de-duplicated English Newscrawl data from 2007-2018 comprising 186 million sentences or 4.5B words after normalization and tokenization.
|
322 |
+
| | Path |
|
323 |
+
|----|----|
|
324 |
+
| `--lm-model` | [transformer_en_lm](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/transformer_lm.pt) |
|
325 |
+
| `--lm-data` | [lm_data](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/lm_dict/)
|
326 |
+
|
327 |
+
|
328 |
+
## Citation
|
329 |
+
|
330 |
+
```bibtex
|
331 |
+
@inproceedings{bhosale2020language,
|
332 |
+
title={Language Models not just for Pre-training: Fast Online Neural Noisy Channel Modeling},
|
333 |
+
author={Shruti Bhosale and Kyra Yee and Sergey Edunov and Michael Auli},
|
334 |
+
booktitle={Proceedings of the Fifth Conference on Machine Translation (WMT)},
|
335 |
+
year={2020},
|
336 |
+
}
|
337 |
+
|
338 |
+
@inproceedings{yee2019simple,
|
339 |
+
title={Simple and Effective Noisy Channel Modeling for Neural Machine Translation},
|
340 |
+
author={Yee, Kyra and Dauphin, Yann and Auli, Michael},
|
341 |
+
booktitle={Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)},
|
342 |
+
pages={5700--5705},
|
343 |
+
year={2019}
|
344 |
+
}
|
345 |
+
```
|
fairseq/examples/fast_noisy_channel/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from . import noisy_channel_translation # noqa
|
7 |
+
from . import noisy_channel_sequence_generator # noqa
|
8 |
+
from . import noisy_channel_beam_search # noqa
|
fairseq/examples/fast_noisy_channel/noisy_channel_beam_search.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from fairseq.search import Search
|
8 |
+
|
9 |
+
|
10 |
+
class NoisyChannelBeamSearch(Search):
|
11 |
+
|
12 |
+
def __init__(self, tgt_dict):
|
13 |
+
super().__init__(tgt_dict)
|
14 |
+
self.fw_scores_buf = None
|
15 |
+
self.lm_scores_buf = None
|
16 |
+
|
17 |
+
def _init_buffers(self, t):
|
18 |
+
# super()._init_buffers(t)
|
19 |
+
if self.fw_scores_buf is None:
|
20 |
+
self.scores_buf = t.new()
|
21 |
+
self.indices_buf = torch.LongTensor().to(device=t.device)
|
22 |
+
self.beams_buf = torch.LongTensor().to(device=t.device)
|
23 |
+
self.fw_scores_buf = t.new()
|
24 |
+
self.lm_scores_buf = t.new()
|
25 |
+
|
26 |
+
def combine_fw_bw(self, combine_method, fw_cum, bw, step):
|
27 |
+
if combine_method == "noisy_channel":
|
28 |
+
fw_norm = fw_cum.div(step + 1)
|
29 |
+
lprobs = bw + fw_norm
|
30 |
+
elif combine_method == "lm_only":
|
31 |
+
lprobs = bw + fw_cum
|
32 |
+
|
33 |
+
return lprobs
|
34 |
+
|
35 |
+
def step(self, step, fw_lprobs, scores, bw_lprobs, lm_lprobs, combine_method):
|
36 |
+
self._init_buffers(fw_lprobs)
|
37 |
+
bsz, beam_size, vocab_size = fw_lprobs.size()
|
38 |
+
|
39 |
+
if step == 0:
|
40 |
+
# at the first step all hypotheses are equally likely, so use
|
41 |
+
# only the first beam
|
42 |
+
fw_lprobs = fw_lprobs[:, ::beam_size, :].contiguous()
|
43 |
+
bw_lprobs = bw_lprobs[:, ::beam_size, :].contiguous()
|
44 |
+
# nothing to add since we are at the first step
|
45 |
+
fw_lprobs_cum = fw_lprobs
|
46 |
+
|
47 |
+
else:
|
48 |
+
# make probs contain cumulative scores for each hypothesis
|
49 |
+
raw_scores = (scores[:, :, step - 1].unsqueeze(-1))
|
50 |
+
fw_lprobs_cum = (fw_lprobs.add(raw_scores))
|
51 |
+
|
52 |
+
combined_lprobs = self.combine_fw_bw(combine_method, fw_lprobs_cum, bw_lprobs, step)
|
53 |
+
|
54 |
+
# choose the top k according to the combined noisy channel model score
|
55 |
+
torch.topk(
|
56 |
+
combined_lprobs.view(bsz, -1),
|
57 |
+
k=min(
|
58 |
+
# Take the best 2 x beam_size predictions. We'll choose the first
|
59 |
+
# beam_size of these which don't predict eos to continue with.
|
60 |
+
beam_size * 2,
|
61 |
+
combined_lprobs.view(bsz, -1).size(1) - 1, # -1 so we never select pad
|
62 |
+
),
|
63 |
+
out=(self.scores_buf, self.indices_buf),
|
64 |
+
)
|
65 |
+
# save corresponding fw and lm scores
|
66 |
+
self.fw_scores_buf = torch.gather(fw_lprobs_cum.view(bsz, -1), 1, self.indices_buf)
|
67 |
+
self.lm_scores_buf = torch.gather(lm_lprobs.view(bsz, -1), 1, self.indices_buf)
|
68 |
+
# Project back into relative indices and beams
|
69 |
+
self.beams_buf = self.indices_buf // vocab_size
|
70 |
+
self.indices_buf.fmod_(vocab_size)
|
71 |
+
return self.scores_buf, self.fw_scores_buf, self.lm_scores_buf, self.indices_buf, self.beams_buf
|
fairseq/examples/fast_noisy_channel/noisy_channel_sequence_generator.py
ADDED
@@ -0,0 +1,842 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from typing import Dict, List, Optional
|
7 |
+
|
8 |
+
import math
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn.functional as F
|
13 |
+
from torch import Tensor
|
14 |
+
|
15 |
+
from .noisy_channel_beam_search import NoisyChannelBeamSearch
|
16 |
+
from fairseq.sequence_generator import EnsembleModel
|
17 |
+
|
18 |
+
|
19 |
+
class NoisyChannelSequenceGenerator(object):
|
20 |
+
def __init__(
|
21 |
+
self,
|
22 |
+
combine_method,
|
23 |
+
tgt_dict,
|
24 |
+
src_dict=None,
|
25 |
+
beam_size=1,
|
26 |
+
max_len_a=0,
|
27 |
+
max_len_b=200,
|
28 |
+
min_len=1,
|
29 |
+
len_penalty=1.0,
|
30 |
+
unk_penalty=0.0,
|
31 |
+
retain_dropout=False,
|
32 |
+
temperature=1.0,
|
33 |
+
match_source_len=False,
|
34 |
+
no_repeat_ngram_size=0,
|
35 |
+
normalize_scores=True,
|
36 |
+
channel_models=None,
|
37 |
+
k2=10,
|
38 |
+
ch_weight=1.0,
|
39 |
+
channel_scoring_type='log_norm',
|
40 |
+
top_k_vocab=0,
|
41 |
+
lm_models=None,
|
42 |
+
lm_dict=None,
|
43 |
+
lm_weight=1.0,
|
44 |
+
normalize_lm_scores_by_tgt_len=False,
|
45 |
+
):
|
46 |
+
"""Generates translations of a given source sentence,
|
47 |
+
using beam search with noisy channel decoding.
|
48 |
+
|
49 |
+
Args:
|
50 |
+
combine_method (string, optional): Method to combine direct, LM and
|
51 |
+
channel model scores (default: None)
|
52 |
+
tgt_dict (~fairseq.data.Dictionary): target dictionary
|
53 |
+
src_dict (~fairseq.data.Dictionary): source dictionary
|
54 |
+
beam_size (int, optional): beam width (default: 1)
|
55 |
+
max_len_a/b (int, optional): generate sequences of maximum length
|
56 |
+
ax + b, where x is the source length
|
57 |
+
min_len (int, optional): the minimum length of the generated output
|
58 |
+
(not including end-of-sentence)
|
59 |
+
len_penalty (float, optional): length penalty, where <1.0 favors
|
60 |
+
shorter, >1.0 favors longer sentences (default: 1.0)
|
61 |
+
unk_penalty (float, optional): unknown word penalty, where <0
|
62 |
+
produces more unks, >0 produces fewer (default: 0.0)
|
63 |
+
retain_dropout (bool, optional): use dropout when generating
|
64 |
+
(default: False)
|
65 |
+
temperature (float, optional): temperature, where values
|
66 |
+
>1.0 produce more uniform samples and values <1.0 produce
|
67 |
+
sharper samples (default: 1.0)
|
68 |
+
match_source_len (bool, optional): outputs should match the source
|
69 |
+
length (default: False)
|
70 |
+
no_repeat_ngram_size (int, optional): Size of n-grams that we avoid
|
71 |
+
repeating in the generation (default: 0)
|
72 |
+
normalize_scores (bool, optional): normalize scores by the length
|
73 |
+
of the output (default: True)
|
74 |
+
channel_models (List[~fairseq.models.FairseqModel]): ensemble of models
|
75 |
+
translating from the target to the source
|
76 |
+
k2 (int, optional): Top K2 candidates to score per beam at each step (default:10)
|
77 |
+
ch_weight (int, optional): Weight associated with the channel model score
|
78 |
+
assuming that the direct model score has weight 1.0 (default: 1.0)
|
79 |
+
channel_scoring_type (str, optional): String specifying how to score
|
80 |
+
the channel model (default: 'log_norm')
|
81 |
+
top_k_vocab (int, optional): If `channel_scoring_type` is `'src_vocab'` or
|
82 |
+
`'src_vocab_batched'`, then this parameter specifies the number of
|
83 |
+
most frequent tokens to include in the channel model output vocabulary,
|
84 |
+
in addition to the source tokens in the input batch (default: 0)
|
85 |
+
lm_models (List[~fairseq.models.FairseqModel]): ensemble of models
|
86 |
+
generating text in the target language
|
87 |
+
lm_dict (~fairseq.data.Dictionary): LM Model dictionary
|
88 |
+
lm_weight (int, optional): Weight associated with the LM model score
|
89 |
+
assuming that the direct model score has weight 1.0 (default: 1.0)
|
90 |
+
normalize_lm_scores_by_tgt_len (bool, optional): Should we normalize LM scores
|
91 |
+
by the target length? By default, we normalize the combination of
|
92 |
+
LM and channel model scores by the source length
|
93 |
+
"""
|
94 |
+
self.pad = tgt_dict.pad()
|
95 |
+
self.unk = tgt_dict.unk()
|
96 |
+
self.eos = tgt_dict.eos()
|
97 |
+
self.vocab_size = len(tgt_dict)
|
98 |
+
self.beam_size = beam_size
|
99 |
+
# the max beam size is the dictionary size - 1, since we never select pad
|
100 |
+
self.beam_size = min(beam_size, self.vocab_size - 1)
|
101 |
+
self.max_len_a = max_len_a
|
102 |
+
self.max_len_b = max_len_b
|
103 |
+
self.min_len = min_len
|
104 |
+
self.normalize_scores = normalize_scores
|
105 |
+
self.len_penalty = len_penalty
|
106 |
+
self.unk_penalty = unk_penalty
|
107 |
+
self.retain_dropout = retain_dropout
|
108 |
+
self.temperature = temperature
|
109 |
+
self.match_source_len = match_source_len
|
110 |
+
self.no_repeat_ngram_size = no_repeat_ngram_size
|
111 |
+
self.channel_models = channel_models
|
112 |
+
self.src_dict = src_dict
|
113 |
+
self.tgt_dict = tgt_dict
|
114 |
+
self.combine_method = combine_method
|
115 |
+
self.k2 = k2
|
116 |
+
self.ch_weight = ch_weight
|
117 |
+
self.channel_scoring_type = channel_scoring_type
|
118 |
+
self.top_k_vocab = top_k_vocab
|
119 |
+
self.lm_models = lm_models
|
120 |
+
self.lm_dict = lm_dict
|
121 |
+
self.lm_weight = lm_weight
|
122 |
+
self.log_softmax_fn = torch.nn.LogSoftmax(dim=1)
|
123 |
+
self.normalize_lm_scores_by_tgt_len = normalize_lm_scores_by_tgt_len
|
124 |
+
|
125 |
+
self.share_tgt_dict = (self.lm_dict == self.tgt_dict)
|
126 |
+
self.tgt_to_lm = make_dict2dict(tgt_dict, lm_dict)
|
127 |
+
|
128 |
+
self.ch_scoring_bsz = 3072
|
129 |
+
|
130 |
+
assert temperature > 0, '--temperature must be greater than 0'
|
131 |
+
|
132 |
+
self.search = NoisyChannelBeamSearch(tgt_dict)
|
133 |
+
|
134 |
+
@torch.no_grad()
|
135 |
+
def generate(
|
136 |
+
self,
|
137 |
+
models,
|
138 |
+
sample,
|
139 |
+
prefix_tokens=None,
|
140 |
+
bos_token=None,
|
141 |
+
**kwargs
|
142 |
+
):
|
143 |
+
"""Generate a batch of translations.
|
144 |
+
Args:
|
145 |
+
models (List[~fairseq.models.FairseqModel]): ensemble of models
|
146 |
+
sample (dict): batch
|
147 |
+
prefix_tokens (torch.LongTensor, optional): force decoder to begin
|
148 |
+
with these tokens
|
149 |
+
"""
|
150 |
+
model = EnsembleModel(models)
|
151 |
+
incremental_states = torch.jit.annotate(
|
152 |
+
List[Dict[str, Dict[str, Optional[Tensor]]]],
|
153 |
+
[
|
154 |
+
torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
|
155 |
+
for i in range(model.models_size)
|
156 |
+
],
|
157 |
+
)
|
158 |
+
if not self.retain_dropout:
|
159 |
+
model.eval()
|
160 |
+
|
161 |
+
# model.forward normally channels prev_output_tokens into the decoder
|
162 |
+
# separately, but SequenceGenerator directly calls model.encoder
|
163 |
+
encoder_input = {
|
164 |
+
k: v for k, v in sample['net_input'].items()
|
165 |
+
if k != 'prev_output_tokens'
|
166 |
+
}
|
167 |
+
src_tokens = encoder_input['src_tokens']
|
168 |
+
src_lengths_no_eos = (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1)
|
169 |
+
input_size = src_tokens.size()
|
170 |
+
# batch dimension goes first followed by source lengths
|
171 |
+
bsz = input_size[0]
|
172 |
+
src_len = input_size[1]
|
173 |
+
beam_size = self.beam_size
|
174 |
+
|
175 |
+
if self.match_source_len:
|
176 |
+
max_len = src_lengths_no_eos.max().item()
|
177 |
+
else:
|
178 |
+
max_len = min(
|
179 |
+
int(self.max_len_a * src_len + self.max_len_b),
|
180 |
+
# exclude the EOS marker
|
181 |
+
model.max_decoder_positions() - 1,
|
182 |
+
)
|
183 |
+
|
184 |
+
# compute the encoder output for each beam
|
185 |
+
encoder_outs = model.forward_encoder(encoder_input)
|
186 |
+
new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
|
187 |
+
new_order = new_order.to(src_tokens.device).long()
|
188 |
+
encoder_outs = model.reorder_encoder_out(encoder_outs, new_order)
|
189 |
+
|
190 |
+
src_lengths = encoder_input['src_lengths']
|
191 |
+
# initialize buffers
|
192 |
+
scores = src_tokens.new(bsz * beam_size, max_len + 1).float().fill_(0)
|
193 |
+
lm_prefix_scores = src_tokens.new(bsz * beam_size).float().fill_(0)
|
194 |
+
|
195 |
+
scores_buf = scores.clone()
|
196 |
+
tokens = src_tokens.new(bsz * beam_size, max_len + 2).long().fill_(self.pad)
|
197 |
+
tokens_buf = tokens.clone()
|
198 |
+
tokens[:, 0] = self.eos if bos_token is None else bos_token
|
199 |
+
|
200 |
+
# reorder source tokens so they may be used as a reference in generating P(S|T)
|
201 |
+
src_tokens = reorder_all_tokens(src_tokens, src_lengths, self.src_dict.eos_index)
|
202 |
+
|
203 |
+
src_tokens = src_tokens.repeat(1, beam_size).view(-1, src_len)
|
204 |
+
src_lengths = src_lengths.view(bsz, -1).repeat(1, beam_size).view(bsz*beam_size, -1)
|
205 |
+
|
206 |
+
attn, attn_buf = None, None
|
207 |
+
nonpad_idxs = None
|
208 |
+
|
209 |
+
# The cands_to_ignore indicates candidates that should be ignored.
|
210 |
+
# For example, suppose we're sampling and have already finalized 2/5
|
211 |
+
# samples. Then the cands_to_ignore would mark 2 positions as being ignored,
|
212 |
+
# so that we only finalize the remaining 3 samples.
|
213 |
+
cands_to_ignore = src_tokens.new_zeros(bsz, beam_size).eq(-1) # forward and backward-compatible False mask
|
214 |
+
|
215 |
+
# list of completed sentences
|
216 |
+
finalized = [[] for i in range(bsz)]
|
217 |
+
finished = [False for i in range(bsz)]
|
218 |
+
num_remaining_sent = bsz
|
219 |
+
|
220 |
+
# number of candidate hypos per step
|
221 |
+
cand_size = 2 * beam_size # 2 x beam size in case half are EOS
|
222 |
+
|
223 |
+
# offset arrays for converting between different indexing schemes
|
224 |
+
bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens)
|
225 |
+
cand_offsets = torch.arange(0, cand_size).type_as(tokens)
|
226 |
+
|
227 |
+
# helper function for allocating buffers on the fly
|
228 |
+
buffers = {}
|
229 |
+
|
230 |
+
def buffer(name, type_of=tokens): # noqa
|
231 |
+
if name not in buffers:
|
232 |
+
buffers[name] = type_of.new()
|
233 |
+
return buffers[name]
|
234 |
+
|
235 |
+
def is_finished(sent, step, unfin_idx):
|
236 |
+
"""
|
237 |
+
Check whether we've finished generation for a given sentence, by
|
238 |
+
comparing the worst score among finalized hypotheses to the best
|
239 |
+
possible score among unfinalized hypotheses.
|
240 |
+
"""
|
241 |
+
assert len(finalized[sent]) <= beam_size
|
242 |
+
if len(finalized[sent]) == beam_size:
|
243 |
+
return True
|
244 |
+
return False
|
245 |
+
|
246 |
+
def finalize_hypos(step, bbsz_idx, eos_scores, combined_noisy_channel_eos_scores):
|
247 |
+
"""
|
248 |
+
Finalize the given hypotheses at this step, while keeping the total
|
249 |
+
number of finalized hypotheses per sentence <= beam_size.
|
250 |
+
|
251 |
+
Note: the input must be in the desired finalization order, so that
|
252 |
+
hypotheses that appear earlier in the input are preferred to those
|
253 |
+
that appear later.
|
254 |
+
|
255 |
+
Args:
|
256 |
+
step: current time step
|
257 |
+
bbsz_idx: A vector of indices in the range [0, bsz*beam_size),
|
258 |
+
indicating which hypotheses to finalize
|
259 |
+
eos_scores: A vector of the same size as bbsz_idx containing
|
260 |
+
fw scores for each hypothesis
|
261 |
+
combined_noisy_channel_eos_scores: A vector of the same size as bbsz_idx containing
|
262 |
+
combined noisy channel scores for each hypothesis
|
263 |
+
"""
|
264 |
+
assert bbsz_idx.numel() == eos_scores.numel()
|
265 |
+
|
266 |
+
# clone relevant token and attention tensors
|
267 |
+
tokens_clone = tokens.index_select(0, bbsz_idx)
|
268 |
+
tokens_clone = tokens_clone[:, 1:step + 2] # skip the first index, which is EOS
|
269 |
+
assert not tokens_clone.eq(self.eos).any()
|
270 |
+
tokens_clone[:, step] = self.eos
|
271 |
+
attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1:step+2] if attn is not None else None
|
272 |
+
|
273 |
+
# compute scores per token position
|
274 |
+
pos_scores = scores.index_select(0, bbsz_idx)[:, :step+1]
|
275 |
+
pos_scores[:, step] = eos_scores
|
276 |
+
# convert from cumulative to per-position scores
|
277 |
+
pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]
|
278 |
+
|
279 |
+
# normalize sentence-level scores
|
280 |
+
if self.normalize_scores:
|
281 |
+
combined_noisy_channel_eos_scores /= (step + 1) ** self.len_penalty
|
282 |
+
|
283 |
+
cum_unfin = []
|
284 |
+
prev = 0
|
285 |
+
for f in finished:
|
286 |
+
if f:
|
287 |
+
prev += 1
|
288 |
+
else:
|
289 |
+
cum_unfin.append(prev)
|
290 |
+
|
291 |
+
sents_seen = set()
|
292 |
+
for i, (idx, score) in enumerate(zip(bbsz_idx.tolist(), combined_noisy_channel_eos_scores.tolist())):
|
293 |
+
unfin_idx = idx // beam_size
|
294 |
+
sent = unfin_idx + cum_unfin[unfin_idx]
|
295 |
+
|
296 |
+
sents_seen.add((sent, unfin_idx))
|
297 |
+
|
298 |
+
if self.match_source_len and step > src_lengths_no_eos[unfin_idx]:
|
299 |
+
score = -math.inf
|
300 |
+
|
301 |
+
def get_hypo():
|
302 |
+
|
303 |
+
if attn_clone is not None:
|
304 |
+
# remove padding tokens from attn scores
|
305 |
+
hypo_attn = attn_clone[i][nonpad_idxs[sent]]
|
306 |
+
_, alignment = hypo_attn.max(dim=0)
|
307 |
+
else:
|
308 |
+
hypo_attn = None
|
309 |
+
alignment = None
|
310 |
+
|
311 |
+
return {
|
312 |
+
'tokens': tokens_clone[i],
|
313 |
+
'score': score,
|
314 |
+
'attention': hypo_attn, # src_len x tgt_len
|
315 |
+
'alignment': alignment,
|
316 |
+
'positional_scores': pos_scores[i],
|
317 |
+
}
|
318 |
+
|
319 |
+
if len(finalized[sent]) < beam_size:
|
320 |
+
finalized[sent].append(get_hypo())
|
321 |
+
|
322 |
+
newly_finished = []
|
323 |
+
for sent, unfin_idx in sents_seen:
|
324 |
+
# check termination conditions for this sentence
|
325 |
+
if not finished[sent] and is_finished(sent, step, unfin_idx):
|
326 |
+
finished[sent] = True
|
327 |
+
newly_finished.append(unfin_idx)
|
328 |
+
return newly_finished
|
329 |
+
|
330 |
+
def noisy_channel_rescoring(lprobs, beam_size, bsz, src_tokens, tokens, k):
|
331 |
+
"""Rescore the top k hypothesis from each beam using noisy channel modeling
|
332 |
+
Returns:
|
333 |
+
new_fw_lprobs: the direct model probabilities after pruning the top k
|
334 |
+
new_ch_lm_lprobs: the combined channel and language model probabilities
|
335 |
+
new_lm_lprobs: the language model probabilities after pruning the top k
|
336 |
+
"""
|
337 |
+
with torch.no_grad():
|
338 |
+
lprobs_size = lprobs.size()
|
339 |
+
if prefix_tokens is not None and step < prefix_tokens.size(1):
|
340 |
+
probs_slice = lprobs.view(bsz, -1, lprobs.size(-1))[:, 0, :]
|
341 |
+
cand_scores = torch.gather(
|
342 |
+
probs_slice, dim=1,
|
343 |
+
index=prefix_tokens[:, step].view(-1, 1).data
|
344 |
+
).expand(-1, beam_size).contiguous().view(bsz*beam_size, 1)
|
345 |
+
cand_indices = prefix_tokens[:, step].view(-1, 1).expand(bsz, beam_size).data.contiguous().view(bsz*beam_size, 1)
|
346 |
+
|
347 |
+
# need to calculate and save fw and lm probs for prefix tokens
|
348 |
+
fw_top_k = cand_scores
|
349 |
+
fw_top_k_idx = cand_indices
|
350 |
+
k = 1
|
351 |
+
else:
|
352 |
+
# take the top k best words for every sentence in batch*beam
|
353 |
+
fw_top_k, fw_top_k_idx = torch.topk(lprobs.view(beam_size*bsz, -1), k=k)
|
354 |
+
eos_idx = torch.nonzero(fw_top_k_idx.view(bsz*beam_size*k, -1) == self.eos)[:, 0]
|
355 |
+
ch_scores = fw_top_k.new_full((beam_size*bsz*k, ), 0)
|
356 |
+
src_size = torch.sum(src_tokens[:, :] != self.src_dict.pad_index, dim=1, keepdim=True, dtype=fw_top_k.dtype)
|
357 |
+
|
358 |
+
if self.combine_method != "lm_only":
|
359 |
+
temp_src_tokens_full = src_tokens[:, :].repeat(1, k).view(bsz*beam_size*k, -1)
|
360 |
+
not_padding = temp_src_tokens_full[:, 1:] != self.src_dict.pad_index
|
361 |
+
cur_tgt_size = step+2
|
362 |
+
|
363 |
+
# add eos to all candidate sentences except those that already end in eos
|
364 |
+
eos_tokens = tokens[:, 0].repeat(1, k).view(-1, 1)
|
365 |
+
eos_tokens[eos_idx] = self.tgt_dict.pad_index
|
366 |
+
|
367 |
+
if step == 0:
|
368 |
+
channel_input = torch.cat((fw_top_k_idx.view(-1, 1), eos_tokens), 1)
|
369 |
+
else:
|
370 |
+
# move eos from beginning to end of target sentence
|
371 |
+
channel_input = torch.cat((tokens[:, 1:step + 1].repeat(1, k).view(-1, step), fw_top_k_idx.view(-1, 1), eos_tokens), 1)
|
372 |
+
|
373 |
+
ch_input_lengths = torch.tensor(np.full(channel_input.size(0), cur_tgt_size))
|
374 |
+
ch_input_lengths[eos_idx] = cur_tgt_size-1
|
375 |
+
if self.channel_scoring_type == "unnormalized":
|
376 |
+
ch_encoder_output = channel_model.encoder(channel_input, src_lengths=ch_input_lengths)
|
377 |
+
ch_decoder_output, _ = channel_model.decoder(temp_src_tokens_full, encoder_out=ch_encoder_output, features_only=True)
|
378 |
+
del ch_encoder_output
|
379 |
+
ch_intermed_scores = channel_model.decoder.unnormalized_scores_given_target(ch_decoder_output, target_ids=temp_src_tokens_full[:, 1:])
|
380 |
+
ch_intermed_scores = ch_intermed_scores.float()
|
381 |
+
ch_intermed_scores *= not_padding.float()
|
382 |
+
ch_scores = torch.sum(ch_intermed_scores, dim=1)
|
383 |
+
elif self.channel_scoring_type == "k2_separate":
|
384 |
+
for k_idx in range(k):
|
385 |
+
k_eos_tokens = eos_tokens[k_idx::k, :]
|
386 |
+
if step == 0:
|
387 |
+
k_ch_input = torch.cat((fw_top_k_idx[:, k_idx:k_idx+1], k_eos_tokens), 1)
|
388 |
+
else:
|
389 |
+
# move eos from beginning to end of target sentence
|
390 |
+
k_ch_input = torch.cat((tokens[:, 1:step + 1], fw_top_k_idx[:, k_idx:k_idx+1], k_eos_tokens), 1)
|
391 |
+
k_ch_input_lengths = ch_input_lengths[k_idx::k]
|
392 |
+
k_ch_output = channel_model(k_ch_input, k_ch_input_lengths, src_tokens)
|
393 |
+
k_ch_lprobs = channel_model.get_normalized_probs(k_ch_output, log_probs=True)
|
394 |
+
k_ch_intermed_scores = torch.gather(k_ch_lprobs[:, :-1, :], 2, src_tokens[:, 1:].unsqueeze(2)).squeeze(2)
|
395 |
+
k_ch_intermed_scores *= not_padding.float()
|
396 |
+
ch_scores[k_idx::k] = torch.sum(k_ch_intermed_scores, dim=1)
|
397 |
+
elif self.channel_scoring_type == "src_vocab":
|
398 |
+
ch_encoder_output = channel_model.encoder(channel_input, src_lengths=ch_input_lengths)
|
399 |
+
ch_decoder_output, _ = channel_model.decoder(temp_src_tokens_full, encoder_out=ch_encoder_output, features_only=True)
|
400 |
+
|
401 |
+
del ch_encoder_output
|
402 |
+
ch_lprobs = normalized_scores_with_batch_vocab(
|
403 |
+
channel_model.decoder,
|
404 |
+
ch_decoder_output, src_tokens, k, bsz, beam_size,
|
405 |
+
self.src_dict.pad_index, top_k=self.top_k_vocab)
|
406 |
+
ch_scores = torch.sum(ch_lprobs, dim=1)
|
407 |
+
elif self.channel_scoring_type == "src_vocab_batched":
|
408 |
+
ch_bsz_size = temp_src_tokens_full.shape[0]
|
409 |
+
ch_lprobs_list = [None] * len(range(0, ch_bsz_size, self.ch_scoring_bsz))
|
410 |
+
for i, start_idx in enumerate(range(0, ch_bsz_size, self.ch_scoring_bsz)):
|
411 |
+
end_idx = min(start_idx + self.ch_scoring_bsz, ch_bsz_size)
|
412 |
+
temp_src_tokens_full_batch = temp_src_tokens_full[start_idx:end_idx, :]
|
413 |
+
channel_input_batch = channel_input[start_idx:end_idx, :]
|
414 |
+
ch_input_lengths_batch = ch_input_lengths[start_idx:end_idx]
|
415 |
+
ch_encoder_output_batch = channel_model.encoder(channel_input_batch, src_lengths=ch_input_lengths_batch)
|
416 |
+
ch_decoder_output_batch, _ = channel_model.decoder(temp_src_tokens_full_batch, encoder_out=ch_encoder_output_batch, features_only=True)
|
417 |
+
ch_lprobs_list[i] = normalized_scores_with_batch_vocab(
|
418 |
+
channel_model.decoder,
|
419 |
+
ch_decoder_output_batch, src_tokens, k, bsz, beam_size,
|
420 |
+
self.src_dict.pad_index, top_k=self.top_k_vocab,
|
421 |
+
start_idx=start_idx, end_idx=end_idx)
|
422 |
+
ch_lprobs = torch.cat(ch_lprobs_list, dim=0)
|
423 |
+
ch_scores = torch.sum(ch_lprobs, dim=1)
|
424 |
+
else:
|
425 |
+
ch_output = channel_model(channel_input, ch_input_lengths, temp_src_tokens_full)
|
426 |
+
ch_lprobs = channel_model.get_normalized_probs(ch_output, log_probs=True)
|
427 |
+
ch_intermed_scores = torch.gather(ch_lprobs[:, :-1, :], 2, temp_src_tokens_full[:, 1:].unsqueeze(2)).squeeze().view(bsz*beam_size*k, -1)
|
428 |
+
ch_intermed_scores *= not_padding.float()
|
429 |
+
ch_scores = torch.sum(ch_intermed_scores, dim=1)
|
430 |
+
|
431 |
+
else:
|
432 |
+
cur_tgt_size = 0
|
433 |
+
ch_scores = ch_scores.view(bsz*beam_size, k)
|
434 |
+
expanded_lm_prefix_scores = lm_prefix_scores.unsqueeze(1).expand(-1, k).flatten()
|
435 |
+
|
436 |
+
if self.share_tgt_dict:
|
437 |
+
lm_scores = get_lm_scores(lm, tokens[:, :step + 1].view(-1, step+1), lm_incremental_states, fw_top_k_idx.view(-1, 1), torch.tensor(np.full(tokens.size(0), step+1)), k)
|
438 |
+
else:
|
439 |
+
new_lm_input = dict2dict(tokens[:, :step + 1].view(-1, step+1), self.tgt_to_lm)
|
440 |
+
new_cands = dict2dict(fw_top_k_idx.view(-1, 1), self.tgt_to_lm)
|
441 |
+
lm_scores = get_lm_scores(lm, new_lm_input, lm_incremental_states, new_cands, torch.tensor(np.full(tokens.size(0), step+1)), k)
|
442 |
+
|
443 |
+
lm_scores.add_(expanded_lm_prefix_scores)
|
444 |
+
ch_lm_scores = combine_ch_lm(self.combine_method, ch_scores, lm_scores, src_size, cur_tgt_size)
|
445 |
+
# initialize all as min value
|
446 |
+
new_fw_lprobs = ch_scores.new(lprobs_size).fill_(-1e17).view(bsz*beam_size, -1)
|
447 |
+
new_ch_lm_lprobs = ch_scores.new(lprobs_size).fill_(-1e17).view(bsz*beam_size, -1)
|
448 |
+
new_lm_lprobs = ch_scores.new(lprobs_size).fill_(-1e17).view(bsz*beam_size, -1)
|
449 |
+
new_fw_lprobs[:, self.pad] = -math.inf
|
450 |
+
new_ch_lm_lprobs[:, self.pad] = -math.inf
|
451 |
+
new_lm_lprobs[:, self.pad] = -math.inf
|
452 |
+
|
453 |
+
new_fw_lprobs.scatter_(1, fw_top_k_idx, fw_top_k)
|
454 |
+
new_ch_lm_lprobs.scatter_(1, fw_top_k_idx, ch_lm_scores)
|
455 |
+
new_lm_lprobs.scatter_(1, fw_top_k_idx, lm_scores.view(-1, k))
|
456 |
+
return new_fw_lprobs, new_ch_lm_lprobs, new_lm_lprobs
|
457 |
+
|
458 |
+
def combine_ch_lm(combine_type, ch_scores, lm_scores1, src_size, tgt_size):
|
459 |
+
if self.channel_scoring_type == "unnormalized":
|
460 |
+
ch_scores = self.log_softmax_fn(
|
461 |
+
ch_scores.view(-1, self.beam_size * self.k2)
|
462 |
+
).view(ch_scores.shape)
|
463 |
+
ch_scores = ch_scores * self.ch_weight
|
464 |
+
lm_scores1 = lm_scores1 * self.lm_weight
|
465 |
+
|
466 |
+
if combine_type == "lm_only":
|
467 |
+
# log P(T|S) + log P(T)
|
468 |
+
ch_scores = lm_scores1.view(ch_scores.size())
|
469 |
+
elif combine_type == "noisy_channel":
|
470 |
+
# 1/t log P(T|S) + 1/s log P(S|T) + 1/t log P(T)
|
471 |
+
if self.normalize_lm_scores_by_tgt_len:
|
472 |
+
ch_scores.div_(src_size)
|
473 |
+
lm_scores_norm = lm_scores1.view(ch_scores.size()).div(tgt_size)
|
474 |
+
ch_scores.add_(lm_scores_norm)
|
475 |
+
# 1/t log P(T|S) + 1/s log P(S|T) + 1/s log P(T)
|
476 |
+
else:
|
477 |
+
ch_scores.add_(lm_scores1.view(ch_scores.size()))
|
478 |
+
ch_scores.div_(src_size)
|
479 |
+
|
480 |
+
return ch_scores
|
481 |
+
|
482 |
+
if self.channel_models is not None:
|
483 |
+
channel_model = self.channel_models[0] # assume only one channel_model model
|
484 |
+
else:
|
485 |
+
channel_model = None
|
486 |
+
|
487 |
+
lm = EnsembleModel(self.lm_models)
|
488 |
+
lm_incremental_states = torch.jit.annotate(
|
489 |
+
List[Dict[str, Dict[str, Optional[Tensor]]]],
|
490 |
+
[
|
491 |
+
torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
|
492 |
+
for i in range(lm.models_size)
|
493 |
+
],
|
494 |
+
)
|
495 |
+
|
496 |
+
reorder_state = None
|
497 |
+
batch_idxs = None
|
498 |
+
for step in range(max_len + 1): # one extra step for EOS marker
|
499 |
+
# reorder decoder internal states based on the prev choice of beams
|
500 |
+
if reorder_state is not None:
|
501 |
+
if batch_idxs is not None:
|
502 |
+
# update beam indices to take into account removed sentences
|
503 |
+
corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(batch_idxs)
|
504 |
+
reorder_state.view(-1, beam_size).add_(corr.unsqueeze(-1) * beam_size)
|
505 |
+
model.reorder_incremental_state(incremental_states, reorder_state)
|
506 |
+
encoder_outs = model.reorder_encoder_out(encoder_outs, reorder_state)
|
507 |
+
|
508 |
+
lm.reorder_incremental_state(lm_incremental_states, reorder_state)
|
509 |
+
|
510 |
+
fw_lprobs, avg_attn_scores = model.forward_decoder(
|
511 |
+
tokens[:, :step + 1], encoder_outs, incremental_states, temperature=self.temperature,
|
512 |
+
)
|
513 |
+
|
514 |
+
fw_lprobs[:, self.pad] = -math.inf # never select pad
|
515 |
+
fw_lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty
|
516 |
+
fw_lprobs, ch_lm_lprobs, lm_lprobs = noisy_channel_rescoring(fw_lprobs, beam_size, bsz, src_tokens, tokens, self.k2)
|
517 |
+
|
518 |
+
# handle min and max length constraints
|
519 |
+
if step >= max_len:
|
520 |
+
fw_lprobs[:, :self.eos] = -math.inf
|
521 |
+
fw_lprobs[:, self.eos + 1:] = -math.inf
|
522 |
+
elif step < self.min_len:
|
523 |
+
fw_lprobs[:, self.eos] = -math.inf
|
524 |
+
|
525 |
+
# handle prefix tokens (possibly with different lengths)
|
526 |
+
if prefix_tokens is not None and step < prefix_tokens.size(1):
|
527 |
+
prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1)
|
528 |
+
prefix_mask = prefix_toks.ne(self.pad)
|
529 |
+
|
530 |
+
prefix_fw_lprobs = fw_lprobs.gather(-1, prefix_toks.unsqueeze(-1))
|
531 |
+
fw_lprobs[prefix_mask] = -math.inf
|
532 |
+
fw_lprobs[prefix_mask] = fw_lprobs[prefix_mask].scatter_(
|
533 |
+
-1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_fw_lprobs
|
534 |
+
)
|
535 |
+
|
536 |
+
prefix_ch_lm_lprobs = ch_lm_lprobs.gather(-1, prefix_toks.unsqueeze(-1))
|
537 |
+
ch_lm_lprobs[prefix_mask] = -math.inf
|
538 |
+
ch_lm_lprobs[prefix_mask] = ch_lm_lprobs[prefix_mask].scatter_(
|
539 |
+
-1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_ch_lm_lprobs
|
540 |
+
)
|
541 |
+
|
542 |
+
prefix_lm_lprobs = lm_lprobs.gather(-1, prefix_toks.unsqueeze(-1))
|
543 |
+
lm_lprobs[prefix_mask] = -math.inf
|
544 |
+
lm_lprobs[prefix_mask] = lm_lprobs[prefix_mask].scatter_(
|
545 |
+
-1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lm_lprobs
|
546 |
+
)
|
547 |
+
|
548 |
+
# if prefix includes eos, then we should make sure tokens and
|
549 |
+
# scores are the same across all beams
|
550 |
+
eos_mask = prefix_toks.eq(self.eos)
|
551 |
+
if eos_mask.any():
|
552 |
+
# validate that the first beam matches the prefix
|
553 |
+
first_beam = tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[:, 0, 1:step + 1]
|
554 |
+
eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0]
|
555 |
+
target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step]
|
556 |
+
assert (first_beam == target_prefix).all()
|
557 |
+
|
558 |
+
def replicate_first_beam(tensor, mask):
|
559 |
+
tensor = tensor.view(-1, beam_size, tensor.size(-1))
|
560 |
+
tensor[mask] = tensor[mask][:, :1, :]
|
561 |
+
return tensor.view(-1, tensor.size(-1))
|
562 |
+
|
563 |
+
# copy tokens, scores and lprobs from the first beam to all beams
|
564 |
+
tokens = replicate_first_beam(tokens, eos_mask_batch_dim)
|
565 |
+
scores = replicate_first_beam(scores, eos_mask_batch_dim)
|
566 |
+
|
567 |
+
fw_lprobs = replicate_first_beam(fw_lprobs, eos_mask_batch_dim)
|
568 |
+
ch_lm_lprobs = replicate_first_beam(ch_lm_lprobs, eos_mask_batch_dim)
|
569 |
+
lm_lprobs = replicate_first_beam(lm_lprobs, eos_mask_batch_dim)
|
570 |
+
|
571 |
+
if self.no_repeat_ngram_size > 0:
|
572 |
+
# for each beam and batch sentence, generate a list of previous ngrams
|
573 |
+
gen_ngrams = [{} for bbsz_idx in range(bsz * beam_size)]
|
574 |
+
for bbsz_idx in range(bsz * beam_size):
|
575 |
+
gen_tokens = tokens[bbsz_idx].tolist()
|
576 |
+
for ngram in zip(*[gen_tokens[i:] for i in range(self.no_repeat_ngram_size)]):
|
577 |
+
gen_ngrams[bbsz_idx][tuple(ngram[:-1])] = \
|
578 |
+
gen_ngrams[bbsz_idx].get(tuple(ngram[:-1]), []) + [ngram[-1]]
|
579 |
+
|
580 |
+
# Record attention scores
|
581 |
+
if avg_attn_scores is not None:
|
582 |
+
if attn is None:
|
583 |
+
attn = scores.new(bsz * beam_size, src_tokens.size(1), max_len + 2)
|
584 |
+
attn_buf = attn.clone()
|
585 |
+
nonpad_idxs = src_tokens.ne(self.pad)
|
586 |
+
attn[:, :, step + 1].copy_(avg_attn_scores)
|
587 |
+
|
588 |
+
scores = scores.type_as(fw_lprobs)
|
589 |
+
scores_buf = scores_buf.type_as(fw_lprobs)
|
590 |
+
|
591 |
+
self.search.set_src_lengths(src_lengths_no_eos)
|
592 |
+
|
593 |
+
if self.no_repeat_ngram_size > 0:
|
594 |
+
def calculate_banned_tokens(bbsz_idx):
|
595 |
+
# before decoding the next token, prevent decoding of ngrams that have already appeared
|
596 |
+
ngram_index = tuple(tokens[bbsz_idx, step + 2 - self.no_repeat_ngram_size:step + 1].tolist())
|
597 |
+
return gen_ngrams[bbsz_idx].get(ngram_index, [])
|
598 |
+
|
599 |
+
if step + 2 - self.no_repeat_ngram_size >= 0:
|
600 |
+
# no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
|
601 |
+
banned_tokens = [calculate_banned_tokens(bbsz_idx) for bbsz_idx in range(bsz * beam_size)]
|
602 |
+
else:
|
603 |
+
banned_tokens = [[] for bbsz_idx in range(bsz * beam_size)]
|
604 |
+
|
605 |
+
for bbsz_idx in range(bsz * beam_size):
|
606 |
+
fw_lprobs[bbsz_idx, banned_tokens[bbsz_idx]] = -math.inf
|
607 |
+
|
608 |
+
combined_noisy_channel_scores, fw_lprobs_top_k, lm_lprobs_top_k, cand_indices, cand_beams = self.search.step(
|
609 |
+
step,
|
610 |
+
fw_lprobs.view(bsz, -1, self.vocab_size),
|
611 |
+
scores.view(bsz, beam_size, -1)[:, :, :step], ch_lm_lprobs.view(bsz, -1, self.vocab_size),
|
612 |
+
lm_lprobs.view(bsz, -1, self.vocab_size), self.combine_method
|
613 |
+
)
|
614 |
+
|
615 |
+
# cand_bbsz_idx contains beam indices for the top candidate
|
616 |
+
# hypotheses, with a range of values: [0, bsz*beam_size),
|
617 |
+
# and dimensions: [bsz, cand_size]
|
618 |
+
cand_bbsz_idx = cand_beams.add(bbsz_offsets)
|
619 |
+
|
620 |
+
# finalize hypotheses that end in eos (except for candidates to be ignored)
|
621 |
+
eos_mask = cand_indices.eq(self.eos)
|
622 |
+
eos_mask[:, :beam_size] &= ~cands_to_ignore
|
623 |
+
|
624 |
+
# only consider eos when it's among the top beam_size indices
|
625 |
+
eos_bbsz_idx = torch.masked_select(
|
626 |
+
cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size]
|
627 |
+
)
|
628 |
+
|
629 |
+
finalized_sents = set()
|
630 |
+
if eos_bbsz_idx.numel() > 0:
|
631 |
+
eos_scores = torch.masked_select(
|
632 |
+
fw_lprobs_top_k[:, :beam_size], mask=eos_mask[:, :beam_size]
|
633 |
+
)
|
634 |
+
combined_noisy_channel_eos_scores = torch.masked_select(
|
635 |
+
combined_noisy_channel_scores[:, :beam_size],
|
636 |
+
mask=eos_mask[:, :beam_size],
|
637 |
+
)
|
638 |
+
|
639 |
+
# finalize hypo using channel model score
|
640 |
+
finalized_sents = finalize_hypos(
|
641 |
+
step, eos_bbsz_idx, eos_scores, combined_noisy_channel_eos_scores)
|
642 |
+
|
643 |
+
num_remaining_sent -= len(finalized_sents)
|
644 |
+
|
645 |
+
assert num_remaining_sent >= 0
|
646 |
+
if num_remaining_sent == 0:
|
647 |
+
break
|
648 |
+
|
649 |
+
if len(finalized_sents) > 0:
|
650 |
+
new_bsz = bsz - len(finalized_sents)
|
651 |
+
|
652 |
+
# construct batch_idxs which holds indices of batches to keep for the next pass
|
653 |
+
batch_mask = cand_indices.new_ones(bsz)
|
654 |
+
batch_mask[cand_indices.new(finalized_sents)] = 0
|
655 |
+
batch_idxs = torch.nonzero(batch_mask).squeeze(-1)
|
656 |
+
|
657 |
+
eos_mask = eos_mask[batch_idxs]
|
658 |
+
cand_beams = cand_beams[batch_idxs]
|
659 |
+
bbsz_offsets.resize_(new_bsz, 1)
|
660 |
+
cand_bbsz_idx = cand_beams.add(bbsz_offsets)
|
661 |
+
|
662 |
+
lm_lprobs_top_k = lm_lprobs_top_k[batch_idxs]
|
663 |
+
|
664 |
+
fw_lprobs_top_k = fw_lprobs_top_k[batch_idxs]
|
665 |
+
cand_indices = cand_indices[batch_idxs]
|
666 |
+
if prefix_tokens is not None:
|
667 |
+
prefix_tokens = prefix_tokens[batch_idxs]
|
668 |
+
src_lengths_no_eos = src_lengths_no_eos[batch_idxs]
|
669 |
+
cands_to_ignore = cands_to_ignore[batch_idxs]
|
670 |
+
|
671 |
+
scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
|
672 |
+
scores_buf.resize_as_(scores)
|
673 |
+
tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
|
674 |
+
tokens_buf.resize_as_(tokens)
|
675 |
+
src_tokens = src_tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
|
676 |
+
src_lengths = src_lengths.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
|
677 |
+
lm_prefix_scores = lm_prefix_scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1).squeeze()
|
678 |
+
|
679 |
+
if attn is not None:
|
680 |
+
attn = attn.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, attn.size(1), -1)
|
681 |
+
attn_buf.resize_as_(attn)
|
682 |
+
bsz = new_bsz
|
683 |
+
else:
|
684 |
+
batch_idxs = None
|
685 |
+
|
686 |
+
# Set active_mask so that values > cand_size indicate eos or
|
687 |
+
# ignored hypos and values < cand_size indicate candidate
|
688 |
+
# active hypos. After this, the min values per row are the top
|
689 |
+
# candidate active hypos.
|
690 |
+
eos_mask[:, :beam_size] |= cands_to_ignore
|
691 |
+
active_mask = torch.add(
|
692 |
+
eos_mask.type_as(cand_offsets) * cand_size,
|
693 |
+
cand_offsets[: eos_mask.size(1)],
|
694 |
+
)
|
695 |
+
|
696 |
+
# get the top beam_size active hypotheses, which are just the hypos
|
697 |
+
# with the smallest values in active_mask
|
698 |
+
active_hypos, new_cands_to_ignore = buffer('active_hypos'), buffer('new_cands_to_ignore')
|
699 |
+
torch.topk(
|
700 |
+
active_mask, k=beam_size, dim=1, largest=False,
|
701 |
+
out=(new_cands_to_ignore, active_hypos)
|
702 |
+
)
|
703 |
+
|
704 |
+
# update cands_to_ignore to ignore any finalized hypos
|
705 |
+
cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size]
|
706 |
+
assert (~cands_to_ignore).any(dim=1).all()
|
707 |
+
|
708 |
+
active_bbsz_idx = buffer('active_bbsz_idx')
|
709 |
+
torch.gather(
|
710 |
+
cand_bbsz_idx, dim=1, index=active_hypos,
|
711 |
+
out=active_bbsz_idx,
|
712 |
+
)
|
713 |
+
active_scores = torch.gather(
|
714 |
+
fw_lprobs_top_k, dim=1, index=active_hypos,
|
715 |
+
out=scores[:, step].view(bsz, beam_size),
|
716 |
+
)
|
717 |
+
|
718 |
+
active_bbsz_idx = active_bbsz_idx.view(-1)
|
719 |
+
active_scores = active_scores.view(-1)
|
720 |
+
|
721 |
+
# copy tokens and scores for active hypotheses
|
722 |
+
torch.index_select(
|
723 |
+
tokens[:, :step + 1], dim=0, index=active_bbsz_idx,
|
724 |
+
out=tokens_buf[:, :step + 1],
|
725 |
+
)
|
726 |
+
torch.gather(
|
727 |
+
cand_indices, dim=1, index=active_hypos,
|
728 |
+
out=tokens_buf.view(bsz, beam_size, -1)[:, :, step + 1],
|
729 |
+
)
|
730 |
+
if step > 0:
|
731 |
+
torch.index_select(
|
732 |
+
scores[:, :step], dim=0, index=active_bbsz_idx,
|
733 |
+
out=scores_buf[:, :step],
|
734 |
+
)
|
735 |
+
torch.gather(
|
736 |
+
fw_lprobs_top_k, dim=1, index=active_hypos,
|
737 |
+
out=scores_buf.view(bsz, beam_size, -1)[:, :, step],
|
738 |
+
)
|
739 |
+
torch.gather(
|
740 |
+
lm_lprobs_top_k, dim=1, index=active_hypos,
|
741 |
+
out=lm_prefix_scores.view(bsz, beam_size)
|
742 |
+
)
|
743 |
+
|
744 |
+
# copy attention for active hypotheses
|
745 |
+
if attn is not None:
|
746 |
+
torch.index_select(
|
747 |
+
attn[:, :, :step + 2], dim=0, index=active_bbsz_idx,
|
748 |
+
out=attn_buf[:, :, :step + 2],
|
749 |
+
)
|
750 |
+
|
751 |
+
# swap buffers
|
752 |
+
tokens, tokens_buf = tokens_buf, tokens
|
753 |
+
scores, scores_buf = scores_buf, scores
|
754 |
+
if attn is not None:
|
755 |
+
attn, attn_buf = attn_buf, attn
|
756 |
+
|
757 |
+
# reorder incremental state in decoder
|
758 |
+
reorder_state = active_bbsz_idx
|
759 |
+
|
760 |
+
# sort by score descending
|
761 |
+
for sent in range(len(finalized)):
|
762 |
+
finalized[sent] = sorted(finalized[sent], key=lambda r: r['score'], reverse=True)
|
763 |
+
|
764 |
+
return finalized
|
765 |
+
|
766 |
+
|
767 |
+
def get_lm_scores(model, input_tokens, incremental_states, cand_tokens, input_len, k):
|
768 |
+
with torch.no_grad():
|
769 |
+
lm_lprobs, avg_attn_scores = model.forward_decoder(
|
770 |
+
input_tokens, encoder_outs=None, incremental_states=incremental_states,
|
771 |
+
)
|
772 |
+
|
773 |
+
lm_lprobs_size = lm_lprobs.size(0)
|
774 |
+
probs_next_wrd = torch.gather(lm_lprobs.repeat(1, k).view(lm_lprobs_size*k, -1), 1, cand_tokens).squeeze().view(-1)
|
775 |
+
|
776 |
+
return probs_next_wrd
|
777 |
+
|
778 |
+
|
779 |
+
def make_dict2dict(old_dict, new_dict):
|
780 |
+
dict2dict_map = {}
|
781 |
+
for sym in old_dict.symbols:
|
782 |
+
dict2dict_map[old_dict.index(sym)] = new_dict.index(sym)
|
783 |
+
return dict2dict_map
|
784 |
+
|
785 |
+
|
786 |
+
def dict2dict(tokens, dict2dict_map):
|
787 |
+
if tokens.device == torch.device('cpu'):
|
788 |
+
tokens_tmp = tokens
|
789 |
+
else:
|
790 |
+
tokens_tmp = tokens.cpu()
|
791 |
+
return tokens_tmp.map_(
|
792 |
+
tokens_tmp,
|
793 |
+
lambda _, val, dict2dict_map=dict2dict_map : dict2dict_map[float(val)]
|
794 |
+
).to(tokens.device)
|
795 |
+
|
796 |
+
|
797 |
+
def reorder_tokens(tokens, lengths, eos):
|
798 |
+
# reorder source tokens so they may be used as reference for P(S|T)
|
799 |
+
return torch.cat((tokens.new([eos]), tokens[-lengths:-1], tokens[:-lengths]), 0)
|
800 |
+
|
801 |
+
|
802 |
+
def reorder_all_tokens(tokens, lengths, eos):
|
803 |
+
# used to reorder src tokens from [<pad> <w1> <w2> .. <eos>] to [<eos> <w1> <w2>...<pad>]
|
804 |
+
# so source tokens can be used to predict P(S|T)
|
805 |
+
return torch.stack([reorder_tokens(token, length, eos) for token, length in zip(tokens, lengths)])
|
806 |
+
|
807 |
+
|
808 |
+
def normalized_scores_with_batch_vocab(
|
809 |
+
model_decoder, features, target_ids, k, bsz, beam_size,
|
810 |
+
pad_idx, top_k=0, vocab_size_meter=None, start_idx=None,
|
811 |
+
end_idx=None, **kwargs):
|
812 |
+
"""
|
813 |
+
Get normalized probabilities (or log probs) from a net's output
|
814 |
+
w.r.t. vocab consisting of target IDs in the batch
|
815 |
+
"""
|
816 |
+
if model_decoder.adaptive_softmax is None:
|
817 |
+
weight = model_decoder.output_projection.weight
|
818 |
+
vocab_ids = torch.unique(
|
819 |
+
torch.cat(
|
820 |
+
(torch.unique(target_ids), torch.arange(top_k, device=target_ids.device))
|
821 |
+
)
|
822 |
+
)
|
823 |
+
id_map = dict(zip(vocab_ids.tolist(), range(len(vocab_ids))))
|
824 |
+
mapped_target_ids = target_ids.cpu().apply_(
|
825 |
+
lambda x, id_map=id_map: id_map[x]
|
826 |
+
).to(target_ids.device)
|
827 |
+
expanded_target_ids = mapped_target_ids[:, :].repeat(1, k).view(bsz*beam_size*k, -1)
|
828 |
+
if start_idx is not None and end_idx is not None:
|
829 |
+
expanded_target_ids = expanded_target_ids[start_idx:end_idx, :]
|
830 |
+
logits = F.linear(features, weight[vocab_ids, :])
|
831 |
+
log_softmax = F.log_softmax(logits, dim=-1, dtype=torch.float32)
|
832 |
+
intermed_scores = torch.gather(
|
833 |
+
log_softmax[:, :-1, :],
|
834 |
+
2,
|
835 |
+
expanded_target_ids[:, 1:].unsqueeze(2),
|
836 |
+
).squeeze()
|
837 |
+
not_padding = expanded_target_ids[:, 1:] != pad_idx
|
838 |
+
intermed_scores *= not_padding.float()
|
839 |
+
return intermed_scores
|
840 |
+
else:
|
841 |
+
raise ValueError("adaptive softmax doesn't work with " +
|
842 |
+
"`normalized_scores_with_batch_vocab()`")
|
fairseq/examples/fast_noisy_channel/noisy_channel_translation.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from fairseq.tasks.translation import TranslationTask
|
7 |
+
from fairseq.tasks.language_modeling import LanguageModelingTask
|
8 |
+
from fairseq import checkpoint_utils
|
9 |
+
import argparse
|
10 |
+
from fairseq.tasks import register_task
|
11 |
+
import torch
|
12 |
+
|
13 |
+
|
14 |
+
@register_task("noisy_channel_translation")
|
15 |
+
class NoisyChannelTranslation(TranslationTask):
|
16 |
+
"""
|
17 |
+
Rescore the top k candidates from each beam using noisy channel modeling
|
18 |
+
"""
|
19 |
+
|
20 |
+
@staticmethod
|
21 |
+
def add_args(parser):
|
22 |
+
"""Add task-specific arguments to the parser."""
|
23 |
+
TranslationTask.add_args(parser)
|
24 |
+
# fmt: off
|
25 |
+
parser.add_argument('--channel-model', metavar='FILE',
|
26 |
+
help='path to P(S|T) model. P(S|T) and P(T|S) must share source and target dictionaries.')
|
27 |
+
parser.add_argument('--combine-method', default='lm_only',
|
28 |
+
choices=['lm_only', 'noisy_channel'],
|
29 |
+
help="""method for combining direct and channel model scores.
|
30 |
+
lm_only: decode with P(T|S)P(T)
|
31 |
+
noisy_channel: decode with 1/t P(T|S) + 1/s(P(S|T)P(T))""")
|
32 |
+
parser.add_argument('--normalize-lm-scores-by-tgt-len', action='store_true', default=False,
|
33 |
+
help='normalize lm score by target length instead of source length')
|
34 |
+
parser.add_argument('--channel-scoring-type', default='log_norm', choices=['unnormalized', 'log_norm', 'k2_separate', 'src_vocab', 'src_vocab_batched'],
|
35 |
+
help="Normalize bw scores with log softmax or return bw scores without log softmax")
|
36 |
+
parser.add_argument('--top-k-vocab', default=0, type=int,
|
37 |
+
help='top k vocab IDs to use with `src_vocab` in channel model scoring')
|
38 |
+
parser.add_argument('--k2', default=50, type=int,
|
39 |
+
help='the top k2 candidates to rescore with the noisy channel model for each beam')
|
40 |
+
parser.add_argument('--ch-wt', default=1, type=float,
|
41 |
+
help='weight for the channel model')
|
42 |
+
parser.add_argument('--lm-model', metavar='FILE',
|
43 |
+
help='path to lm model file, to model P(T). P(T) must share the same vocab as the direct model on the target side')
|
44 |
+
parser.add_argument('--lm-data', metavar='FILE',
|
45 |
+
help='path to lm model training data for target language, used to properly load LM with correct dictionary')
|
46 |
+
parser.add_argument('--lm-wt', default=1, type=float,
|
47 |
+
help='the weight of the lm in joint decoding')
|
48 |
+
# fmt: on
|
49 |
+
|
50 |
+
def build_generator(
|
51 |
+
self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None
|
52 |
+
):
|
53 |
+
if getattr(args, "score_reference", False):
|
54 |
+
raise NotImplementedError()
|
55 |
+
else:
|
56 |
+
from .noisy_channel_sequence_generator import NoisyChannelSequenceGenerator
|
57 |
+
use_cuda = torch.cuda.is_available() and not self.args.cpu
|
58 |
+
assert self.args.lm_model is not None, '--lm-model required for noisy channel generation!'
|
59 |
+
assert self.args.lm_data is not None, '--lm-data required for noisy channel generation to map between LM and bitext vocabs'
|
60 |
+
if self.args.channel_model is not None:
|
61 |
+
import copy
|
62 |
+
ch_args_task = copy.deepcopy(self.args)
|
63 |
+
tmp = ch_args_task.source_lang
|
64 |
+
ch_args_task.source_lang = ch_args_task.target_lang
|
65 |
+
ch_args_task.target_lang = tmp
|
66 |
+
ch_args_task._name = 'translation'
|
67 |
+
channel_task = TranslationTask.setup_task(ch_args_task)
|
68 |
+
|
69 |
+
arg_dict = {}
|
70 |
+
arg_dict['task'] = 'language_modeling'
|
71 |
+
arg_dict['sample_break_mode'] = 'eos'
|
72 |
+
arg_dict['data'] = self.args.lm_data
|
73 |
+
arg_dict['output_dictionary_size'] = -1
|
74 |
+
lm_args = argparse.Namespace(**arg_dict)
|
75 |
+
lm_task = LanguageModelingTask.setup_task(lm_args)
|
76 |
+
lm_dict = lm_task.output_dictionary
|
77 |
+
|
78 |
+
if self.args.channel_model is not None:
|
79 |
+
channel_models, _ = checkpoint_utils.load_model_ensemble(self.args.channel_model.split(':'), task=channel_task)
|
80 |
+
|
81 |
+
for model in channel_models:
|
82 |
+
model.make_generation_fast_(
|
83 |
+
beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
|
84 |
+
need_attn=args.print_alignment,
|
85 |
+
)
|
86 |
+
if self.args.fp16:
|
87 |
+
model.half()
|
88 |
+
if use_cuda:
|
89 |
+
model.cuda()
|
90 |
+
else:
|
91 |
+
channel_models = None
|
92 |
+
|
93 |
+
lm_models, _ = checkpoint_utils.load_model_ensemble(self.args.lm_model.split(':'), task=lm_task)
|
94 |
+
|
95 |
+
for model in lm_models:
|
96 |
+
model.make_generation_fast_(
|
97 |
+
beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
|
98 |
+
need_attn=args.print_alignment,
|
99 |
+
)
|
100 |
+
if self.args.fp16:
|
101 |
+
model.half()
|
102 |
+
if use_cuda:
|
103 |
+
model.cuda()
|
104 |
+
return NoisyChannelSequenceGenerator(
|
105 |
+
combine_method=self.args.combine_method,
|
106 |
+
tgt_dict=self.target_dictionary,
|
107 |
+
src_dict=self.source_dictionary,
|
108 |
+
beam_size=getattr(args, 'beam', 5),
|
109 |
+
max_len_a=getattr(args, 'max_len_a', 0),
|
110 |
+
max_len_b=getattr(args, 'max_len_b', 200),
|
111 |
+
min_len=getattr(args, 'min_len', 1),
|
112 |
+
len_penalty=getattr(args, 'lenpen', 1),
|
113 |
+
unk_penalty=getattr(args, 'unkpen', 0),
|
114 |
+
temperature=getattr(args, 'temperature', 1.),
|
115 |
+
match_source_len=getattr(args, 'match_source_len', False),
|
116 |
+
no_repeat_ngram_size=getattr(args, 'no_repeat_ngram_size', 0),
|
117 |
+
normalize_scores=(not getattr(args, 'unnormalized', False)),
|
118 |
+
channel_models=channel_models,
|
119 |
+
k2=getattr(self.args, 'k2', 50),
|
120 |
+
ch_weight=getattr(self.args, 'ch_wt', 1),
|
121 |
+
channel_scoring_type=self.args.channel_scoring_type,
|
122 |
+
top_k_vocab=self.args.top_k_vocab,
|
123 |
+
lm_models=lm_models,
|
124 |
+
lm_dict=lm_dict,
|
125 |
+
lm_weight=getattr(self.args, 'lm_wt', 1),
|
126 |
+
normalize_lm_scores_by_tgt_len=getattr(self.args, 'normalize_lm_scores_by_tgt_len', False),
|
127 |
+
)
|
fairseq/examples/flores101/README.md
ADDED
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<p align="center">
|
2 |
+
<img src="flores_logo.png" width="500">
|
3 |
+
</p>
|
4 |
+
|
5 |
+
# Flores101: Large-Scale Multilingual Machine Translation
|
6 |
+
|
7 |
+
## Introduction
|
8 |
+
|
9 |
+
Baseline pretrained models for small and large tracks of WMT 21 Large-Scale Multilingual Machine Translation competition.
|
10 |
+
|
11 |
+
Flores Task at WMT 21: http://www.statmt.org/wmt21/large-scale-multilingual-translation-task.html
|
12 |
+
|
13 |
+
Flores announement blog post: https://ai.facebook.com/blog/flores-researchers-kick-off-multilingual-translation-challenge-at-wmt-and-call-for-compute-grants/
|
14 |
+
|
15 |
+
|
16 |
+
|
17 |
+
## Pretrained models
|
18 |
+
|
19 |
+
Model | Num layers | Embed dimension | FFN dimension| Vocab Size | #params | Download
|
20 |
+
---|---|---|---|---|---|---
|
21 |
+
`flores101_mm100_615M` | 12 | 1024 | 4096 | 256,000 | 615M | https://dl.fbaipublicfiles.com/flores101/pretrained_models/flores101_mm100_615M.tar.gz
|
22 |
+
`flores101_mm100_175M` | 6 | 512 | 2048 | 256,000 | 175M | https://dl.fbaipublicfiles.com/flores101/pretrained_models/flores101_mm100_175M.tar.gz
|
23 |
+
|
24 |
+
|
25 |
+
These models are trained similar to [M2M-100](https://arxiv.org/abs/2010.11125) with additional support for the languages that are part of the WMT Large-Scale Multilingual Machine Translation track. Full list of languages can be found at the bottom.
|
26 |
+
|
27 |
+
|
28 |
+
## Example Generation code
|
29 |
+
|
30 |
+
### Download model, sentencepiece vocab
|
31 |
+
|
32 |
+
```bash
|
33 |
+
fairseq=/path/to/fairseq
|
34 |
+
cd $fairseq
|
35 |
+
|
36 |
+
# Download 615M param model.
|
37 |
+
wget https://dl.fbaipublicfiles.com/flores101/pretrained_models/flores101_mm100_615M.tar.gz
|
38 |
+
|
39 |
+
# Extract
|
40 |
+
tar -xvzf flores101_mm100_615M.tar.gz
|
41 |
+
```
|
42 |
+
|
43 |
+
### Encode using our SentencePiece Model
|
44 |
+
Note: Install SentencePiece from [here](https://github.com/google/sentencepiece)
|
45 |
+
|
46 |
+
|
47 |
+
```bash
|
48 |
+
fairseq=/path/to/fairseq
|
49 |
+
cd $fairseq
|
50 |
+
|
51 |
+
# Download example dataset From German to French
|
52 |
+
sacrebleu --echo src -l de-fr -t wmt19 | head -n 20 > raw_input.de-fr.de
|
53 |
+
sacrebleu --echo ref -l de-fr -t wmt19 | head -n 20 > raw_input.de-fr.fr
|
54 |
+
|
55 |
+
for lang in de fr ; do
|
56 |
+
python scripts/spm_encode.py \
|
57 |
+
--model flores101_mm100_615M/sentencepiece.bpe.model \
|
58 |
+
--output_format=piece \
|
59 |
+
--inputs=raw_input.de-fr.${lang} \
|
60 |
+
--outputs=spm.de-fr.${lang}
|
61 |
+
done
|
62 |
+
```
|
63 |
+
|
64 |
+
### Binarization
|
65 |
+
|
66 |
+
```bash
|
67 |
+
fairseq-preprocess \
|
68 |
+
--source-lang de --target-lang fr \
|
69 |
+
--testpref spm.de-fr \
|
70 |
+
--thresholdsrc 0 --thresholdtgt 0 \
|
71 |
+
--destdir data_bin \
|
72 |
+
--srcdict flores101_mm100_615M/dict.txt --tgtdict flores101_mm100_615M/dict.txt
|
73 |
+
```
|
74 |
+
|
75 |
+
### Generation
|
76 |
+
|
77 |
+
|
78 |
+
```bash
|
79 |
+
fairseq-generate \
|
80 |
+
data_bin \
|
81 |
+
--batch-size 1 \
|
82 |
+
--path flores101_mm100_615M/model.pt \
|
83 |
+
--fixed-dictionary flores101_mm100_615M/dict.txt \
|
84 |
+
-s de -t fr \
|
85 |
+
--remove-bpe 'sentencepiece' \
|
86 |
+
--beam 5 \
|
87 |
+
--task translation_multi_simple_epoch \
|
88 |
+
--lang-pairs flores101_mm100_615M/language_pairs.txt \
|
89 |
+
--decoder-langtok --encoder-langtok src \
|
90 |
+
--gen-subset test \
|
91 |
+
--fp16 \
|
92 |
+
--dataset-impl mmap \
|
93 |
+
--distributed-world-size 1 --distributed-no-spawn
|
94 |
+
```
|
95 |
+
|
96 |
+
### Supported Languages and lang code
|
97 |
+
|
98 |
+
Language | lang code
|
99 |
+
---|---
|
100 |
+
Akrikaans | af
|
101 |
+
Amharic | am
|
102 |
+
Arabic | ar
|
103 |
+
Assamese | as
|
104 |
+
Asturian | ast
|
105 |
+
Aymara | ay
|
106 |
+
Azerbaijani | az
|
107 |
+
Bashkir | ba
|
108 |
+
Belarusian | be
|
109 |
+
Bulgarian | bg
|
110 |
+
Bengali | bn
|
111 |
+
Breton | br
|
112 |
+
Bosnian | bs
|
113 |
+
Catalan | ca
|
114 |
+
Cebuano | ceb
|
115 |
+
Chokwe | cjk
|
116 |
+
Czech | cs
|
117 |
+
Welsh | cy
|
118 |
+
Danish | da
|
119 |
+
German | de
|
120 |
+
Dyula| dyu
|
121 |
+
Greek | el
|
122 |
+
English | en
|
123 |
+
Spanish | es
|
124 |
+
Estonian | et
|
125 |
+
Persian | fa
|
126 |
+
Fulah | ff
|
127 |
+
Finnish | fi
|
128 |
+
French | fr
|
129 |
+
Western Frisian | fy
|
130 |
+
Irish | ga
|
131 |
+
Scottish Gaelic | gd
|
132 |
+
Galician | gl
|
133 |
+
Gujarati | gu
|
134 |
+
Hausa | ha
|
135 |
+
Hebrew | he
|
136 |
+
Hindi | hi
|
137 |
+
Croatian | hr
|
138 |
+
Haitian Creole | ht
|
139 |
+
Hungarian | hu
|
140 |
+
Armenian | hy
|
141 |
+
Indonesian | id
|
142 |
+
Igbo | ig
|
143 |
+
Iloko | ilo
|
144 |
+
Icelandic | is
|
145 |
+
Italian | it
|
146 |
+
Japanese | ja
|
147 |
+
Javanese | jv
|
148 |
+
Georgian | ka
|
149 |
+
Kachin | kac
|
150 |
+
Kamba | kam
|
151 |
+
Kabuverdianu | kea
|
152 |
+
Kongo | kg
|
153 |
+
Kazakh | kk
|
154 |
+
Central Khmer | km
|
155 |
+
Kimbundu | kmb
|
156 |
+
Northern Kurdish | kmr
|
157 |
+
Kannada | kn
|
158 |
+
Korean | ko
|
159 |
+
Kurdish | ku
|
160 |
+
Kyrgyz | ky
|
161 |
+
Luxembourgish | lb
|
162 |
+
Ganda | lg
|
163 |
+
Lingala | ln
|
164 |
+
Lao | lo
|
165 |
+
Lithuanian | lt
|
166 |
+
Luo | luo
|
167 |
+
Latvian | lv
|
168 |
+
Malagasy | mg
|
169 |
+
Maori | mi
|
170 |
+
Macedonian | mk
|
171 |
+
Malayalam | ml
|
172 |
+
Mongolian | mn
|
173 |
+
Marathi | mr
|
174 |
+
Malay | ms
|
175 |
+
Maltese | mt
|
176 |
+
Burmese | my
|
177 |
+
Nepali | ne
|
178 |
+
Dutch | nl
|
179 |
+
Norwegian | no
|
180 |
+
Northern Sotho | ns
|
181 |
+
Nyanja | ny
|
182 |
+
Occitan | oc
|
183 |
+
Oromo | om
|
184 |
+
Oriya | or
|
185 |
+
Punjabi | pa
|
186 |
+
Polish | pl
|
187 |
+
Pashto | ps
|
188 |
+
Portuguese | pt
|
189 |
+
Quechua | qu
|
190 |
+
Romanian | ro
|
191 |
+
Russian | ru
|
192 |
+
Sindhi | sd
|
193 |
+
Shan | shn
|
194 |
+
Sinhala | si
|
195 |
+
Slovak | sk
|
196 |
+
Slovenian | sl
|
197 |
+
Shona | sn
|
198 |
+
Somali | so
|
199 |
+
Albanian | sq
|
200 |
+
Serbian | sr
|
201 |
+
Swati | ss
|
202 |
+
Sundanese | su
|
203 |
+
Swedish | sv
|
204 |
+
Swahili | sw
|
205 |
+
Tamil | ta
|
206 |
+
Telugu | te
|
207 |
+
Tajik | tg
|
208 |
+
Thai | th
|
209 |
+
Tigrinya | ti
|
210 |
+
Tagalog | tl
|
211 |
+
Tswana | tn
|
212 |
+
Turkish | tr
|
213 |
+
Ukrainian | uk
|
214 |
+
Umbundu | umb
|
215 |
+
Urdu | ur
|
216 |
+
Uzbek | uz
|
217 |
+
Vietnamese | vi
|
218 |
+
Wolof | wo
|
219 |
+
Xhosa | xh
|
220 |
+
Yiddish | yi
|
221 |
+
Yoruba | yo
|
222 |
+
Chinese| zh
|
223 |
+
Zulu | zu
|
fairseq/examples/flores101/flores_logo.png
ADDED
![]() |
fairseq/examples/fully_sharded_data_parallel/README.md
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Fully Sharded Data Parallel (FSDP)
|
2 |
+
|
3 |
+
## Overview
|
4 |
+
Recent work by [Microsoft](https://arxiv.org/abs/1910.02054) and
|
5 |
+
[Google](https://arxiv.org/abs/2004.13336) has shown that data parallel
|
6 |
+
training can be made significantly more efficient by sharding the model
|
7 |
+
parameters and optimizer state across data parallel workers. These ideas are
|
8 |
+
encapsulated in the new **`FullyShardedDataParallel` (FSDP)** wrapper provided
|
9 |
+
by [fairscale](https://github.com/facebookresearch/fairscale/).
|
10 |
+
|
11 |
+
Compared to PyTorch DDP:
|
12 |
+
* FSDP produces identical results as PyTorch DDP (it's still synchronous data parallel training)
|
13 |
+
* FSDP shards parameters (FP16 + FP32) and optimizer state across data parallel GPUs
|
14 |
+
* FSDP is faster than PyTorch DDP because the optimizer step is sharded, and the communication can be overlapped with the forward pass
|
15 |
+
* FSDP enables training 13B parameter models on 8 GPUs and 175B parameter models on 128 GPUs
|
16 |
+
|
17 |
+
FSDP is fully supported in fairseq via the following new arguments:
|
18 |
+
* `--ddp-backend=fully_sharded`: enables full sharding via FSDP
|
19 |
+
* `--cpu-offload`: offloads the optimizer state and FP32 model copy to CPU (combine with `--optimizer=cpu_adam`)
|
20 |
+
* `--no-reshard-after-forward`: increases training speed for large models (1B+ params) and is similar to ZeRO stage 2
|
21 |
+
* other popular options (`--fp16`, `--update-freq`, `--checkpoint-activations`, `--offload-activations`, etc.) continue to work as normal
|
22 |
+
|
23 |
+
<details><summary>Limitations</summary><p>
|
24 |
+
|
25 |
+
FSDP currently has several limitations compared to fairseq's default DDP backend (PyTorch DDP):
|
26 |
+
* while FSDP is full compatible with pointwise Optimizers (e.g., Adam, AdamW, Adadelta, Adamax, SGD, etc.), it is not currently compatible with non-pointwise Optimizers (e.g., Adagrad, Adafactor, LAMB, etc.)
|
27 |
+
* FSDP depends on flattening the parameters, so models that currently require `--fp16-no-flatten-grads` may not be supported
|
28 |
+
|
29 |
+
See the [fairscale docs](https://fairscale.readthedocs.io/en/latest/api/nn/fsdp_tips.html) for a more detailed
|
30 |
+
explanation of these and other limitations.
|
31 |
+
|
32 |
+
</p></details>
|
33 |
+
|
34 |
+
<details><summary>How it works</summary><p>
|
35 |
+
|
36 |
+
<img width="800" alt="Fully Sharded Data Parallel" src="https://user-images.githubusercontent.com/231798/110406775-c2de0000-8050-11eb-9718-fbfc4510a76a.png">
|
37 |
+
|
38 |
+
See the [fairscale docs](https://fairscale.readthedocs.io/en/latest/api/nn/fsdp_tips.html) for a more detailed
|
39 |
+
explanation of how FSDP works.
|
40 |
+
|
41 |
+
</p></details>
|
42 |
+
|
43 |
+
## Example usage
|
44 |
+
|
45 |
+
The following examples illustrate how to train a very large language model with
|
46 |
+
13 billion parameters on 1 GPU by offloading parameters and optimizer states to
|
47 |
+
CPU, or on 8 GPUs by fully sharding the params and optimizer states across GPUs.
|
48 |
+
|
49 |
+
These examples use the WikiText-103 dataset for demonstration purposes, but
|
50 |
+
in practice a much larger dataset will be needed to achieve good results.
|
51 |
+
Follow the [instructions here](https://github.com/pytorch/fairseq/blob/main/examples/roberta/README.pretraining.md#1-preprocess-the-data)
|
52 |
+
to preprocess the WikiText-103 dataset using the GPT-2/RoBERTa vocabulary.
|
53 |
+
|
54 |
+
### 13B params on 1 V100 GPU (with CPU offloading)
|
55 |
+
|
56 |
+
The following command trains a 13B parameter GPT-3 model on a single V100 GPU
|
57 |
+
using the `--cpu-offload` feature to offload parameters and optimizer states to
|
58 |
+
CPU. In this setting, the optimizer step (Adam) happens on CPU. We also use the
|
59 |
+
`--checkpoint-activations` feature (sometimes called [gradient checkpointing](https://pytorch.org/docs/stable/checkpoint.html)),
|
60 |
+
which further saves memory in exchange for a small increase in computation.
|
61 |
+
|
62 |
+
**Requirements:**
|
63 |
+
- Install the latest master version of fairscale: `pip install git+https://github.com/facebookresearch/fairscale.git@master`
|
64 |
+
- You'll need 32GB of GPU memory and ~256GB of system memory to train the 13B param model.
|
65 |
+
- If you have less system memory, the 6.7B param model can be trained with ~128GB of system memory, just set `--arch transformer_lm_gpt3_6_7`
|
66 |
+
- We use the CPU Adam optimizer from [DeepSpeed](https://github.com/microsoft/DeepSpeed), so you'll need to `pip install deepspeed` before running the command.
|
67 |
+
|
68 |
+
**Notes:**
|
69 |
+
- The command will take ~5 minutes to start training, during which time it will appear to be hung, since randomly initializing 13B weights can be slow.
|
70 |
+
- The `--cpu-offload` feature requires training in mixed precision (`--fp16`).
|
71 |
+
- Tune the `OMP_NUM_THREADS` env variable for best performance with CPU offloading.
|
72 |
+
- The example command below stops training after 10 steps (`--max-update 10`) and does not save checkpoints (`--no-save`).
|
73 |
+
|
74 |
+
```bash
|
75 |
+
OMP_NUM_THREADS=20 CUDA_VISIBLE_DEVICES=0 \
|
76 |
+
fairseq-train data-bin/wikitext-103-roberta-bpe-bin \
|
77 |
+
--ddp-backend fully_sharded --fp16 --fp16-init-scale 4 \
|
78 |
+
--cpu-offload --checkpoint-activations \
|
79 |
+
--task language_modeling --tokens-per-sample 2048 --batch-size 8 \
|
80 |
+
--arch transformer_lm_gpt3_13 \
|
81 |
+
--optimizer cpu_adam --adam-betas "(0.9,0.98)" \
|
82 |
+
--lr 0.0001 --lr-scheduler polynomial_decay --warmup-updates 5 --total-num-update 10 \
|
83 |
+
--max-update 10 --no-save --log-format json --log-interval 1
|
84 |
+
```
|
85 |
+
|
86 |
+
<details><summary>Example output</summary><p>
|
87 |
+
|
88 |
+
```
|
89 |
+
(...)
|
90 |
+
2021-03-08 12:29:51 | INFO | fairseq_cli.train | num. model params: 13,110,865,920 (num. trained: 13,110,865,920)
|
91 |
+
(...)
|
92 |
+
2021-03-08 12:29:51 | INFO | fairseq_cli.train | training on 1 devices (GPUs/TPUs)
|
93 |
+
2021-03-08 12:29:51 | INFO | fairseq_cli.train | max tokens per GPU = None and batch size per GPU = 8
|
94 |
+
(...)
|
95 |
+
Adam Optimizer #0 is created with AVX2 arithmetic capability.
|
96 |
+
Config: alpha=0.000100, betas=(0.900000, 0.980000), weight_decay=0.000000, adam_w=1
|
97 |
+
(...)
|
98 |
+
2021-03-08 12:31:36 | INFO | train_inner | {"epoch": 1, "update": 0.0, "loss": "16.475", "ppl": "91120.8", "wps": "0", "ups": "0", "wpb": "16384", "bsz": "8", "num_updates": "1", "lr": "2e-05", "gnorm": "20.751", "loss_scale": "4", "train_wall": "99", "gb_free": "9.3", "wall": "105"}
|
99 |
+
2021-03-08 12:32:33 | INFO | train_inner | {"epoch": 1, "update": 0.0, "loss": "16.446", "ppl": "89281.6", "wps": "288.7", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "2", "lr": "4e-05", "gnorm": "19.777", "loss_scale": "4", "train_wall": "57", "gb_free": "9.3", "wall": "161"}
|
100 |
+
2021-03-08 12:33:12 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 2.0
|
101 |
+
2021-03-08 12:33:51 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 1.0
|
102 |
+
2021-03-08 12:34:45 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "25.22", "ppl": "3.90691e+07", "wps": "123.4", "ups": "0.01", "wpb": "16384", "bsz": "8", "num_updates": "3", "lr": "6e-05", "gnorm": "131.281", "loss_scale": "1", "train_wall": "133", "gb_free": "9.3", "wall": "294"}
|
103 |
+
2021-03-08 12:35:43 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "18.079", "ppl": "276809", "wps": "285.5", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "4", "lr": "8e-05", "gnorm": "13.776", "loss_scale": "1", "train_wall": "57", "gb_free": "9.3", "wall": "351"}
|
104 |
+
2021-03-08 12:36:35 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "23.729", "ppl": "1.39088e+07", "wps": "316.7", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "5", "lr": "0.0001", "gnorm": "72.774", "loss_scale": "1", "train_wall": "52", "gb_free": "9.3", "wall": "403"}
|
105 |
+
2021-03-08 12:37:28 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "20.429", "ppl": "1.41203e+06", "wps": "307.6", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "6", "lr": "8e-05", "gnorm": "60.846", "loss_scale": "1", "train_wall": "53", "gb_free": "9.3", "wall": "456"}
|
106 |
+
2021-03-08 12:38:27 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "18.965", "ppl": "511684", "wps": "279.4", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "7", "lr": "6e-05", "gnorm": "22.687", "loss_scale": "1", "train_wall": "59", "gb_free": "9.3", "wall": "515"}
|
107 |
+
2021-03-08 12:39:18 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "18.345", "ppl": "332887", "wps": "319.1", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "8", "lr": "4e-05", "gnorm": "8.451", "loss_scale": "1", "train_wall": "51", "gb_free": "9.3", "wall": "566"}
|
108 |
+
2021-03-08 12:40:11 | INFO | train_inner | {"epoch": 1, "update": 0.002, "loss": "18.262", "ppl": "314336", "wps": "305.9", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "9", "lr": "2e-05", "gnorm": "6.457", "loss_scale": "1", "train_wall": "54", "gb_free": "9.3", "wall": "620"}
|
109 |
+
2021-03-08 12:41:04 | INFO | train_inner | {"epoch": 1, "update": 0.002, "loss": "17.556", "ppl": "192686", "wps": "311.8", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "10", "lr": "0", "gnorm": "5.796", "loss_scale": "1", "train_wall": "53", "gb_free": "9.3", "wall": "673"}
|
110 |
+
2021-03-08 12:41:04 | INFO | fairseq_cli.train | Stopping training due to num_updates: 10 >= max_update: 10
|
111 |
+
2021-03-08 12:41:04 | INFO | fairseq_cli.train | begin validation on "valid" subset
|
112 |
+
2021-03-08 12:43:15 | INFO | valid | {"epoch": 1, "valid_loss": "17.953", "valid_ppl": "253807", "valid_wps": "1868.4", "valid_wpb": "15400.2", "valid_bsz": "7.6", "valid_num_updates": "10"}
|
113 |
+
2021-03-08 12:43:15 | INFO | fairseq_cli.train | end of epoch 1 (average epoch stats below)
|
114 |
+
2021-03-08 12:43:15 | INFO | train | {"epoch": 1, "train_loss": "19.351", "train_ppl": "668509", "train_wps": "210.9", "train_ups": "0.01", "train_wpb": "16384", "train_bsz": "8", "train_num_updates": "10", "train_lr": "0", "train_gnorm": "36.26", "train_loss_scale": "1", "train_train_wall": "667", "train_gb_free": "9.3", "train_wall": "804"}
|
115 |
+
2021-03-08 12:43:15 | INFO | fairseq_cli.train | done training in 798.6 seconds
|
116 |
+
```
|
117 |
+
|
118 |
+
</p></details>
|
119 |
+
|
120 |
+
### 13B params on 8 V100 GPUs (with full parameter + optimizer state sharding)
|
121 |
+
|
122 |
+
FSDP can also shard the parameters and optimizer states across multiple GPUs,
|
123 |
+
reducing memory requirements significantly. On 8 x 32GB GPUs, sharding enables
|
124 |
+
training the same 13B parameter model *without offloading the parameters to
|
125 |
+
CPU*. However, without CPU offloading we'd only be able to fit a batch size of
|
126 |
+
1 per GPU, which would cause training speed to suffer.
|
127 |
+
|
128 |
+
We obtain the best performance on 8 GPUs by combining full sharding and CPU
|
129 |
+
offloading. The following command trains the same 13B parameter GPT-3 model as
|
130 |
+
before on 8 x 32GB V100 GPUs; training speed increases superlinearly from ~310
|
131 |
+
words per second to ~3200 words per second.
|
132 |
+
|
133 |
+
```bash
|
134 |
+
OMP_NUM_THREADS=20 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
|
135 |
+
fairseq-train data-bin/wikitext-103-roberta-bpe-bin \
|
136 |
+
--ddp-backend fully_sharded --fp16 --fp16-init-scale 4 \
|
137 |
+
--cpu-offload --checkpoint-activations \
|
138 |
+
--task language_modeling --tokens-per-sample 2048 --batch-size 8 \
|
139 |
+
--arch transformer_lm_gpt3_13 \
|
140 |
+
--optimizer cpu_adam --adam-betas "(0.9,0.98)" \
|
141 |
+
--lr 0.0001 --lr-scheduler polynomial_decay --warmup-updates 5 --total-num-update 10 \
|
142 |
+
--max-update 10 --no-save --log-format json --log-interval 1
|
143 |
+
```
|
144 |
+
|
145 |
+
<details><summary>Example output</summary><p>
|
146 |
+
|
147 |
+
```
|
148 |
+
(...)
|
149 |
+
2021-03-08 18:04:09 | INFO | fairseq_cli.train | num. model params: 13,110,865,920 (num. trained: 13,110,865,920)
|
150 |
+
(...)
|
151 |
+
2021-03-08 18:04:09 | INFO | fairseq_cli.train | training on 8 devices (GPUs/TPUs)
|
152 |
+
2021-03-08 18:04:09 | INFO | fairseq_cli.train | max tokens per GPU = None and batch size per GPU = 8
|
153 |
+
(...)
|
154 |
+
Adam Optimizer #0 is created with AVX2 arithmetic capability.
|
155 |
+
Config: alpha=0.000100, betas=(0.900000, 0.980000), weight_decay=0.000000, adam_w=1
|
156 |
+
(...)
|
157 |
+
2021-03-08 18:05:06 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "16.408", "ppl": "86945.6", "wps": "0", "ups": "0", "wpb": "131072", "bsz": "64", "num_updates": "1", "lr": "2e-05", "gnorm": "18.27", "loss_scale": "4", "train_wall": "47", "gb_free": "9.3", "wall": "56"}
|
158 |
+
2021-03-08 18:05:45 | INFO | train_inner | {"epoch": 1, "update": 0.002, "loss": "16.352", "ppl": "83644.3", "wps": "3283.4", "ups": "0.03", "wpb": "131072", "bsz": "64", "num_updates": "2", "lr": "4e-05", "gnorm": "18.411", "loss_scale": "4", "train_wall": "40", "gb_free": "9.3", "wall": "96"}
|
159 |
+
2021-03-08 18:06:21 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 2.0
|
160 |
+
2021-03-08 18:06:56 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 1.0
|
161 |
+
2021-03-08 18:07:37 | INFO | train_inner | {"epoch": 1, "update": 0.006, "loss": "23.682", "ppl": "1.34537e+07", "wps": "1176.6", "ups": "0.01", "wpb": "131072", "bsz": "64", "num_updates": "3", "lr": "6e-05", "gnorm": "119.682", "loss_scale": "1", "train_wall": "111", "gb_free": "9.3", "wall": "208"}
|
162 |
+
2021-03-08 18:08:18 | INFO | train_inner | {"epoch": 1, "update": 0.007, "loss": "18.988", "ppl": "519921", "wps": "3189.1", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "4", "lr": "8e-05", "gnorm": "14.934", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "249"}
|
163 |
+
2021-03-08 18:08:59 | INFO | train_inner | {"epoch": 1, "update": 0.008, "loss": "20.08", "ppl": "1.10798e+06", "wps": "3223.1", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "5", "lr": "0.0001", "gnorm": "59.92", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "289"}
|
164 |
+
2021-03-08 18:09:39 | INFO | train_inner | {"epoch": 1, "update": 0.009, "loss": "18.323", "ppl": "327980", "wps": "3256.6", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "6", "lr": "8e-05", "gnorm": "37.425", "loss_scale": "1", "train_wall": "40", "gb_free": "9.3", "wall": "330"}
|
165 |
+
2021-03-08 18:10:20 | INFO | train_inner | {"epoch": 1, "update": 0.01, "loss": "17.264", "ppl": "157354", "wps": "3188.7", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "7", "lr": "6e-05", "gnorm": "10.824", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "371"}
|
166 |
+
2021-03-08 18:11:01 | INFO | train_inner | {"epoch": 1, "update": 0.011, "loss": "16.794", "ppl": "113647", "wps": "3230", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "8", "lr": "4e-05", "gnorm": "5.616", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "411"}
|
167 |
+
2021-03-08 18:11:39 | INFO | train_inner | {"epoch": 1, "update": 0.012, "loss": "16.706", "ppl": "106938", "wps": "3384", "ups": "0.03", "wpb": "131072", "bsz": "64", "num_updates": "9", "lr": "2e-05", "gnorm": "5.318", "loss_scale": "1", "train_wall": "39", "gb_free": "9.3", "wall": "450"}
|
168 |
+
2021-03-08 18:12:19 | INFO | train_inner | {"epoch": 1, "update": 0.013, "loss": "16.548", "ppl": "95796.2", "wps": "3274.4", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "10", "lr": "0", "gnorm": "5.22", "loss_scale": "1", "train_wall": "40", "gb_free": "9.3", "wall": "490"}
|
169 |
+
2021-03-08 18:12:19 | INFO | fairseq_cli.train | Stopping training due to num_updates: 10 >= max_update: 10
|
170 |
+
2021-03-08 18:12:19 | INFO | fairseq_cli.train | begin validation on "valid" subset
|
171 |
+
2021-03-08 18:12:45 | INFO | valid | {"epoch": 1, "valid_loss": "16.624", "valid_ppl": "101000", "valid_wps": "10855.9", "valid_wpb": "123202", "valid_bsz": "60.5", "valid_num_updates": "10"}
|
172 |
+
2021-03-08 18:12:45 | INFO | fairseq_cli.train | end of epoch 1 (average epoch stats below)
|
173 |
+
2021-03-08 18:12:45 | INFO | train | {"epoch": 1, "train_loss": "18.114", "train_ppl": "283776", "train_wps": "2567.8", "train_ups": "0.02", "train_wpb": "131072", "train_bsz": "64", "train_num_updates": "10", "train_lr": "0", "train_gnorm": "29.562", "train_loss_scale": "1", "train_train_wall": "480", "train_gb_free": "9.3", "train_wall": "516"}
|
174 |
+
2021-03-08 18:12:45 | INFO | fairseq_cli.train | done training in 509.9 seconds
|
175 |
+
```
|
176 |
+
|
177 |
+
</p></details>
|
fairseq/examples/gottbert/README.md
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# GottBERT: a pure German language model
|
2 |
+
|
3 |
+
## Introduction
|
4 |
+
|
5 |
+
[GottBERT](http://arxiv.org/abs/2012.02110) is a pretrained language model trained on 145GB of German text based on RoBERTa.
|
6 |
+
|
7 |
+
## Example usage
|
8 |
+
|
9 |
+
### fairseq
|
10 |
+
##### Load GottBERT from torch.hub (PyTorch >= 1.1):
|
11 |
+
```python
|
12 |
+
import torch
|
13 |
+
gottbert = torch.hub.load('pytorch/fairseq', 'gottbert-base')
|
14 |
+
gottbert.eval() # disable dropout (or leave in train mode to finetune)
|
15 |
+
```
|
16 |
+
|
17 |
+
##### Load GottBERT (for PyTorch 1.0 or custom models):
|
18 |
+
```python
|
19 |
+
# Download gottbert model
|
20 |
+
wget https://dl.gottbert.de/fairseq/models/gottbert-base.tar.gz
|
21 |
+
tar -xzvf gottbert.tar.gz
|
22 |
+
|
23 |
+
# Load the model in fairseq
|
24 |
+
from fairseq.models.roberta import GottbertModel
|
25 |
+
gottbert = GottbertModel.from_pretrained('/path/to/gottbert')
|
26 |
+
gottbert.eval() # disable dropout (or leave in train mode to finetune)
|
27 |
+
```
|
28 |
+
|
29 |
+
##### Filling masks:
|
30 |
+
```python
|
31 |
+
masked_line = 'Gott ist <mask> ! :)'
|
32 |
+
gottbert.fill_mask(masked_line, topk=3)
|
33 |
+
# [('Gott ist gut ! :)', 0.3642110526561737, ' gut'),
|
34 |
+
# ('Gott ist überall ! :)', 0.06009674072265625, ' überall'),
|
35 |
+
# ('Gott ist großartig ! :)', 0.0370681993663311, ' großartig')]
|
36 |
+
```
|
37 |
+
|
38 |
+
##### Extract features from GottBERT
|
39 |
+
|
40 |
+
```python
|
41 |
+
# Extract the last layer's features
|
42 |
+
line = "Der erste Schluck aus dem Becher der Naturwissenschaft macht atheistisch , aber auf dem Grunde des Bechers wartet Gott !"
|
43 |
+
tokens = gottbert.encode(line)
|
44 |
+
last_layer_features = gottbert.extract_features(tokens)
|
45 |
+
assert last_layer_features.size() == torch.Size([1, 27, 768])
|
46 |
+
|
47 |
+
# Extract all layer's features (layer 0 is the embedding layer)
|
48 |
+
all_layers = gottbert.extract_features(tokens, return_all_hiddens=True)
|
49 |
+
assert len(all_layers) == 13
|
50 |
+
assert torch.all(all_layers[-1] == last_layer_features)
|
51 |
+
```
|
52 |
+
## Citation
|
53 |
+
If you use our work, please cite:
|
54 |
+
|
55 |
+
```bibtex
|
56 |
+
@misc{scheible2020gottbert,
|
57 |
+
title={GottBERT: a pure German Language Model},
|
58 |
+
author={Raphael Scheible and Fabian Thomczyk and Patric Tippmann and Victor Jaravine and Martin Boeker},
|
59 |
+
year={2020},
|
60 |
+
eprint={2012.02110},
|
61 |
+
archivePrefix={arXiv},
|
62 |
+
primaryClass={cs.CL}
|
63 |
+
}
|
64 |
+
```
|
fairseq/examples/hubert/README.md
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# HuBERT
|
2 |
+
|
3 |
+
## Pre-trained and fine-tuned (ASR) models
|
4 |
+
Model | Pretraining Data | Finetuning Dataset | Model | Quantizer
|
5 |
+
|---|---|---|---|---
|
6 |
+
HuBERT Base (~95M params) | [Librispeech](http://www.openslr.org/12) 960 hr | No finetuning (Pretrained Model) | [download](https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt) | [L9 km500](https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960_L9_km500.bin)
|
7 |
+
HuBERT Large (~316M params) | [Libri-Light](https://github.com/facebookresearch/libri-light) 60k hr | No finetuning (Pretrained Model) | [download](https://dl.fbaipublicfiles.com/hubert/hubert_large_ll60k.pt)
|
8 |
+
HuBERT Extra Large (~1B params) | [Libri-Light](https://github.com/facebookresearch/libri-light) 60k hr | No finetuning (Pretrained Model) | [download](https://dl.fbaipublicfiles.com/hubert/hubert_xtralarge_ll60k.pt)
|
9 |
+
HuBERT Large | [Libri-Light](https://github.com/facebookresearch/libri-light) 60k hr | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/hubert/hubert_large_ll60k_finetune_ls960.pt)
|
10 |
+
HuBERT Extra Large | [Libri-Light](https://github.com/facebookresearch/libri-light) 60k hr | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/hubert/hubert_xtralarge_ll60k_finetune_ls960.pt)
|
11 |
+
|
12 |
+
## Load a model
|
13 |
+
```
|
14 |
+
ckpt_path = "/path/to/the/checkpoint.pt"
|
15 |
+
models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
|
16 |
+
model = models[0]
|
17 |
+
```
|
18 |
+
|
19 |
+
## Train a new model
|
20 |
+
|
21 |
+
### Data preparation
|
22 |
+
|
23 |
+
Follow the steps in `./simple_kmeans` to create:
|
24 |
+
- `{train,valid}.tsv` waveform list files
|
25 |
+
- `{train,valid}.km` frame-aligned pseudo label files.
|
26 |
+
- `dict.km.txt` a dummy dictionary
|
27 |
+
The `label_rate` is the same as the feature frame rate used for clustering,
|
28 |
+
which is 100Hz for MFCC features and 50Hz for HuBERT features by default.
|
29 |
+
|
30 |
+
### Pre-train a HuBERT model
|
31 |
+
|
32 |
+
Suppose `{train,valid}.tsv` are saved at `/path/to/data`, `{train,valid}.km`
|
33 |
+
are saved at `/path/to/labels`, and the label rate is 100Hz.
|
34 |
+
|
35 |
+
To train a base model (12 layer transformer), run:
|
36 |
+
```sh
|
37 |
+
$ python fairseq_cli/hydra_train.py \
|
38 |
+
--config-dir /path/to/fairseq-py/examples/hubert/config/pretrain \
|
39 |
+
--config-name hubert_base_librispeech \
|
40 |
+
task.data=/path/to/data task.label_dir=/path/to/labels task.labels='["km"]' model.label_rate=100
|
41 |
+
```
|
42 |
+
|
43 |
+
### Fine-tune a HuBERT model with a CTC loss
|
44 |
+
|
45 |
+
Suppose `{train,valid}.tsv` are saved at `/path/to/data`, and their
|
46 |
+
corresponding character transcripts `{train,valid}.ltr` are saved at
|
47 |
+
`/path/to/trans`.
|
48 |
+
|
49 |
+
To fine-tune a pre-trained HuBERT model at `/path/to/checkpoint`, run
|
50 |
+
```sh
|
51 |
+
$ python fairseq_cli/hydra_train.py \
|
52 |
+
--config-dir /path/to/fairseq-py/examples/hubert/config/finetune \
|
53 |
+
--config-name base_10h \
|
54 |
+
task.data=/path/to/data task.label_dir=/path/to/trans \
|
55 |
+
model.w2v_path=/path/to/checkpoint
|
56 |
+
```
|
57 |
+
|
58 |
+
### Decode a HuBERT model
|
59 |
+
|
60 |
+
Suppose the `test.tsv` and `test.ltr` are the waveform list and transcripts of
|
61 |
+
the split to be decoded, saved at `/path/to/data`, and the fine-tuned model is
|
62 |
+
saved at `/path/to/checkpoint`. We support three decoding modes:
|
63 |
+
- Viterbi decoding: greedy decoding without a language model
|
64 |
+
- KenLM decoding: decoding with an arpa-format KenLM n-gram language model
|
65 |
+
- Fairseq-LM deocding: decoding with a Fairseq neural language model
|
66 |
+
|
67 |
+
|
68 |
+
#### Viterbi decoding
|
69 |
+
|
70 |
+
`task.normalize` needs to be consistent with the value used during fine-tuning.
|
71 |
+
Decoding results will be saved at
|
72 |
+
`/path/to/experiment/directory/decode/viterbi/test`.
|
73 |
+
|
74 |
+
```sh
|
75 |
+
$ python examples/speech_recognition/new/infer.py \
|
76 |
+
--config-dir /path/to/fairseq-py/examples/hubert/config/decode \
|
77 |
+
--config-name infer_viterbi \
|
78 |
+
task.data=/path/to/data \
|
79 |
+
task.normalize=[true|false] \
|
80 |
+
decoding.exp_dir=/path/to/experiment/directory \
|
81 |
+
common_eval.path=/path/to/checkpoint
|
82 |
+
dataset.gen_subset=test \
|
83 |
+
```
|
84 |
+
|
85 |
+
#### KenLM / Fairseq-LM decoding
|
86 |
+
|
87 |
+
Suppose the pronunciation lexicon and the n-gram LM are saved at
|
88 |
+
`/path/to/lexicon` and `/path/to/arpa`, respectively. Decoding results will be
|
89 |
+
saved at `/path/to/experiment/directory/decode/kenlm/test`.
|
90 |
+
|
91 |
+
```sh
|
92 |
+
$ python examples/speech_recognition/new/infer.py \
|
93 |
+
--config-dir /path/to/fairseq-py/examples/hubert/config/decode \
|
94 |
+
--config-name infer_kenlm \
|
95 |
+
task.data=/path/to/data \
|
96 |
+
task.normalize=[true|false] \
|
97 |
+
decoding.exp_dir=/path/to/experiment/directory \
|
98 |
+
common_eval.path=/path/to/checkpoint
|
99 |
+
dataset.gen_subset=test \
|
100 |
+
decoding.decoder.lexicon=/path/to/lexicon \
|
101 |
+
decoding.decoder.lmpath=/path/to/arpa
|
102 |
+
```
|
103 |
+
|
104 |
+
The command above uses the default decoding hyperparameter, which can be found
|
105 |
+
in `examples/speech_recognition/hydra/decoder.py`. These parameters can be
|
106 |
+
configured from the command line. For example, to search with a beam size of
|
107 |
+
500, we can append the command above with `decoding.decoder.beam=500`.
|
108 |
+
Important parameters include:
|
109 |
+
- decoding.decoder.beam
|
110 |
+
- decoding.decoder.beamthreshold
|
111 |
+
- decoding.decoder.lmweight
|
112 |
+
- decoding.decoder.wordscore
|
113 |
+
- decoding.decoder.silweight
|
114 |
+
|
115 |
+
To decode with a Fairseq LM, use `--config-name infer_fsqlm` instead, and
|
116 |
+
change the path of lexicon and LM accordingly.
|
fairseq/examples/hubert/config/decode/ax_sweep/ngram.yaml
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
common_eval:
|
4 |
+
results_path: ${decoding.exp_dir}/decode/${decoding.decoder.name}_ax/${dataset.gen_subset}
|
5 |
+
|
6 |
+
hydra:
|
7 |
+
sweeper:
|
8 |
+
ax_config:
|
9 |
+
max_trials: 60
|
10 |
+
early_stop:
|
11 |
+
minimize: true
|
12 |
+
max_epochs_without_improvement: 10
|
13 |
+
epsilon: 0.025
|
14 |
+
experiment:
|
15 |
+
name: ${dataset.gen_subset}
|
16 |
+
objective_name: wer
|
17 |
+
minimize: true
|
18 |
+
parameter_constraints: null
|
19 |
+
outcome_constraints: null
|
20 |
+
status_quo: null
|
21 |
+
client:
|
22 |
+
verbose_logging: false
|
23 |
+
random_seed: null
|
24 |
+
params:
|
25 |
+
decoding.decoder.lmweight:
|
26 |
+
type: range
|
27 |
+
bounds: [0.0, 8.0]
|
28 |
+
decoding.decoder.wordscore:
|
29 |
+
type: range
|
30 |
+
bounds: [-5.0, 5.0]
|
31 |
+
decoding.decoder.silweight:
|
32 |
+
type: range
|
33 |
+
bounds: [-10.0, 0.0]
|
fairseq/examples/hubert/config/decode/ax_sweep/transformer.yaml
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
common_eval:
|
4 |
+
results_path: ${decoding.exp_dir}/decode/${decoding.decoder.name}_ax/${dataset.gen_subset}
|
5 |
+
|
6 |
+
hydra:
|
7 |
+
sweeper:
|
8 |
+
ax_config:
|
9 |
+
max_trials: 60
|
10 |
+
early_stop:
|
11 |
+
minimize: true
|
12 |
+
max_epochs_without_improvement: 10
|
13 |
+
epsilon: 0.025
|
14 |
+
experiment:
|
15 |
+
name: ${dataset.gen_subset}
|
16 |
+
objective_name: wer
|
17 |
+
minimize: true
|
18 |
+
parameter_constraints: null
|
19 |
+
outcome_constraints: null
|
20 |
+
status_quo: null
|
21 |
+
client:
|
22 |
+
verbose_logging: false
|
23 |
+
random_seed: null
|
24 |
+
params:
|
25 |
+
decoding.decoder.lmweight:
|
26 |
+
type: range
|
27 |
+
bounds: [0.0, 4.0]
|
28 |
+
decoding.decoder.wordscore:
|
29 |
+
type: range
|
30 |
+
bounds: [-5.0, 5.0]
|
31 |
+
decoding.decoder.silweight:
|
32 |
+
type: range
|
33 |
+
bounds: [-8.0, 0.0]
|
fairseq/examples/hubert/config/decode/infer_fsqlm.yaml
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _group_
|
2 |
+
|
3 |
+
defaults:
|
4 |
+
- model: null
|
5 |
+
|
6 |
+
hydra:
|
7 |
+
run:
|
8 |
+
dir: ${common_eval.results_path}/beam${decoding.beam}_th${decoding.beamthreshold}_lmw${decoding.lmweight}_wrd${decoding.wordscore}_sil${decoding.silweight}
|
9 |
+
sweep:
|
10 |
+
dir: ${common_eval.results_path}
|
11 |
+
subdir: beam${decoding.beam}_th${decoding.beamthreshold}_lmw${decoding.lmweight}_wrd${decoding.wordscore}_sil${decoding.silweight}
|
12 |
+
|
13 |
+
task:
|
14 |
+
_name: hubert_pretraining
|
15 |
+
single_target: true
|
16 |
+
fine_tuning: true
|
17 |
+
data: ???
|
18 |
+
normalize: ???
|
19 |
+
|
20 |
+
decoding:
|
21 |
+
type: fairseqlm
|
22 |
+
lexicon: ???
|
23 |
+
lmpath: ???
|
24 |
+
beamthreshold: 25
|
25 |
+
beam: 500
|
26 |
+
lmweight: 2
|
27 |
+
wordscore: -1
|
28 |
+
silweight: 0
|
29 |
+
unique_wer_file: true
|
30 |
+
common_eval:
|
31 |
+
results_path: ???
|
32 |
+
path: ???
|
33 |
+
post_process: letter
|
34 |
+
dataset:
|
35 |
+
max_tokens: 1100000
|
36 |
+
gen_subset: ???
|
fairseq/examples/hubert/config/decode/infer_kenlm.yaml
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _group_
|
2 |
+
|
3 |
+
defaults:
|
4 |
+
- model: null
|
5 |
+
|
6 |
+
hydra:
|
7 |
+
run:
|
8 |
+
dir: ${common_eval.results_path}/beam${decoding.beam}_th${decoding.beamthreshold}_lmw${decoding.lmweight}_wrd${decoding.wordscore}_sil${decoding.silweight}
|
9 |
+
sweep:
|
10 |
+
dir: ${common_eval.results_path}
|
11 |
+
subdir: beam${decoding.beam}_th${decoding.beamthreshold}_lmw${decoding.lmweight}_wrd${decoding.wordscore}_sil${decoding.silweight}
|
12 |
+
|
13 |
+
task:
|
14 |
+
_name: hubert_pretraining
|
15 |
+
single_target: true
|
16 |
+
fine_tuning: true
|
17 |
+
data: ???
|
18 |
+
normalize: ???
|
19 |
+
|
20 |
+
decoding:
|
21 |
+
type: kenlm
|
22 |
+
lexicon: ???
|
23 |
+
lmpath: ???
|
24 |
+
beamthreshold: 100
|
25 |
+
beam: 500
|
26 |
+
lmweight: 2
|
27 |
+
wordscore: -1
|
28 |
+
silweight: 0
|
29 |
+
unique_wer_file: true
|
30 |
+
common_eval:
|
31 |
+
results_path: ???
|
32 |
+
path: ???
|
33 |
+
post_process: letter
|
34 |
+
dataset:
|
35 |
+
max_tokens: 1100000
|
36 |
+
gen_subset: ???
|
fairseq/examples/hubert/config/decode/infer_viterbi.yaml
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _group_
|
2 |
+
|
3 |
+
defaults:
|
4 |
+
- model: null
|
5 |
+
|
6 |
+
hydra:
|
7 |
+
run:
|
8 |
+
dir: ${common_eval.results_path}/viterbi
|
9 |
+
sweep:
|
10 |
+
dir: ${common_eval.results_path}
|
11 |
+
subdir: viterbi
|
12 |
+
|
13 |
+
task:
|
14 |
+
_name: hubert_pretraining
|
15 |
+
single_target: true
|
16 |
+
fine_tuning: true
|
17 |
+
data: ???
|
18 |
+
normalize: ???
|
19 |
+
|
20 |
+
decoding:
|
21 |
+
type: viterbi
|
22 |
+
unique_wer_file: true
|
23 |
+
common_eval:
|
24 |
+
results_path: ???
|
25 |
+
path: ???
|
26 |
+
post_process: letter
|
27 |
+
dataset:
|
28 |
+
max_tokens: 1100000
|
29 |
+
gen_subset: ???
|
fairseq/examples/hubert/config/decode/run/submitit_slurm.yaml
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
hydra:
|
3 |
+
launcher:
|
4 |
+
cpus_per_task: ${distributed_training.distributed_world_size}
|
5 |
+
gpus_per_node: ${distributed_training.distributed_world_size}
|
6 |
+
tasks_per_node: ${hydra.launcher.gpus_per_node}
|
7 |
+
nodes: 1
|
8 |
+
mem_gb: 200
|
9 |
+
timeout_min: 4320
|
10 |
+
max_num_timeout: 50
|
11 |
+
name: ${hydra.job.config_name}
|
12 |
+
submitit_folder: ${hydra.sweep.dir}/submitit
|
13 |
+
|
14 |
+
distributed_training:
|
15 |
+
distributed_world_size: 1
|
16 |
+
distributed_no_spawn: true
|
17 |
+
distributed_port: 29761
|
fairseq/examples/hubert/config/decode/run/submitit_slurm_8gpu.yaml
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
hydra:
|
3 |
+
launcher:
|
4 |
+
cpus_per_task: ${distributed_training.distributed_world_size}
|
5 |
+
gpus_per_node: ${distributed_training.distributed_world_size}
|
6 |
+
tasks_per_node: ${hydra.launcher.gpus_per_node}
|
7 |
+
nodes: 1
|
8 |
+
mem_gb: 200
|
9 |
+
timeout_min: 4320
|
10 |
+
max_num_timeout: 50
|
11 |
+
name: ${hydra.job.config_name}
|
12 |
+
submitit_folder: ${hydra.sweep.dir}/submitit
|
13 |
+
|
14 |
+
distributed_training:
|
15 |
+
distributed_world_size: 8
|
16 |
+
distributed_no_spawn: true
|
17 |
+
distributed_port: 29761
|
fairseq/examples/hubert/config/finetune/base_10h.yaml
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _group_
|
2 |
+
|
3 |
+
common:
|
4 |
+
fp16: true
|
5 |
+
log_format: json
|
6 |
+
log_interval: 200
|
7 |
+
tensorboard_logdir: tblog
|
8 |
+
seed: 1337
|
9 |
+
|
10 |
+
checkpoint:
|
11 |
+
save_interval: 5
|
12 |
+
keep_interval_updates: 1
|
13 |
+
no_epoch_checkpoints: true
|
14 |
+
best_checkpoint_metric: wer
|
15 |
+
|
16 |
+
distributed_training:
|
17 |
+
ddp_backend: c10d
|
18 |
+
find_unused_parameters: true
|
19 |
+
distributed_world_size: 1
|
20 |
+
distributed_port: 29671
|
21 |
+
nprocs_per_node: 8
|
22 |
+
|
23 |
+
task:
|
24 |
+
_name: hubert_pretraining
|
25 |
+
data: ???
|
26 |
+
fine_tuning: true
|
27 |
+
label_dir: ???
|
28 |
+
normalize: false # must be consistent with pre-training
|
29 |
+
labels: ["ltr"]
|
30 |
+
single_target: true
|
31 |
+
|
32 |
+
dataset:
|
33 |
+
num_workers: 0
|
34 |
+
max_tokens: 3200000
|
35 |
+
validate_after_updates: ${model.freeze_finetune_updates}
|
36 |
+
validate_interval: 5
|
37 |
+
train_subset: train
|
38 |
+
valid_subset: valid
|
39 |
+
|
40 |
+
criterion:
|
41 |
+
_name: ctc
|
42 |
+
zero_infinity: true
|
43 |
+
|
44 |
+
optimization:
|
45 |
+
max_update: 25000
|
46 |
+
lr: [2e-5]
|
47 |
+
sentence_avg: true
|
48 |
+
update_freq: [1]
|
49 |
+
|
50 |
+
optimizer:
|
51 |
+
_name: adam
|
52 |
+
adam_betas: (0.9,0.98)
|
53 |
+
adam_eps: 1e-08
|
54 |
+
|
55 |
+
lr_scheduler:
|
56 |
+
_name: tri_stage
|
57 |
+
warmup_steps: 8000
|
58 |
+
hold_steps: 0
|
59 |
+
decay_steps: 72000
|
60 |
+
final_lr_scale: 0.05
|
61 |
+
|
62 |
+
model:
|
63 |
+
_name: hubert_ctc
|
64 |
+
w2v_path: ???
|
65 |
+
apply_mask: true
|
66 |
+
mask_selection: static
|
67 |
+
mask_length: 10
|
68 |
+
mask_other: 0
|
69 |
+
mask_prob: 0.75
|
70 |
+
mask_channel_selection: static
|
71 |
+
mask_channel_length: 64
|
72 |
+
mask_channel_other: 0
|
73 |
+
mask_channel_prob: 0.5
|
74 |
+
layerdrop: 0.1
|
75 |
+
dropout: 0.0
|
76 |
+
activation_dropout: 0.1
|
77 |
+
attention_dropout: 0.0
|
78 |
+
feature_grad_mult: 0.0
|
79 |
+
freeze_finetune_updates: 10000
|
80 |
+
|
81 |
+
hydra:
|
82 |
+
job:
|
83 |
+
config:
|
84 |
+
override_dirname:
|
85 |
+
kv_sep: '-'
|
86 |
+
item_sep: '__'
|
87 |
+
exclude_keys:
|
88 |
+
- run
|
89 |
+
- task.data
|
90 |
+
- task.label_dir
|
91 |
+
- model.w2v_path
|
92 |
+
- dataset.train_subset
|
93 |
+
- dataset.valid_subset
|
94 |
+
- criterion.wer_kenlm_model
|
95 |
+
- criterion.wer_lexicon
|
96 |
+
run:
|
97 |
+
dir: ???
|
98 |
+
sweep:
|
99 |
+
dir: ???
|
100 |
+
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
|
fairseq/examples/hubert/config/finetune/ckpt/it1.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
task:
|
4 |
+
normalize: false
|
5 |
+
|
6 |
+
model:
|
7 |
+
w2v_path: /checkpoint/wnhsu/w2v/hubert_final/iter1/hubert.km.randcrop.pmw1_0.puw0_0.grpnorm.ml10.mp0_8.untie.mxsz250000.ufreq1.maxtok1400000.MU400k.s1337.ngpu32/checkpoint_last.pt
|
fairseq/examples/hubert/config/finetune/lm/ls_4gram.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
criterion:
|
4 |
+
wer_kenlm_model: /checkpoint/abdo/old_checkpoint02/datasets/librispeech/4-gram.bin
|
5 |
+
wer_lexicon: /checkpoint/abdo/old_checkpoint02/datasets/librispeech/10h/raw/lexicon_ltr.lst
|
6 |
+
wer_lm_weight: 2.0
|
7 |
+
wer_word_score: -1.0
|
fairseq/examples/hubert/config/finetune/run/submitit_reg.yaml
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
hydra:
|
4 |
+
launcher:
|
5 |
+
cpus_per_task: 8
|
6 |
+
gpus_per_node: 8
|
7 |
+
tasks_per_node: ${hydra.launcher.gpus_per_node}
|
8 |
+
nodes: 1
|
9 |
+
comment: null
|
10 |
+
mem_gb: 384
|
11 |
+
timeout_min: 4320
|
12 |
+
max_num_timeout: 100
|
13 |
+
constraint: volta32gb
|
14 |
+
name: ${hydra.job.config_name}/${hydra.job.override_dirname}
|
15 |
+
submitit_folder: ${hydra.sweep.dir}/submitit/%j
|
16 |
+
|
17 |
+
distributed_training:
|
18 |
+
distributed_world_size: 8
|
19 |
+
distributed_port: 29671
|
20 |
+
nprocs_per_node: 8
|
fairseq/examples/hubert/config/pretrain/data/iter1.yaml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
task:
|
4 |
+
label_dir: ???
|
5 |
+
labels: ["km"]
|
6 |
+
|
7 |
+
model:
|
8 |
+
label_rate: 100
|
fairseq/examples/hubert/config/pretrain/data/iter2.yaml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
task:
|
4 |
+
label_dir: ???
|
5 |
+
labels: ["km"]
|
6 |
+
|
7 |
+
model:
|
8 |
+
label_rate: 50
|