update
- examples/conv_tasnet/run.sh +170 -0
- examples/conv_tasnet/step_1_prepare_data.py +201 -0
- examples/conv_tasnet/step_2_train_model.py +413 -0
- examples/conv_tasnet/yaml/config.yaml +42 -0
- examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_clean_emotional_speech.py +90 -0
- examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_clean_read_speech.py +123 -0
- examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_demand.py +71 -0
- examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_impulse_responses.py +93 -0
- examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_noise.py +77 -0
- examples/data_preprocess/dns_challenge_to_8k/process_musan.py +8 -0
- examples/mpnet/run.sh +2 -2
- examples/nx_mpnet/yaml/config.yaml +5 -5
- main.py +8 -1
- requirements.txt +1 -0
- toolbox/torchaudio/losses/__init__.py +6 -0
- toolbox/torchaudio/losses/perceptual.py +75 -0
- toolbox/torchaudio/losses/snr.py +101 -0
- toolbox/torchaudio/losses/spectral.py +351 -0
- toolbox/torchaudio/metrics/__init__.py +6 -0
- toolbox/torchaudio/metrics/pesq.py +80 -0
- toolbox/torchaudio/models/conv_tasnet/configuration_conv_tasnet.py +52 -0
- toolbox/torchaudio/models/conv_tasnet/modeling_conv_tasnet.py +477 -2
- toolbox/torchaudio/models/conv_tasnet/utils.py +55 -0
- toolbox/torchaudio/models/conv_tasnet/yaml/config.yaml +17 -0
- toolbox/torchaudio/models/demucs/__init__.py +6 -0
- toolbox/torchaudio/models/demucs/configuration_demucs.py +51 -0
- toolbox/torchaudio/models/demucs/modeling_demucs.py +299 -0
- toolbox/torchaudio/models/demucs/resample.py +81 -0
- toolbox/torchaudio/models/nx_dfnet/configuration_nx_dfnet.py +102 -0
- toolbox/torchaudio/models/nx_dfnet/modeling_nx_dfnet.py +989 -0
- toolbox/torchaudio/models/nx_dfnet/utils.py +55 -0
examples/conv_tasnet/run.sh
ADDED
@@ -0,0 +1,170 @@
#!/usr/bin/env bash

: <<'END'


sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir --final_model_name mpnet-aishell-20250224 \
--noise_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/noise" \
--speech_dir "E:/programmer/asr_datasets/aishell/data_aishell/wav/train"


sh run.sh --stage 3 --stop_stage 3 --system_version centos --file_folder_name file_dir --final_model_name mpnet-aishell-20250224 \
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"

sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name nx-clean-unet-aishell-20250228 \
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train" \
--max_epochs 100


sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name mpnet-nx-speech-20250224 \
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech" \
--max_epochs 100 --max_count 10000


END


# params
system_version="windows";
verbose=true;
stage=0  # start from 0 if you need to start from data preparation
stop_stage=9

work_dir="$(pwd)"
file_folder_name=file_folder_name
final_model_name=final_model_name
config_file="yaml/config.yaml"
limit=10

noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train

max_count=10000000

nohup_name=nohup.out

# model params
batch_size=64
max_epochs=200
save_top_k=10
patience=5


# parse options
while true; do
  [ -z "${1:-}" ] && break;  # break if there are no arguments
  case "$1" in
    --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
      eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
      old_value="$(eval echo \$$name)";
      if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
        was_bool=true;
      else
        was_bool=false;
      fi

      # Set the variable to the right value-- the escaped quotes make it work if
      # the option had spaces, like --cmd "queue.pl -sync y"
      eval "${name}=\"$2\"";

      # Check that Boolean-valued arguments are really Boolean.
      if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
        echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
        exit 1;
      fi
      shift 2;
      ;;

    *) break;
  esac
done

file_dir="${work_dir}/${file_folder_name}"
final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
evaluation_audio_dir="${file_dir}/evaluation_audio"

dataset="${file_dir}/dataset.xlsx"
train_dataset="${file_dir}/train.xlsx"
valid_dataset="${file_dir}/valid.xlsx"

$verbose && echo "system_version: ${system_version}"
$verbose && echo "file_folder_name: ${file_folder_name}"

if [ $system_version == "windows" ]; then
  alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
  #source /data/local/bin/nx_denoise/bin/activate
  alias python3='/data/local/bin/nx_denoise/bin/python3'
fi


if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
  $verbose && echo "stage 1: prepare data"
  cd "${work_dir}" || exit 1
  python3 step_1_prepare_data.py \
  --file_dir "${file_dir}" \
  --noise_dir "${noise_dir}" \
  --speech_dir "${speech_dir}" \
  --train_dataset "${train_dataset}" \
  --valid_dataset "${valid_dataset}" \
  --max_count "${max_count}"

fi


if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
  $verbose && echo "stage 2: train model"
  cd "${work_dir}" || exit 1
  python3 step_2_train_model.py \
  --train_dataset "${train_dataset}" \
  --valid_dataset "${valid_dataset}" \
  --serialization_dir "${file_dir}" \
  --config_file "${config_file}"

fi


if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
  $verbose && echo "stage 3: test model"
  cd "${work_dir}" || exit 1
  python3 step_3_evaluation.py \
  --valid_dataset "${valid_dataset}" \
  --model_dir "${file_dir}/best" \
  --evaluation_audio_dir "${evaluation_audio_dir}" \
  --limit "${limit}"

fi


if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
  $verbose && echo "stage 4: collect files"
  cd "${work_dir}" || exit 1

  mkdir -p ${final_model_dir}

  cp "${file_dir}/best"/* "${final_model_dir}"
  cp -r "${file_dir}/evaluation_audio" "${final_model_dir}"

  cd "${final_model_dir}/.." || exit 1;

  if [ -e "${final_model_name}.zip" ]; then
    rm -rf "${final_model_name}_backup.zip"
    mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
  fi

  zip -r "${final_model_name}.zip" "${final_model_name}"
  rm -rf "${final_model_name}"

fi


if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
  $verbose && echo "stage 5: clear file_dir"
  cd "${work_dir}" || exit 1

  rm -rf "${file_dir}";

fi
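A note on the option parser in this script: each `--flag value` pair is mapped onto a shell variable of the same name (dashes become underscores), and only variables already declared in the params section can be overridden; an unknown flag aborts with "invalid option". For example, `sh run.sh --stage 1 --stop_stage 2 --batch_size 32` overrides the defaults `stage=0`, `stop_stage=9`, and `batch_size=64`.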
examples/conv_tasnet/step_1_prepare_data.py
ADDED
@@ -0,0 +1,201 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import os
from pathlib import Path
import random
import sys
import shutil

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import pandas as pd
from scipy.io import wavfile
from tqdm import tqdm
import librosa

from project_settings import project_path


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--file_dir", default="./", type=str)

    parser.add_argument(
        "--noise_dir",
        default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
        type=str
    )
    parser.add_argument(
        "--speech_dir",
        default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
        type=str
    )

    parser.add_argument("--train_dataset", default="train.xlsx", type=str)
    parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)

    parser.add_argument("--duration", default=2.0, type=float)
    parser.add_argument("--min_snr_db", default=-10, type=float)
    parser.add_argument("--max_snr_db", default=20, type=float)

    parser.add_argument("--target_sample_rate", default=8000, type=int)

    parser.add_argument("--max_count", default=10000, type=int)

    args = parser.parse_args()
    return args


def filename_generator(data_dir: str):
    data_dir = Path(data_dir)
    for filename in data_dir.glob("**/*.wav"):
        yield filename.as_posix()


def target_second_signal_generator(data_dir: str, duration: float = 2.0, sample_rate: int = 8000):
    data_dir = Path(data_dir)
    for filename in data_dir.glob("**/*.wav"):
        signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
        raw_duration = librosa.get_duration(y=signal, sr=sample_rate)

        if raw_duration < duration:
            # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
            continue
        if signal.ndim != 1:
            raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")

        signal_length = len(signal)
        win_size = int(duration * sample_rate)
        for begin in range(0, signal_length - win_size, win_size):
            row = {
                "filename": filename.as_posix(),
                "raw_duration": round(raw_duration, 4),
                "offset": round(begin / sample_rate, 4),
                "duration": round(duration, 4),
            }
            yield row


def get_dataset(args):
    file_dir = Path(args.file_dir)
    file_dir.mkdir(exist_ok=True)

    noise_dir = Path(args.noise_dir)
    speech_dir = Path(args.speech_dir)

    noise_generator = target_second_signal_generator(
        noise_dir.as_posix(),
        duration=args.duration,
        sample_rate=args.target_sample_rate
    )
    speech_generator = target_second_signal_generator(
        speech_dir.as_posix(),
        duration=args.duration,
        sample_rate=args.target_sample_rate
    )

    dataset = list()

    count = 0
    process_bar = tqdm(desc="build dataset excel")
    for noise, speech in zip(noise_generator, speech_generator):
        if count >= args.max_count:
            break

        noise_filename = noise["filename"]
        noise_raw_duration = noise["raw_duration"]
        noise_offset = noise["offset"]
        noise_duration = noise["duration"]

        speech_filename = speech["filename"]
        speech_raw_duration = speech["raw_duration"]
        speech_offset = speech["offset"]
        speech_duration = speech["duration"]

        random1 = random.random()
        random2 = random.random()

        row = {
            "noise_filename": noise_filename,
            "noise_raw_duration": noise_raw_duration,
            "noise_offset": noise_offset,
            "noise_duration": noise_duration,

            "speech_filename": speech_filename,
            "speech_raw_duration": speech_raw_duration,
            "speech_offset": speech_offset,
            "speech_duration": speech_duration,

            "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),

            "random1": random1,
            "random2": random2,
            "flag": "TRAIN" if random2 < 0.8 else "TEST",
        }
        dataset.append(row)
        count += 1
        duration_seconds = count * args.duration
        duration_hours = duration_seconds / 3600

        process_bar.update(n=1)
        process_bar.set_postfix({
            # "duration_seconds": round(duration_seconds, 4),
            "duration_hours": round(duration_hours, 4),
        })

    dataset = pd.DataFrame(dataset)
    dataset = dataset.sort_values(by=["random1"], ascending=False)
    dataset.to_excel(
        file_dir / "dataset.xlsx",
        index=False,
    )
    return


def split_dataset(args):
    """Split into train and test sets."""
    file_dir = Path(args.file_dir)
    file_dir.mkdir(exist_ok=True)

    df = pd.read_excel(file_dir / "dataset.xlsx")

    train = list()
    test = list()

    for i, row in df.iterrows():
        flag = row["flag"]
        if flag == "TRAIN":
            train.append(row)
        else:
            test.append(row)

    train = pd.DataFrame(train)
    train.to_excel(
        args.train_dataset,
        index=False,
        # encoding="utf_8_sig"
    )
    test = pd.DataFrame(test)
    test.to_excel(
        args.valid_dataset,
        index=False,
        # encoding="utf_8_sig"
    )

    return


def main():
    args = get_args()

    get_dataset(args)
    split_dataset(args)
    return


if __name__ == "__main__":
    main()
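The spreadsheet built above stores an `snr_db` per (noise, speech) pair but not the mixture itself, so the mixing presumably happens downstream (e.g. in `DenoiseExcelDataset`). A minimal sketch of scaling a noise segment to hit a target SNR before mixing; the helper name and the epsilon guard are illustrative, not taken from this repository:

import numpy as np

def mix_at_snr(speech: np.ndarray, noise: np.ndarray, snr_db: float, eps: float = 1e-8) -> np.ndarray:
    # choose gain so that 10 * log10(P_speech / (gain**2 * P_noise)) == snr_db
    speech_power = np.mean(speech ** 2) + eps
    noise_power = np.mean(noise ** 2) + eps
    gain = np.sqrt(speech_power / (noise_power * 10.0 ** (snr_db / 10.0)))
    return speech + gain * noise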
examples/conv_tasnet/step_2_train_model.py
ADDED
@@ -0,0 +1,413 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://github.com/yxlu-0102/MP-SENet/blob/main/train.py
"""
import argparse
import json
import logging
from logging.handlers import TimedRotatingFileHandler
import os
import platform
from pathlib import Path
import random
import sys
import shutil
from typing import List

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset
from toolbox.torchaudio.models.conv_tasnet.configuration_conv_tasnet import ConvTasNetConfig
from toolbox.torchaudio.models.conv_tasnet.modeling_conv_tasnet import ConvTasNet, ConvTasNetPretrainedModel
from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
from toolbox.torchaudio.losses.spectral import LSDLoss
from toolbox.torchaudio.losses.perceptual import NegSTOILoss
from toolbox.torchaudio.metrics.pesq import run_pesq_score


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_dataset", default="train.xlsx", type=str)
    parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)

    parser.add_argument("--max_epochs", default=100, type=int)

    parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
    parser.add_argument("--patience", default=5, type=int)
    parser.add_argument("--serialization_dir", default="serialization_dir", type=str)

    parser.add_argument("--config_file", default="config.yaml", type=str)

    args = parser.parse_args()
    return args


def logging_config(file_dir: str):
    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"

    logging.basicConfig(format=fmt,
                        datefmt="%m/%d/%Y %H:%M:%S",
                        level=logging.INFO)
    file_handler = TimedRotatingFileHandler(
        filename=os.path.join(file_dir, "main.log"),
        encoding="utf-8",
        when="D",
        interval=1,
        backupCount=7
    )
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter(fmt))
    logger = logging.getLogger(__name__)
    logger.addHandler(file_handler)

    return logger


class CollateFunction(object):
    def __init__(self):
        pass

    def __call__(self, batch: List[dict]):
        clean_audios = list()
        noisy_audios = list()

        for sample in batch:
            # noise_wave: torch.Tensor = sample["noise_wave"]
            clean_audio: torch.Tensor = sample["speech_wave"]
            noisy_audio: torch.Tensor = sample["mix_wave"]
            # snr_db: float = sample["snr_db"]

            clean_audios.append(clean_audio)
            noisy_audios.append(noisy_audio)

        clean_audios = torch.stack(clean_audios)
        noisy_audios = torch.stack(noisy_audios)

        # assert
        if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
            raise AssertionError("nan or inf in clean_audios")
        if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
            raise AssertionError("nan or inf in noisy_audios")
        return clean_audios, noisy_audios


collate_fn = CollateFunction()


def main():
    args = get_args()

    config = ConvTasNetConfig.from_pretrained(
        pretrained_model_name_or_path=args.config_file,
    )

    serialization_dir = Path(args.serialization_dir)
    serialization_dir.mkdir(parents=True, exist_ok=True)

    logger = logging_config(serialization_dir)

    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    logger.info(f"set seed: {config.seed}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info(f"GPU available count: {n_gpu}; device: {device}")

    # datasets
    train_dataset = DenoiseExcelDataset(
        excel_file=args.train_dataset,
        expected_sample_rate=8000,
        max_wave_value=32768.0,
    )
    valid_dataset = DenoiseExcelDataset(
        excel_file=args.valid_dataset,
        expected_sample_rate=8000,
        max_wave_value=32768.0,
    )
    train_data_loader = DataLoader(
        dataset=train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        sampler=None,
        # Multiple worker processes can be used to load data on Linux, but not on Windows.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        prefetch_factor=16,
    )
    valid_data_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        sampler=None,
        # Multiple worker processes can be used to load data on Linux, but not on Windows.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        prefetch_factor=16,
    )

    # models
    logger.info(f"prepare models. config_file: {args.config_file}")
    model = ConvTasNetPretrainedModel(config).to(device)
    model.train()

    # optimizer
    logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
    # the script defines no --learning_rate flag, so the value comes from the config
    optimizer = torch.optim.AdamW(model.parameters(), config.learning_rate)

    # resume training
    last_epoch = -1
    for epoch_i in serialization_dir.glob("epoch-*"):
        epoch_i = Path(epoch_i)
        epoch_idx = epoch_i.stem.split("-")[1]
        epoch_idx = int(epoch_idx)
        if epoch_idx > last_epoch:
            last_epoch = epoch_idx

    if last_epoch != -1:
        logger.info(f"resume from epoch-{last_epoch}.")
        model_pt = serialization_dir / f"epoch-{last_epoch}/model.pt"
        optimizer_pth = serialization_dir / f"epoch-{last_epoch}/optimizer.pth"

        logger.info(f"load state dict for model.")
        with open(model_pt.as_posix(), "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)

        logger.info(f"load state dict for optimizer.")
        with open(optimizer_pth.as_posix(), "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        optimizer.load_state_dict(state_dict)

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
    )

    ae_loss_fn = nn.L1Loss(reduction="mean").to(device)
    neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
    neg_stoi_loss_fn = NegSTOILoss(sample_rate=8000, reduction="mean").to(device)
    lsd_loss_fn = LSDLoss(reduction="mean").to(device)

    # training loop

    # state
    average_pesq_score = 1000000000
    average_loss = 1000000000
    average_ae_loss = 1000000000
    average_neg_si_snr_loss = 1000000000
    average_neg_stoi_loss = 1000000000
    average_lsd_loss = 1000000000

    model_list = list()
    best_idx_epoch = None
    best_metric = None
    patience_count = 0

    logger.info("training")
    for idx_epoch in range(max(0, last_epoch + 1), args.max_epochs):
        # train
        model.train()

        total_pesq_score = 0.
        total_loss = 0.
        total_ae_loss = 0.
        total_neg_si_snr_loss = 0.
        total_neg_stoi_loss = 0.
        total_lsd_loss = 0.
        total_batches = 0.
        progress_bar = tqdm(
            total=len(train_data_loader),
            desc="Training; epoch: {}".format(idx_epoch),
        )
        for batch in train_data_loader:
            clean_audios, noisy_audios = batch
            clean_audios = clean_audios.to(device)
            noisy_audios = noisy_audios.to(device)

            denoise_audios = model.forward(noisy_audios)
            denoise_audios = torch.squeeze(denoise_audios, dim=1)

            ae_loss = ae_loss_fn.forward(denoise_audios, clean_audios)
            neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
            neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
            lsd_loss = lsd_loss_fn.forward(denoise_audios, clean_audios)

            loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * lsd_loss

            denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
            clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
            pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=8000, mode="nb")

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()

            total_pesq_score += pesq_score
            total_loss += loss.item()
            total_ae_loss += ae_loss.item()
            total_neg_si_snr_loss += neg_si_snr_loss.item()
            total_neg_stoi_loss += neg_stoi_loss.item()
            total_lsd_loss += lsd_loss.item()
            total_batches += 1

            average_pesq_score = round(total_pesq_score / total_batches, 4)
            average_loss = round(total_loss / total_batches, 4)
            average_ae_loss = round(total_ae_loss / total_batches, 4)
            average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
            average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
            average_lsd_loss = round(total_lsd_loss / total_batches, 4)

            progress_bar.update(1)
            progress_bar.set_postfix({
                "pesq_score": average_pesq_score,
                "loss": average_loss,
                "ae_loss": average_ae_loss,
                "neg_si_snr_loss": average_neg_si_snr_loss,
                "neg_stoi_loss": average_neg_stoi_loss,
                "lsd_loss": average_lsd_loss,
            })

        # evaluation
        model.eval()
        torch.cuda.empty_cache()

        total_pesq_score = 0.
        total_loss = 0.
        total_ae_loss = 0.
        total_neg_si_snr_loss = 0.
        total_neg_stoi_loss = 0.
        total_lsd_loss = 0.
        total_batches = 0.

        progress_bar = tqdm(
            total=len(valid_data_loader),
            desc="Evaluation; epoch: {}".format(idx_epoch),
        )
        with torch.no_grad():
            for batch in valid_data_loader:
                clean_audios, noisy_audios = batch
                clean_audios = clean_audios.to(device)
                noisy_audios = noisy_audios.to(device)

                denoise_audios = model.forward(noisy_audios)
                denoise_audios = torch.squeeze(denoise_audios, dim=1)

                ae_loss = ae_loss_fn.forward(denoise_audios, clean_audios)
                neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
                neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
                lsd_loss = lsd_loss_fn.forward(denoise_audios, clean_audios)

                loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * lsd_loss

                denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
                clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
                pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=8000, mode="nb")

                total_pesq_score += pesq_score
                total_loss += loss.item()
                total_ae_loss += ae_loss.item()
                total_neg_si_snr_loss += neg_si_snr_loss.item()
                total_neg_stoi_loss += neg_stoi_loss.item()
                total_lsd_loss += lsd_loss.item()
                total_batches += 1

                average_pesq_score = round(total_pesq_score / total_batches, 4)
                average_loss = round(total_loss / total_batches, 4)
                average_ae_loss = round(total_ae_loss / total_batches, 4)
                average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
                average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
                average_lsd_loss = round(total_lsd_loss / total_batches, 4)

                progress_bar.update(1)
                progress_bar.set_postfix({
                    "pesq_score": average_pesq_score,
                    "loss": average_loss,
                    "ae_loss": average_ae_loss,
                    "neg_si_snr_loss": average_neg_si_snr_loss,
                    "neg_stoi_loss": average_neg_stoi_loss,
                    "lsd_loss": average_lsd_loss,
                })

        # scheduler: stepped once per batch inside the training loop above

        # save path
        epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch)
        epoch_dir.mkdir(parents=True, exist_ok=False)

        # save models
        model.save_pretrained(epoch_dir.as_posix())

        model_list.append(epoch_dir)
        if len(model_list) >= args.num_serialized_models_to_keep:
            model_to_delete: Path = model_list.pop(0)
            shutil.rmtree(model_to_delete.as_posix())

        # save optim
        torch.save(optimizer.state_dict(), (epoch_dir / "optimizer.pth").as_posix())

        # save metric
        if best_metric is None:
            best_idx_epoch = idx_epoch
            best_metric = average_loss
        elif average_loss < best_metric:
            # smaller is better.
            best_idx_epoch = idx_epoch
            best_metric = average_loss
        else:
            pass

        metrics = {
            "idx_epoch": idx_epoch,
            "best_idx_epoch": best_idx_epoch,
            "pesq_score": average_pesq_score,
            "loss": average_loss,
            "ae_loss": average_ae_loss,
            "neg_si_snr_loss": average_neg_si_snr_loss,
            "neg_stoi_loss": average_neg_stoi_loss,
            "lsd_loss": average_lsd_loss,
        }
        metrics_filename = epoch_dir / "metrics_epoch.json"
        with open(metrics_filename, "w", encoding="utf-8") as f:
            json.dump(metrics, f, indent=4, ensure_ascii=False)

        # save best
        best_dir = serialization_dir / "best"
        if best_idx_epoch == idx_epoch:
            if best_dir.exists():
                shutil.rmtree(best_dir)
            shutil.copytree(epoch_dir, best_dir)

        # early stop
        early_stop_flag = False
        if best_idx_epoch == idx_epoch:
            patience_count = 0
        else:
            patience_count += 1
            if patience_count >= args.patience:
                early_stop_flag = True

        if early_stop_flag:
            break

    return


if __name__ == "__main__":
    main()
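For reference, the training objective above is an equally weighted combination of the four imported losses, with $\hat{s}$ the denoised waveform of length $T$ and $s$ the clean target:

$$\mathcal{L}(\hat{s}, s) = \tfrac{1}{4}\Bigl(\tfrac{1}{T}\lVert \hat{s} - s \rVert_1 \;-\; \mathrm{SISNR}(\hat{s}, s) \;-\; \mathrm{STOI}(\hat{s}, s) \;+\; \mathcal{L}_{\mathrm{LSD}}(\hat{s}, s)\Bigr)$$

The SI-SNR and STOI terms enter negated because higher values of those metrics indicate better quality, so minimizing the combined loss drives them up.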
examples/conv_tasnet/yaml/config.yaml
ADDED
@@ -0,0 +1,42 @@
model_name: "nx_clean_unet"

sample_rate: 8000
segment_size: 16000
n_fft: 512
win_size: 200
hop_size: 80

down_sampling_num_layers: 6
down_sampling_in_channels: 1
down_sampling_hidden_channels: 64
down_sampling_kernel_size: 4
down_sampling_stride: 2

causal_in_channels: 1
causal_out_channels: 1
causal_kernel_size: 3
causal_bias: false
causal_separable: true
causal_f_stride: 1
causal_num_layers: 5

tsfm_hidden_size: 256
tsfm_attention_heads: 8
tsfm_num_blocks: 6
tsfm_dropout_rate: 0.1
tsfm_max_length: 512
tsfm_chunk_size: 1
tsfm_num_left_chunks: 128
tsfm_num_right_chunks: 4

discriminator_dim: 32
discriminator_in_channel: 2

compress_factor: 0.3

batch_size: 64
learning_rate: 0.0005
adam_b1: 0.8
adam_b2: 0.99
lr_decay: 0.99
seed: 1234
examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_clean_emotional_speech.py
ADDED
@@ -0,0 +1,90 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://github.com/microsoft/DNS-Challenge/blob/master/download-dns-challenge-3.sh

1.2G
wget https://dns3public.blob.core.windows.net/dns3archive/DEMAND.tar.bz2

14G
wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.noise.tar.bz2

38G
wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.read_speech.tar.bz2

247M
wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.emotional_speech.tar.bz2

"""
import argparse
import os
from pathlib import Path
import sys

import numpy as np
from tqdm import tqdm

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import librosa
from scipy.io import wavfile


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--data_dir",
        default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.emotional_speech\datasets\clean\emotional_speech",
        type=str
    )
    parser.add_argument(
        "--output_dir",
        default=r"E:\programmer\asr_datasets\denoise\dns-clean-emotional-speech-8k",
        type=str
    )
    parser.add_argument("--sample_rate", default=8000, type=int)
    args = parser.parse_args()
    return args


def main():
    args = get_args()

    data_dir = Path(args.data_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # finished_set
    finished_set = set()
    for filename in tqdm(output_dir.glob("**/*.wav")):
        name = filename.stem
        finished_set.add(name)
    print(f"finished_set count: {len(finished_set)}")

    for filename in tqdm(data_dir.glob("**/*.wav")):
        label = filename.parts[-2]
        name = filename.stem
        # print(f"filename: {filename.as_posix()}")
        if name in finished_set:
            continue

        signal, _ = librosa.load(filename.as_posix(), sr=args.sample_rate)

        signal = signal * (1 << 15)
        signal = np.array(signal, dtype=np.int16)

        to_file = output_dir / f"{label}/{name}.wav"
        to_file.parent.mkdir(parents=True, exist_ok=True)
        wavfile.write(
            to_file.as_posix(),
            rate=args.sample_rate,
            data=signal,
        )
    return


if __name__ == "__main__":
    main()
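One caveat about the int16 conversion these preprocessing scripts share: `signal * (1 << 15)` followed by an int16 cast can wrap around for samples at or near full scale, since librosa returns floats in [-1, 1]. A minimal sketch of a saturating conversion; this helper is illustrative and not part of the original scripts:

import numpy as np

def float_to_int16(signal: np.ndarray) -> np.ndarray:
    # clip first so full-scale samples saturate instead of wrapping around
    signal = np.clip(signal, -1.0, 32767.0 / 32768.0)
    return (signal * (1 << 15)).astype(np.int16)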
examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_clean_read_speech.py
ADDED
@@ -0,0 +1,123 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://github.com/microsoft/DNS-Challenge/blob/master/download-dns-challenge-3.sh

1.2G
wget https://dns3public.blob.core.windows.net/dns3archive/DEMAND.tar.bz2

14G
wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.noise.tar.bz2

38G
wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.read_speech.tar.bz2

12G
wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.french_data.tar.bz2

43G
wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.german_speech.tar.bz2

7.9G
wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.italian_speech.tar.bz2

12G
wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.mandarin_speech.tar.bz2

3.1G
wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.russian_speech.tar.bz2

9.7G
wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.spanish_speech.tar.bz2

617M
wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.singing_voice.tar.bz2

"""
import argparse
import os
from pathlib import Path
import sys

import numpy as np
from tqdm import tqdm

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import librosa
from scipy.io import wavfile


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--data_dir",
        # default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.read_speech\datasets\clean",
        # default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.mandarin_speech\datasets\clean\mandarin_speech",
        # default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.singing_voice\datasets\clean\singing_voice",
        # default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.german_speech\datasets\clean\german_speech",
        # default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.italian_speech\datasets\clean\italian_speech",
        default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.russian_speech\datasets\clean\russian_speech",
        type=str
    )
    parser.add_argument(
        "--output_dir",
        # default=r"E:\programmer\asr_datasets\denoise\dns-clean-read-speech-8k",
        # default=r"E:\programmer\asr_datasets\denoise\dns-clean-mandarin-speech-8k",
        # default=r"E:\programmer\asr_datasets\denoise\dns-clean-singing-voice-8k",
        # default=r"E:\programmer\asr_datasets\denoise\dns-clean-german-speech-8k",
        # default=r"E:\programmer\asr_datasets\denoise\dns-clean-italian-speech-8k",
        default=r"E:\programmer\asr_datasets\denoise\dns-clean-russian-speech-8k",
        type=str
    )
    parser.add_argument("--sample_rate", default=8000, type=int)
    args = parser.parse_args()
    return args


def main():
    args = get_args()

    data_dir = Path(args.data_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # finished_set
    finished_set = set()
    for filename in tqdm(output_dir.glob("**/*.wav")):
        name = filename.stem
        finished_set.add(name)
    print(f"finished_set count: {len(finished_set)}")

    for filename in tqdm(data_dir.glob("**/*.wav")):
        label = filename.parts[-2]
        name = filename.stem
        relative_name = filename.relative_to(data_dir)
        # print(f"filename: {filename.as_posix()}")
        if name in finished_set:
            continue
        finished_set.add(name)

        try:
            signal, _ = librosa.load(filename.as_posix(), sr=args.sample_rate)
        except Exception:
            print(f"skip file: {filename.as_posix()}")
            continue

        signal = signal * (1 << 15)
        signal = np.array(signal, dtype=np.int16)

        to_file = output_dir / relative_name.as_posix()
        to_file.parent.mkdir(parents=True, exist_ok=True)
        wavfile.write(
            to_file.as_posix(),
            rate=args.sample_rate,
            data=signal,
        )
    return


if __name__ == "__main__":
    main()
examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_demand.py
ADDED
@@ -0,0 +1,71 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://github.com/microsoft/DNS-Challenge/blob/master/download-dns-challenge-3.sh

1.2G
wget https://dns3public.blob.core.windows.net/dns3archive/DEMAND.tar.bz2

"""
import argparse
import os
from pathlib import Path
import random
import sys
import shutil

import numpy as np

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import librosa
from scipy.io import wavfile


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--data_dir",
        default=r"E:\programmer\asr_datasets\dns-challenge\DEMAND\demand",
        type=str
    )
    parser.add_argument(
        "--output_dir",
        default=r"E:\programmer\asr_datasets\denoise\demand-8k",
        type=str
    )
    parser.add_argument("--sample_rate", default=8000, type=int)
    args = parser.parse_args()
    return args


def main():
    args = get_args()

    data_dir = Path(args.data_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=False)

    for filename in data_dir.glob("**/ch01.wav"):
        label = filename.parts[-2]
        name = filename.stem

        signal, _ = librosa.load(filename.as_posix(), sr=args.sample_rate)

        signal = signal * (1 << 15)
        signal = np.array(signal, dtype=np.int16)

        to_file = output_dir / f"{label}/{name}.wav"
        to_file.parent.mkdir(parents=True, exist_ok=True)
        wavfile.write(
            to_file.as_posix(),
            rate=args.sample_rate,
            data=signal,
        )
    return


if __name__ == '__main__':
    main()
examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_impulse_responses.py
ADDED
@@ -0,0 +1,93 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://github.com/microsoft/DNS-Challenge/blob/master/download-dns-challenge-3.sh

1.2G
wget https://dns3public.blob.core.windows.net/dns3archive/DEMAND.tar.bz2

14G
wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.noise.tar.bz2

38G
wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.read_speech.tar.bz2

247M
wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.emotional_speech.tar.bz2

240M
wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.impulse_responses.tar.bz2

"""
import argparse
import os
from pathlib import Path
import sys

import numpy as np
from tqdm import tqdm

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import librosa
from scipy.io import wavfile


def get_args():
    parser = argparse.ArgumentParser()

    # NOTE: the defaults below appear to be carried over from the emotional-speech
    # script rather than pointing at the impulse-responses archive.
    parser.add_argument(
        "--data_dir",
        default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.emotional_speech\datasets\clean\emotional_speech",
        type=str
    )
    parser.add_argument(
        "--output_dir",
        default=r"E:\programmer\asr_datasets\denoise\dns-clean-emotional-speech-8k",
        type=str
    )
    parser.add_argument("--sample_rate", default=8000, type=int)
    args = parser.parse_args()
    return args


def main():
    args = get_args()

    data_dir = Path(args.data_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # finished_set
    finished_set = set()
    for filename in tqdm(output_dir.glob("**/*.wav")):
        name = filename.stem
        finished_set.add(name)
    print(f"finished_set count: {len(finished_set)}")

    for filename in tqdm(data_dir.glob("**/*.wav")):
        label = filename.parts[-2]
        name = filename.stem
        # print(f"filename: {filename.as_posix()}")
        if name in finished_set:
            continue

        signal, _ = librosa.load(filename.as_posix(), sr=args.sample_rate)

        signal = signal * (1 << 15)
        signal = np.array(signal, dtype=np.int16)

        to_file = output_dir / f"{label}/{name}.wav"
        to_file.parent.mkdir(parents=True, exist_ok=True)
        wavfile.write(
            to_file.as_posix(),
            rate=args.sample_rate,
            data=signal,
        )
    return


if __name__ == "__main__":
    main()
examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_noise.py
ADDED
@@ -0,0 +1,77 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://github.com/microsoft/DNS-Challenge/blob/master/download-dns-challenge-3.sh

1.2G
wget https://dns3public.blob.core.windows.net/dns3archive/DEMAND.tar.bz2

14G
wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.noise.tar.bz2

38G
wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.read_speech.tar.bz2

"""
import argparse
import os
from pathlib import Path
import sys

import numpy as np
from tqdm import tqdm

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import librosa
from scipy.io import wavfile


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--data_dir",
        default=r"E:\programmer\asr_datasets\dns-challenge\datasets.noise\datasets",
        type=str
    )
    parser.add_argument(
        "--output_dir",
        default=r"E:\programmer\asr_datasets\denoise\dns-noise-8k",
        type=str
    )
    parser.add_argument("--sample_rate", default=8000, type=int)
    args = parser.parse_args()
    return args


def main():
    args = get_args()

    data_dir = Path(args.data_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    for filename in tqdm(data_dir.glob("**/*.wav")):
        label = filename.parts[-2]
        name = filename.stem
        # print(f"filename: {filename.as_posix()}")

        signal, _ = librosa.load(filename.as_posix(), sr=args.sample_rate)

        signal = signal * (1 << 15)
        signal = np.array(signal, dtype=np.int16)

        to_file = output_dir / f"{label}/{name}.wav"
        to_file.parent.mkdir(parents=True, exist_ok=True)
        wavfile.write(
            to_file.as_posix(),
            rate=args.sample_rate,
            data=signal,
        )
    return


if __name__ == '__main__':
    main()
examples/data_preprocess/dns_challenge_to_8k/process_musan.py
ADDED
@@ -0,0 +1,8 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://www.openslr.org/17/
"""

if __name__ == '__main__':
    pass
examples/mpnet/run.sh
CHANGED
@@ -17,10 +17,10 @@ sh run.sh --stage 5 --stop_stage 5 --system_version centos --file_folder_name fi
 --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
 
 
-sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name mpnet-nx-speech
+sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name mpnet-nx-speech \
 --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
 --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech" \
---max_epochs
+--max_epochs 100
 
 
 END
examples/nx_mpnet/yaml/config.yaml
CHANGED
@@ -15,15 +15,15 @@ mask_hidden_size: 64
 phase_num_blocks: 4
 phase_hidden_size: 64
 
-tsfm_hidden_size:
-tsfm_attention_heads:
-tsfm_num_blocks:
+tsfm_hidden_size: 64
+tsfm_attention_heads: 4
+tsfm_num_blocks: 4
 tsfm_dropout_rate: 0.0
 tsfm_max_time_relative_position: 2048
 tsfm_max_freq_relative_position: 256
 tsfm_chunk_size: 1
-tsfm_num_left_chunks:
-tsfm_num_right_chunks:
+tsfm_num_left_chunks: 128
+tsfm_num_right_chunks: 64
 
 discriminator_dim: 32
 discriminator_in_channel: 2
main.py
CHANGED
@@ -67,6 +67,13 @@ denoise_engines = {
             project_path / "trained_models/mpnet-nx-speech-1-epoch.zip").as_posix()
         }
     },
+    "mpnet-nx-speech-20-epoch": {
+        "infer_cls": InferenceMPNet,
+        "kwargs": {
+            "pretrained_model_path_or_zip_file": (
+                project_path / "trained_models/mpnet-nx-speech-20-epoch.zip").as_posix()
+        }
+    },
     "mpnet-aishell-1-epoch": {
         "infer_cls": InferenceMPNet,
         "kwargs": {
@@ -187,7 +194,7 @@ def main():
         outputs=[shell_output],
     )
 
-    # http://127.0.0.1:
+    # http://127.0.0.1:7865/
     blocks.queue().launch(
         share=False if platform.system() == "Windows" else False,
         server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
requirements.txt
CHANGED
@@ -12,3 +12,4 @@ torch-pesq==0.1.2
 torchmetrics==1.6.1
 torchmetrics[audio]==1.6.1
 einops==0.8.1
+torch_stoi==0.2.3
toolbox/torchaudio/losses/__init__.py
ADDED
@@ -0,0 +1,6 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-


if __name__ == '__main__':
    pass
toolbox/torchaudio/losses/perceptual.py
ADDED
@@ -0,0 +1,75 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
https://zhuanlan.zhihu.com/p/627039860
|
5 |
+
"""
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
from torch_stoi import NegSTOILoss as TorchNegSTOILoss
|
9 |
+
|
10 |
+
|
11 |
+
class PMSQELoss(object):
|
12 |
+
"""
|
13 |
+
A Deep Learning Loss Function based on the Perceptual Evaluation of the Speech Quality
|
14 |
+
https://sigmat.ugr.es/PMSQE/
|
15 |
+
|
16 |
+
On Loss Functions for Supervised Monaural Time-Domain Speech Enhancement
|
17 |
+
https://arxiv.org/abs/1909.01019
|
18 |
+
|
19 |
+
https://github.com/asteroid-team/asteroid/blob/master/asteroid/losses/pmsqe.py
|
20 |
+
"""
|
21 |
+
|
22 |
+
|
23 |
+
class NegSTOILoss(nn.Module):
|
24 |
+
"""
|
25 |
+
STOI短时客观可懂度(Short-Time Objective Intelligibility),
|
26 |
+
通过计算语音信号的时域和频域特征之间的相关性来预测语音的可理解度,
|
27 |
+
范围从0到1,分数越高可懂度越高。
|
28 |
+
它适用于评估噪声环境下的语音可懂度改善效果。
|
29 |
+
|
30 |
+
https://github.com/mpariente/pytorch_stoi
|
31 |
+
https://github.com/mpariente/pystoi
|
32 |
+
https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/nnet/loss/stoi_loss.py
|
33 |
+
"""
|
34 |
+
def __init__(self,
|
35 |
+
sample_rate: int,
|
36 |
+
reduction: str = "mean",
|
37 |
+
):
|
38 |
+
super(NegSTOILoss, self).__init__()
|
39 |
+
self.loss_fn = TorchNegSTOILoss(sample_rate=sample_rate)
|
40 |
+
self.reduction = reduction
|
41 |
+
|
42 |
+
if reduction not in ("sum", "mean"):
|
43 |
+
raise AssertionError(f"param reduction must be sum or mean.")
|
44 |
+
|
45 |
+
def forward(self, denoise: torch.Tensor, clean: torch.Tensor):
|
46 |
+
|
47 |
+
batch_loss = self.loss_fn.forward(denoise, clean)
|
48 |
+
|
49 |
+
if self.reduction == "mean":
|
50 |
+
loss = torch.mean(batch_loss)
|
51 |
+
elif self.reduction == "sum":
|
52 |
+
loss = torch.sum(batch_loss)
|
53 |
+
else:
|
54 |
+
raise AssertionError
|
55 |
+
return loss
|
56 |
+
|
57 |
+
|
58 |
+
def main():
|
59 |
+
sample_rate = 16000
|
60 |
+
|
61 |
+
loss_func = NegSTOILoss(
|
62 |
+
sample_rate=sample_rate,
|
63 |
+
reduction="mean",
|
64 |
+
)
|
65 |
+
|
66 |
+
denoise = torch.randn(2, sample_rate)
|
67 |
+
clean = torch.randn(2, sample_rate)
|
68 |
+
|
69 |
+
loss_batch = loss_func.forward(denoise, clean)
|
70 |
+
print(loss_batch)
|
71 |
+
return
|
72 |
+
|
73 |
+
|
74 |
+
if __name__ == "__main__":
|
75 |
+
main()
|
toolbox/torchaudio/losses/snr.py
ADDED
@@ -0,0 +1,101 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://zhuanlan.zhihu.com/p/627039860
"""
import torch
import torch.nn as nn


class NegativeSNRLoss(nn.Module):
    """
    Signal-to-Noise Ratio
    """
    def __init__(self, eps: float = 1e-8):
        super(NegativeSNRLoss, self).__init__()
        self.eps = eps

    def forward(self, denoise: torch.Tensor, clean: torch.Tensor):
        """
        Compute the negative SNR loss between the estimated signal and the target signal.

        :param denoise: The estimated signal (batch_size, signal_length)
        :param clean: The target signal (batch_size, signal_length)
        :return: The negative SNR loss, a scalar averaged over the batch.
        """
        if denoise.shape != clean.shape:
            raise AssertionError("Input signals must have the same shape")

        denoise = denoise - torch.mean(denoise, dim=-1, keepdim=True)
        clean = clean - torch.mean(clean, dim=-1, keepdim=True)

        noise = denoise - clean

        clean_power = torch.norm(clean, p=2, dim=-1) ** 2
        noise_power = torch.norm(noise, p=2, dim=-1) ** 2

        snr = 10 * torch.log10((clean_power + self.eps) / (noise_power + self.eps))

        return -snr.mean()


class NegativeSISNRLoss(nn.Module):
    """
    Scale-Invariant Source-to-Noise Ratio

    https://arxiv.org/abs/2206.07293
    """
    def __init__(self,
                 reduction: str = "mean",
                 eps: float = 1e-8,
                 ):
        super(NegativeSISNRLoss, self).__init__()
        self.reduction = reduction
        self.eps = eps

    def forward(self, denoise: torch.Tensor, clean: torch.Tensor):
        """
        Compute the negative SI-SNR loss between the estimated signal and the target signal.

        :param denoise: The estimated signal (batch_size, signal_length)
        :param clean: The target signal (batch_size, signal_length)
        :return: The negative SI-SNR loss, reduced to a scalar.
        """
        if denoise.shape != clean.shape:
            raise AssertionError("Input signals must have the same shape")

        denoise = denoise - torch.mean(denoise, dim=-1, keepdim=True)
        clean = clean - torch.mean(clean, dim=-1, keepdim=True)

        # project the estimate onto the clean target
        s_target = torch.sum(denoise * clean, dim=-1, keepdim=True) * clean / (torch.norm(clean, p=2, dim=-1, keepdim=True) ** 2 + self.eps)

        e_noise = denoise - s_target

        batch_si_snr = 10 * torch.log10(torch.norm(s_target, p=2, dim=-1) ** 2 / (torch.norm(e_noise, p=2, dim=-1) ** 2 + self.eps))
        # batch_si_snr shape: [batch_size,]

        if self.reduction == "mean":
            loss = torch.mean(batch_si_snr)
        elif self.reduction == "sum":
            loss = torch.sum(batch_si_snr)
        else:
            raise AssertionError
        return -loss


def main():
    batch_size = 2
    signal_length = 16000
    estimated_signal = torch.randn(batch_size, signal_length)
    target_signal = torch.randn(batch_size, signal_length)

    si_snr_loss = NegativeSISNRLoss()

    loss = si_snr_loss.forward(estimated_signal, target_signal)
    print(f"loss: {loss.item()}")

    return


if __name__ == "__main__":
    main()
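A quick property check that distinguishes the two losses: SI-SNR is invariant to rescaling the estimate, while plain SNR is not. A minimal sketch, reusing the two classes above (assumes they are importable from this module):

import torch

si_snr = NegativeSISNRLoss()
snr = NegativeSNRLoss()

clean = torch.randn(1, 16000)
denoise = clean + 0.1 * torch.randn(1, 16000)

print(si_snr(denoise, clean), si_snr(2.0 * denoise, clean))  # equal (up to eps): scale-invariant
print(snr(denoise, clean), snr(2.0 * denoise, clean))        # differ: SNR penalizes the gain error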
toolbox/torchaudio/losses/spectral.py
ADDED
@@ -0,0 +1,351 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://zhuanlan.zhihu.com/p/627039860

https://github.com/facebookresearch/denoiser/blob/main/denoiser/stft_loss.py
"""
from typing import List

import torch
import torch.nn as nn
from torch.nn import functional as F


class LSDLoss(nn.Module):
    """
    Log Spectral Distance

    Mean square error of the log power spectrum.
    """
    def __init__(self,
                 n_fft: int = 512,
                 win_size: int = 512,
                 hop_size: int = 256,
                 center: bool = True,
                 eps: float = 1e-8,
                 reduction: str = "mean",
                 ):
        super(LSDLoss, self).__init__()
        self.n_fft = n_fft
        self.win_size = win_size
        self.hop_size = hop_size
        self.center = center
        self.eps = eps
        self.reduction = reduction

        if reduction not in ("sum", "mean"):
            raise AssertionError("param reduction must be sum or mean.")

    def forward(self, denoise_power: torch.Tensor, clean_power: torch.Tensor):
        """
        :param denoise_power: power spectrum of the estimated signal, shape: [batch_size, freq_bins, time_steps]
        :param clean_power: power spectrum of the target signal, shape: [batch_size, freq_bins, time_steps]
        :return:
        """
        denoise_power = denoise_power + self.eps
        clean_power = clean_power + self.eps

        log_denoise_power = torch.log10(denoise_power)
        log_clean_power = torch.log10(clean_power)

        # mean_square_error shape: [b, f]
        mean_square_error = torch.mean(torch.square(log_denoise_power - log_clean_power), dim=-1)

        if self.reduction == "mean":
            lsd_loss = torch.mean(mean_square_error)
        elif self.reduction == "sum":
            lsd_loss = torch.sum(mean_square_error)
        else:
            raise AssertionError
        return lsd_loss


class ComplexSpectralLoss(nn.Module):
    def __init__(self,
                 n_fft: int = 512,
                 win_size: int = 512,
                 hop_size: int = 256,
                 center: bool = True,
                 eps: float = 1e-8,
                 reduction: str = "mean",
                 factor_mag: float = 0.5,
                 factor_pha: float = 0.3,
                 factor_gra: float = 0.2,
                 ):
        super().__init__()
        self.n_fft = n_fft
        self.win_size = win_size
        self.hop_size = hop_size
        self.center = center
        self.eps = eps
        self.reduction = reduction

        self.factor_mag = factor_mag
        self.factor_pha = factor_pha
        self.factor_gra = factor_gra

        if reduction not in ("sum", "mean"):
            raise AssertionError("param reduction must be sum or mean.")

        self.window = nn.Parameter(torch.hann_window(win_size), requires_grad=False)

    def forward(self, denoise: torch.Tensor, clean: torch.Tensor):
        """
        :param denoise: The estimated signal (batch_size, signal_length)
        :param clean: The target signal (batch_size, signal_length)
        :return:
        """
        if denoise.shape != clean.shape:
            raise AssertionError("Input signals must have the same shape")

        # denoise_stft, clean_stft shape: [b, f, t]
        denoise_stft = torch.stft(
            denoise,
            n_fft=self.n_fft,
            win_length=self.win_size,
            hop_length=self.hop_size,
            window=self.window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            return_complex=True
        )
        clean_stft = torch.stft(
            clean,
            n_fft=self.n_fft,
            win_length=self.win_size,
            hop_length=self.hop_size,
            window=self.window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            return_complex=True
        )

        # complex_diff shape: [b, f, t], dtype: torch.complex64
        complex_diff = denoise_stft - clean_stft

        # magnitude_diff, phase_diff shape: [b, f, t], dtype: torch.float32
        magnitude_diff = torch.abs(complex_diff)
        phase_diff = torch.angle(complex_diff)

        # magnitude_loss, phase_loss shape: [b,]
        magnitude_loss = torch.norm(magnitude_diff, p=2, dim=(-1, -2))
        phase_loss = torch.norm(phase_diff, p=1, dim=(-1, -2))

        # phase_grad shape: [b, f, t-1], dtype: torch.float32
        phase_grad = torch.diff(torch.angle(denoise_stft), dim=-1)
        grad_loss = torch.mean(torch.abs(phase_grad), dim=(-1, -2))

        # batch_loss shape: [b,]
        batch_loss = self.factor_mag * magnitude_loss + self.factor_pha * phase_loss + self.factor_gra * grad_loss
        # print(f"magnitude_loss: {magnitude_loss}")
        # print(f"phase_loss: {phase_loss}")
        # print(f"grad_loss: {grad_loss}")

        if self.reduction == "mean":
            loss = torch.mean(batch_loss)
        elif self.reduction == "sum":
            loss = torch.sum(batch_loss)
        else:
            raise AssertionError
        return loss


class SpectralConvergenceLoss(torch.nn.Module):
    """Spectral convergence loss module."""

    def __init__(self,
                 reduction: str = "mean",
                 ):
        super(SpectralConvergenceLoss, self).__init__()
        self.reduction = reduction

        if reduction not in ("sum", "mean"):
            raise AssertionError("param reduction must be sum or mean.")

    def forward(self,
                denoise_magnitude: torch.Tensor,
                clean_magnitude: torch.Tensor,
                ):
        """
        :param denoise_magnitude: Tensor, shape: [batch_size, time_steps, freq_bins]
        :param clean_magnitude: Tensor, shape: [batch_size, time_steps, freq_bins]
        :return:
        """
        error_norm = torch.norm(denoise_magnitude - clean_magnitude, p="fro", dim=(-1, -2))
        truth_norm = torch.norm(clean_magnitude, p="fro", dim=(-1, -2))
        batch_loss = error_norm / truth_norm
        if self.reduction == "mean":
            loss = torch.mean(batch_loss)
        elif self.reduction == "sum":
            loss = torch.sum(batch_loss)
        else:
            raise AssertionError
        return loss


class LogSTFTMagnitudeLoss(torch.nn.Module):
    """Log STFT magnitude loss module."""

    def __init__(self,
                 reduction: str = "mean",
                 ):
        super(LogSTFTMagnitudeLoss, self).__init__()
        self.reduction = reduction

        if reduction not in ("sum", "mean"):
            raise AssertionError("param reduction must be sum or mean.")

    def forward(self,
                denoise_magnitude: torch.Tensor,
                clean_magnitude: torch.Tensor,
                ):
        """
        :param denoise_magnitude: Tensor, shape: [batch_size, time_steps, freq_bins]
        :param clean_magnitude: Tensor, shape: [batch_size, time_steps, freq_bins]
        :return:
        """
        # magnitudes are assumed strictly positive; add an eps upstream if needed
        return F.l1_loss(torch.log(denoise_magnitude), torch.log(clean_magnitude), reduction=self.reduction)


class STFTLoss(torch.nn.Module):
    """STFT loss module."""

    def __init__(self,
                 n_fft: int = 1024,
                 win_size: int = 600,
                 hop_size: int = 120,
                 center: bool = True,
                 reduction: str = "mean",
                 ):
        super(STFTLoss, self).__init__()
        self.n_fft = n_fft
        self.win_size = win_size
        self.hop_size = hop_size
        self.center = center
        self.reduction = reduction

        self.window = nn.Parameter(torch.hann_window(win_size), requires_grad=False)

        self.spectral_convergence_loss = SpectralConvergenceLoss(reduction=reduction)
        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss(reduction=reduction)

    def forward(self, denoise: torch.Tensor, clean: torch.Tensor):
        """
        :param denoise:
        :param clean:
        :return:
        """
        if denoise.shape != clean.shape:
            raise AssertionError("Input signals must have the same shape")

        # denoise_stft, clean_stft shape: [b, f, t]
        denoise_stft = torch.stft(
            denoise,
            n_fft=self.n_fft,
            win_length=self.win_size,
            hop_length=self.hop_size,
            window=self.window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            return_complex=True
        )
        clean_stft = torch.stft(
            clean,
            n_fft=self.n_fft,
            win_length=self.win_size,
            hop_length=self.hop_size,
            window=self.window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            return_complex=True
        )

        denoise_magnitude = torch.abs(denoise_stft)
        clean_magnitude = torch.abs(clean_stft)

        sc_loss = self.spectral_convergence_loss.forward(denoise_magnitude, clean_magnitude)
        mag_loss = self.log_stft_magnitude_loss.forward(denoise_magnitude, clean_magnitude)

        return sc_loss, mag_loss


class MultiResolutionSTFTLoss(torch.nn.Module):
    """Multi resolution STFT loss module."""

    def __init__(self,
                 fft_size_list: List[int] = None,
                 win_size_list: List[int] = None,
                 hop_size_list: List[int] = None,
                 factor_sc=0.1,
                 factor_mag=0.1,
                 ):
        super(MultiResolutionSTFTLoss, self).__init__()
        fft_size_list = fft_size_list or [1024, 2048, 512]
        win_size_list = win_size_list or [600, 1200, 240]
        hop_size_list = hop_size_list or [120, 240, 50]

        if not len(fft_size_list) == len(win_size_list) == len(hop_size_list):
            raise AssertionError

        loss_fn_list = list()
        for n_fft, win_size, hop_size in zip(fft_size_list, win_size_list, hop_size_list):
            loss_fn_list.append(
                STFTLoss(
                    n_fft=n_fft,
                    win_size=win_size,
                    hop_size=hop_size,
                )
            )

        # register as a ModuleList so each STFTLoss window follows .to(device)
        self.loss_fn_list = nn.ModuleList(loss_fn_list)
        self.factor_sc = factor_sc
        self.factor_mag = factor_mag

    def forward(self, denoise: torch.Tensor, clean: torch.Tensor):
        """
        :param denoise:
        :param clean:
        :return:
        """
        if denoise.shape != clean.shape:
            raise AssertionError("Input signals must have the same shape")

        sc_loss = 0.0
        mag_loss = 0.0
        for loss_fn in self.loss_fn_list:
            sc_l, mag_l = loss_fn.forward(denoise, clean)
            sc_loss += sc_l
            mag_loss += mag_l
        sc_loss = sc_loss / len(self.loss_fn_list)
        mag_loss = mag_loss / len(self.loss_fn_list)

        sc_loss = self.factor_sc * sc_loss
        mag_loss = self.factor_mag * mag_loss

        loss = sc_loss + mag_loss
        return loss


def main():
    batch_size = 2
    signal_length = 16000
    estimated_signal = torch.randn(batch_size, signal_length)
    target_signal = torch.randn(batch_size, signal_length)

    # loss_fn = LSDLoss()
    # loss_fn = ComplexSpectralLoss()
    loss_fn = MultiResolutionSTFTLoss()

    loss = loss_fn.forward(estimated_signal, target_signal)
    print(f"loss: {loss.item()}")

    return


if __name__ == "__main__":
    main()
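In practice a multi-resolution STFT loss is usually an auxiliary term next to a time-domain loss (the denoiser repo linked above pairs it with an L1 waveform loss). A minimal sketch of one such combination, assuming the SI-SNR loss from losses/snr.py in this same commit:

import torch

# assumes NegativeSISNRLoss (toolbox/torchaudio/losses/snr.py) and
# MultiResolutionSTFTLoss from this module are importable
si_snr_loss_fn = NegativeSISNRLoss()
mr_stft_loss_fn = MultiResolutionSTFTLoss()

denoise = torch.randn(2, 16000)
clean = torch.randn(2, 16000)

# both terms return scalars, so they can simply be summed for backprop
loss = si_snr_loss_fn(denoise, clean) + mr_stft_loss_fn(denoise, clean)
print(loss.item())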
toolbox/torchaudio/metrics/__init__.py
ADDED
@@ -0,0 +1,6 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-


if __name__ == '__main__':
    pass
toolbox/torchaudio/metrics/pesq.py
ADDED
@@ -0,0 +1,80 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
from typing import List

from joblib import Parallel, delayed
import numpy as np
from pesq import pesq
from pesq import cypesq


def run_pesq(clean_audio: np.ndarray,
             noisy_audio: np.ndarray,
             sample_rate: int = 16000,
             mode: str = "wb",
             ) -> float:
    if sample_rate == 8000 and mode == "wb":
        raise AssertionError("mode should be `nb` when sample_rate is 8000")
    try:
        pesq_score = pesq(sample_rate, clean_audio, noisy_audio, mode)
    except cypesq.NoUtterancesError as e:
        # no speech detected in the reference; return -1 as a failure sentinel
        pesq_score = -1
    except Exception as e:
        print(f"pesq failed. error type: {type(e)}, error text: {str(e)}")
        pesq_score = -1
    return pesq_score


def run_batch_pesq(clean_audio_list: List[np.ndarray],
                   noisy_audio_list: List[np.ndarray],
                   sample_rate: int = 16000,
                   mode: str = "wb",
                   n_jobs: int = 4,
                   ) -> List[float]:
    parallel = Parallel(n_jobs=n_jobs)

    parallel_tasks = list()
    for clean_audio, noisy_audio in zip(clean_audio_list, noisy_audio_list):
        parallel_task = delayed(run_pesq)(clean_audio, noisy_audio, sample_rate, mode)
        parallel_tasks.append(parallel_task)

    pesq_score_list = parallel(parallel_tasks)
    return pesq_score_list


def run_pesq_score(clean_audio_list: List[np.ndarray],
                   noisy_audio_list: List[np.ndarray],
                   sample_rate: int = 16000,
                   mode: str = "wb",
                   n_jobs: int = 4,
                   ) -> float:

    pesq_score_list = run_batch_pesq(clean_audio_list=clean_audio_list,
                                     noisy_audio_list=noisy_audio_list,
                                     sample_rate=sample_rate,
                                     mode=mode,
                                     n_jobs=n_jobs,
                                     )

    # note: -1 failure sentinels are included in this mean
    pesq_score = np.mean(pesq_score_list)
    return pesq_score


def main():
    clean_audio = np.random.uniform(low=0, high=1, size=(2, 160000,))
    noisy_audio = np.random.uniform(low=0, high=1, size=(2, 160000,))

    clean_audio_list = list(clean_audio)
    noisy_audio_list = list(noisy_audio)

    pesq_score_list = run_batch_pesq(clean_audio_list, noisy_audio_list)
    print(pesq_score_list)

    pesq_score = run_pesq_score(clean_audio_list, noisy_audio_list)
    print(pesq_score)

    return


if __name__ == "__main__":
    main()
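Because failed utterances come back as -1, averaging the raw list drags the mean down. A minimal sketch of excluding failures before averaging (mean_pesq_excluding_failures is a hypothetical helper, not part of this module):

import numpy as np

def mean_pesq_excluding_failures(pesq_score_list):
    # drop the -1 sentinels produced when pesq raises (e.g. NoUtterancesError)
    valid_scores = [s for s in pesq_score_list if s != -1]
    if len(valid_scores) == 0:
        return float("nan")
    return float(np.mean(valid_scores))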
toolbox/torchaudio/models/conv_tasnet/configuration_conv_tasnet.py
ADDED
@@ -0,0 +1,52 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
from toolbox.torchaudio.configuration_utils import PretrainedConfig


class ConvTasNetConfig(PretrainedConfig):
    """
    https://github.com/kaituoxu/Conv-TasNet/blob/master/src/train.py
    """
    def __init__(self,
                 sample_rate: int = 8000,
                 segment_size: int = 4,

                 win_size: int = 20,

                 freq_bins: int = 256,
                 bottleneck_channels: int = 256,
                 num_speakers: int = 2,
                 num_blocks: int = 4,
                 num_sub_blocks: int = 8,
                 sub_blocks_channels: int = 512,
                 sub_blocks_kernel_size: int = 3,

                 norm_type: str = "gLN",
                 causal: bool = False,
                 mask_nonlinear: str = "relu",

                 **kwargs
                 ):
        super(ConvTasNetConfig, self).__init__(**kwargs)
        self.sample_rate = sample_rate
        self.segment_size = segment_size

        self.win_size = win_size

        self.freq_bins = freq_bins
        self.bottleneck_channels = bottleneck_channels
        self.num_speakers = num_speakers
        self.num_blocks = num_blocks
        self.num_sub_blocks = num_sub_blocks
        self.sub_blocks_channels = sub_blocks_channels
        self.sub_blocks_kernel_size = sub_blocks_kernel_size

        self.norm_type = norm_type
        self.causal = causal
        self.mask_nonlinear = mask_nonlinear


if __name__ == "__main__":
    pass
toolbox/torchaudio/models/conv_tasnet/modeling_conv_tasnet.py
CHANGED
@@ -2,8 +2,483 @@
 # -*- coding: utf-8 -*-
 """
 https://github.com/kaituoxu/Conv-TasNet/blob/master/src/conv_tasnet.py
+
+https://pytorch.org/audio/2.5.0/generated/torchaudio.models.ConvTasNet.html
 """
+import os
+from typing import List, Optional, Union
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+from toolbox.torchaudio.configuration_utils import CONFIG_FILE
+from toolbox.torchaudio.models.conv_tasnet.utils import overlap_and_add
+from toolbox.torchaudio.models.conv_tasnet.configuration_conv_tasnet import ConvTasNetConfig
+
+
+class ChannelwiseLayerNorm(nn.Module):
+    """Channel-wise Layer Normalization (cLN)"""
+    def __init__(self,
+                 channels: int,
+                 eps: float = 1e-8
+                 ):
+        super(ChannelwiseLayerNorm, self).__init__()
+        self.gamma = nn.Parameter(torch.Tensor(1, channels, 1))
+        self.beta = nn.Parameter(torch.Tensor(1, channels, 1))
+        self.reset_parameters()
+
+        self.eps = eps
+
+    def reset_parameters(self):
+        self.gamma.data.fill_(1)
+        self.beta.data.zero_()
+
+    def forward(self, y):
+        """
+        :param y: Tensor, shape: [batch_size, channels, time_steps]
+        :return: cln_y: Tensor, shape: [batch_size, channels, time_steps]
+        """
+        # mean, var shape: [batch_size, 1, time_steps]
+        mean = torch.mean(y, dim=1, keepdim=True)
+        var = torch.var(y, dim=1, keepdim=True, unbiased=False)
+
+        cln_y = self.gamma * (y - mean) / torch.pow(var + self.eps, 0.5) + self.beta
+        return cln_y
+
+
+class GlobalLayerNorm(nn.Module):
+    """Global Layer Normalization (gLN)"""
+    def __init__(self,
+                 channels: int,
+                 eps: float = 1e-8
+                 ):
+        super(GlobalLayerNorm, self).__init__()
+        self.gamma = nn.Parameter(torch.Tensor(1, channels, 1))
+        self.beta = nn.Parameter(torch.Tensor(1, channels, 1))
+        self.reset_parameters()
+
+        self.eps = eps
+
+    def reset_parameters(self):
+        self.gamma.data.fill_(1)
+        self.beta.data.zero_()
+
+    def forward(self, y):
+        """
+        :param y: Tensor, shape: [batch_size, channels, time_steps]
+        :return: gln_y: Tensor, shape: [batch_size, channels, time_steps]
+        """
+        # mean, var shape: [batch_size, 1, 1]
+        mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)
+        var = (torch.pow(y - mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)
+
+        gln_y = self.gamma * (y - mean) / torch.pow(var + self.eps, 0.5) + self.beta
+        return gln_y
+
+
+def choose_norm(norm_type: str, channels: int):
+    """
+    The input of normalization will be (M, C, K), where M is batch size,
+    C is channel size and K is sequence length.
+    """
+    if norm_type == "gLN":
+        return GlobalLayerNorm(channels)
+    elif norm_type == "cLN":
+        return ChannelwiseLayerNorm(channels)
+    else:  # norm_type == "BN"
+        # Given input (M, C, K), nn.BatchNorm1d(C) will accumulate statistics
+        # along M and K, so this BN usage is right.
+        return nn.BatchNorm1d(channels)
+
+
+class Chomp1d(nn.Module):
+    """
+    To ensure the output length is the same as the input.
+    """
+    def __init__(self, chomp_size: int):
+        super(Chomp1d, self).__init__()
+        self.chomp_size = chomp_size
+
+    def forward(self, x: torch.Tensor):
+        """
+        :param x: Tensor, shape: [batch_size, hidden_size, k_pad]
+        :return: Tensor, shape: [batch_size, hidden_size, k]
+        """
+        return x[:, :, :-self.chomp_size].contiguous()
+
+
+class DepthwiseSeparableConv(nn.Module):
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: int,
+                 stride: int,
+                 padding: int,
+                 dilation: int,
+                 norm_type="gLN",
+                 causal=False
+                 ):
+        super(DepthwiseSeparableConv, self).__init__()
+        # Use `groups` option to implement depthwise convolution
+        # [M, H, K] -> [M, H, K]
+        self.depthwise_conv = nn.Conv1d(
+            in_channels=in_channels, out_channels=in_channels,
+            kernel_size=kernel_size, stride=stride,
+            padding=padding, dilation=dilation,
+            groups=in_channels, bias=False,
+        )
+
+        self.chomp = None
+        if causal:
+            self.chomp = Chomp1d(padding)
+
+        self.prelu = nn.PReLU()
+        self.norm = choose_norm(norm_type, in_channels)
+        # [M, H, K] -> [M, B, K]
+        self.pointwise_conv = nn.Conv1d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=1, bias=False
+        )
+
+    def forward(self, x: torch.Tensor):
+        """
+        :param x: Tensor, shape: [batch_size, hidden_size, k]
+        :return: Tensor, shape: [batch_size, b, k]
+        """
+        x = self.depthwise_conv.forward(x)
+        if self.chomp is not None:
+            x = self.chomp.forward(x)
+        x = self.prelu.forward(x)
+        x = self.norm.forward(x)
+        x = self.pointwise_conv.forward(x)
+
+        return x
+
+
+class Encoder(nn.Module):
+    def __init__(self, win_size: int, freq_bins: int):
+        super(Encoder, self).__init__()
+        self.win_size = win_size
+        self.freq_bins = freq_bins
+
+        self.conv1d_U = nn.Conv1d(
+            in_channels=1,
+            out_channels=freq_bins,
+            kernel_size=win_size,
+            stride=win_size // 2,
+            bias=False
+        )
+
+    def forward(self, mixture):
+        """
+        :param mixture: Tensor, shape: [batch_size, num_samples]
+        :return: mixture_w, Tensor, shape: [batch_size, freq_bins, time_steps],
+            where time_steps = (num_samples - win_size) / (win_size / 2) + 1 = 2 * num_samples / win_size - 1
+        """
+        mixture = torch.unsqueeze(mixture, 1)  # [M, 1, T]
+        mixture_w = F.relu(self.conv1d_U(mixture))  # [M, N, K]
+        return mixture_w
+
+
+class Decoder(nn.Module):
+    def __init__(self, win_size: int, freq_bins: int):
+        super(Decoder, self).__init__()
+        self.win_size = win_size
+        self.freq_bins = freq_bins
+
+        self.basis_signals = nn.Linear(
+            in_features=freq_bins,
+            out_features=win_size,
+            bias=False
+        )
+
+    def forward(self,
+                mixture_w: torch.Tensor,
+                est_mask: torch.Tensor,
+                ):
+        """
+        :param mixture_w: Tensor, shape: [batch_size, freq_bins, time_steps],
+            where time_steps = (num_samples - win_size) / (win_size / 2) + 1 = 2 * num_samples / win_size - 1
+        :param est_mask: Tensor, shape: [batch_size, c, freq_bins, time_steps]
+        :return: Tensor, shape: [batch_size, c, num_samples]
+        """
+        source_w = torch.unsqueeze(mixture_w, 1) * est_mask
+        source_w = torch.transpose(source_w, 2, 3)
+        est_source = self.basis_signals(source_w)
+        est_source = overlap_and_add(est_source, self.win_size // 2)
+        return est_source
+
+
+class TemporalBlock(nn.Module):
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: int,
+                 stride: int,
+                 padding: int,
+                 dilation: int,
+                 norm_type="gLN",
+                 causal=False
+                 ):
+        super(TemporalBlock, self).__init__()
+        self.conv1x1 = nn.Conv1d(in_channels, out_channels, 1, bias=False)
+        self.prelu = nn.PReLU()
+        self.norm = choose_norm(norm_type, out_channels)
+        # [M, H, K] -> [M, B, K]
+        self.dsconv = DepthwiseSeparableConv(
+            out_channels, in_channels,
+            kernel_size, stride,
+            padding, dilation,
+            norm_type, causal,
+        )
+
+    def forward(self, x):
+        residual = x
+
+        x = self.conv1x1.forward(x)
+        x = self.prelu.forward(x)
+        x = self.norm.forward(x)
+        x = self.dsconv.forward(x)
+
+        out = x + residual
+        return out
+
+
+class TemporalConvNet(nn.Module):
+    def __init__(self,
+                 freq_bins: int = 256,
+                 bottleneck_channels: int = 256,
+                 num_speakers: int = 2,
+                 num_blocks: int = 4,
+                 num_sub_blocks: int = 8,
+                 sub_blocks_channels: int = 512,
+                 sub_blocks_kernel_size: int = 3,
+                 norm_type: str = "gLN",
+                 causal: bool = False,
+                 mask_nonlinear: str = "relu",
+
+                 ):
+        super(TemporalConvNet, self).__init__()
+        self.freq_bins = freq_bins
+        self.bottleneck_channels = bottleneck_channels
+        self.num_speakers = num_speakers
+
+        self.num_blocks = num_blocks
+        self.num_sub_blocks = num_sub_blocks
+        self.sub_blocks_channels = sub_blocks_channels
+        self.sub_blocks_kernel_size = sub_blocks_kernel_size
+
+        self.mask_nonlinear = mask_nonlinear
+
+        self.layer_norm = ChannelwiseLayerNorm(freq_bins)
+        self.bottleneck_conv1x1 = nn.Conv1d(freq_bins, bottleneck_channels, 1, bias=False)
+
+        self.temporal_conv_list = nn.ModuleList([])
+        for num_block_idx in range(num_blocks):
+            sub_blocks = list()
+            for num_sub_block_idx in range(num_sub_blocks):
+                dilation = 2 ** num_sub_block_idx
+                padding = (sub_blocks_kernel_size - 1) * dilation
+                if not causal:
+                    padding = padding // 2
+                temporal_block = TemporalBlock(
+                    bottleneck_channels, sub_blocks_channels,
+                    sub_blocks_kernel_size, stride=1,
+                    padding=padding, dilation=dilation,
+                    norm_type=norm_type, causal=causal,
+                )
+                sub_blocks.append(temporal_block)
+            self.temporal_conv_list.extend(sub_blocks)
+
+        self.mask_conv1x1 = nn.Conv1d(
+            in_channels=bottleneck_channels,
+            out_channels=num_speakers * freq_bins,
+            kernel_size=1,
+            bias=False,
+        )
+
+    def forward(self, mixture_w: torch.Tensor):
+        """
+        :param mixture_w: Tensor, shape: [batch_size, freq_bins, time_steps]
+        :return: est_mask: Tensor, shape: [batch_size, num_speakers, freq_bins, time_steps]
+        """
+        batch_size, freq_bins, time_steps = mixture_w.size()
+
+        x = self.layer_norm.forward(mixture_w)
+        x = self.bottleneck_conv1x1.forward(x)
+
+        for temporal_conv in self.temporal_conv_list:
+            x = temporal_conv.forward(x)
+
+        score = self.mask_conv1x1.forward(x)
+
+        # [M, C*N, K] -> [M, C, N, K]
+        score = score.view(batch_size, self.num_speakers, freq_bins, time_steps)
+
+        if self.mask_nonlinear == "softmax":
+            est_mask = F.softmax(score, dim=1)
+        elif self.mask_nonlinear == "relu":
+            est_mask = F.relu(score)
+        else:
+            raise ValueError("Unsupported mask non-linear function")
+
+        return est_mask
+
+
+class ConvTasNet(nn.Module):
+    def __init__(self,
+                 win_size: int = 20,
+                 freq_bins: int = 256,
+                 bottleneck_channels: int = 256,
+                 num_speakers: int = 2,
+                 num_blocks: int = 4,
+                 num_sub_blocks: int = 8,
+                 sub_blocks_channels: int = 512,
+                 sub_blocks_kernel_size: int = 3,
+                 norm_type: str = "gLN",
+                 causal: bool = False,
+                 mask_nonlinear: str = "relu",
+
+                 ):
+        super(ConvTasNet, self).__init__()
+        self.win_size = win_size
+
+        self.freq_bins = freq_bins
+        self.bottleneck_channels = bottleneck_channels
+        self.num_speakers = num_speakers
+
+        self.num_blocks = num_blocks
+        self.num_sub_blocks = num_sub_blocks
+        self.sub_blocks_channels = sub_blocks_channels
+        self.sub_blocks_kernel_size = sub_blocks_kernel_size
+
+        self.norm_type = norm_type
+        self.causal = causal
+        self.mask_nonlinear = mask_nonlinear
+
+        self.encoder = Encoder(win_size, freq_bins)
+        self.separator = TemporalConvNet(
+            freq_bins=freq_bins,
+            bottleneck_channels=bottleneck_channels,
+            sub_blocks_channels=sub_blocks_channels,
+            sub_blocks_kernel_size=sub_blocks_kernel_size,
+            num_sub_blocks=num_sub_blocks,
+            num_blocks=num_blocks,
+            num_speakers=num_speakers,
+            norm_type=norm_type,
+            causal=causal,
+            mask_nonlinear=mask_nonlinear,
+        )
+        self.decoder = Decoder(win_size=win_size, freq_bins=freq_bins)
+
+        for p in self.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_normal_(p)
+
+    def forward(self, mixture: torch.Tensor):
+        """
+        :param mixture: Tensor, shape: [batch_size, num_samples]
+        :return: est_source: Tensor, shape: [batch_size, c, num_samples]
+        """
+        # mixture shape: [batch_size, num_samples]
+        mixture_w = self.encoder.forward(mixture)
+        # mixture_w shape: [batch_size, freq_bins, time_steps]
+        est_mask = self.separator.forward(mixture_w)
+        # est_mask shape: [batch_size, num_speakers, freq_bins, time_steps]
+        est_source = self.decoder.forward(mixture_w, est_mask)
+
+        num_samples1 = mixture.size(-1)
+        num_samples2 = est_source.size(-1)
+        est_source = F.pad(est_source, (0, num_samples1 - num_samples2))
+        return est_source
+
+
+MODEL_FILE = "model.pt"
+
+
+class ConvTasNetPretrainedModel(ConvTasNet):
+    def __init__(self,
+                 config: ConvTasNetConfig,
+                 ):
+        super(ConvTasNetPretrainedModel, self).__init__(
+            win_size=config.win_size,
+            freq_bins=config.freq_bins,
+            bottleneck_channels=config.bottleneck_channels,
+            sub_blocks_channels=config.sub_blocks_channels,
+            sub_blocks_kernel_size=config.sub_blocks_kernel_size,
+            num_sub_blocks=config.num_sub_blocks,
+            num_blocks=config.num_blocks,
+            num_speakers=config.num_speakers,
+            norm_type=config.norm_type,
+            causal=config.causal,
+            mask_nonlinear=config.mask_nonlinear,
+        )
+        self.config = config
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+        config = ConvTasNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+        model = cls(config)
+
+        if os.path.isdir(pretrained_model_name_or_path):
+            ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
+        else:
+            ckpt_file = pretrained_model_name_or_path
+
+        with open(ckpt_file, "rb") as f:
+            state_dict = torch.load(f, map_location="cpu", weights_only=True)
+        model.load_state_dict(state_dict, strict=True)
+        return model
+
+    def save_pretrained(self,
+                        save_directory: Union[str, os.PathLike],
+                        state_dict: Optional[dict] = None,
+                        ):
+
+        model = self
+
+        if state_dict is None:
+            state_dict = model.state_dict()
+
+        os.makedirs(save_directory, exist_ok=True)
+
+        # save state dict
+        model_file = os.path.join(save_directory, MODEL_FILE)
+        torch.save(state_dict, model_file)
+
+        # save config
+        config_file = os.path.join(save_directory, CONFIG_FILE)
+        self.config.to_yaml_file(config_file)
+        return save_directory
+
+
+def main():
+    config = ConvTasNetConfig()
+    tas_net = ConvTasNet(
+        win_size=config.win_size,
+        freq_bins=config.freq_bins,
+        bottleneck_channels=config.bottleneck_channels,
+        sub_blocks_channels=config.sub_blocks_channels,
+        sub_blocks_kernel_size=config.sub_blocks_kernel_size,
+        num_sub_blocks=config.num_sub_blocks,
+        num_blocks=config.num_blocks,
+        num_speakers=config.num_speakers,
+        norm_type=config.norm_type,
+        causal=config.causal,
+        mask_nonlinear=config.mask_nonlinear,
+    )
+
+    print(tas_net)
+
+    mixture = torch.rand(size=(1, 8000*4), dtype=torch.float32)
+
+    outputs = tas_net.forward(mixture)
+    print(outputs.shape)
+    return


-if __name__ == '__main__':
-    pass
+if __name__ == "__main__":
+    main()
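A minimal round-trip sketch for the new save/load helpers (the save directory is hypothetical, and this assumes ConvTasNetConfig.from_pretrained can read back the config.yaml written by save_pretrained):

import torch

config = ConvTasNetConfig()
model = ConvTasNetPretrainedModel(config)

save_dir = model.save_pretrained("file_dir/conv_tasnet")   # writes model.pt and config.yaml
restored = ConvTasNetPretrainedModel.from_pretrained(save_dir)

mixture = torch.rand(size=(1, 8000 * 4), dtype=torch.float32)
# strict state_dict loading means the restored model reproduces the original outputs
assert torch.allclose(model.forward(mixture), restored.forward(mixture))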
toolbox/torchaudio/models/conv_tasnet/utils.py
ADDED
@@ -0,0 +1,55 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://github.com/kaituoxu/Conv-TasNet/blob/master/src/utils.py
"""
import math
import torch


def overlap_and_add(signal: torch.Tensor, frame_step: int):
    """
    Reconstructs a signal from a framed representation.

    Adds potentially overlapping frames of a signal with shape
    `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`.
    The resulting tensor has shape `[..., output_size]` where

        output_size = (frames - 1) * frame_step + frame_length

    Based on https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py

    :param signal: Tensor, shape: [..., frames, frame_length]. All dimensions may be unknown, and rank must be at least 2.
    :param frame_step: int, overlap offset. Must be less than or equal to frame_length.
    :return: Tensor, shape: [..., output_size],
        containing the overlap-added frames of signal's inner-most two dimensions.
        output_size = (frames - 1) * frame_step + frame_length
    """
    outer_dimensions = signal.size()[:-2]
    frames, frame_length = signal.size()[-2:]

    subframe_length = math.gcd(frame_length, frame_step)  # gcd = greatest common divisor
    subframe_step = frame_step // subframe_length
    subframes_per_frame = frame_length // subframe_length

    output_size = frame_step * (frames - 1) + frame_length
    output_subframes = output_size // subframe_length

    subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)

    # for each frame, the output sub-frame indices that its sub-frames are added into
    frame = torch.arange(0, output_subframes).unfold(0, subframes_per_frame, subframe_step)

    frame = frame.clone().detach()
    frame = frame.to(signal.device)
    frame = frame.long()

    frame = frame.contiguous().view(-1)

    result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
    result.index_add_(-2, frame, subframe_signal)
    result = result.view(*outer_dimensions, -1)
    return result


if __name__ == "__main__":
    pass
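A quick sanity check of overlap_and_add: framing a constant signal with torch.unfold at 50% overlap and adding the frames back doubles the interior samples, since each interior position is covered by two frames. A minimal sketch (assumes overlap_and_add from this module):

import torch

signal = torch.ones(1, 16)
frames = signal.unfold(-1, 4, 2)             # shape: [1, 7, 4]; frame_length=4, frame_step=2
out = overlap_and_add(frames, frame_step=2)  # shape: [1, 16]
print(out)  # 1 at the first/last two samples, 2 in the interior where two frames overlap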
toolbox/torchaudio/models/conv_tasnet/yaml/config.yaml
ADDED
@@ -0,0 +1,17 @@
model_name: "conv_tasnet"

sample_rate: 8000
segment_size: 4

win_size: 20
freq_bins: 256
bottleneck_channels: 256
num_speakers: 2
num_blocks: 4
num_sub_blocks: 8
sub_blocks_channels: 512
sub_blocks_kernel_size: 3

norm_type: "gLN"
causal: false
mask_nonlinear: "relu"
toolbox/torchaudio/models/demucs/__init__.py
ADDED
@@ -0,0 +1,6 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-


if __name__ == '__main__':
    pass
toolbox/torchaudio/models/demucs/configuration_demucs.py
ADDED
@@ -0,0 +1,51 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
from toolbox.torchaudio.configuration_utils import PretrainedConfig


class DemucsConfig(PretrainedConfig):
    def __init__(self,
                 sample_rate: int = 8000,

                 in_channels: int = 1,
                 out_channels: int = 1,
                 hidden_channels: int = 48,

                 depth: int = 5,
                 kernel_size: int = 8,
                 stride: int = 4,

                 causal: bool = True,
                 resample: int = 4,
                 growth: int = 2,

                 max_hidden: int = 10_000,
                 do_normalize: bool = True,
                 rescale: float = 0.1,
                 floor: float = 1e-3,

                 **kwargs
                 ):
        super(DemucsConfig, self).__init__(**kwargs)
        self.sample_rate = sample_rate

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels

        self.depth = depth
        self.kernel_size = kernel_size
        self.stride = stride

        self.causal = causal
        self.resample = resample
        self.growth = growth

        self.max_hidden = max_hidden
        self.do_normalize = do_normalize
        self.rescale = rescale
        self.floor = floor


if __name__ == "__main__":
    pass
toolbox/torchaudio/models/demucs/modeling_demucs.py
ADDED
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
https://arxiv.org/abs/2006.12847
|
5 |
+
|
6 |
+
https://github.com/facebookresearch/denoiser
|
7 |
+
"""
|
8 |
+
import math
|
9 |
+
import os
|
10 |
+
from typing import List, Optional, Union
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
from torch.nn import functional as F
|
15 |
+
|
16 |
+
from toolbox.torchaudio.configuration_utils import CONFIG_FILE
|
17 |
+
from toolbox.torchaudio.models.demucs.configuration_demucs import DemucsConfig
|
18 |
+
from toolbox.torchaudio.models.demucs.resample import upsample2, downsample2
|
19 |
+
|
20 |
+
|
21 |
+
activation_layer_dict = {
|
22 |
+
"glu": nn.GLU,
|
23 |
+
"relu": nn.ReLU,
|
24 |
+
"identity": nn.Identity,
|
25 |
+
"sigmoid": nn.Sigmoid,
|
26 |
+
}
|
27 |
+
|
28 |
+
|
29 |
+
class BLSTM(nn.Module):
|
30 |
+
def __init__(self,
|
31 |
+
hidden_size: int,
|
32 |
+
num_layers: int = 2,
|
33 |
+
bidirectional: bool = True,
|
34 |
+
):
|
35 |
+
super().__init__()
|
36 |
+
self.lstm = nn.LSTM(bidirectional=bidirectional,
|
37 |
+
num_layers=num_layers,
|
38 |
+
hidden_size=hidden_size,
|
39 |
+
input_size=hidden_size
|
40 |
+
)
|
41 |
+
self.linear = None
|
42 |
+
if bidirectional:
|
43 |
+
self.linear = nn.Linear(2 * hidden_size, hidden_size)
|
44 |
+
|
45 |
+
def forward(self,
|
46 |
+
x: torch.Tensor,
|
47 |
+
hx: torch.Tensor = None
|
48 |
+
):
|
49 |
+
x, hx = self.lstm.forward(x, hx)
|
50 |
+
if self.linear:
|
51 |
+
x = self.linear(x)
|
52 |
+
return x, hx
|
53 |
+
|
54 |
+
|
55 |
+
def rescale_conv(conv, reference):
|
56 |
+
std = conv.weight.std().detach()
|
57 |
+
scale = (std / reference)**0.5
|
58 |
+
conv.weight.data /= scale
|
59 |
+
if conv.bias is not None:
|
60 |
+
conv.bias.data /= scale
|
61 |
+
|
62 |
+
|
63 |
+
def rescale_module(module, reference):
|
64 |
+
for sub in module.modules():
|
65 |
+
if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)):
|
66 |
+
rescale_conv(sub, reference)
|
67 |
+
|
68 |
+
|
69 |
+
class DemucsModel(nn.Module):
|
70 |
+
def __init__(self,
|
71 |
+
in_channels: int = 1,
|
72 |
+
out_channels: int = 1,
|
73 |
+
hidden_channels: int = 48,
|
74 |
+
depth: int = 5,
|
75 |
+
kernel_size: int = 8,
|
76 |
+
stride: int = 4,
|
77 |
+
causal: bool = True,
|
78 |
+
resample: int = 4,
|
79 |
+
growth: int = 2,
|
80 |
+
max_hidden: int = 10_000,
|
81 |
+
do_normalize: bool = True,
|
82 |
+
rescale: float = 0.1,
|
83 |
+
floor: float = 1e-3,
|
84 |
+
):
|
85 |
+
super(DemucsModel, self).__init__()
|
86 |
+
|
87 |
+
self.in_channels = in_channels
|
88 |
+
self.out_channels = out_channels
|
89 |
+
self.hidden_channels = hidden_channels
|
90 |
+
|
91 |
+
self.depth = depth
|
92 |
+
self.kernel_size = kernel_size
|
93 |
+
self.stride = stride
|
94 |
+
|
95 |
+
self.causal = causal
|
96 |
+
|
97 |
+
self.resample = resample
|
98 |
+
self.growth = growth
|
99 |
+
self.max_hidden = max_hidden
|
100 |
+
self.do_normalize = do_normalize
|
101 |
+
self.rescale = rescale
|
102 |
+
self.floor = floor
|
103 |
+
|
104 |
+
if resample not in [1, 2, 4]:
|
105 |
+
raise ValueError("Resample should be 1, 2 or 4.")
|
106 |
+
|
107 |
+
self.encoder = nn.ModuleList()
|
108 |
+
self.decoder = nn.ModuleList()
|
109 |
+
|
110 |
+
for index in range(depth):
|
111 |
+
encode = []
|
112 |
+
encode += [
|
113 |
+
nn.Conv1d(in_channels, hidden_channels, kernel_size, stride),
|
114 |
+
nn.ReLU(),
|
115 |
+
nn.Conv1d(hidden_channels, hidden_channels * 2, 1),
|
116 |
+
nn.GLU(1),
|
117 |
+
]
|
118 |
+
self.encoder.append(nn.Sequential(*encode))
|
119 |
+
|
120 |
+
decode = []
|
121 |
+
decode += [
|
122 |
+
nn.Conv1d(hidden_channels, 2 * hidden_channels, 1),
|
123 |
+
nn.GLU(1),
|
124 |
+
nn.ConvTranspose1d(hidden_channels, out_channels, kernel_size, stride),
|
125 |
+
]
|
126 |
+
if index > 0:
|
127 |
+
decode.append(nn.ReLU())
|
128 |
+
self.decoder.insert(0, nn.Sequential(*decode))
|
129 |
+
out_channels = hidden_channels
|
130 |
+
in_channels = hidden_channels
|
131 |
+
hidden_channels = min(int(growth * hidden_channels), max_hidden)
|
132 |
+
|
133 |
+
self.lstm = BLSTM(in_channels, bidirectional=not causal)
|
134 |
+
|
135 |
+
if rescale:
|
136 |
+
rescale_module(self, reference=rescale)
|
137 |
+
|
138 |
+
@staticmethod
|
139 |
+
def valid_length(length: int, depth: int, kernel_size: int, stride: int, resample: int):
|
140 |
+
"""
|
141 |
+
Return the nearest valid length to use with the model so that
|
142 |
+
there is no time steps left over in a convolutions, e.g. for all
|
143 |
+
layers, size of the input - kernel_size % stride = 0.
|
144 |
+
|
145 |
+
If the mixture has a valid length, the estimated sources
|
146 |
+
will have exactly the same length.
|
147 |
+
"""
|
148 |
+
length = math.ceil(length * resample)
|
149 |
+
for idx in range(depth):
|
150 |
+
length = math.ceil((length - kernel_size) / stride) + 1
|
151 |
+
length = max(length, 1)
|
152 |
+
for idx in range(depth):
|
153 |
+
length = (length - 1) * stride + kernel_size
|
154 |
+
length = int(math.ceil(length / resample))
|
155 |
+
return int(length)
|
156 |
+
|
157 |
+
    def forward(self, noisy: torch.Tensor):
        """
        :param noisy: Tensor, shape: [batch_size, num_samples] or [batch_size, channels, num_samples]
        :return:
        """
        if noisy.dim() == 2:
            noisy = noisy.unsqueeze(1)
        # noisy shape: [batch_size, channels, num_samples]

        if self.do_normalize:
            mono = noisy.mean(dim=1, keepdim=True)
            std = mono.std(dim=-1, keepdim=True)
            noisy = noisy / (self.floor + std)
        else:
            std = 1

        _, _, length = noisy.shape
        x = noisy

        length_ = self.valid_length(length, self.depth, self.kernel_size, self.stride, self.resample)
        x = F.pad(x, (0, length_ - length))

        if self.resample == 2:
            x = upsample2(x)
        elif self.resample == 4:
            x = upsample2(x)
            x = upsample2(x)

        skips = []
        for encode in self.encoder:
            x = encode(x)
            skips.append(x)
        x = x.permute(2, 0, 1)
        x, _ = self.lstm(x)
        x = x.permute(1, 2, 0)

        for decode in self.decoder:
            skip = skips.pop(-1)
            x = x + skip[..., :x.shape[-1]]
            x = decode(x)

        if self.resample == 2:
            x = downsample2(x)
        elif self.resample == 4:
            x = downsample2(x)
            x = downsample2(x)

        x = x[..., :length]
        return std * x

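# Normalization note (illustrative numbers, not from this repo): with
# do_normalize=True and floor=1e-3, an input whose mono std is 0.05 is
# divided by 0.051 before the encoder and multiplied by std again in
# `return std * x`, so the network always sees roughly unit-variance audio
# while the output keeps the level of the original signal.
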
MODEL_FILE = "model.pt"


class DemucsPretrainedModel(DemucsModel):
    def __init__(self,
                 config: DemucsConfig,
                 ):
        super(DemucsPretrainedModel, self).__init__(
            # sample_rate=config.sample_rate,
            in_channels=config.in_channels,
            out_channels=config.out_channels,
            hidden_channels=config.hidden_channels,
            depth=config.depth,
            kernel_size=config.kernel_size,
            stride=config.stride,
            causal=config.causal,
            resample=config.resample,
            growth=config.growth,
            max_hidden=config.max_hidden,
            do_normalize=config.do_normalize,
            rescale=config.rescale,
            floor=config.floor,
        )
        self.config = config

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        config = DemucsConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        model = cls(config)

        if os.path.isdir(pretrained_model_name_or_path):
            ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
        else:
            ckpt_file = pretrained_model_name_or_path

        with open(ckpt_file, "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)
        return model

    def save_pretrained(self,
                        save_directory: Union[str, os.PathLike],
                        state_dict: Optional[dict] = None,
                        ):
        model = self

        if state_dict is None:
            state_dict = model.state_dict()

        os.makedirs(save_directory, exist_ok=True)

        # save state dict
        model_file = os.path.join(save_directory, MODEL_FILE)
        torch.save(state_dict, model_file)

        # save config
        config_file = os.path.join(save_directory, CONFIG_FILE)
        self.config.to_yaml_file(config_file)
        return save_directory


def main():
    config = DemucsConfig()
    model = DemucsModel(
        in_channels=config.in_channels,
        out_channels=config.out_channels,
        hidden_channels=config.hidden_channels,
        depth=config.depth,
        kernel_size=config.kernel_size,
        stride=config.stride,
        causal=config.causal,
        resample=config.resample,
        growth=config.growth,
        max_hidden=config.max_hidden,
        do_normalize=config.do_normalize,
        rescale=config.rescale,
        floor=config.floor,
    )

    print(model)

    noisy = torch.rand(size=(1, 8000 * 4), dtype=torch.float32)

    denoise = model.forward(noisy)
    print(denoise.shape)
    return


if __name__ == "__main__":
    main()
toolbox/torchaudio/models/demucs/resample.py
ADDED
@@ -0,0 +1,81 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# author: adefossez

import math

import torch as th
from torch.nn import functional as F


def sinc(t):
    """sinc.

    :param t: the input tensor
    """
    return th.where(t == 0, th.tensor(1., device=t.device, dtype=t.dtype), th.sin(t) / t)


def kernel_upsample2(zeros=56):
    """kernel_upsample2.

    Compute the windowed-sinc kernel used to interpolate the odd samples
    when upsampling by 2.
    """
    win = th.hann_window(4 * zeros + 1, periodic=False)
    winodd = win[1::2]
    t = th.linspace(-zeros + 0.5, zeros - 0.5, 2 * zeros)
    t *= math.pi
    kernel = (sinc(t) * winodd).view(1, 1, -1)
    return kernel


def upsample2(x, zeros=56):
    """
    Upsampling the input by 2 using sinc interpolation.
    Smith, Julius, and Phil Gossett. "A flexible sampling-rate conversion method."
    ICASSP'84. IEEE International Conference on Acoustics, Speech, and Signal Processing.
    Vol. 9. IEEE, 1984.
    """
    *other, time = x.shape
    kernel = kernel_upsample2(zeros).to(x)
    out = F.conv1d(x.view(-1, 1, time), kernel, padding=zeros)[..., 1:].view(*other, time)
    y = th.stack([x, out], dim=-1)
    return y.view(*other, -1)

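# Shape check (illustrative): for x of shape [batch, channels, time],
# upsample2 interleaves the original samples with the sinc-interpolated
# ones, so the output has shape [batch, channels, 2 * time].
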
def kernel_downsample2(zeros=56):
    """kernel_downsample2.

    Compute the windowed-sinc kernel used to low-pass the odd samples
    when downsampling by 2.
    """
    win = th.hann_window(4 * zeros + 1, periodic=False)
    winodd = win[1::2]
    t = th.linspace(-zeros + 0.5, zeros - 0.5, 2 * zeros)
    t.mul_(math.pi)
    kernel = (sinc(t) * winodd).view(1, 1, -1)
    return kernel


def downsample2(x, zeros=56):
    """
    Downsampling the input by 2 using sinc interpolation.
    Smith, Julius, and Phil Gossett. "A flexible sampling-rate conversion method."
    ICASSP'84. IEEE International Conference on Acoustics, Speech, and Signal Processing.
    Vol. 9. IEEE, 1984.
    """
    if x.shape[-1] % 2 != 0:
        x = F.pad(x, (0, 1))
    xeven = x[..., ::2]
    xodd = x[..., 1::2]
    *other, time = xodd.shape
    kernel = kernel_downsample2(zeros).to(x)
    out = xeven + F.conv1d(xodd.view(-1, 1, time), kernel, padding=zeros)[..., :-1].view(
        *other, time)
    return out.view(*other, -1).mul(0.5)

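# Round-trip sketch (illustrative, not part of the module's API):
#     x = th.randn(1, 1, 1024)
#     err = (downsample2(upsample2(x)) - x).abs().max()
# err should be small but non-zero, since the sinc kernels are truncated
# to 2 * zeros taps and Hann-windowed.
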
if __name__ == "__main__":
    pass
toolbox/torchaudio/models/nx_dfnet/configuration_nx_dfnet.py
ADDED
@@ -0,0 +1,102 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
from typing import Tuple

from toolbox.torchaudio.configuration_utils import PretrainedConfig


class NXDfNetConfig(PretrainedConfig):
    def __init__(self,
                 sample_rate: int = 8000,
                 freq_bins: int = 256,
                 win_size: int = 200,
                 hop_size: int = 100,

                 conv_channels: int = 64,
                 conv_kernel_size_input: Tuple[int, int] = (3, 3),
                 conv_kernel_size_inner: Tuple[int, int] = (1, 3),
                 conv_lookahead: int = 0,

                 convt_kernel_size_inner: Tuple[int, int] = (1, 3),

                 embedding_hidden_size: int = 256,
                 encoder_combine_op: str = "concat",

                 encoder_emb_skip_op: str = "none",
                 encoder_emb_linear_groups: int = 16,
                 encoder_emb_hidden_size: int = 256,

                 encoder_linear_groups: int = 32,

                 lsnr_max: int = 30,
                 lsnr_min: int = -15,
                 norm_tau: float = 1.,

                 decoder_emb_num_layers: int = 3,
                 decoder_emb_skip_op: str = "none",
                 decoder_emb_linear_groups: int = 16,
                 decoder_emb_hidden_size: int = 256,

                 df_decoder_hidden_size: int = 256,
                 df_num_layers: int = 2,
                 df_order: int = 5,
                 df_bins: int = 96,
                 df_gru_skip: str = "grouped_linear",
                 df_decoder_linear_groups: int = 16,
                 df_pathway_kernel_size_t: int = 5,
                 df_lookahead: int = 2,

                 use_post_filter: bool = False,
                 **kwargs
                 ):
        super(NXDfNetConfig, self).__init__(**kwargs)
        # transform
        self.sample_rate = sample_rate
        self.freq_bins = freq_bins
        self.win_size = win_size
        self.hop_size = hop_size

        # conv
        self.conv_channels = conv_channels
        self.conv_kernel_size_input = conv_kernel_size_input
        self.conv_kernel_size_inner = conv_kernel_size_inner
        self.conv_lookahead = conv_lookahead

        self.convt_kernel_size_inner = convt_kernel_size_inner

        self.embedding_hidden_size = embedding_hidden_size

        # encoder
        self.encoder_emb_skip_op = encoder_emb_skip_op
        self.encoder_emb_linear_groups = encoder_emb_linear_groups
        self.encoder_emb_hidden_size = encoder_emb_hidden_size

        self.encoder_linear_groups = encoder_linear_groups
        self.encoder_combine_op = encoder_combine_op

        self.lsnr_max = lsnr_max
        self.lsnr_min = lsnr_min
        self.norm_tau = norm_tau

        # decoder
        self.decoder_emb_num_layers = decoder_emb_num_layers
        self.decoder_emb_skip_op = decoder_emb_skip_op
        self.decoder_emb_linear_groups = decoder_emb_linear_groups
        self.decoder_emb_hidden_size = decoder_emb_hidden_size

        # df decoder
        self.df_decoder_hidden_size = df_decoder_hidden_size
        self.df_num_layers = df_num_layers
        self.df_order = df_order
        self.df_bins = df_bins
        self.df_gru_skip = df_gru_skip
        self.df_decoder_linear_groups = df_decoder_linear_groups
        self.df_pathway_kernel_size_t = df_pathway_kernel_size_t
        self.df_lookahead = df_lookahead

        # runtime
        self.use_post_filter = use_post_filter

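# Usage sketch (assumed from the PretrainedConfig base class used elsewhere
# in this repo): fields can be overridden at construction time, e.g.
#     config = NXDfNetConfig(sample_rate=8000, df_order=5, use_post_filter=True)
# and a config saved with to_yaml_file() can be restored with from_pretrained().
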
if __name__ == "__main__":
    pass
toolbox/torchaudio/models/nx_dfnet/modeling_nx_dfnet.py
ADDED
@@ -0,0 +1,989 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import os
import math
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
import torchaudio

from toolbox.torchaudio.models.nx_dfnet.utils import overlap_and_add
from toolbox.torchaudio.models.nx_dfnet.configuration_nx_dfnet import NXDfNetConfig
from toolbox.torchaudio.configuration_utils import CONFIG_FILE


MODEL_FILE = "model.pt"


norm_layer_dict = {
    "batch_norm_2d": torch.nn.BatchNorm2d
}


activation_layer_dict = {
    "relu": torch.nn.ReLU,
    "identity": torch.nn.Identity,
    "sigmoid": torch.nn.Sigmoid,
}

class CausalConv2d(nn.Sequential):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Iterable[int]],
                 fstride: int = 1,
                 dilation: int = 1,
                 fpad: bool = True,
                 bias: bool = True,
                 separable: bool = False,
                 norm_layer: str = "batch_norm_2d",
                 activation_layer: str = "relu",
                 lookahead: int = 0
                 ):
        """
        Causal Conv2d by delaying the signal for any lookahead.

        Expected input format: [batch_size, channels, time_steps, spec_dim]

        :param in_channels:
        :param out_channels:
        :param kernel_size:
        :param fstride:
        :param dilation:
        :param fpad:
        """
        super(CausalConv2d, self).__init__()
        kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)

        if fpad:
            fpad_ = kernel_size[1] // 2 + dilation - 1
        else:
            fpad_ = 0

        # for last 2 dim, pad (left, right, top, bottom).
        pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead)

        layers = list()
        if any(x > 0 for x in pad):
            layers.append(nn.ConstantPad2d(pad, 0.0))

        groups = math.gcd(in_channels, out_channels) if separable else 1
        if groups == 1:
            separable = False
        if max(kernel_size) == 1:
            separable = False

        layers.append(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                padding=(0, fpad_),
                stride=(1, fstride),  # stride over time is always 1
                dilation=(1, dilation),  # dilation over time is always 1
                groups=groups,
                bias=bias,
            )
        )

        if separable:
            layers.append(
                nn.Conv2d(
                    out_channels,
                    out_channels,
                    kernel_size=1,
                    bias=False,
                )
            )

        if norm_layer is not None:
            norm_layer = norm_layer_dict[norm_layer]
            layers.append(norm_layer(out_channels))

        if activation_layer is not None:
            activation_layer = activation_layer_dict[activation_layer]
            layers.append(activation_layer())

        super().__init__(*layers)

    def forward(self, inputs):
        for module in self:
            inputs = module(inputs)
        return inputs

class CausalConvTranspose2d(nn.Sequential):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Iterable[int]],
                 fstride: int = 1,
                 dilation: int = 1,
                 fpad: bool = True,
                 bias: bool = True,
                 separable: bool = False,
                 norm_layer: str = "batch_norm_2d",
                 activation_layer: str = "relu",
                 lookahead: int = 0
                 ):
        """
        Causal ConvTranspose2d.

        Expected input format: [batch_size, channels, time_steps, spec_dim]
        """
        super(CausalConvTranspose2d, self).__init__()

        kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size

        if fpad:
            fpad_ = kernel_size[1] // 2
        else:
            fpad_ = 0

        # for last 2 dim, pad (left, right, top, bottom).
        pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead)

        layers = []
        if any(x > 0 for x in pad):
            layers.append(nn.ConstantPad2d(pad, 0.0))

        groups = math.gcd(in_channels, out_channels) if separable else 1
        if groups == 1:
            separable = False

        layers.append(
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                padding=(kernel_size[0] - 1, fpad_ + dilation - 1),
                output_padding=(0, fpad_),
                stride=(1, fstride),  # stride over time is always 1
                dilation=(1, dilation),  # dilation over time is always 1
                groups=groups,
                bias=bias,
            )
        )

        if separable:
            layers.append(
                nn.Conv2d(
                    out_channels,
                    out_channels,
                    kernel_size=1,
                    bias=False,
                )
            )

        if norm_layer is not None:
            norm_layer = norm_layer_dict[norm_layer]
            layers.append(norm_layer(out_channels))

        if activation_layer is not None:
            activation_layer = activation_layer_dict[activation_layer]
            layers.append(activation_layer())

        super().__init__(*layers)

class GroupedLinear(nn.Module):

    def __init__(self, input_size: int, hidden_size: int, groups: int = 1):
        super().__init__()
        # self.weight: Tensor
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.groups = groups
        assert input_size % groups == 0, f"Input size {input_size} not divisible by {groups}"
        assert hidden_size % groups == 0, f"Hidden size {hidden_size} not divisible by {groups}"
        self.ws = input_size // groups
        self.register_parameter(
            "weight",
            torch.nn.Parameter(
                torch.zeros(groups, input_size // groups, hidden_size // groups), requires_grad=True
            ),
        )
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))  # type: ignore

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [..., I]
        b, t, _ = x.shape
        # new_shape = list(x.shape)[:-1] + [self.groups, self.ws]
        new_shape = (b, t, self.groups, self.ws)
        x = x.view(new_shape)
        # The better way, but not supported by torchscript
        # x = x.unflatten(-1, (self.groups, self.ws))  # [..., G, I/G]
        x = torch.einsum("btgi,gih->btgh", x, self.weight)  # [..., G, H/G]
        x = x.flatten(2, 3)  # [B, T, H]
        return x

    def __repr__(self):
        cls = self.__class__.__name__
        return f"{cls}(input_size: {self.input_size}, hidden_size: {self.hidden_size}, groups: {self.groups})"

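# Parameter-count intuition (illustrative): a dense nn.Linear(256, 256) holds
# 256 * 256 = 65,536 weights, while GroupedLinear(256, 256, groups=16) holds
# 16 * 16 * 16 = 4,096. Each of the 16 groups maps its own 16-dim slice of
# the input to a 16-dim slice of the output, i.e. a block-diagonal weight.
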
class SqueezedGRU_S(nn.Module):
    """
    SGE net: Video object detection with squeezed GRU and information entropy map
    https://arxiv.org/abs/2106.07224
    """

    def __init__(
            self,
            input_size: int,
            hidden_size: int,
            output_size: Optional[int] = None,
            num_layers: int = 1,
            linear_groups: int = 8,
            batch_first: bool = True,
            skip_op: str = "none",
            activation_layer: str = "identity",
    ):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        self.linear_in = nn.Sequential(
            GroupedLinear(
                input_size=input_size,
                hidden_size=hidden_size,
                groups=linear_groups,
            ),
            activation_layer_dict[activation_layer](),
        )

        # gru skip operator
        self.gru_skip_op = None

        if skip_op == "none":
            self.gru_skip_op = None
        elif skip_op == "identity":
            # identity skip requires matching input/output feature sizes
            if input_size != output_size:
                raise AssertionError("Dimensions do not match")
            self.gru_skip_op = nn.Identity()
        elif skip_op == "grouped_linear":
            self.gru_skip_op = GroupedLinear(
                input_size=hidden_size,
                hidden_size=hidden_size,
                groups=linear_groups,
            )
        else:
            raise NotImplementedError()

        self.gru = nn.GRU(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=batch_first,
            bidirectional=False,
        )

        if output_size is not None:
            self.linear_out = nn.Sequential(
                GroupedLinear(
                    input_size=hidden_size,
                    hidden_size=output_size,
                    groups=linear_groups,
                ),
                activation_layer_dict[activation_layer](),
            )
        else:
            self.linear_out = nn.Identity()

    def forward(self, inputs: torch.Tensor, h=None) -> Tuple[torch.Tensor, torch.Tensor]:
        x = self.linear_in(inputs)

        x, h = self.gru.forward(x, h)

        x = self.linear_out(x)

        if self.gru_skip_op is not None:
            x = x + self.gru_skip_op(inputs)

        return x, h

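# Design note: the GRU itself runs at hidden_size rather than input_size;
# the grouped linear layers squeeze the feature dimension on the way in and
# restore it on the way out, so most of the recurrent compute happens in the
# smaller space. That is the point of the SGE-style squeezed GRU above.
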
class Add(nn.Module):
    def forward(self, a, b):
        return a + b


class Concat(nn.Module):
    def forward(self, a, b):
        return torch.cat((a, b), dim=-1)

class DeepSTFT(nn.Module):
    def __init__(self, win_size: int, freq_bins: int):
        super(DeepSTFT, self).__init__()
        self.win_size = win_size
        self.freq_bins = freq_bins

        self.conv1d_U = nn.Conv1d(
            in_channels=1,
            out_channels=freq_bins * 2,
            kernel_size=win_size,
            stride=win_size // 2,
            bias=False
        )

    def forward(self, signal: torch.Tensor):
        """
        :param signal: Tensor, shape: [batch_size, num_samples]
        :return: Tensor, shape: [batch_size, freq_bins, time_steps, 2],
                 where time_steps = (num_samples - win_size) / (win_size / 2) + 1 = 2 * num_samples / win_size - 1
        """
        signal = torch.unsqueeze(signal, 1)
        # signal shape: [batch_size, 1, num_samples]
        spec = F.relu(self.conv1d_U(signal))
        # spec shape: [batch_size, freq_bins * 2, time_steps]
        b, f2, t = spec.shape
        spec = spec.view(b, f2 // 2, 2, t).permute(0, 1, 3, 2)
        # spec shape: [batch_size, freq_bins, time_steps, 2]
        return spec

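# Frame-count arithmetic (illustrative): with win_size=200 and the stride of
# win_size // 2 = 100 used above, a 1-second 8 kHz signal of 8000 samples
# yields (8000 - 200) / 100 + 1 = 79 frames, matching the docstring formula
# 2 * 8000 / 200 - 1 = 79.
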
class DeepISTFT(nn.Module):
    def __init__(self, win_size: int, freq_bins: int):
        super(DeepISTFT, self).__init__()
        self.win_size = win_size
        self.freq_bins = freq_bins

        self.basis_signals = nn.Linear(
            in_features=freq_bins * 2,
            out_features=win_size,
            bias=False
        )

    def forward(self,
                spec: torch.Tensor,
                ):
        """
        :param spec: Tensor, shape: [batch_size, freq_bins, time_steps, 2],
                     where time_steps = (num_samples - win_size) / (win_size / 2) + 1 = 2 * num_samples / win_size - 1
        :return: Tensor, shape: [batch_size, 1, num_samples]
        """
        b, f, t, _ = spec.shape
        # spec shape: [b, f, t, 2]
        spec = spec.permute(0, 2, 1, 3)
        # spec shape: [b, t, f, 2]
        spec = spec.view(b, 1, t, -1)
        # spec shape: [b, 1, t, f2]
        signal = self.basis_signals(spec)
        # signal shape: [b, 1, t, win_size]
        signal = overlap_and_add(signal, self.win_size // 2)
        # signal shape: [b, 1, num_samples]
        return signal

class Encoder(nn.Module):
    def __init__(self, config: NXDfNetConfig):
        super(Encoder, self).__init__()
        self.embedding_input_size = config.conv_channels * config.freq_bins // 4
        self.embedding_output_size = config.conv_channels * config.freq_bins // 4
        self.embedding_hidden_size = config.embedding_hidden_size

        self.spec_conv0 = CausalConv2d(
            in_channels=1,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_input,
            bias=False,
            separable=True,
            fstride=1,
            lookahead=config.conv_lookahead,
        )
        self.spec_conv1 = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_inner,
            bias=False,
            separable=True,
            fstride=2,
            lookahead=config.conv_lookahead,
        )
        self.spec_conv2 = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_inner,
            bias=False,
            separable=True,
            fstride=2,
            lookahead=config.conv_lookahead,
        )
        self.spec_conv3 = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_inner,
            bias=False,
            separable=True,
            fstride=1,
            lookahead=config.conv_lookahead,
        )

        self.df_conv0 = CausalConv2d(
            in_channels=2,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_input,
            bias=False,
            separable=True,
            fstride=1,
        )
        self.df_conv1 = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_inner,
            bias=False,
            separable=True,
            fstride=2,
        )
        self.df_fc_emb = nn.Sequential(
            GroupedLinear(
                config.conv_channels * config.df_bins // 2,
                self.embedding_input_size,
                groups=config.encoder_linear_groups
            ),
            nn.ReLU(inplace=True)
        )

        if config.encoder_combine_op == "concat":
            self.embedding_input_size *= 2
            self.combine = Concat()
        else:
            self.combine = Add()

        # emb_gru
        if config.freq_bins % 8 != 0:
            raise AssertionError("freq_bins should be divisible by 8")

        self.emb_gru = SqueezedGRU_S(
            self.embedding_input_size,
            self.embedding_hidden_size,
            output_size=self.embedding_output_size,
            num_layers=1,
            batch_first=True,
            skip_op=config.encoder_emb_skip_op,
            linear_groups=config.encoder_emb_linear_groups,
            activation_layer="relu",
        )

        # lsnr
        self.lsnr_fc = nn.Sequential(
            nn.Linear(self.embedding_output_size, 1),
            nn.Sigmoid()
        )
        self.lsnr_scale = config.lsnr_max - config.lsnr_min
        self.lsnr_offset = config.lsnr_min

    def forward(self,
                power_spec: torch.Tensor,
                df_spec: torch.Tensor,
                hidden_state: torch.Tensor = None,
                ):
        # power_spec shape: (batch_size, 1, time_steps, spec_dim)
        e0 = self.spec_conv0.forward(power_spec)
        e1 = self.spec_conv1.forward(e0)
        e2 = self.spec_conv2.forward(e1)
        e3 = self.spec_conv3.forward(e2)
        # e0 shape: [batch_size, channels, time_steps, spec_dim]
        # e1 shape: [batch_size, channels, time_steps, spec_dim // 2]
        # e2 shape: [batch_size, channels, time_steps, spec_dim // 4]
        # e3 shape: [batch_size, channels, time_steps, spec_dim // 4]

        # df_spec, shape: (batch_size, 2, time_steps, df_bins)
        c0 = self.df_conv0(df_spec)
        c1 = self.df_conv1(c0)
        # c0 shape: [batch_size, channels, time_steps, df_bins]
        # c1 shape: [batch_size, channels, time_steps, df_bins // 2]

        cemb = c1.permute(0, 2, 3, 1)
        # cemb shape: [batch_size, time_steps, df_bins // 2, channels]
        cemb = cemb.flatten(2)
        # cemb shape: [batch_size, time_steps, df_bins // 2 * channels]
        cemb = self.df_fc_emb(cemb)
        # cemb shape: [batch_size, time_steps, spec_dim // 4 * channels]

        # e3 shape: [batch_size, channels, time_steps, spec_dim // 4]
        emb = e3.permute(0, 2, 3, 1)
        # emb shape: [batch_size, time_steps, spec_dim // 4, channels]
        emb = emb.flatten(2)
        # emb shape: [batch_size, time_steps, spec_dim // 4 * channels]

        emb = self.combine(emb, cemb)
        # if concat; emb shape: [batch_size, time_steps, spec_dim // 4 * channels * 2]
        # if add; emb shape: [batch_size, time_steps, spec_dim // 4 * channels]

        emb, h = self.emb_gru.forward(emb, hidden_state)
        # emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
        # h shape: [batch_size, 1, spec_dim]

        lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset
        # lsnr shape: [batch_size, time_steps, 1]

        return e0, e1, e2, e3, emb, c0, lsnr, h

class Decoder(nn.Module):
    def __init__(self, config: NXDfNetConfig):
        super(Decoder, self).__init__()

        if config.freq_bins % 8 != 0:
            raise AssertionError("freq_bins should be divisible by 8")

        self.emb_in_dim = config.conv_channels * config.freq_bins // 4
        self.emb_out_dim = config.conv_channels * config.freq_bins // 4
        self.emb_hidden_dim = config.decoder_emb_hidden_size

        self.emb_gru = SqueezedGRU_S(
            self.emb_in_dim,
            self.emb_hidden_dim,
            output_size=self.emb_out_dim,
            num_layers=config.decoder_emb_num_layers - 1,
            batch_first=True,
            skip_op=config.decoder_emb_skip_op,
            linear_groups=config.decoder_emb_linear_groups,
            activation_layer="relu",
        )
        self.conv3p = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=1,
            bias=False,
            separable=True,
            fstride=1,
            lookahead=config.conv_lookahead,
        )
        self.convt3 = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_inner,
            bias=False,
            separable=True,
            fstride=1,
            lookahead=config.conv_lookahead,
        )
        self.conv2p = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=1,
            bias=False,
            separable=True,
            fstride=1,
            lookahead=config.conv_lookahead,
        )
        self.convt2 = CausalConvTranspose2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=config.convt_kernel_size_inner,
            bias=False,
            separable=True,
            fstride=2,
            lookahead=config.conv_lookahead,
        )
        self.conv1p = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=1,
            bias=False,
            separable=True,
            fstride=1,
            lookahead=config.conv_lookahead,
        )
        self.convt1 = CausalConvTranspose2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=config.convt_kernel_size_inner,
            bias=False,
            separable=True,
            fstride=2,
            lookahead=config.conv_lookahead,
        )
        self.conv0p = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=1,
            bias=False,
            separable=True,
            fstride=1,
            lookahead=config.conv_lookahead,
        )
        self.conv0_out = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=1,
            kernel_size=config.conv_kernel_size_inner,
            activation_layer="sigmoid",
            bias=False,
            separable=True,
            fstride=1,
            lookahead=config.conv_lookahead,
        )

    def forward(self, emb, e3, e2, e1, e0) -> torch.Tensor:
        # Estimates erb mask
        b, _, t, f8 = e3.shape

        # emb shape: [batch_size, time_steps, (freq_dim // 4) * conv_channels]
        emb, _ = self.emb_gru(emb)
        emb = emb.view(b, t, f8, -1).permute(0, 3, 1, 2)
        # emb shape: [batch_size, conv_channels, time_steps, freq_dim // 4]
        e3 = self.convt3(self.conv3p(e3) + emb)
        # e3 shape: [batch_size, conv_channels, time_steps, freq_dim // 4]
        e2 = self.convt2(self.conv2p(e2) + e3)
        # e2 shape: [batch_size, conv_channels, time_steps, freq_dim // 2]
        e1 = self.convt1(self.conv1p(e1) + e2)
        # e1 shape: [batch_size, conv_channels, time_steps, freq_dim]
        mask = self.conv0_out(self.conv0p(e0) + e1)
        # mask shape: [batch_size, 1, time_steps, freq_dim]
        return mask

class DfDecoder(nn.Module):
    def __init__(self, config: NXDfNetConfig):
        super(DfDecoder, self).__init__()

        self.embedding_input_size = config.conv_channels * config.freq_bins // 4
        self.df_decoder_hidden_size = config.df_decoder_hidden_size
        self.df_num_layers = config.df_num_layers

        self.df_order = config.df_order

        self.df_bins = config.df_bins
        self.df_out_ch = config.df_order * 2

        self.df_convp = CausalConv2d(
            config.conv_channels,
            self.df_out_ch,
            fstride=1,
            kernel_size=(config.df_pathway_kernel_size_t, 1),
            separable=True,
            bias=False,
        )
        self.df_gru = SqueezedGRU_S(
            self.embedding_input_size,
            self.df_decoder_hidden_size,
            num_layers=self.df_num_layers,
            batch_first=True,
            skip_op="none",
            activation_layer="relu",
        )

        if config.df_gru_skip == "none":
            self.df_skip = None
        elif config.df_gru_skip == "identity":
            if config.embedding_hidden_size != config.df_decoder_hidden_size:
                raise AssertionError("Dimensions do not match")
            self.df_skip = nn.Identity()
        elif config.df_gru_skip == "grouped_linear":
            self.df_skip = GroupedLinear(
                self.embedding_input_size,
                self.df_decoder_hidden_size,
                groups=config.df_decoder_linear_groups
            )
        else:
            raise NotImplementedError()

        self.df_out: nn.Module
        out_dim = self.df_bins * self.df_out_ch

        self.df_out = nn.Sequential(
            GroupedLinear(
                input_size=self.df_decoder_hidden_size,
                hidden_size=out_dim,
                groups=config.df_decoder_linear_groups
            ),
            nn.Tanh()
        )
        self.df_fc_a = nn.Sequential(
            nn.Linear(self.df_decoder_hidden_size, 1),
            nn.Sigmoid()
        )

    def forward(self, emb: torch.Tensor, c0: torch.Tensor) -> torch.Tensor:
        # emb shape: [batch_size, time_steps, df_bins // 4 * channels]
        b, t, _ = emb.shape
        df_coefs, _ = self.df_gru(emb)
        if self.df_skip is not None:
            df_coefs = df_coefs + self.df_skip(emb)
        # df_coefs shape: [batch_size, time_steps, df_decoder_hidden_size]

        # c0 shape: [batch_size, channels, time_steps, df_bins]
        c0 = self.df_convp(c0)
        # c0 shape: [batch_size, df_order * 2, time_steps, df_bins]
        c0 = c0.permute(0, 2, 3, 1)
        # c0 shape: [batch_size, time_steps, df_bins, df_order * 2]

        df_coefs = self.df_out(df_coefs)  # [B, T, F*O*2], O: df_order
        # df_coefs shape: [batch_size, time_steps, df_bins * df_order * 2]
        df_coefs = df_coefs.view(b, t, self.df_bins, self.df_out_ch)
        # df_coefs shape: [batch_size, time_steps, df_bins, df_order * 2]
        df_coefs = df_coefs + c0
        # df_coefs shape: [batch_size, time_steps, df_bins, df_order * 2]
        return df_coefs

class DfOutputReshapeMF(nn.Module):
    """Coefficients output reshape for multiframe/MultiFrameModule

    Requires input of shape B, C, T, F, 2.
    """

    def __init__(self, df_order: int, df_bins: int):
        super().__init__()
        self.df_order = df_order
        self.df_bins = df_bins

    def forward(self, coefs: torch.Tensor) -> torch.Tensor:
        # [B, T, F, O*2] -> [B, O, T, F, 2]
        new_shape = list(coefs.shape)
        new_shape[-1] = -1
        new_shape.append(2)
        coefs = coefs.view(new_shape)
        coefs = coefs.permute(0, 3, 1, 2, 4)
        return coefs

class Mask(nn.Module):
    def __init__(self, use_post_filter: bool = False, eps: float = 1e-12):
        super().__init__()
        self.use_post_filter = use_post_filter
        self.eps = eps

    def post_filter(self, mask: torch.Tensor, beta: float = 0.02) -> torch.Tensor:
        """
        Post-Filter

        A Perceptually-Motivated Approach for Low-Complexity, Real-Time Enhancement of Fullband Speech.
        https://arxiv.org/abs/2008.04259

        :param mask: Real valued mask, typically of shape [B, C, T, F].
        :param beta: Global gain factor.
        :return:
        """
        mask_sin = mask * torch.sin(np.pi * mask / 2)
        mask_pf = (1 + beta) * mask / (1 + beta * mask.div(mask_sin.clamp_min(self.eps)).pow(2))
        return mask_pf

    def forward(self, spec: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # spec shape: [batch_size, 1, time_steps, freq_bins, 2]

        if not self.training and self.use_post_filter:
            mask = self.post_filter(mask)

        # mask shape: [batch_size, 1, time_steps, freq_bins]
        mask = mask.unsqueeze(4)
        # mask shape: [batch_size, 1, time_steps, freq_bins, 1]
        return spec * mask

class DeepFiltering(nn.Module):
    def __init__(self,
                 df_bins: int,
                 df_order: int,
                 lookahead: int = 0,
                 ):
        super(DeepFiltering, self).__init__()
        self.df_bins = df_bins
        self.df_order = df_order
        self.need_unfold = df_order > 1
        self.lookahead = lookahead

        self.pad = nn.ConstantPad2d((0, 0, df_order - 1 - lookahead, lookahead), 0.0)

    def spec_unfold(self, spec: torch.Tensor):
        """
        Pads and unfolds the spectrogram according to frame_size.
        :param spec: complex Tensor, Spectrogram of shape [B, C, T, F].
        :return: Tensor, Unfolded spectrogram of shape [B, C, T, F, N], where N: frame_size.
        """
        if self.need_unfold:
            # spec shape: [batch_size, 1, time_steps, freq_bins]
            spec_pad = self.pad(spec)
            # spec_pad shape: [batch_size, 1, time_steps_pad, freq_bins]
            spec_unfold = spec_pad.unfold(2, self.df_order, 1)
            # spec_unfold shape: [batch_size, 1, time_steps, freq_bins, df_order]
            return spec_unfold
        else:
            return spec.unsqueeze(-1)

    def forward(self,
                spec: torch.Tensor,
                coefs: torch.Tensor,
                ):
        # spec shape: [batch_size, 1, time_steps, freq_bins, 2]
        spec = spec.contiguous()
        spec_u = self.spec_unfold(torch.view_as_complex(spec))
        # spec_u shape: [batch_size, 1, time_steps, freq_bins, df_order]

        # coefs shape: [batch_size, df_order, time_steps, df_bins, 2]
        coefs = torch.view_as_complex(coefs)
        # coefs shape: [batch_size, df_order, time_steps, df_bins]
        spec_f = spec_u.narrow(-2, 0, self.df_bins)
        # spec_f shape: [batch_size, 1, time_steps, df_bins, df_order]

        coefs = coefs.view(coefs.shape[0], -1, self.df_order, *coefs.shape[2:])
        # coefs shape: [batch_size, 1, df_order, time_steps, df_bins]

        spec_f = self.df(spec_f, coefs)
        # spec_f shape: [batch_size, 1, time_steps, df_bins]

        if self.training:
            spec = spec.clone()
        spec[..., :self.df_bins, :] = torch.view_as_real(spec_f)
        # spec shape: [batch_size, 1, time_steps, freq_bins, 2]
        return spec

    @staticmethod
    def df(spec: torch.Tensor, coefs: torch.Tensor) -> torch.Tensor:
        """
        Deep filter implementation using `torch.einsum`. Requires unfolded spectrogram.
        :param spec: (complex Tensor). Spectrogram of shape [B, C, T, F, N].
        :param coefs: (complex Tensor). Coefficients of shape [B, C, N, T, F].
        :return: (complex Tensor). Spectrogram of shape [B, C, T, F].
        """
        return torch.einsum("...tfn,...ntf->...tf", spec, coefs)

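# What the einsum in DeepFiltering.df computes, written out (illustrative):
# for each time step t and bin f, the enhanced value is a complex FIR filter
# over the df_order most recent frames,
#     Y(t, f) = sum_{n=0}^{N-1} coefs(n, t, f) * X(t - N + 1 + n, f),
# where spec_unfold supplies the N = df_order past frames (shifted by
# `lookahead`), so each frequency bin gets its own short filter in time.
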
class NXDfNet(nn.Module):
    def __init__(self, config: NXDfNetConfig):
        super(NXDfNet, self).__init__()
        self.config = config

        self.stft = DeepSTFT(win_size=config.win_size, freq_bins=config.freq_bins)
        self.istft = DeepISTFT(win_size=config.win_size, freq_bins=config.freq_bins)

        self.encoder = Encoder(config)
        self.decoder = Decoder(config)

        self.df_decoder = DfDecoder(config)
        self.df_out_transform = DfOutputReshapeMF(config.df_order, config.df_bins)
        self.df_op = DeepFiltering(
            df_bins=config.df_bins,
            df_order=config.df_order,
            lookahead=config.df_lookahead,
        )

        self.mask = Mask(use_post_filter=config.use_post_filter)

    def forward(self,
                noisy: torch.Tensor,
                ):
        """
        :param noisy: Tensor, shape: [batch_size, num_samples]
        :return:
        """
        spec = self.stft.forward(noisy)
        # spec shape: [batch_size, freq_bins, time_steps, 2]
        power_spec = torch.sum(torch.square(spec), dim=-1)
        # power_spec shape: [batch_size, freq_bins, time_steps]
        power_spec = power_spec.unsqueeze(1).permute(0, 1, 3, 2)
        # power_spec shape: [batch_size, 1, time_steps, freq_bins]

        df_spec = spec.permute(0, 3, 2, 1)
        # df_spec shape: [batch_size, 2, time_steps, freq_bins]
        df_spec = df_spec[..., :self.df_decoder.df_bins]
        # df_spec shape: [batch_size, 2, time_steps, df_bins]

        # spec shape: [batch_size, freq_bins, time_steps, 2]
        spec = torch.transpose(spec, dim0=1, dim1=2)
        # spec shape: [batch_size, time_steps, freq_bins, 2]
        spec = torch.unsqueeze(spec, dim=1)
        # spec shape: [batch_size, 1, time_steps, freq_bins, 2]

        e0, e1, e2, e3, emb, c0, _, h = self.encoder.forward(power_spec, df_spec)

        mask = self.decoder.forward(emb, e3, e2, e1, e0)
        # mask shape: [batch_size, 1, time_steps, freq_bins]
        if torch.any(mask > 1) or torch.any(mask < 0):
            raise AssertionError

        spec_m = self.mask.forward(spec, mask)

        # lsnr shape: [batch_size, time_steps, 1]
        # lsnr = torch.transpose(lsnr, dim0=2, dim1=1)
        # lsnr shape: [batch_size, 1, time_steps]

        df_coefs = self.df_decoder.forward(emb, c0)
        df_coefs = self.df_out_transform(df_coefs)
        # df_coefs shape: [batch_size, df_order, time_steps, df_bins, 2]

        spec_e = self.df_op.forward(spec.clone(), df_coefs)
        # spec_e shape: [batch_size, 1, time_steps, freq_bins, 2]

        spec_e[..., self.df_decoder.df_bins:, :] = spec_m[..., self.df_decoder.df_bins:, :]

        spec_e = torch.squeeze(spec_e, dim=1)
        spec_e = spec_e.permute(0, 2, 1, 3)
        # spec_e shape: [batch_size, freq_bins, time_steps, 2]

        denoise = self.istft.forward(spec_e)
        # denoise shape: [batch_size, 1, num_samples]
        return denoise

class NXDfNetPretrainedModel(NXDfNet):
    def __init__(self,
                 config: NXDfNetConfig,
                 ):
        super(NXDfNetPretrainedModel, self).__init__(
            config=config,
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        config = NXDfNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        model = cls(config)

        if os.path.isdir(pretrained_model_name_or_path):
            ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
        else:
            ckpt_file = pretrained_model_name_or_path

        with open(ckpt_file, "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)
        return model

    def save_pretrained(self,
                        save_directory: Union[str, os.PathLike],
                        state_dict: Optional[dict] = None,
                        ):
        model = self

        if state_dict is None:
            state_dict = model.state_dict()

        os.makedirs(save_directory, exist_ok=True)

        # save state dict
        model_file = os.path.join(save_directory, MODEL_FILE)
        torch.save(state_dict, model_file)

        # save config
        config_file = os.path.join(save_directory, CONFIG_FILE)
        self.config.to_yaml_file(config_file)
        return save_directory


def main():
    config = NXDfNetConfig()
    model = NXDfNet(config=config)

    inputs = torch.randn(size=(1, 16000), dtype=torch.float32)

    denoise = model.forward(inputs)
    print(denoise.shape)
    return


if __name__ == "__main__":
    main()
toolbox/torchaudio/models/nx_dfnet/utils.py
ADDED
@@ -0,0 +1,55 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://github.com/kaituoxu/Conv-TasNet/blob/master/src/utils.py
"""
import math
import torch


def overlap_and_add(signal: torch.Tensor, frame_step: int):
    """
    Reconstructs a signal from a framed representation.

    Adds potentially overlapping frames of a signal with shape
    `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`.
    The resulting tensor has shape `[..., output_size]` where

        output_size = (frames - 1) * frame_step + frame_length

    Based on https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py

    :param signal: Tensor, shape: [..., frames, frame_length]. All dimensions may be unknown, and rank must be at least 2.
    :param frame_step: int, overlap offsets. Must be less than or equal to frame_length.
    :return: Tensor, shape: [..., output_size],
             containing the overlap-added frames of signal's inner-most two dimensions.
             output_size = (frames - 1) * frame_step + frame_length
    """
    outer_dimensions = signal.size()[:-2]
    frames, frame_length = signal.size()[-2:]

    subframe_length = math.gcd(frame_length, frame_step)  # gcd = Greatest Common Divisor
    subframe_step = frame_step // subframe_length
    subframes_per_frame = frame_length // subframe_length

    output_size = frame_step * (frames - 1) + frame_length
    output_subframes = output_size // subframe_length

    subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)

    frame = torch.arange(0, output_subframes).unfold(0, subframes_per_frame, subframe_step)

    frame = frame.clone().detach()
    frame = frame.to(signal.device)
    frame = frame.long()

    frame = frame.contiguous().view(-1)

    result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
    result.index_add_(-2, frame, subframe_signal)
    result = result.view(*outer_dimensions, -1)
    return result

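# Worked example (illustrative): frames=3, frame_length=4, frame_step=2 gives
# output_size = 2 * 2 + 4 = 8. With subframe_length = gcd(4, 2) = 2, each
# frame is split into 2 subframes and frame i contributes its subframes at
# positions [i, i + 1] of the 4-subframe output, so overlapping halves of
# consecutive frames are summed, exactly as an inverse STFT overlap-add does.
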
if __name__ == "__main__":
    pass