HoneyTian committed on
Commit e86d760 · 1 Parent(s): b1eb75a
Files changed (31)
  1. examples/conv_tasnet/run.sh +170 -0
  2. examples/conv_tasnet/step_1_prepare_data.py +201 -0
  3. examples/conv_tasnet/step_2_train_model.py +413 -0
  4. examples/conv_tasnet/yaml/config.yaml +42 -0
  5. examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_clean_emotional_speech.py +90 -0
  6. examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_clean_read_speech.py +123 -0
  7. examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_demand.py +71 -0
  8. examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_impulse_responses.py +93 -0
  9. examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_noise.py +77 -0
  10. examples/data_preprocess/dns_challenge_to_8k/process_musan.py +8 -0
  11. examples/mpnet/run.sh +2 -2
  12. examples/nx_mpnet/yaml/config.yaml +5 -5
  13. main.py +8 -1
  14. requirements.txt +1 -0
  15. toolbox/torchaudio/losses/__init__.py +6 -0
  16. toolbox/torchaudio/losses/perceptual.py +75 -0
  17. toolbox/torchaudio/losses/snr.py +101 -0
  18. toolbox/torchaudio/losses/spectral.py +351 -0
  19. toolbox/torchaudio/metrics/__init__.py +6 -0
  20. toolbox/torchaudio/metrics/pesq.py +80 -0
  21. toolbox/torchaudio/models/conv_tasnet/configuration_conv_tasnet.py +52 -0
  22. toolbox/torchaudio/models/conv_tasnet/modeling_conv_tasnet.py +477 -2
  23. toolbox/torchaudio/models/conv_tasnet/utils.py +55 -0
  24. toolbox/torchaudio/models/conv_tasnet/yaml/config.yaml +17 -0
  25. toolbox/torchaudio/models/demucs/__init__.py +6 -0
  26. toolbox/torchaudio/models/demucs/configuration_demucs.py +51 -0
  27. toolbox/torchaudio/models/demucs/modeling_demucs.py +299 -0
  28. toolbox/torchaudio/models/demucs/resample.py +81 -0
  29. toolbox/torchaudio/models/nx_dfnet/configuration_nx_dfnet.py +102 -0
  30. toolbox/torchaudio/models/nx_dfnet/modeling_nx_dfnet.py +989 -0
  31. toolbox/torchaudio/models/nx_dfnet/utils.py +55 -0
examples/conv_tasnet/run.sh ADDED
@@ -0,0 +1,170 @@
1
+ #!/usr/bin/env bash
2
+
3
+ : <<'END'
4
+
5
+
6
+ sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir --final_model_name mpnet-aishell-20250224 \
7
+ --noise_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/noise" \
8
+ --speech_dir "E:/programmer/asr_datasets/aishell/data_aishell/wav/train"
9
+
10
+
11
+ sh run.sh --stage 3 --stop_stage 3 --system_version centos --file_folder_name file_dir --final_model_name mpnet-aishell-20250224 \
12
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
13
+ --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
14
+
15
+ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name nx-clean-unet-aishell-20250228 \
16
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
17
+ --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train" \
18
+ --max_epochs 100
19
+
20
+
21
+ sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name mpnet-nx-speech-20250224 \
22
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
23
+ --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech" \
24
+ --max_epochs 100 --max_count 10000
25
+
26
+
27
+ END
28
+
29
+
30
+ # params
31
+ system_version="windows";
32
+ verbose=true;
33
+ stage=0 # start from 0 if you need to start from data preparation
34
+ stop_stage=9
35
+
36
+ work_dir="$(pwd)"
37
+ file_folder_name=file_folder_name
38
+ final_model_name=final_model_name
39
+ config_file="yaml/config.yaml"
40
+ limit=10
41
+
42
+ noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
43
+ speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
44
+
45
+ max_count=10000000
46
+
47
+ nohup_name=nohup.out
48
+
49
+ # model params
50
+ batch_size=64
51
+ max_epochs=200
52
+ save_top_k=10
53
+ patience=5
54
+
55
+
56
+ # parse options
57
+ while true; do
58
+ [ -z "${1:-}" ] && break; # break if there are no arguments
59
+ case "$1" in
60
+ --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
61
+ eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
62
+ old_value="$(eval echo \\$$name)";
63
+ if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
64
+ was_bool=true;
65
+ else
66
+ was_bool=false;
67
+ fi
68
+
69
+ # Set the variable to the right value-- the escaped quotes make it work if
70
+ # the option had spaces, like --cmd "queue.pl -sync y"
71
+ eval "${name}=\"$2\"";
72
+
73
+ # Check that Boolean-valued arguments are really Boolean.
74
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
75
+ echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
76
+ exit 1;
77
+ fi
78
+ shift 2;
79
+ ;;
80
+
81
+ *) break;
82
+ esac
83
+ done
84
+
85
+ file_dir="${work_dir}/${file_folder_name}"
86
+ final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
87
+ evaluation_audio_dir="${file_dir}/evaluation_audio"
88
+
89
+ dataset="${file_dir}/dataset.xlsx"
90
+ train_dataset="${file_dir}/train.xlsx"
91
+ valid_dataset="${file_dir}/valid.xlsx"
92
+
93
+ $verbose && echo "system_version: ${system_version}"
94
+ $verbose && echo "file_folder_name: ${file_folder_name}"
95
+
96
+ if [ $system_version == "windows" ]; then
97
+ alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
98
+ elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
99
+ #source /data/local/bin/nx_denoise/bin/activate
100
+ alias python3='/data/local/bin/nx_denoise/bin/python3'
101
+ fi
102
+
103
+
104
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
105
+ $verbose && echo "stage 1: prepare data"
106
+ cd "${work_dir}" || exit 1
107
+ python3 step_1_prepare_data.py \
108
+ --file_dir "${file_dir}" \
109
+ --noise_dir "${noise_dir}" \
110
+ --speech_dir "${speech_dir}" \
111
+ --train_dataset "${train_dataset}" \
112
+ --valid_dataset "${valid_dataset}" \
113
+ --max_count "${max_count}" \
114
+
115
+ fi
116
+
117
+
118
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
119
+ $verbose && echo "stage 2: train model"
120
+ cd "${work_dir}" || exit 1
121
+ python3 step_2_train_model.py \
122
+ --train_dataset "${train_dataset}" \
123
+ --valid_dataset "${valid_dataset}" \
124
+ --serialization_dir "${file_dir}" \
125
+ --config_file "${config_file}" \
126
+
127
+ fi
128
+
129
+
130
+ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
131
+ $verbose && echo "stage 3: test model"
132
+ cd "${work_dir}" || exit 1
133
+ python3 step_3_evaluation.py \
134
+ --valid_dataset "${valid_dataset}" \
135
+ --model_dir "${file_dir}/best" \
136
+ --evaluation_audio_dir "${evaluation_audio_dir}" \
137
+ --limit "${limit}" \
138
+
139
+ fi
140
+
141
+
142
+ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
143
+ $verbose && echo "stage 4: collect files"
144
+ cd "${work_dir}" || exit 1
145
+
146
+ mkdir -p ${final_model_dir}
147
+
148
+ cp "${file_dir}/best"/* "${final_model_dir}"
149
+ cp -r "${file_dir}/evaluation_audio" "${final_model_dir}"
150
+
151
+ cd "${final_model_dir}/.." || exit 1;
152
+
153
+ if [ -e "${final_model_name}.zip" ]; then
154
+ rm -rf "${final_model_name}_backup.zip"
155
+ mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
156
+ fi
157
+
158
+ zip -r "${final_model_name}.zip" "${final_model_name}"
159
+ rm -rf "${final_model_name}"
160
+
161
+ fi
162
+
163
+
164
+ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
165
+ $verbose && echo "stage 5: clear file_dir"
166
+ cd "${work_dir}" || exit 1
167
+
168
+ rm -rf "${file_dir}";
169
+
170
+ fi
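Note: the usage examples in the heredoc at the top of this script still reference the mpnet and nx-clean-unet recipes. A hypothetical invocation for this conv_tasnet recipe (the final_model_name below is a placeholder, not taken from the commit; the data paths are the script defaults) would follow the same pattern:

sh run.sh --stage 1 --stop_stage 3 --system_version centos --file_folder_name file_dir --final_model_name conv-tasnet-aishell \
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
  --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train" \
  --max_epochs 100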
examples/conv_tasnet/step_1_prepare_data.py ADDED
@@ -0,0 +1,201 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import os
5
+ from pathlib import Path
6
+ import random
7
+ import sys
8
+ import shutil
9
+
10
+ pwd = os.path.abspath(os.path.dirname(__file__))
11
+ sys.path.append(os.path.join(pwd, "../../"))
12
+
13
+ import pandas as pd
14
+ from scipy.io import wavfile
15
+ from tqdm import tqdm
16
+ import librosa
17
+
18
+ from project_settings import project_path
19
+
20
+
21
+ def get_args():
22
+ parser = argparse.ArgumentParser()
23
+ parser.add_argument("--file_dir", default="./", type=str)
24
+
25
+ parser.add_argument(
26
+ "--noise_dir",
27
+ default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
28
+ type=str
29
+ )
30
+ parser.add_argument(
31
+ "--speech_dir",
32
+ default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
33
+ type=str
34
+ )
35
+
36
+ parser.add_argument("--train_dataset", default="train.xlsx", type=str)
37
+ parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
38
+
39
+ parser.add_argument("--duration", default=2.0, type=float)
40
+ parser.add_argument("--min_snr_db", default=-10, type=float)
41
+ parser.add_argument("--max_snr_db", default=20, type=float)
42
+
43
+ parser.add_argument("--target_sample_rate", default=8000, type=int)
44
+
45
+ parser.add_argument("--max_count", default=10000, type=int)
46
+
47
+ args = parser.parse_args()
48
+ return args
49
+
50
+
51
+ def filename_generator(data_dir: str):
52
+ data_dir = Path(data_dir)
53
+ for filename in data_dir.glob("**/*.wav"):
54
+ yield filename.as_posix()
55
+
56
+
57
+ def target_second_signal_generator(data_dir: str, duration: float = 2.0, sample_rate: int = 8000):
58
+ data_dir = Path(data_dir)
59
+ for filename in data_dir.glob("**/*.wav"):
60
+ signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
61
+ raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
62
+
63
+ if raw_duration < duration:
64
+ # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
65
+ continue
66
+ if signal.ndim != 1:
67
+ raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
68
+
69
+ signal_length = len(signal)
70
+ win_size = int(duration * sample_rate)
71
+ for begin in range(0, signal_length - win_size, win_size):
72
+ row = {
73
+ "filename": filename.as_posix(),
74
+ "raw_duration": round(raw_duration, 4),
75
+ "offset": round(begin / sample_rate, 4),
76
+ "duration": round(duration, 4),
77
+ }
78
+ yield row
79
+
80
+
81
+ def get_dataset(args):
82
+ file_dir = Path(args.file_dir)
83
+ file_dir.mkdir(exist_ok=True)
84
+
85
+ noise_dir = Path(args.noise_dir)
86
+ speech_dir = Path(args.speech_dir)
87
+
88
+ noise_generator = target_second_signal_generator(
89
+ noise_dir.as_posix(),
90
+ duration=args.duration,
91
+ sample_rate=args.target_sample_rate
92
+ )
93
+ speech_generator = target_second_signal_generator(
94
+ speech_dir.as_posix(),
95
+ duration=args.duration,
96
+ sample_rate=args.target_sample_rate
97
+ )
98
+
99
+ dataset = list()
100
+
101
+ count = 0
102
+ process_bar = tqdm(desc="build dataset excel")
103
+ for noise, speech in zip(noise_generator, speech_generator):
104
+ if count >= args.max_count:
105
+ break
106
+
107
+ noise_filename = noise["filename"]
108
+ noise_raw_duration = noise["raw_duration"]
109
+ noise_offset = noise["offset"]
110
+ noise_duration = noise["duration"]
111
+
112
+ speech_filename = speech["filename"]
113
+ speech_raw_duration = speech["raw_duration"]
114
+ speech_offset = speech["offset"]
115
+ speech_duration = speech["duration"]
116
+
117
+ random1 = random.random()
118
+ random2 = random.random()
119
+
120
+ row = {
121
+ "noise_filename": noise_filename,
122
+ "noise_raw_duration": noise_raw_duration,
123
+ "noise_offset": noise_offset,
124
+ "noise_duration": noise_duration,
125
+
126
+ "speech_filename": speech_filename,
127
+ "speech_raw_duration": speech_raw_duration,
128
+ "speech_offset": speech_offset,
129
+ "speech_duration": speech_duration,
130
+
131
+ "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),
132
+
133
+ "random1": random1,
134
+ "random2": random2,
135
+ "flag": "TRAIN" if random2 < 0.8 else "TEST",
136
+ }
137
+ dataset.append(row)
138
+ count += 1
139
+ duration_seconds = count * args.duration
140
+ duration_hours = duration_seconds / 3600
141
+
142
+ process_bar.update(n=1)
143
+ process_bar.set_postfix({
144
+ # "duration_seconds": round(duration_seconds, 4),
145
+ "duration_hours": round(duration_hours, 4),
146
+
147
+ })
148
+
149
+ dataset = pd.DataFrame(dataset)
150
+ dataset = dataset.sort_values(by=["random1"], ascending=False)
151
+ dataset.to_excel(
152
+ file_dir / "dataset.xlsx",
153
+ index=False,
154
+ )
155
+ return
156
+
157
+
158
+
159
+ def split_dataset(args):
160
+ """分割训练集, 测试集"""
161
+ file_dir = Path(args.file_dir)
162
+ file_dir.mkdir(exist_ok=True)
163
+
164
+ df = pd.read_excel(file_dir / "dataset.xlsx")
165
+
166
+ train = list()
167
+ test = list()
168
+
169
+ for i, row in df.iterrows():
170
+ flag = row["flag"]
171
+ if flag == "TRAIN":
172
+ train.append(row)
173
+ else:
174
+ test.append(row)
175
+
176
+ train = pd.DataFrame(train)
177
+ train.to_excel(
178
+ args.train_dataset,
179
+ index=False,
180
+ # encoding="utf_8_sig"
181
+ )
182
+ test = pd.DataFrame(test)
183
+ test.to_excel(
184
+ args.valid_dataset,
185
+ index=False,
186
+ # encoding="utf_8_sig"
187
+ )
188
+
189
+ return
190
+
191
+
192
+ def main():
193
+ args = get_args()
194
+
195
+ get_dataset(args)
196
+ split_dataset(args)
197
+ return
198
+
199
+
200
+ if __name__ == "__main__":
201
+ main()
examples/conv_tasnet/step_2_train_model.py ADDED
@@ -0,0 +1,413 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/yxlu-0102/MP-SENet/blob/main/train.py
5
+ """
6
+ import argparse
7
+ import json
8
+ import logging
9
+ from logging.handlers import TimedRotatingFileHandler
10
+ import os
11
+ import platform
12
+ from pathlib import Path
13
+ import random
14
+ import sys
15
+ import shutil
16
+ from typing import List
17
+
18
+ pwd = os.path.abspath(os.path.dirname(__file__))
19
+ sys.path.append(os.path.join(pwd, "../../"))
20
+
21
+ import numpy as np
22
+ import torch
23
+ import torch.nn as nn
24
+ from torch.nn import functional as F
25
+ from torch.utils.data.dataloader import DataLoader
26
+ from tqdm import tqdm
27
+
28
+ from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset
29
+ from toolbox.torchaudio.models.conv_tasnet.configuration_conv_tasnet import ConvTasNetConfig
30
+ from toolbox.torchaudio.models.conv_tasnet.modeling_conv_tasnet import ConvTasNet, ConvTasNetPretrainedModel
31
+ from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
32
+ from toolbox.torchaudio.losses.spectral import LSDLoss
33
+ from toolbox.torchaudio.losses.perceptual import NegSTOILoss
34
+ from toolbox.torchaudio.metrics.pesq import run_pesq_score
35
+
36
+
37
+ def get_args():
38
+ parser = argparse.ArgumentParser()
39
+ parser.add_argument("--train_dataset", default="train.xlsx", type=str)
40
+ parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
41
+
42
+ parser.add_argument("--max_epochs", default=100, type=int)
43
+
44
+ parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
45
+ parser.add_argument("--patience", default=5, type=int)
46
+ parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
47
+
48
+ parser.add_argument("--config_file", default="config.yaml", type=str)
49
+
50
+ args = parser.parse_args()
51
+ return args
52
+
53
+
54
+ def logging_config(file_dir: str):
55
+ fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
56
+
57
+ logging.basicConfig(format=fmt,
58
+ datefmt="%m/%d/%Y %H:%M:%S",
59
+ level=logging.INFO)
60
+ file_handler = TimedRotatingFileHandler(
61
+ filename=os.path.join(file_dir, "main.log"),
62
+ encoding="utf-8",
63
+ when="D",
64
+ interval=1,
65
+ backupCount=7
66
+ )
67
+ file_handler.setLevel(logging.INFO)
68
+ file_handler.setFormatter(logging.Formatter(fmt))
69
+ logger = logging.getLogger(__name__)
70
+ logger.addHandler(file_handler)
71
+
72
+ return logger
73
+
74
+
75
+ class CollateFunction(object):
76
+ def __init__(self):
77
+ pass
78
+
79
+ def __call__(self, batch: List[dict]):
80
+ clean_audios = list()
81
+ noisy_audios = list()
82
+
83
+ for sample in batch:
84
+ # noise_wave: torch.Tensor = sample["noise_wave"]
85
+ clean_audio: torch.Tensor = sample["speech_wave"]
86
+ noisy_audio: torch.Tensor = sample["mix_wave"]
87
+ # snr_db: float = sample["snr_db"]
88
+
89
+ clean_audios.append(clean_audio)
90
+ noisy_audios.append(noisy_audio)
91
+
92
+ clean_audios = torch.stack(clean_audios)
93
+ noisy_audios = torch.stack(noisy_audios)
94
+
95
+ # assert
96
+ if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
97
+ raise AssertionError("nan or inf in clean_audios")
98
+ if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
99
+ raise AssertionError("nan or inf in noisy_audios")
100
+ return clean_audios, noisy_audios
101
+
102
+
103
+ collate_fn = CollateFunction()
104
+
105
+
106
+ def main():
107
+ args = get_args()
108
+
109
+ config = ConvTasNetConfig.from_pretrained(
110
+ pretrained_model_name_or_path=args.config_file,
111
+ )
112
+
113
+ serialization_dir = Path(args.serialization_dir)
114
+ serialization_dir.mkdir(parents=True, exist_ok=True)
115
+
116
+ logger = logging_config(serialization_dir)
117
+
118
+ random.seed(config.seed)
119
+ np.random.seed(config.seed)
120
+ torch.manual_seed(config.seed)
121
+ logger.info(f"set seed: {config.seed}")
122
+
123
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
124
+ n_gpu = torch.cuda.device_count()
125
+ logger.info(f"GPU available count: {n_gpu}; device: {device}")
126
+
127
+ # datasets
128
+ train_dataset = DenoiseExcelDataset(
129
+ excel_file=args.train_dataset,
130
+ expected_sample_rate=8000,
131
+ max_wave_value=32768.0,
132
+ )
133
+ valid_dataset = DenoiseExcelDataset(
134
+ excel_file=args.valid_dataset,
135
+ expected_sample_rate=8000,
136
+ max_wave_value=32768.0,
137
+ )
138
+ train_data_loader = DataLoader(
139
+ dataset=train_dataset,
140
+ batch_size=config.batch_size,
141
+ shuffle=True,
142
+ sampler=None,
143
+ # On Linux, multiple worker processes can be used to load data; on Windows they cannot.
144
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
145
+ collate_fn=collate_fn,
146
+ pin_memory=False,
147
+ prefetch_factor=16,
148
+ )
149
+ valid_data_loader = DataLoader(
150
+ dataset=valid_dataset,
151
+ batch_size=config.batch_size,
152
+ shuffle=True,
153
+ sampler=None,
154
+ # On Linux, multiple worker processes can be used to load data; on Windows they cannot.
155
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
156
+ collate_fn=collate_fn,
157
+ pin_memory=False,
158
+ prefetch_factor=16,
159
+ )
160
+
161
+ # models
162
+ logger.info(f"prepare models. config_file: {args.config_file}")
163
+ model = ConvTasNetPretrainedModel(config).to(device)
164
+ model.to(device)
165
+ model.train()
166
+
167
+ # optimizer
168
+ logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
169
+ optimizer = torch.optim.AdamW(model.parameters(), config.learning_rate)
170
+
171
+ # resume training
172
+ last_epoch = -1
173
+ for epoch_i in serialization_dir.glob("epoch-*"):
174
+ epoch_i = Path(epoch_i)
175
+ epoch_idx = epoch_i.stem.split("-")[1]
176
+ epoch_idx = int(epoch_idx)
177
+ if epoch_idx > last_epoch:
178
+ last_epoch = epoch_idx
179
+
180
+ if last_epoch != -1:
181
+ logger.info(f"resume from epoch-{last_epoch}.")
182
+ model_pt = serialization_dir / f"epoch-{last_epoch}/model.pt"
183
+ optimizer_pth = serialization_dir / f"epoch-{last_epoch}/optimizer.pth"
184
+
185
+ logger.info(f"load state dict for model.")
186
+ with open(model_pt.as_posix(), "rb") as f:
187
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
188
+ model.load_state_dict(state_dict, strict=True)
189
+
190
+ logger.info(f"load state dict for optimizer.")
191
+ with open(optimizer_pth.as_posix(), "rb") as f:
192
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
193
+ optimizer.load_state_dict(state_dict)
194
+
195
+ lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
196
+ optimizer,
197
+ milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
198
+ )
199
+
200
+ ae_loss_fn = nn.L1Loss(reduction="mean").to(device)
201
+ neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
202
+ neg_stoi_loss_fn = NegSTOILoss(sample_rate=8000, reduction="mean").to(device)
203
+ lds_loss_fn = LSDLoss(reduction="mean").to(device)
204
+
205
+ # training loop
206
+
207
+ # state
208
+ average_pesq_score = 1000000000
209
+ average_loss = 1000000000
210
+ average_ae_loss = 1000000000
211
+ average_neg_si_snr_loss = 1000000000
212
+ average_neg_stoi_loss = 1000000000
213
+ average_lds_loss = 1000000000
214
+
215
+ model_list = list()
216
+ best_idx_epoch = None
217
+ best_metric = None
218
+ patience_count = 0
219
+
220
+ logger.info("training")
221
+ for idx_epoch in range(max(0, last_epoch+1), args.max_epochs):
222
+ # train
223
+ model.train()
224
+
225
+ total_pesq_score = 0.
226
+ total_loss = 0.
227
+ total_ae_loss = 0.
228
+ total_neg_si_snr_loss = 0.
229
+ total_neg_stoi_loss = 0.
230
+ total_lds_loss = 0.
231
+ total_batches = 0.
232
+ progress_bar = tqdm(
233
+ total=len(train_data_loader),
234
+ desc="Training; epoch: {}".format(idx_epoch),
235
+ )
236
+ for batch in train_data_loader:
237
+ clean_audios, noisy_audios = batch
238
+ clean_audios = clean_audios.to(device)
239
+ noisy_audios = noisy_audios.to(device)
240
+
241
+ denoise_audios = model.forward(noisy_audios)
242
+ denoise_audios = torch.squeeze(denoise_audios, dim=1)
243
+
244
+ ae_loss = ae_loss_fn.forward(denoise_audios, clean_audios)
245
+ neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
246
+ neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
247
+ lds_loss = lds_loss_fn.forward(denoise_audios, clean_audios)
248
+
249
+ loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * lds_loss
250
+
251
+ denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
252
+ clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
253
+ pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=8000, mode="nb")
254
+
255
+ optimizer.zero_grad()
256
+ loss.backward()
257
+ optimizer.step()
258
+ lr_scheduler.step()
259
+
260
+ total_pesq_score += pesq_score
261
+ total_loss += loss.item()
262
+ total_ae_loss += ae_loss.item()
263
+ total_neg_si_snr_loss += neg_si_snr_loss.item()
264
+ total_neg_stoi_loss += neg_stoi_loss.item()
265
+ total_lds_loss += lds_loss.item()
266
+ total_batches += 1
267
+
268
+ average_pesq_score = round(total_pesq_score / total_batches, 4)
269
+ average_loss = round(total_loss / total_batches, 4)
270
+ average_ae_loss = round(total_ae_loss / total_batches, 4)
271
+ average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
272
+ average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
273
+ average_lds_loss = round(total_lds_loss / total_batches, 4)
274
+
275
+ progress_bar.update(1)
276
+ progress_bar.set_postfix({
277
+ "pesq_score": average_pesq_score,
278
+ "loss": average_loss,
279
+ "ae_loss": average_ae_loss,
280
+ "neg_si_snr_loss": average_neg_si_snr_loss,
281
+ "neg_stoi_loss": average_neg_stoi_loss,
282
+ "lds_loss": average_lds_loss,
283
+ })
284
+
285
+ # evaluation
286
+ model.eval()
287
+ torch.cuda.empty_cache()
288
+
289
+ total_pesq_score = 0.
290
+ total_loss = 0.
291
+ total_ae_loss = 0.
292
+ total_neg_si_snr_loss = 0.
293
+ total_neg_stoi_loss = 0.
294
+ total_lds_loss = 0.
295
+ total_batches = 0.
296
+
297
+ progress_bar = tqdm(
298
+ total=len(valid_data_loader),
299
+ desc="Evaluation; epoch: {}".format(idx_epoch),
300
+ )
301
+ with torch.no_grad():
302
+ for batch in valid_data_loader:
303
+ clean_audios, noisy_audios = batch
304
+ clean_audios = clean_audios.to(device)
305
+ noisy_audios = noisy_audios.to(device)
306
+
307
+ denoise_audios = model.forward(noisy_audios)
308
+ denoise_audios = torch.squeeze(denoise_audios, dim=1)
309
+
310
+ ae_loss = ae_loss_fn.forward(denoise_audios, clean_audios)
311
+ neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
312
+ neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
313
+ lds_loss = lds_loss_fn.forward(denoise_audios, clean_audios)
314
+
315
+ loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * lds_loss
316
+
317
+ denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
318
+ clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
319
+ pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=8000, mode="nb")
320
+
321
+ total_pesq_score += pesq_score
322
+ total_loss += loss.item()
323
+ total_ae_loss += ae_loss.item()
324
+ total_neg_si_snr_loss += neg_si_snr_loss.item()
325
+ total_neg_stoi_loss += neg_stoi_loss.item()
326
+ total_lds_loss += lds_loss.item()
327
+ total_batches += 1
328
+
329
+ average_pesq_score = round(total_pesq_score / total_batches, 4)
330
+ average_loss = round(total_loss / total_batches, 4)
331
+ average_ae_loss = round(total_ae_loss / total_batches, 4)
332
+ average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
333
+ average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
334
+ average_lds_loss = round(total_lds_loss / total_batches, 4)
335
+
336
+ progress_bar.update(1)
337
+ progress_bar.set_postfix({
338
+ "pesq_score": average_pesq_score,
339
+ "loss": average_loss,
340
+ "ae_loss": average_ae_loss,
341
+ "neg_si_snr_loss": average_neg_si_snr_loss,
342
+ "neg_stoi_loss": average_neg_stoi_loss,
343
+ "lds_loss": average_lds_loss,
344
+ })
345
+
346
+ # scheduler
347
+ lr_scheduler.step()
348
+
349
+ # save path
350
+ epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch)
351
+ epoch_dir.mkdir(parents=True, exist_ok=False)
352
+
353
+ # save models
354
+ model.save_pretrained(epoch_dir.as_posix())
355
+
356
+ model_list.append(epoch_dir)
357
+ if len(model_list) >= args.num_serialized_models_to_keep:
358
+ model_to_delete: Path = model_list.pop(0)
359
+ shutil.rmtree(model_to_delete.as_posix())
360
+
361
+ # save optim
362
+ torch.save(optimizer.state_dict(), (epoch_dir / "optimizer.pth").as_posix())
363
+
364
+ # save metric
365
+ if best_metric is None:
366
+ best_idx_epoch = idx_epoch
367
+ best_metric = average_loss
368
+ elif average_loss < best_metric:
369
+ # lower is better.
370
+ best_idx_epoch = idx_epoch
371
+ best_metric = average_loss
372
+ else:
373
+ pass
374
+
375
+ metrics = {
376
+ "idx_epoch": idx_epoch,
377
+ "best_idx_epoch": best_idx_epoch,
378
+ "pesq_score": average_pesq_score,
379
+ "loss": average_loss,
380
+ "ae_loss": average_ae_loss,
381
+ "neg_si_snr_loss": average_neg_si_snr_loss,
382
+ "neg_stoi_loss": average_neg_stoi_loss,
383
+ "lds_loss": average_lds_loss,
384
+ }
385
+ metrics_filename = epoch_dir / "metrics_epoch.json"
386
+ with open(metrics_filename, "w", encoding="utf-8") as f:
387
+ json.dump(metrics, f, indent=4, ensure_ascii=False)
388
+
389
+ # save best
390
+ best_dir = serialization_dir / "best"
391
+ if best_idx_epoch == idx_epoch:
392
+ if best_dir.exists():
393
+ shutil.rmtree(best_dir)
394
+ shutil.copytree(epoch_dir, best_dir)
395
+
396
+ # early stop
397
+ early_stop_flag = False
398
+ if best_idx_epoch == idx_epoch:
399
+ patience_count = 0
400
+ else:
401
+ patience_count += 1
402
+ if patience_count >= args.patience:
403
+ early_stop_flag = True
404
+
405
+ # early stop
406
+ if early_stop_flag:
407
+ break
408
+
409
+ return
410
+
411
+
412
+ if __name__ == "__main__":
413
+ main()
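For reference, the objective minimized in both the training and evaluation loops above is an equal-weighted sum of the four component losses, with the 0.25 weights taken directly from the code:

\mathcal{L} = 0.25\,\mathcal{L}_{\mathrm{L1}} + 0.25\,\mathcal{L}_{-\mathrm{SI\text{-}SNR}} + 0.25\,\mathcal{L}_{-\mathrm{STOI}} + 0.25\,\mathcal{L}_{\mathrm{LSD}}

where \mathcal{L}_{\mathrm{L1}} is the waveform L1 loss and the remaining terms are the negative SI-SNR, negative STOI, and log-spectral-distance losses from toolbox/torchaudio/losses.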
examples/conv_tasnet/yaml/config.yaml ADDED
@@ -0,0 +1,42 @@
1
+ model_name: "nx_clean_unet"
2
+
3
+ sample_rate: 8000
4
+ segment_size: 16000
5
+ n_fft: 512
6
+ win_size: 200
7
+ hop_size: 80
8
+
9
+ down_sampling_num_layers: 6
10
+ down_sampling_in_channels: 1
11
+ down_sampling_hidden_channels: 64
12
+ down_sampling_kernel_size: 4
13
+ down_sampling_stride: 2
14
+
15
+ causal_in_channels: 1
16
+ causal_out_channels: 1
17
+ causal_kernel_size: 3
18
+ causal_bias: false
19
+ causal_separable: true
20
+ causal_f_stride: 1
21
+ causal_num_layers: 5
22
+
23
+ tsfm_hidden_size: 256
24
+ tsfm_attention_heads: 8
25
+ tsfm_num_blocks: 6
26
+ tsfm_dropout_rate: 0.1
27
+ tsfm_max_length: 512
28
+ tsfm_chunk_size: 1
29
+ tsfm_num_left_chunks: 128
30
+ tsfm_num_right_chunks: 4
31
+
32
+ discriminator_dim: 32
33
+ discriminator_in_channel: 2
34
+
35
+ compress_factor: 0.3
36
+
37
+ batch_size: 64
38
+ learning_rate: 0.0005
39
+ adam_b1: 0.8
40
+ adam_b2: 0.99
41
+ lr_decay: 0.99
42
+ seed: 1234
examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_clean_emotional_speech.py ADDED
@@ -0,0 +1,90 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/microsoft/DNS-Challenge/blob/master/download-dns-challenge-3.sh
5
+
6
+ 1.2G
7
+ wget https://dns3public.blob.core.windows.net/dns3archive/DEMAND.tar.bz2
8
+
9
+ 14G
10
+ wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.noise.tar.bz2
11
+
12
+ 38G
13
+ wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.read_speech.tar.bz2
14
+
15
+ 247M
16
+ wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.emotional_speech.tar.bz2
17
+
18
+
19
+ """
20
+ import argparse
21
+ import os
22
+ from pathlib import Path
23
+ import sys
24
+
25
+ import numpy as np
26
+ from tqdm import tqdm
27
+
28
+ pwd = os.path.abspath(os.path.dirname(__file__))
29
+ sys.path.append(os.path.join(pwd, "../../"))
30
+
31
+ import librosa
32
+ from scipy.io import wavfile
33
+
34
+
35
+ def get_args():
36
+ parser = argparse.ArgumentParser()
37
+
38
+ parser.add_argument(
39
+ "--data_dir",
40
+ default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.emotional_speech\datasets\clean\emotional_speech",
41
+ type=str
42
+ )
43
+ parser.add_argument(
44
+ "--output_dir",
45
+ default=r"E:\programmer\asr_datasets\denoise\dns-clean-emotional-speech-8k",
46
+ type=str
47
+ )
48
+ parser.add_argument("--sample_rate", default=8000, type=int)
49
+ args = parser.parse_args()
50
+ return args
51
+
52
+
53
+ def main():
54
+ args = get_args()
55
+
56
+ data_dir = Path(args.data_dir)
57
+ output_dir = Path(args.output_dir)
58
+ output_dir.mkdir(parents=True, exist_ok=True)
59
+
60
+ # finished_set
61
+ finished_set = set()
62
+ for filename in tqdm(output_dir.glob("**/*.wav")):
63
+ name = filename.stem
64
+ finished_set.add(name)
65
+ print(f"finished_set count: {len(finished_set)}")
66
+
67
+ for filename in tqdm(data_dir.glob("**/*.wav")):
68
+ label = filename.parts[-2]
69
+ name = filename.stem
70
+ # print(f"filename: {filename.as_posix()}")
71
+ if name in finished_set:
72
+ continue
73
+
74
+ signal, _ = librosa.load(filename.as_posix(), sr=args.sample_rate)
75
+
76
+ signal = signal * (1 << 15)
77
+ signal = np.array(signal, dtype=np.int16)
78
+
79
+ to_file = output_dir / f"{label}/{name}.wav"
80
+ to_file.parent.mkdir(parents=True, exist_ok=True)
81
+ wavfile.write(
82
+ to_file.as_posix(),
83
+ rate=args.sample_rate,
84
+ data=signal,
85
+ )
86
+ return
87
+
88
+
89
+ if __name__ == "__main__":
90
+ main()
examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_clean_read_speech.py ADDED
@@ -0,0 +1,123 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/microsoft/DNS-Challenge/blob/master/download-dns-challenge-3.sh
5
+
6
+ 1.2G
7
+ wget https://dns3public.blob.core.windows.net/dns3archive/DEMAND.tar.bz2
8
+
9
+ 14G
10
+ wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.noise.tar.bz2
11
+
12
+ 38G
13
+ wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.read_speech.tar.bz2
14
+
15
+ 12G
16
+ wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.french_data.tar.bz2
17
+
18
+ 43G
19
+ wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.german_speech.tar.bz2
20
+
21
+ 7.9G
22
+ wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.italian_speech.tar.bz2
23
+
24
+ 12G
25
+ wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.mandarin_speech.tar.bz2
26
+
27
+ 3.1G
28
+ wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.russian_speech.tar.bz2
29
+
30
+ 9.7G
31
+ wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.spanish_speech.tar.bz2
32
+
33
+ 617M
34
+ wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.singing_voice.tar.bz2
35
+
36
+ """
37
+ import argparse
38
+ import os
39
+ from pathlib import Path
40
+ import sys
41
+
42
+ import numpy as np
43
+ from tqdm import tqdm
44
+
45
+ pwd = os.path.abspath(os.path.dirname(__file__))
46
+ sys.path.append(os.path.join(pwd, "../../"))
47
+
48
+ import librosa
49
+ from scipy.io import wavfile
50
+
51
+
52
+ def get_args():
53
+ parser = argparse.ArgumentParser()
54
+
55
+ parser.add_argument(
56
+ "--data_dir",
57
+ # default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.read_speech\datasets\clean",
58
+ # default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.mandarin_speech\datasets\clean\mandarin_speech",
59
+ # default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.singing_voice\datasets\clean\singing_voice",
60
+ # default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.german_speech\datasets\clean\german_speech",
61
+ # default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.italian_speech\datasets\clean\italian_speech",
62
+ default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.russian_speech\datasets\clean\russian_speech",
63
+ type=str
64
+ )
65
+ parser.add_argument(
66
+ "--output_dir",
67
+ # default=r"E:\programmer\asr_datasets\denoise\dns-clean-read-speech-8k",
68
+ # default=r"E:\programmer\asr_datasets\denoise\dns-clean-mandarin-speech-8k",
69
+ # default=r"E:\programmer\asr_datasets\denoise\dns-clean-singing-voice-8k",
70
+ # default=r"E:\programmer\asr_datasets\denoise\dns-clean-german-speech-8k",
71
+ # default=r"E:\programmer\asr_datasets\denoise\dns-clean-italian-speech-8k",
72
+ default=r"E:\programmer\asr_datasets\denoise\dns-clean-russian-speech-8k",
73
+ type=str
74
+ )
75
+ parser.add_argument("--sample_rate", default=8000, type=int)
76
+ args = parser.parse_args()
77
+ return args
78
+
79
+
80
+ def main():
81
+ args = get_args()
82
+
83
+ data_dir = Path(args.data_dir)
84
+ output_dir = Path(args.output_dir)
85
+ output_dir.mkdir(parents=True, exist_ok=True)
86
+
87
+ # finished_set
88
+ finished_set = set()
89
+ for filename in tqdm(output_dir.glob("**/*.wav")):
90
+ name = filename.stem
91
+ finished_set.add(name)
92
+ print(f"finished_set count: {len(finished_set)}")
93
+
94
+ for filename in tqdm(data_dir.glob("**/*.wav")):
95
+ label = filename.parts[-2]
96
+ name = filename.stem
97
+ relative_name = filename.relative_to(data_dir)
98
+ # print(f"filename: {filename.as_posix()}")
99
+ if name in finished_set:
100
+ continue
101
+ finished_set.add(name)
102
+
103
+ try:
104
+ signal, _ = librosa.load(filename.as_posix(), sr=args.sample_rate)
105
+ except Exception:
106
+ print(f"skip file: {filename.as_posix()}")
107
+ continue
108
+
109
+ signal = signal * (1 << 15)
110
+ signal = np.array(signal, dtype=np.int16)
111
+
112
+ to_file = output_dir / relative_name.as_posix()
113
+ to_file.parent.mkdir(parents=True, exist_ok=True)
114
+ wavfile.write(
115
+ to_file.as_posix(),
116
+ rate=args.sample_rate,
117
+ data=signal,
118
+ )
119
+ return
120
+
121
+
122
+ if __name__ == "__main__":
123
+ main()
examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_demand.py ADDED
@@ -0,0 +1,71 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/microsoft/DNS-Challenge/blob/master/download-dns-challenge-3.sh
5
+
6
+ 1.2G
7
+ wget https://dns3public.blob.core.windows.net/dns3archive/DEMAND.tar.bz2
8
+
9
+ """
10
+ import argparse
11
+ import os
12
+ from pathlib import Path
13
+ import random
14
+ import sys
15
+ import shutil
16
+
17
+ import numpy as np
18
+
19
+ pwd = os.path.abspath(os.path.dirname(__file__))
20
+ sys.path.append(os.path.join(pwd, "../../"))
21
+
22
+ import librosa
23
+ from scipy.io import wavfile
24
+
25
+
26
+ def get_args():
27
+ parser = argparse.ArgumentParser()
28
+
29
+ parser.add_argument(
30
+ "--data_dir",
31
+ default=r"E:\programmer\asr_datasets\dns-challenge\DEMAND\demand",
32
+ type=str
33
+ )
34
+ parser.add_argument(
35
+ "--output_dir",
36
+ default=r"E:\programmer\asr_datasets\denoise\demand-8k",
37
+ type=str
38
+ )
39
+ parser.add_argument("--sample_rate", default=8000, type=int)
40
+ args = parser.parse_args()
41
+ return args
42
+
43
+
44
+ def main():
45
+ args = get_args()
46
+
47
+ data_dir = Path(args.data_dir)
48
+ output_dir = Path(args.output_dir)
49
+ output_dir.mkdir(parents=True, exist_ok=False)
50
+
51
+ for filename in data_dir.glob("**/ch01.wav"):
52
+ label = filename.parts[-2]
53
+ name = filename.stem
54
+
55
+ signal, _ = librosa.load(filename.as_posix(), sr=args.sample_rate)
56
+
57
+ signal = signal * (1 << 15)
58
+ signal = np.array(signal, dtype=np.int16)
59
+
60
+ to_file = output_dir / f"{label}/{name}.wav"
61
+ to_file.parent.mkdir(parents=True, exist_ok=True)
62
+ wavfile.write(
63
+ to_file.as_posix(),
64
+ rate=args.sample_rate,
65
+ data=signal,
66
+ )
67
+ return
68
+
69
+
70
+ if __name__ == '__main__':
71
+ main()
examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_impulse_responses.py ADDED
@@ -0,0 +1,93 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/microsoft/DNS-Challenge/blob/master/download-dns-challenge-3.sh
5
+
6
+ 1.2G
7
+ wget https://dns3public.blob.core.windows.net/dns3archive/DEMAND.tar.bz2
8
+
9
+ 14G
10
+ wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.noise.tar.bz2
11
+
12
+ 38G
13
+ wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.read_speech.tar.bz2
14
+
15
+ 247M
16
+ wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.emotional_speech.tar.bz2
17
+
18
+ 240M
19
+ wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.impulse_responses.tar.bz2
20
+
21
+
22
+ """
23
+ import argparse
24
+ import os
25
+ from pathlib import Path
26
+ import sys
27
+
28
+ import numpy as np
29
+ from tqdm import tqdm
30
+
31
+ pwd = os.path.abspath(os.path.dirname(__file__))
32
+ sys.path.append(os.path.join(pwd, "../../"))
33
+
34
+ import librosa
35
+ from scipy.io import wavfile
36
+
37
+
38
+ def get_args():
39
+ parser = argparse.ArgumentParser()
40
+
41
+ parser.add_argument(
42
+ "--data_dir",
43
+ default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.emotional_speech\datasets\clean\emotional_speech",
44
+ type=str
45
+ )
46
+ parser.add_argument(
47
+ "--output_dir",
48
+ default=r"E:\programmer\asr_datasets\denoise\dns-clean-emotional-speech-8k",
49
+ type=str
50
+ )
51
+ parser.add_argument("--sample_rate", default=8000, type=int)
52
+ args = parser.parse_args()
53
+ return args
54
+
55
+
56
+ def main():
57
+ args = get_args()
58
+
59
+ data_dir = Path(args.data_dir)
60
+ output_dir = Path(args.output_dir)
61
+ output_dir.mkdir(parents=True, exist_ok=True)
62
+
63
+ # finished_set
64
+ finished_set = set()
65
+ for filename in tqdm(output_dir.glob("**/*.wav")):
66
+ name = filename.stem
67
+ finished_set.add(name)
68
+ print(f"finished_set count: {len(finished_set)}")
69
+
70
+ for filename in tqdm(data_dir.glob("**/*.wav")):
71
+ label = filename.parts[-2]
72
+ name = filename.stem
73
+ # print(f"filename: {filename.as_posix()}")
74
+ if name in finished_set:
75
+ continue
76
+
77
+ signal, _ = librosa.load(filename.as_posix(), sr=args.sample_rate)
78
+
79
+ signal = signal * (1 << 15)
80
+ signal = np.array(signal, dtype=np.int16)
81
+
82
+ to_file = output_dir / f"{label}/{name}.wav"
83
+ to_file.parent.mkdir(parents=True, exist_ok=True)
84
+ wavfile.write(
85
+ to_file.as_posix(),
86
+ rate=args.sample_rate,
87
+ data=signal,
88
+ )
89
+ return
90
+
91
+
92
+ if __name__ == "__main__":
93
+ main()
examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_noise.py ADDED
@@ -0,0 +1,77 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/microsoft/DNS-Challenge/blob/master/download-dns-challenge-3.sh
5
+
6
+ 1.2G
7
+ wget https://dns3public.blob.core.windows.net/dns3archive/DEMAND.tar.bz2
8
+
9
+ 14G
10
+ wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.noise.tar.bz2
11
+
12
+ 38G
13
+ wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.read_speech.tar.bz2
14
+
15
+ """
16
+ import argparse
17
+ import os
18
+ from pathlib import Path
19
+ import sys
20
+
21
+ import numpy as np
22
+ from tqdm import tqdm
23
+
24
+ pwd = os.path.abspath(os.path.dirname(__file__))
25
+ sys.path.append(os.path.join(pwd, "../../"))
26
+
27
+ import librosa
28
+ from scipy.io import wavfile
29
+
30
+
31
+ def get_args():
32
+ parser = argparse.ArgumentParser()
33
+
34
+ parser.add_argument(
35
+ "--data_dir",
36
+ default=r"E:\programmer\asr_datasets\dns-challenge\datasets.noise\datasets",
37
+ type=str
38
+ )
39
+ parser.add_argument(
40
+ "--output_dir",
41
+ default=r"E:\programmer\asr_datasets\denoise\dns-noise-8k",
42
+ type=str
43
+ )
44
+ parser.add_argument("--sample_rate", default=8000, type=int)
45
+ args = parser.parse_args()
46
+ return args
47
+
48
+
49
+ def main():
50
+ args = get_args()
51
+
52
+ data_dir = Path(args.data_dir)
53
+ output_dir = Path(args.output_dir)
54
+ output_dir.mkdir(parents=True, exist_ok=True)
55
+
56
+ for filename in tqdm(data_dir.glob("**/*.wav")):
57
+ label = filename.parts[-2]
58
+ name = filename.stem
59
+ # print(f"filename: {filename.as_posix()}")
60
+
61
+ signal, _ = librosa.load(filename.as_posix(), sr=args.sample_rate)
62
+
63
+ signal = signal * (1 << 15)
64
+ signal = np.array(signal, dtype=np.int16)
65
+
66
+ to_file = output_dir / f"{label}/{name}.wav"
67
+ to_file.parent.mkdir(parents=True, exist_ok=True)
68
+ wavfile.write(
69
+ to_file.as_posix(),
70
+ rate=args.sample_rate,
71
+ data=signal,
72
+ )
73
+ return
74
+
75
+
76
+ if __name__ == '__main__':
77
+ main()
examples/data_preprocess/dns_challenge_to_8k/process_musan.py ADDED
@@ -0,0 +1,8 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://www.openslr.org/17/
5
+ """
6
+
7
+ if __name__ == '__main__':
8
+ pass
examples/mpnet/run.sh CHANGED
@@ -17,10 +17,10 @@ sh run.sh --stage 5 --stop_stage 5 --system_version centos --file_folder_name fi
17
  --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
18
 
19
 
20
- sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name mpnet-nx-speech-20250224 \
21
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
22
  --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech" \
23
- --max_epochs 1
24
 
25
 
26
  END
 
17
  --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
18
 
19
 
20
+ sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name mpnet-nx-speech \
21
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
22
  --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech" \
23
+ --max_epochs 100
24
 
25
 
26
  END
examples/nx_mpnet/yaml/config.yaml CHANGED
@@ -15,15 +15,15 @@ mask_hidden_size: 64
15
  phase_num_blocks: 4
16
  phase_hidden_size: 64
17
 
18
- tsfm_hidden_size: 128
19
- tsfm_attention_heads: 8
20
- tsfm_num_blocks: 6
21
  tsfm_dropout_rate: 0.0
22
  tsfm_max_time_relative_position: 2048
23
  tsfm_max_freq_relative_position: 256
24
  tsfm_chunk_size: 1
25
- tsfm_num_left_chunks: 64
26
- tsfm_num_right_chunks: 32
27
 
28
  discriminator_dim: 32
29
  discriminator_in_channel: 2
 
15
  phase_num_blocks: 4
16
  phase_hidden_size: 64
17
 
18
+ tsfm_hidden_size: 64
19
+ tsfm_attention_heads: 4
20
+ tsfm_num_blocks: 4
21
  tsfm_dropout_rate: 0.0
22
  tsfm_max_time_relative_position: 2048
23
  tsfm_max_freq_relative_position: 256
24
  tsfm_chunk_size: 1
25
+ tsfm_num_left_chunks: 128
26
+ tsfm_num_right_chunks: 64
27
 
28
  discriminator_dim: 32
29
  discriminator_in_channel: 2
main.py CHANGED
@@ -67,6 +67,13 @@ denoise_engines = {
67
  project_path / "trained_models/mpnet-nx-speech-1-epoch.zip").as_posix()
68
  }
69
  },
70
  "mpnet-aishell-1-epoch": {
71
  "infer_cls": InferenceMPNet,
72
  "kwargs": {
@@ -187,7 +194,7 @@ def main():
187
  outputs=[shell_output],
188
  )
189
 
190
- # http://127.0.0.1:7864/
191
  blocks.queue().launch(
192
  share=False if platform.system() == "Windows" else False,
193
  server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
 
67
  project_path / "trained_models/mpnet-nx-speech-1-epoch.zip").as_posix()
68
  }
69
  },
70
+ "mpnet-nx-speech-20-epoch": {
71
+ "infer_cls": InferenceMPNet,
72
+ "kwargs": {
73
+ "pretrained_model_path_or_zip_file": (
74
+ project_path / "trained_models/mpnet-nx-speech-20-epoch.zip").as_posix()
75
+ }
76
+ },
77
  "mpnet-aishell-1-epoch": {
78
  "infer_cls": InferenceMPNet,
79
  "kwargs": {
 
194
  outputs=[shell_output],
195
  )
196
 
197
+ # http://127.0.0.1:7865/
198
  blocks.queue().launch(
199
  share=False if platform.system() == "Windows" else False,
200
  server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
requirements.txt CHANGED
@@ -12,3 +12,4 @@ torch-pesq==0.1.2
12
  torchmetrics==1.6.1
13
  torchmetrics[audio]==1.6.1
14
  einops==0.8.1
 
  torchmetrics==1.6.1
13
  torchmetrics[audio]==1.6.1
14
  einops==0.8.1
15
+ torch_stoi==0.2.3
toolbox/torchaudio/losses/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == '__main__':
6
+ pass
toolbox/torchaudio/losses/perceptual.py ADDED
@@ -0,0 +1,75 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://zhuanlan.zhihu.com/p/627039860
5
+ """
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch_stoi import NegSTOILoss as TorchNegSTOILoss
9
+
10
+
11
+ class PMSQELoss(object):
12
+ """
13
+ A Deep Learning Loss Function based on the Perceptual Evaluation of the Speech Quality
14
+ https://sigmat.ugr.es/PMSQE/
15
+
16
+ On Loss Functions for Supervised Monaural Time-Domain Speech Enhancement
17
+ https://arxiv.org/abs/1909.01019
18
+
19
+ https://github.com/asteroid-team/asteroid/blob/master/asteroid/losses/pmsqe.py
20
+ """
21
+
22
+
23
+ class NegSTOILoss(nn.Module):
24
+ """
25
+ STOI (Short-Time Objective Intelligibility)
26
+ predicts how intelligible speech is from the correlation between time-domain and frequency-domain features of the signals.
27
+ Scores range from 0 to 1; higher scores indicate higher intelligibility.
28
+ It is well suited to evaluating intelligibility improvements in noisy conditions.
29
+
30
+ https://github.com/mpariente/pytorch_stoi
31
+ https://github.com/mpariente/pystoi
32
+ https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/nnet/loss/stoi_loss.py
33
+ """
34
+ def __init__(self,
35
+ sample_rate: int,
36
+ reduction: str = "mean",
37
+ ):
38
+ super(NegSTOILoss, self).__init__()
39
+ self.loss_fn = TorchNegSTOILoss(sample_rate=sample_rate)
40
+ self.reduction = reduction
41
+
42
+ if reduction not in ("sum", "mean"):
43
+ raise AssertionError(f"param reduction must be sum or mean.")
44
+
45
+ def forward(self, denoise: torch.Tensor, clean: torch.Tensor):
46
+
47
+ batch_loss = self.loss_fn.forward(denoise, clean)
48
+
49
+ if self.reduction == "mean":
50
+ loss = torch.mean(batch_loss)
51
+ elif self.reduction == "sum":
52
+ loss = torch.sum(batch_loss)
53
+ else:
54
+ raise AssertionError
55
+ return loss
56
+
57
+
58
+ def main():
59
+ sample_rate = 16000
60
+
61
+ loss_func = NegSTOILoss(
62
+ sample_rate=sample_rate,
63
+ reduction="mean",
64
+ )
65
+
66
+ denoise = torch.randn(2, sample_rate)
67
+ clean = torch.randn(2, sample_rate)
68
+
69
+ loss_batch = loss_func.forward(denoise, clean)
70
+ print(loss_batch)
71
+ return
72
+
73
+
74
+ if __name__ == "__main__":
75
+ main()
toolbox/torchaudio/losses/snr.py ADDED
@@ -0,0 +1,101 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://zhuanlan.zhihu.com/p/627039860
5
+ """
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+
10
+ class NegativeSNRLoss(nn.Module):
11
+ """
12
+ Signal-to-Noise Ratio
13
+ """
14
+ def __init__(self, eps: float = 1e-8):
15
+ super(NegativeSNRLoss, self).__init__()
16
+ self.eps = eps
17
+
18
+ def forward(self, denoise: torch.Tensor, clean: torch.Tensor):
19
+ """
20
+ Compute the SI-SNR loss between the estimated signal and the target signal.
21
+
22
+ :param denoise: The estimated signal (batch_size, signal_length)
23
+ :param clean: The target signal (batch_size, signal_length)
24
+ :return: The SI-SNR loss (batch_size,)
25
+ """
26
+ if denoise.shape != clean.shape:
27
+ raise AssertionError("Input signals must have the same shape")
28
+
29
+ denoise = denoise - torch.mean(denoise, dim=-1, keepdim=True)
30
+ clean = clean - torch.mean(clean, dim=-1, keepdim=True)
31
+
32
+ noise = denoise - clean
33
+
34
+ clean_power = torch.norm(clean, p=2, dim=-1) ** 2
35
+ noise_power = torch.norm(noise, p=2, dim=-1) ** 2
36
+
37
+ snr = 10 * torch.log10((clean_power + self.eps) / (noise_power + self.eps))
38
+
39
+ return -snr.mean()
40
+
41
+
42
+ class NegativeSISNRLoss(nn.Module):
43
+ """
44
+ Scale-Invariant Source-to-Noise Ratio
45
+
46
+ https://arxiv.org/abs/2206.07293
47
+ """
48
+ def __init__(self,
49
+ reduction: str = "mean",
50
+ eps: float = 1e-8,
51
+ ):
52
+ super(NegativeSISNRLoss, self).__init__()
53
+ self.reduction = reduction
54
+ self.eps = eps
55
+
56
+ def forward(self, denoise: torch.Tensor, clean: torch.Tensor):
57
+ """
58
+ Compute the SI-SNR loss between the estimated signal and the target signal.
59
+
60
+ :param denoise: The estimated signal (batch_size, signal_length)
61
+ :param clean: The target signal (batch_size, signal_length)
62
+ :return: The SI-SNR loss (batch_size,)
63
+ """
64
+ if denoise.shape != clean.shape:
65
+ raise AssertionError("Input signals must have the same shape")
66
+
67
+ denoise = denoise - torch.mean(denoise, dim=-1, keepdim=True)
68
+ clean = clean - torch.mean(clean, dim=-1, keepdim=True)
69
+
70
+ s_target = torch.sum(denoise * clean, dim=-1, keepdim=True) * clean / (torch.norm(clean, p=2, dim=-1, keepdim=True) ** 2 + self.eps)
71
+
72
+ e_noise = denoise - s_target
73
+
74
+ batch_si_snr = 10 * torch.log10(torch.norm(s_target, p=2, dim=-1) ** 2 / (torch.norm(e_noise, p=2, dim=-1) ** 2 + self.eps))
75
+ # si_snr shape: [batch_size,]
76
+
77
+ if self.reduction == "mean":
78
+ loss = torch.mean(batch_si_snr)
79
+ elif self.reduction == "sum":
80
+ loss = torch.sum(batch_si_snr)
81
+ else:
82
+ raise AssertionError
83
+ return -loss
84
+
85
+
86
+ def main():
87
+ batch_size = 2
88
+ signal_length = 16000
89
+ estimated_signal = torch.randn(batch_size, signal_length)
90
+ target_signal = torch.randn(batch_size, signal_length)
91
+
92
+ si_snr_loss = NegativeSISNRLoss()
93
+
94
+ loss = si_snr_loss.forward(estimated_signal, target_signal)
95
+ print(f"loss: {loss.item()}")
96
+
97
+ return
98
+
99
+
100
+ if __name__ == "__main__":
101
+ main()
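For reference, NegativeSISNRLoss above implements the scale-invariant SNR. After removing the mean from both signals it computes

s_{\mathrm{target}} = \frac{\langle \hat{s}, s \rangle}{\lVert s \rVert^{2} + \epsilon}\, s, \qquad e = \hat{s} - s_{\mathrm{target}}, \qquad \mathrm{SI\text{-}SNR} = 10 \log_{10} \frac{\lVert s_{\mathrm{target}} \rVert^{2}}{\lVert e \rVert^{2} + \epsilon}

where \hat{s} is the denoised estimate and s is the clean target; the loss is the negated mean (or sum) over the batch, so minimizing it maximizes SI-SNR.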
toolbox/torchaudio/losses/spectral.py ADDED
@@ -0,0 +1,351 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://zhuanlan.zhihu.com/p/627039860
5
+
6
+ https://github.com/facebookresearch/denoiser/blob/main/denoiser/stft_loss.py
7
+ """
8
+ from typing import List
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from torch.nn import functional as F
13
+
14
+
15
+ class LSDLoss(nn.Module):
16
+ """
17
+ Log Spectral Distance
18
+
19
+ Mean square error of power spectrum
20
+ """
21
+ def __init__(self,
22
+ n_fft: int = 512,
23
+ win_size: int = 512,
24
+ hop_size: int = 256,
25
+ center: bool = True,
26
+ eps: float = 1e-8,
27
+ reduction: str = "mean",
28
+ ):
29
+ super(LSDLoss, self).__init__()
30
+ self.n_fft = n_fft
31
+ self.win_size = win_size
32
+ self.hop_size = hop_size
33
+ self.center = center
34
+ self.eps = eps
35
+ self.reduction = reduction
36
+
37
+ if reduction not in ("sum", "mean"):
38
+ raise AssertionError(f"param reduction must be sum or mean.")
39
+
40
+ def forward(self, denoise_power: torch.Tensor, clean_power: torch.Tensor):
41
+ """
42
+ :param denoise_power: The estimated signal (batch_size, signal_length)
43
+ :param clean_power: The target signal (batch_size, signal_length)
44
+ :return:
45
+ """
46
+ denoise_power = denoise_power + self.eps
47
+ clean_power = clean_power + self.eps
48
+
49
+ log_denoise_power = torch.log10(denoise_power)
50
+ log_clean_power = torch.log10(clean_power)
51
+
52
+ # mean_square_error shape: [b, f]
53
+ mean_square_error = torch.mean(torch.square(log_denoise_power - log_clean_power), dim=-1)
54
+
55
+ if self.reduction == "mean":
56
+ lsd_loss = torch.mean(mean_square_error)
57
+ elif self.reduction == "sum":
58
+ lsd_loss = torch.sum(mean_square_error)
59
+ else:
60
+ raise AssertionError
61
+ return lsd_loss
62
+
63
+
64
+ class ComplexSpectralLoss(nn.Module):
65
+ def __init__(self,
66
+ n_fft: int = 512,
67
+ win_size: int = 512,
68
+ hop_size: int = 256,
69
+ center: bool = True,
70
+ eps: float = 1e-8,
71
+ reduction: str = "mean",
72
+ factor_mag: float = 0.5,
73
+ factor_pha: float = 0.3,
74
+ factor_gra: float = 0.2,
75
+ ):
76
+ super().__init__()
77
+ self.n_fft = n_fft
78
+ self.win_size = win_size
79
+ self.hop_size = hop_size
80
+ self.center = center
81
+ self.eps = eps
82
+ self.reduction = reduction
83
+
84
+ self.factor_mag = factor_mag
85
+ self.factor_pha = factor_pha
86
+ self.factor_gra = factor_gra
87
+
88
+ if reduction not in ("sum", "mean"):
89
+ raise AssertionError(f"param reduction must be sum or mean.")
90
+
91
+ self.window = nn.Parameter(torch.hann_window(win_size), requires_grad=False)
92
+
93
+ def forward(self, denoise: torch.Tensor, clean: torch.Tensor):
94
+ """
95
+ :param denoise: The estimated signal (batch_size, signal_length)
96
+ :param clean: The target signal (batch_size, signal_length)
97
+ :return:
98
+ """
99
+ if denoise.shape != clean.shape:
100
+ raise AssertionError("Input signals must have the same shape")
101
+
102
+ # denoise_stft, clean_stft shape: [b, f, t]
103
+ denoise_stft = torch.stft(
104
+ denoise,
105
+ n_fft=self.n_fft,
106
+ win_length=self.win_size,
107
+ hop_length=self.hop_size,
108
+ window=self.window,
109
+ center=self.center,
110
+ pad_mode="reflect",
111
+ normalized=False,
112
+ return_complex=True
113
+ )
114
+ clean_stft = torch.stft(
115
+ clean,
116
+ n_fft=self.n_fft,
117
+ win_length=self.win_size,
118
+ hop_length=self.hop_size,
119
+ window=self.window,
120
+ center=self.center,
121
+ pad_mode="reflect",
122
+ normalized=False,
123
+ return_complex=True
124
+ )
125
+
126
+ # complex_diff shape: [b, f, t], dtype: torch.complex64
127
+ complex_diff = denoise_stft - clean_stft
128
+
129
+ # magnitude_diff, phase_diff shape: [b, f, t], dtype: torch.float32
130
+ magnitude_diff = torch.abs(complex_diff)
131
+ phase_diff = torch.angle(complex_diff)
132
+
133
+ # magnitude_loss, phase_loss shape: [b,]
134
+ magnitude_loss = torch.norm(magnitude_diff, p=2, dim=(-1, -2))
135
+ phase_loss = torch.norm(phase_diff, p=1, dim=(-1, -2))
136
+
137
+ # phase_grad shape: [b, f, t-1], dtype: torch.float32
138
+ phase_grad = torch.diff(torch.angle(denoise_stft), dim=-1)
139
+ grad_loss = torch.mean(torch.abs(phase_grad), dim=(-1, -2))
140
+
141
+ # loss, grad_loss shape: [b,]
142
+ batch_loss = self.factor_mag * magnitude_loss + self.factor_pha * phase_loss + self.factor_gra * grad_loss
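+ # batch_loss: weighted sum of the magnitude, phase and phase-gradient terms (default weights 0.5 / 0.3 / 0.2)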
143
+ # print(f"magnitude_loss: {magnitude_loss}")
144
+ # print(f"phase_loss: {phase_loss}")
145
+ # print(f"grad_loss: {grad_loss}")
146
+
147
+ if self.reduction == "mean":
148
+ loss = torch.mean(batch_loss)
149
+ elif self.reduction == "sum":
150
+ loss = torch.sum(batch_loss)
151
+ else:
152
+ raise AssertionError
153
+ return loss
154
+
155
+
156
+ class SpectralConvergenceLoss(torch.nn.Module):
157
+ """Spectral convergence loss module."""
158
+
159
+ def __init__(self,
160
+ reduction: str = "mean",
161
+ ):
162
+ super(SpectralConvergenceLoss, self).__init__()
163
+ self.reduction = reduction
164
+
165
+ if reduction not in ("sum", "mean"):
166
+ raise AssertionError(f"param reduction must be sum or mean.")
167
+
168
+ def forward(self,
169
+ denoise_magnitude: torch.Tensor,
170
+ clean_magnitude: torch.Tensor,
171
+ ):
172
+ """
173
+ :param denoise_magnitude: Tensor, shape: [batch_size, time_steps, freq_bins]
174
+ :param clean_magnitude: Tensor, shape: [batch_size, time_steps, freq_bins]
175
+ :return:
176
+ """
177
+ error_norm = torch.norm(denoise_magnitude - clean_magnitude, p="fro", dim=(-1, -2))
178
+ truth_norm = torch.norm(clean_magnitude, p="fro", dim=(-1, -2))
179
+ batch_loss = error_norm / truth_norm
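+ # spectral convergence: Frobenius norm of the magnitude error, normalized by the norm of the clean magnitude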
180
+ if self.reduction == "mean":
181
+ loss = torch.mean(batch_loss)
182
+ elif self.reduction == "sum":
183
+ loss = torch.sum(batch_loss)
184
+ else:
185
+ raise AssertionError
186
+ return loss
187
+
188
+
189
+ class LogSTFTMagnitudeLoss(torch.nn.Module):
190
+ """Log STFT magnitude loss module."""
191
+
192
+ def __init__(self,
193
+ reduction: str = "mean",
194
+ ):
195
+ super(LogSTFTMagnitudeLoss, self).__init__()
196
+ self.reduction = reduction
197
+
198
+ if reduction not in ("sum", "mean"):
199
+ raise AssertionError(f"param reduction must be sum or mean.")
200
+
201
+ def forward(self,
202
+ denoise_magnitude: torch.Tensor,
203
+ clean_magnitude: torch.Tensor,
204
+ ):
205
+ """
206
+ :param denoise_magnitude: Tensor, shape: [batch_size, time_steps, freq_bins]
207
+ :param clean_magnitude: Tensor, shape: [batch_size, time_steps, freq_bins]
208
+ :return:
209
+ """
210
+ return F.l1_loss(torch.log(denoise_magnitude), torch.log(clean_magnitude), reduction=self.reduction)
211
+
212
+
213
+ class STFTLoss(torch.nn.Module):
214
+ """STFT loss module."""
215
+
216
+ def __init__(self,
217
+ n_fft: int = 1024,
218
+ win_size: int = 600,
219
+ hop_size: int = 120,
220
+ center: bool = True,
221
+ reduction: str = "mean",
222
+ ):
223
+ super(STFTLoss, self).__init__()
224
+ self.n_fft = n_fft
225
+ self.win_size = win_size
226
+ self.hop_size = hop_size
227
+ self.center = center
228
+ self.reduction = reduction
229
+
230
+ self.window = nn.Parameter(torch.hann_window(win_size), requires_grad=False)
231
+
232
+ self.spectral_convergence_loss = SpectralConvergenceLoss(reduction=reduction)
233
+ self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss(reduction=reduction)
234
+
235
+ def forward(self, denoise: torch.Tensor, clean: torch.Tensor):
236
+ """
237
+ :param denoise:
238
+ :param clean:
239
+ :return:
240
+ """
241
+ if denoise.shape != clean.shape:
242
+ raise AssertionError("Input signals must have the same shape")
243
+
244
+ # denoise_stft, clean_stft shape: [b, f, t]
245
+ denoise_stft = torch.stft(
246
+ denoise,
247
+ n_fft=self.n_fft,
248
+ win_length=self.win_size,
249
+ hop_length=self.hop_size,
250
+ window=self.window,
251
+ center=self.center,
252
+ pad_mode="reflect",
253
+ normalized=False,
254
+ return_complex=True
255
+ )
256
+ clean_stft = torch.stft(
257
+ clean,
258
+ n_fft=self.n_fft,
259
+ win_length=self.win_size,
260
+ hop_length=self.hop_size,
261
+ window=self.window,
262
+ center=self.center,
263
+ pad_mode="reflect",
264
+ normalized=False,
265
+ return_complex=True
266
+ )
267
+
268
+ denoise_magnitude = torch.abs(denoise_stft)
269
+ clean_magnitude = torch.abs(clean_stft)
270
+
271
+ sc_loss = self.spectral_convergence_loss.forward(denoise_magnitude, clean_magnitude)
272
+ mag_loss = self.log_stft_magnitude_loss.forward(denoise_magnitude, clean_magnitude)
273
+
274
+ return sc_loss, mag_loss
275
+
276
+
277
+ class MultiResolutionSTFTLoss(torch.nn.Module):
278
+ """Multi resolution STFT loss module."""
279
+
280
+ def __init__(self,
281
+ fft_size_list: List[int] = None,
282
+ win_size_list: List[int] = None,
283
+ hop_size_list: List[int] = None,
284
+ factor_sc=0.1,
285
+ factor_mag=0.1,
286
+ ):
287
+ super(MultiResolutionSTFTLoss, self).__init__()
288
+ fft_size_list = fft_size_list or [1024, 2048, 512]
289
+ win_size_list = win_size_list or [600, 1200, 240]
290
+ hop_size_list = hop_size_list or [120, 240, 50]
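+ # the three default resolutions follow the multi-resolution STFT settings of the denoiser reference implementation linked above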
291
+
292
+ if not len(fft_size_list) == len(win_size_list) == len(hop_size_list):
293
+ raise AssertionError
294
+
295
+ loss_fn_list = list()
296
+ for n_fft, win_size, hop_size in zip(fft_size_list, win_size_list, hop_size_list):
297
+ loss_fn_list.append(
298
+ STFTLoss(
299
+ n_fft=n_fft,
300
+ win_size=win_size,
301
+ hop_size=hop_size,
302
+ )
303
+ )
304
+
305
+ self.loss_fn_list = nn.ModuleList(loss_fn_list)
306
+ self.factor_sc = factor_sc
307
+ self.factor_mag = factor_mag
308
+
309
+ def forward(self, denoise: torch.Tensor, clean: torch.Tensor):
310
+ """
311
+ :param denoise:
312
+ :param clean:
313
+ :return:
314
+ """
315
+ if denoise.shape != clean.shape:
316
+ raise AssertionError("Input signals must have the same shape")
317
+
318
+ sc_loss = 0.0
319
+ mag_loss = 0.0
320
+ for loss_fn in self.loss_fn_list:
321
+ sc_l, mag_l = loss_fn.forward(denoise, clean)
322
+ sc_loss += sc_l
323
+ mag_loss += mag_l
324
+ sc_loss = sc_loss / len(self.loss_fn_list)
325
+ mag_loss = mag_loss / len(self.loss_fn_list)
326
+
327
+ sc_loss = self.factor_sc * sc_loss
328
+ mag_loss = self.factor_mag * mag_loss
329
+
330
+ loss = sc_loss + mag_loss
331
+ return loss
332
+
333
+
334
+ def main():
335
+ batch_size = 2
336
+ signal_length = 16000
337
+ estimated_signal = torch.randn(batch_size, signal_length)
338
+ target_signal = torch.randn(batch_size, signal_length)
339
+
340
+ # loss_fn = LSDLoss()
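+ # note: LSDLoss expects power spectra rather than waveforms; a minimal sketch (STFT parameters are illustrative):
+ # estimated_power = torch.stft(estimated_signal, n_fft=512, return_complex=True).abs() ** 2
+ # target_power = torch.stft(target_signal, n_fft=512, return_complex=True).abs() ** 2
+ # loss = LSDLoss().forward(estimated_power, target_power)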
341
+ # loss_fn = ComplexSpectralLoss()
342
+ loss_fn = MultiResolutionSTFTLoss()
343
+
344
+ loss = loss_fn.forward(estimated_signal, target_signal)
345
+ print(f"loss: {loss.item()}")
346
+
347
+ return
348
+
349
+
350
+ if __name__ == "__main__":
351
+ main()
toolbox/torchaudio/metrics/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == '__main__':
6
+ pass
toolbox/torchaudio/metrics/pesq.py ADDED
@@ -0,0 +1,80 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from joblib import Parallel, delayed
4
+ import numpy as np
5
+ from pesq import pesq
6
+ from typing import List
7
+
8
+ from pesq import cypesq
9
+
10
+
11
+ def run_pesq(clean_audio: np.ndarray,
12
+ noisy_audio: np.ndarray,
13
+ sample_rate: int = 16000,
14
+ mode: str = "wb",
15
+ ) -> float:
16
+ if sample_rate == 8000 and mode == "wb":
17
+ raise AssertionError(f"mode should be `nb` when sample_rate is 8000")
18
+ try:
19
+ pesq_score = pesq(sample_rate, clean_audio, noisy_audio, mode)
20
+ except cypesq.NoUtterancesError as e:
21
+ pesq_score = -1
22
+ except Exception as e:
23
+ print(f"pesq failed. error type: {type(e)}, error text: {str(e)}")
24
+ pesq_score = -1
25
+ return pesq_score
26
+
27
+
28
+ def run_batch_pesq(clean_audio_list: List[np.ndarray],
29
+ noisy_audio_list: List[np.ndarray],
30
+ sample_rate: int = 16000,
31
+ mode: str = "wb",
32
+ n_jobs: int = 4,
33
+ ) -> List[float]:
34
+ parallel = Parallel(n_jobs=n_jobs)
35
+
36
+ parallel_tasks = list()
37
+ for clean_audio, noisy_audio in zip(clean_audio_list, noisy_audio_list):
38
+ parallel_task = delayed(run_pesq)(clean_audio, noisy_audio, sample_rate, mode)
39
+ parallel_tasks.append(parallel_task)
40
+
41
+ pesq_score_list = parallel.__call__(parallel_tasks)
42
+ return pesq_score_list
43
+
44
+
45
+ def run_pesq_score(clean_audio_list: List[np.ndarray],
46
+ noisy_audio_list: List[np.ndarray],
47
+ sample_rate: int = 16000,
48
+ mode: str = "wb",
49
+ n_jobs: int = 4,
50
+ ) -> float:
51
+
52
+ pesq_score_list = run_batch_pesq(clean_audio_list=clean_audio_list,
53
+ noisy_audio_list=noisy_audio_list,
54
+ sample_rate=sample_rate,
55
+ mode=mode,
56
+ n_jobs=n_jobs,
57
+ )
58
+
59
+ pesq_score = np.mean(pesq_score_list)
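+ # average over utterances; note that failed utterances contribute a score of -1 to the mean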
60
+ return pesq_score
61
+
62
+
63
+ def main():
64
+ clean_audio = np.random.uniform(low=0, high=1, size=(2, 160000,))
65
+ noisy_audio = np.random.uniform(low=0, high=1, size=(2, 160000,))
66
+
67
+ clean_audio_list = list(clean_audio)
68
+ noisy_audio_list = list(noisy_audio)
69
+
70
+ pesq_score_list = run_batch_pesq(clean_audio_list, noisy_audio_list)
71
+ print(pesq_score_list)
72
+
73
+ pesq_score = run_pesq_score(clean_audio_list, noisy_audio_list)
74
+ print(pesq_score)
75
+
76
+ return
77
+
78
+
79
+ if __name__ == "__main__":
80
+ main()
toolbox/torchaudio/models/conv_tasnet/configuration_conv_tasnet.py ADDED
@@ -0,0 +1,52 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from typing import Tuple
4
+
5
+ from toolbox.torchaudio.configuration_utils import PretrainedConfig
6
+
7
+
8
+ class ConvTasNetConfig(PretrainedConfig):
9
+ """
10
+ https://github.com/kaituoxu/Conv-TasNet/blob/master/src/train.py
11
+ """
12
+ def __init__(self,
13
+ sample_rate: int = 8000,
14
+ segment_size: int = 4,
15
+
16
+ win_size: int = 20,
17
+
18
+ freq_bins: int = 256,
19
+ bottleneck_channels: int = 256,
20
+ num_speakers: int = 2,
21
+ num_blocks: int = 4,
22
+ num_sub_blocks: int = 8,
23
+ sub_blocks_channels: int = 512,
24
+ sub_blocks_kernel_size: int = 3,
25
+
26
+ norm_type: str = "gLN",
27
+ causal: bool = False,
28
+ mask_nonlinear: str = "relu",
29
+
30
+ **kwargs
31
+ ):
32
+ super(ConvTasNetConfig, self).__init__(**kwargs)
33
+ self.sample_rate = sample_rate
34
+ self.segment_size = segment_size
35
+
36
+ self.win_size = win_size
37
+
38
+ self.freq_bins = freq_bins
39
+ self.bottleneck_channels = bottleneck_channels
40
+ self.num_speakers = num_speakers
41
+ self.num_blocks = num_blocks
42
+ self.num_sub_blocks = num_sub_blocks
43
+ self.sub_blocks_channels = sub_blocks_channels
44
+ self.sub_blocks_kernel_size = sub_blocks_kernel_size
45
+
46
+ self.norm_type = norm_type
47
+ self.causal = causal
48
+ self.mask_nonlinear = mask_nonlinear
49
+
50
+
51
+ if __name__ == "__main__":
52
+ pass
toolbox/torchaudio/models/conv_tasnet/modeling_conv_tasnet.py CHANGED
@@ -2,8 +2,483 @@
2
  # -*- coding: utf-8 -*-
3
  """
4
  https://github.com/kaituoxu/Conv-TasNet/blob/master/src/conv_tasnet.py
 
 
5
  """
6
 
7
 
8
- if __name__ == '__main__':
9
- pass
 
2
  # -*- coding: utf-8 -*-
3
  """
4
  https://github.com/kaituoxu/Conv-TasNet/blob/master/src/conv_tasnet.py
5
+
6
+ https://pytorch.org/audio/2.5.0/generated/torchaudio.models.ConvTasNet.html
7
  """
8
+ import os
9
+ from typing import List, Optional, Union
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch.nn import functional as F
14
+
15
+ from toolbox.torchaudio.configuration_utils import CONFIG_FILE
16
+ from toolbox.torchaudio.models.conv_tasnet.utils import overlap_and_add
17
+ from toolbox.torchaudio.models.conv_tasnet.configuration_conv_tasnet import ConvTasNetConfig
18
+
19
+
20
+ class ChannelwiseLayerNorm(nn.Module):
21
+ """Channel-wise Layer Normalization (cLN)"""
22
+ def __init__(self,
23
+ channels: int,
24
+ eps: float = 1e-8
25
+ ):
26
+ super(ChannelwiseLayerNorm, self).__init__()
27
+ self.gamma = nn.Parameter(torch.Tensor(1, channels, 1))
28
+ self.beta = nn.Parameter(torch.Tensor(1, channels, 1))
29
+ self.reset_parameters()
30
+
31
+ self.eps = eps
32
+
33
+ def reset_parameters(self):
34
+ self.gamma.data.fill_(1)
35
+ self.beta.data.zero_()
36
+
37
+ def forward(self, y):
38
+ """
39
+ :param y: Tensor, shape: [batch_size, channels, time_steps]
40
+ :return: gln_y: Tensor, shape: [batch_size, channels, time_steps]
41
+ """
42
+ # mean, var shape: [batch_size, 1, time_steps]
43
+ mean = torch.mean(y, dim=1, keepdim=True)
44
+ var = torch.var(y, dim=1, keepdim=True, unbiased=False)
45
+
46
+ cln_y = self.gamma * (y - mean) / torch.pow(var + self.eps, 0.5) + self.beta
47
+ return cln_y
48
+
49
+
50
+ class GlobalLayerNorm(nn.Module):
51
+ """Global Layer Normalization (gLN)"""
52
+ def __init__(self,
53
+ channels: int,
54
+ eps: float = 1e-8
55
+ ):
56
+ super(GlobalLayerNorm, self).__init__()
57
+ self.gamma = nn.Parameter(torch.Tensor(1, channels, 1))
58
+ self.beta = nn.Parameter(torch.Tensor(1, channels, 1))
59
+ self.reset_parameters()
60
+
61
+ self.eps = eps
62
+
63
+ def reset_parameters(self):
64
+ self.gamma.data.fill_(1)
65
+ self.beta.data.zero_()
66
+
67
+ def forward(self, y):
68
+ """
69
+ :param y: Tensor, shape: [batch_size, channels, time_steps]
70
+ :return: gln_y: Tensor, shape: [batch_size, channels, time_steps]
71
+ """
72
+ # mean, var shape: [batch_size, 1, 1]
73
+ mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)
74
+ var = (torch.pow(y-mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)
75
+
76
+ gln_y = self.gamma * (y - mean) / torch.pow(var + self.eps, 0.5) + self.beta
77
+ return gln_y
78
+
79
+
80
+ def choose_norm(norm_type: str, channels: int):
81
+ """
82
+ The input of normalization will be (M, C, K), where M is batch size,
83
+ C is channel size and K is sequence length.
84
+ """
85
+ if norm_type == "gLN":
86
+ return GlobalLayerNorm(channels)
87
+ elif norm_type == "cLN":
88
+ return ChannelwiseLayerNorm(channels)
89
+ else: # norm_type == "BN":
90
+ # Given input (M, C, K), nn.BatchNorm1d(C) will accumulate statics
91
+ # along M and K, so this BN usage is right.
92
+ return nn.BatchNorm1d(channels)
93
+
94
+
95
+ class Chomp1d(nn.Module):
96
+ """
97
+ To ensure the output length is the same as the input.
98
+ """
99
+ def __init__(self, chomp_size: int):
100
+ super(Chomp1d, self).__init__()
101
+ self.chomp_size = chomp_size
102
+
103
+ def forward(self, x: torch.Tensor):
104
+ """
105
+ :param x: Tensor, shape: [batch_size, hidden_size, k_pad]
106
+ :return: Tensor, shape: [batch_size, hidden_size, k]
107
+ """
108
+ return x[:, :, :-self.chomp_size].contiguous()
109
+
110
+
111
+ class DepthwiseSeparableConv(nn.Module):
112
+ def __init__(self,
113
+ in_channels: int,
114
+ out_channels: int,
115
+ kernel_size: int,
116
+ stride: int,
117
+ padding: int,
118
+ dilation: int,
119
+ norm_type="gLN",
120
+ causal=False
121
+ ):
122
+ super(DepthwiseSeparableConv, self).__init__()
123
+ # Use `groups` option to implement depthwise convolution
124
+ # [M, H, K] -> [M, H, K]
125
+ self.depthwise_conv = nn.Conv1d(
126
+ in_channels=in_channels, out_channels=in_channels,
127
+ kernel_size=kernel_size, stride=stride,
128
+ padding=padding, dilation=dilation,
129
+ groups=in_channels, bias=False,
130
+ )
131
+
132
+ self.chomp = None
133
+ if causal:
134
+ self.chomp = Chomp1d(padding)
135
+
136
+ self.prelu = nn.PReLU()
137
+ self.norm = choose_norm(norm_type, in_channels)
138
+ # [M, H, K] -> [M, B, K]
139
+ self.pointwise_conv = nn.Conv1d(
140
+ in_channels=in_channels,
141
+ out_channels=out_channels,
142
+ kernel_size=1, bias=False
143
+ )
144
+
145
+ def forward(self, x: torch.Tensor):
146
+ """
147
+ :param x: Tensor, shape: [batch_size, hidden_size, k]
148
+ :return: Tensor, shape: [batch_size, b, k]
149
+ """
150
+ x = self.depthwise_conv.forward(x)
151
+ if self.chomp is not None:
152
+ x = self.chomp.forward(x)
153
+ x = self.prelu.forward(x)
154
+ x = self.norm.forward(x)
155
+ x = self.pointwise_conv.forward(x)
156
+
157
+ return x
158
+
159
+
160
+ class Encoder(nn.Module):
161
+ def __init__(self, win_size: int, freq_bins: int):
162
+ super(Encoder, self).__init__()
163
+ self.win_size = win_size
164
+ self.freq_bins = freq_bins
165
+
166
+ self.conv1d_U = nn.Conv1d(
167
+ in_channels=1,
168
+ out_channels=freq_bins,
169
+ kernel_size=win_size,
170
+ stride=win_size // 2,
171
+ bias=False
172
+ )
173
+
174
+ def forward(self, mixture):
175
+ """
176
+ :param mixture: Tensor, shape: [batch_size, num_samples]
177
+ :return: mixture_w, Tensor, shape: [batch_size, freq_bins, time_steps],
178
+ where time_steps = (num_samples - win_size) / (win_size / 2) + 1 = 2 * num_samples / win_size - 1
179
+ """
180
+ mixture = torch.unsqueeze(mixture, 1) # [M, 1, T]
181
+ mixture_w = F.relu(self.conv1d_U(mixture)) # [M, N, K]
182
+ return mixture_w
183
+
184
+
185
+ class Decoder(nn.Module):
186
+ def __init__(self, win_size: int, freq_bins: int):
187
+ super(Decoder, self).__init__()
188
+ self.win_size = win_size
189
+ self.freq_bins = freq_bins
190
+
191
+ self.basis_signals = nn.Linear(
192
+ in_features=freq_bins,
193
+ out_features=win_size,
194
+ bias=False
195
+ )
196
+
197
+ def forward(self,
198
+ mixture_w: torch.Tensor,
199
+ est_mask: torch.Tensor,
200
+ ):
201
+ """
202
+ :param mixture_w: Tensor, shape: [batch_size, freq_bins, time_steps],
203
+ where time_steps = (num_samples - win_size) / (win_size / 2) + 1 = 2 * num_samples / win_size - 1
204
+ :param est_mask: Tensor, shape: [batch_size, c, freq_bins, time_steps],
205
+ :return: Tensor, shape: [batch_size, c, num_samples],
206
+ """
207
+ source_w = torch.unsqueeze(mixture_w, 1) * est_mask
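+ # source_w shape: [batch_size, c, freq_bins, time_steps]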
208
+ source_w = torch.transpose(source_w, 2, 3)
209
+ est_source = self.basis_signals(source_w)
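+ # est_source shape: [batch_size, c, time_steps, win_size]; overlap-add with hop win_size//2 reconstructs the waveform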
210
+ est_source = overlap_and_add(est_source, self.win_size//2)
211
+ return est_source
212
+
213
+
214
+ class TemporalBlock(nn.Module):
215
+ def __init__(self,
216
+ in_channels: int,
217
+ out_channels: int,
218
+ kernel_size: int,
219
+ stride: int,
220
+ padding: int,
221
+ dilation: int,
222
+ norm_type="gLN",
223
+ causal=False
224
+ ):
225
+ super(TemporalBlock, self).__init__()
226
+ self.conv1x1 = nn.Conv1d(in_channels, out_channels, 1, bias=False)
227
+ self.prelu = nn.PReLU()
228
+ self.norm = choose_norm(norm_type, out_channels)
229
+ # [M, H, K] -> [M, B, K]
230
+ self.dsconv = DepthwiseSeparableConv(
231
+ out_channels, in_channels,
232
+ kernel_size, stride,
233
+ padding, dilation,
234
+ norm_type, causal,
235
+ )
236
+
237
+ def forward(self, x):
238
+ residual = x
239
+
240
+ x = self.conv1x1.forward(x)
241
+ x = self.prelu.forward(x)
242
+ x = self.norm.forward(x)
243
+ x = self.dsconv.forward(x)
244
+
245
+ out = x + residual
246
+ return out
247
+
248
+
249
+ class TemporalConvNet(nn.Module):
250
+ def __init__(self,
251
+ freq_bins: int = 256,
252
+ bottleneck_channels: int = 256,
253
+ num_speakers: int = 2,
254
+ num_blocks: int = 4,
255
+ num_sub_blocks: int = 8,
256
+ sub_blocks_channels: int = 512,
257
+ sub_blocks_kernel_size: int = 3,
258
+ norm_type: str = "gLN",
259
+ causal: bool = False,
260
+ mask_nonlinear: str = "relu",
261
+
262
+ ):
263
+ super(TemporalConvNet, self).__init__()
264
+ self.freq_bins = freq_bins
265
+ self.bottleneck_channels = bottleneck_channels
266
+ self.num_speakers = num_speakers
267
+
268
+ self.num_blocks = num_blocks
269
+ self.num_sub_blocks = num_sub_blocks
270
+ self.sub_blocks_channels = sub_blocks_channels
271
+ self.sub_blocks_kernel_size = sub_blocks_kernel_size
272
+
273
+ self.mask_nonlinear = mask_nonlinear
274
+
275
+ self.layer_norm = ChannelwiseLayerNorm(freq_bins)
276
+ self.bottleneck_conv1x1 = nn.Conv1d(freq_bins, bottleneck_channels, 1, bias=False)
277
+
278
+ self.temporal_conv_list = nn.ModuleList([])
279
+ for num_block_idx in range(num_blocks):
280
+ sub_blocks = list()
281
+ for num_sub_block_idx in range(num_sub_blocks):
282
+ dilation = 2 ** num_sub_block_idx
283
+ padding = (sub_blocks_kernel_size - 1) * dilation
284
+ if not causal:
285
+ padding = padding // 2
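+ # dilation doubles with each sub-block (1, 2, 4, ...); causal blocks pad the full amount and trim the future side with Chomp1d, non-causal blocks pad symmetrically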
286
+ temporal_block = TemporalBlock(
287
+ bottleneck_channels, sub_blocks_channels,
288
+ sub_blocks_kernel_size, stride=1,
289
+ padding=padding, dilation=dilation,
290
+ norm_type=norm_type, causal=causal,
291
+ )
292
+ sub_blocks.append(temporal_block)
293
+ self.temporal_conv_list.extend(sub_blocks)
294
+
295
+ self.mask_conv1x1 = nn.Conv1d(
296
+ in_channels=bottleneck_channels,
297
+ out_channels=num_speakers * freq_bins,
298
+ kernel_size=1,
299
+ bias=False,
300
+ )
301
+
302
+ def forward(self, mixture_w: torch.Tensor):
303
+ """
304
+ :param mixture_w: Tensor, shape: [batch_size, freq_bins, time_steps]
305
+ :return: est_mask: Tensor, shape: [batch_size, num_speakers, freq_bins, time_steps]
306
+ """
307
+ batch_size, freq_bins, time_steps = mixture_w.size()
308
+
309
+ x = self.layer_norm.forward(mixture_w)
310
+ x = self.bottleneck_conv1x1.forward(x)
311
+
312
+ for temporal_conv in self.temporal_conv_list:
313
+ x = temporal_conv.forward(x)
314
+
315
+ score = self.mask_conv1x1.forward(x)
316
+
317
+ # [M, C*N, K] -> [M, C, N, K]
318
+ score = score.view(batch_size, self.num_speakers, freq_bins, time_steps)
319
+
320
+ if self.mask_nonlinear == "softmax":
321
+ est_mask = F.softmax(score, dim=1)
322
+ elif self.mask_nonlinear == "relu":
323
+ est_mask = F.relu(score)
324
+ else:
325
+ raise ValueError("Unsupported mask non-linear function")
326
+
327
+ return est_mask
328
+
329
+
330
+ class ConvTasNet(nn.Module):
331
+ def __init__(self,
332
+ win_size: int = 20,
333
+ freq_bins: int = 256,
334
+ bottleneck_channels: int = 256,
335
+ num_speakers: int = 2,
336
+ num_blocks: int = 4,
337
+ num_sub_blocks: int = 8,
338
+ sub_blocks_channels: int = 512,
339
+ sub_blocks_kernel_size: int = 3,
340
+ norm_type: str = "gLN",
341
+ causal: bool = False,
342
+ mask_nonlinear: str = "relu",
343
+
344
+ ):
345
+ super(ConvTasNet, self).__init__()
346
+ self.win_size = win_size
347
+
348
+ self.freq_bins = freq_bins
349
+ self.bottleneck_channels = bottleneck_channels
350
+ self.num_speakers = num_speakers
351
+
352
+ self.num_blocks = num_blocks
353
+ self.num_sub_blocks = num_sub_blocks
354
+ self.sub_blocks_channels = sub_blocks_channels
355
+ self.sub_blocks_kernel_size = sub_blocks_kernel_size
356
+
357
+ self.norm_type = norm_type
358
+ self.causal = causal
359
+ self.mask_nonlinear = mask_nonlinear
360
+
361
+ self.encoder = Encoder(win_size, freq_bins)
362
+ self.separator = TemporalConvNet(
363
+ freq_bins=freq_bins,
364
+ bottleneck_channels=bottleneck_channels,
365
+ sub_blocks_channels=sub_blocks_channels,
366
+ sub_blocks_kernel_size=sub_blocks_kernel_size,
367
+ num_sub_blocks=num_sub_blocks,
368
+ num_blocks=num_blocks,
369
+ num_speakers=num_speakers,
370
+ norm_type=norm_type,
371
+ causal=causal,
372
+ mask_nonlinear=mask_nonlinear,
373
+ )
374
+ self.decoder = Decoder(win_size=win_size, freq_bins=freq_bins)
375
+
376
+ for p in self.parameters():
377
+ if p.dim() > 1:
378
+ nn.init.xavier_normal_(p)
379
+
380
+ def forward(self, mixture: torch.Tensor):
381
+ """
382
+ :param mixture: Tensor, shape: [batch_size, num_samples]
383
+ :return: est_source: Tensor, shape: [batch_size, c, num_samples]
384
+ """
385
+ # mixture shape: [batch_size, num_samples]
386
+ mixture_w = self.encoder.forward(mixture)
387
+ # mixture_w shape: [batch_size, freq_bins, time_steps]
388
+ est_mask = self.separator.forward(mixture_w)
389
+ # est_mask shape: [batch_size, num_speakers, freq_bins, time_steps]
390
+ est_source = self.decoder.forward(mixture_w, est_mask)
391
+
392
+ num_samples1 = mixture.size(-1)
393
+ num_samples2 = est_source.size(-1)
394
+ est_source = F.pad(est_source, (0, num_samples1 - num_samples2))
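+ # pad the decoded waveform back to the input length (the encoder/decoder striding can drop trailing samples)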
395
+ return est_source
396
+
397
+
398
+ MODEL_FILE = "model.pt"
399
+
400
+
401
+ class ConvTasNetPretrainedModel(ConvTasNet):
402
+ def __init__(self,
403
+ config: ConvTasNetConfig,
404
+ ):
405
+ super(ConvTasNetPretrainedModel, self).__init__(
406
+ win_size=config.win_size,
407
+ freq_bins=config.freq_bins,
408
+ bottleneck_channels=config.bottleneck_channels,
409
+ sub_blocks_channels=config.sub_blocks_channels,
410
+ sub_blocks_kernel_size=config.sub_blocks_kernel_size,
411
+ num_sub_blocks=config.num_sub_blocks,
412
+ num_blocks=config.num_blocks,
413
+ num_speakers=config.num_speakers,
414
+ norm_type=config.norm_type,
415
+ causal=config.causal,
416
+ mask_nonlinear=config.mask_nonlinear,
417
+ )
418
+ self.config = config
419
+
420
+ @classmethod
421
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
422
+ config = ConvTasNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
423
+
424
+ model = cls(config)
425
+
426
+ if os.path.isdir(pretrained_model_name_or_path):
427
+ ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
428
+ else:
429
+ ckpt_file = pretrained_model_name_or_path
430
+
431
+ with open(ckpt_file, "rb") as f:
432
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
433
+ model.load_state_dict(state_dict, strict=True)
434
+ return model
435
+
436
+ def save_pretrained(self,
437
+ save_directory: Union[str, os.PathLike],
438
+ state_dict: Optional[dict] = None,
439
+ ):
440
+
441
+ model = self
442
+
443
+ if state_dict is None:
444
+ state_dict = model.state_dict()
445
+
446
+ os.makedirs(save_directory, exist_ok=True)
447
+
448
+ # save state dict
449
+ model_file = os.path.join(save_directory, MODEL_FILE)
450
+ torch.save(state_dict, model_file)
451
+
452
+ # save config
453
+ config_file = os.path.join(save_directory, CONFIG_FILE)
454
+ self.config.to_yaml_file(config_file)
455
+ return save_directory
456
+
457
+
458
+ def main():
459
+ config = ConvTasNetConfig()
460
+ tas_net = ConvTasNet(
461
+ win_size=config.win_size,
462
+ freq_bins=config.freq_bins,
463
+ bottleneck_channels=config.bottleneck_channels,
464
+ sub_blocks_channels=config.sub_blocks_channels,
465
+ sub_blocks_kernel_size=config.sub_blocks_kernel_size,
466
+ num_sub_blocks=config.num_sub_blocks,
467
+ num_blocks=config.num_blocks,
468
+ num_speakers=config.num_speakers,
469
+ norm_type=config.norm_type,
470
+ causal=config.causal,
471
+ mask_nonlinear=config.mask_nonlinear,
472
+ )
473
+
474
+ print(tas_net)
475
+
476
+ mixture = torch.rand(size=(1, 8000*4), dtype=torch.float32)
477
+
478
+ outputs = tas_net.forward(mixture)
479
+ print(outputs.shape)
480
+ return
481
 
482
 
483
+ if __name__ == "__main__":
484
+ main()
toolbox/torchaudio/models/conv_tasnet/utils.py ADDED
@@ -0,0 +1,55 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/kaituoxu/Conv-TasNet/blob/master/src/utils.py
5
+ """
6
+ import math
7
+ import torch
8
+
9
+
10
+ def overlap_and_add(signal: torch.Tensor, frame_step: int):
11
+ """
12
+ Reconstructs a signal from a framed representation.
13
+
14
+ Adds potentially overlapping frames of a signal with shape
15
+ `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`.
16
+ The resulting tensor has shape `[..., output_size]` where
17
+
18
+ output_size = (frames - 1) * frame_step + frame_length
19
+
20
+ Based on https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py
21
+
22
+ :param signal: Tensor, shape: [..., frames, frame_length]. All dimensions may be unknown, and rank must be at least 2.
23
+ :param frame_step: int, overlap offsets. Must be less than or equal to frame_length.
24
+ :return: Tensor, shape: [..., output_size].
25
+ containing the overlap-added frames of signal's inner-most two dimensions.
26
+ output_size = (frames - 1) * frame_step + frame_length
27
+ """
28
+ outer_dimensions = signal.size()[:-2]
29
+ frames, frame_length = signal.size()[-2:]
30
+
31
+ subframe_length = math.gcd(frame_length, frame_step) # gcd=Greatest Common Divisor
32
+ subframe_step = frame_step // subframe_length
33
+ subframes_per_frame = frame_length // subframe_length
34
+
35
+ output_size = frame_step * (frames - 1) + frame_length
36
+ output_subframes = output_size // subframe_length
37
+
38
+ subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)
39
+
40
+ frame = torch.arange(0, output_subframes).unfold(0, subframes_per_frame, subframe_step)
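+ # frame: for each input frame, the output sub-frame indices it is scattered into (consumed by index_add_ below)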
41
+
42
+ frame = frame.clone().detach()
43
+ frame = frame.to(signal.device)
44
+ frame = frame.long()
45
+
46
+ frame = frame.contiguous().view(-1)
47
+
48
+ result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
49
+ result.index_add_(-2, frame, subframe_signal)
50
+ result = result.view(*outer_dimensions, -1)
51
+ return result
52
+
53
+
54
+ if __name__ == "__main__":
55
+ pass
toolbox/torchaudio/models/conv_tasnet/yaml/config.yaml ADDED
@@ -0,0 +1,17 @@
1
+ model_name: "conv_tasnet"
2
+
3
+ sample_rate: 8000
4
+ segment_size: 4
5
+
6
+ win_size: 20
7
+ freq_bins: 256
8
+ bottleneck_channels: 256
9
+ num_speakers: 2
10
+ num_blocks: 4
11
+ num_sub_blocks: 8
12
+ sub_blocks_channels: 512
13
+ sub_blocks_kernel_size: 3
14
+
15
+ norm_type: "gLN"
16
+ causal: false
17
+ mask_nonlinear: "relu"
toolbox/torchaudio/models/demucs/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == '__main__':
6
+ pass
toolbox/torchaudio/models/demucs/configuration_demucs.py ADDED
@@ -0,0 +1,51 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from toolbox.torchaudio.configuration_utils import PretrainedConfig
4
+
5
+
6
+ class DemucsConfig(PretrainedConfig):
7
+ def __init__(self,
8
+ sample_rate: int = 8000,
9
+
10
+ in_channels: int = 1,
11
+ out_channels: int = 1,
12
+ hidden_channels: int = 48,
13
+
14
+ depth: int = 5,
15
+ kernel_size: int = 8,
16
+ stride: int = 4,
17
+
18
+ causal: bool = True,
19
+ resample: int = 4,
20
+ growth: int = 2,
21
+
22
+ max_hidden: int = 10_000,
23
+ do_normalize: bool = True,
24
+ rescale: float = 0.1,
25
+ floor: float = 1e-3,
26
+
27
+ **kwargs
28
+ ):
29
+ super(DemucsConfig, self).__init__(**kwargs)
30
+ self.sample_rate = sample_rate
31
+
32
+ self.in_channels = in_channels
33
+ self.out_channels = out_channels
34
+ self.hidden_channels = hidden_channels
35
+
36
+ self.depth = depth
37
+ self.kernel_size = kernel_size
38
+ self.stride = stride
39
+
40
+ self.causal = causal
41
+ self.resample = resample
42
+ self.growth = growth
43
+
44
+ self.max_hidden = max_hidden
45
+ self.do_normalize = do_normalize
46
+ self.rescale = rescale
47
+ self.floor = floor
48
+
49
+
50
+ if __name__ == "__main__":
51
+ pass
toolbox/torchaudio/models/demucs/modeling_demucs.py ADDED
@@ -0,0 +1,299 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://arxiv.org/abs/2006.12847
5
+
6
+ https://github.com/facebookresearch/denoiser
7
+ """
8
+ import math
9
+ import os
10
+ from typing import List, Optional, Union
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from torch.nn import functional as F
15
+
16
+ from toolbox.torchaudio.configuration_utils import CONFIG_FILE
17
+ from toolbox.torchaudio.models.demucs.configuration_demucs import DemucsConfig
18
+ from toolbox.torchaudio.models.demucs.resample import upsample2, downsample2
19
+
20
+
21
+ activation_layer_dict = {
22
+ "glu": nn.GLU,
23
+ "relu": nn.ReLU,
24
+ "identity": nn.Identity,
25
+ "sigmoid": nn.Sigmoid,
26
+ }
27
+
28
+
29
+ class BLSTM(nn.Module):
30
+ def __init__(self,
31
+ hidden_size: int,
32
+ num_layers: int = 2,
33
+ bidirectional: bool = True,
34
+ ):
35
+ super().__init__()
36
+ self.lstm = nn.LSTM(bidirectional=bidirectional,
37
+ num_layers=num_layers,
38
+ hidden_size=hidden_size,
39
+ input_size=hidden_size
40
+ )
41
+ self.linear = None
42
+ if bidirectional:
43
+ self.linear = nn.Linear(2 * hidden_size, hidden_size)
44
+
45
+ def forward(self,
46
+ x: torch.Tensor,
47
+ hx: torch.Tensor = None
48
+ ):
49
+ x, hx = self.lstm.forward(x, hx)
50
+ if self.linear:
51
+ x = self.linear(x)
52
+ return x, hx
53
+
54
+
55
+ def rescale_conv(conv, reference):
56
+ std = conv.weight.std().detach()
57
+ scale = (std / reference)**0.5
58
+ conv.weight.data /= scale
59
+ if conv.bias is not None:
60
+ conv.bias.data /= scale
61
+
62
+
63
+ def rescale_module(module, reference):
64
+ for sub in module.modules():
65
+ if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)):
66
+ rescale_conv(sub, reference)
67
+
68
+
69
+ class DemucsModel(nn.Module):
70
+ def __init__(self,
71
+ in_channels: int = 1,
72
+ out_channels: int = 1,
73
+ hidden_channels: int = 48,
74
+ depth: int = 5,
75
+ kernel_size: int = 8,
76
+ stride: int = 4,
77
+ causal: bool = True,
78
+ resample: int = 4,
79
+ growth: int = 2,
80
+ max_hidden: int = 10_000,
81
+ do_normalize: bool = True,
82
+ rescale: float = 0.1,
83
+ floor: float = 1e-3,
84
+ ):
85
+ super(DemucsModel, self).__init__()
86
+
87
+ self.in_channels = in_channels
88
+ self.out_channels = out_channels
89
+ self.hidden_channels = hidden_channels
90
+
91
+ self.depth = depth
92
+ self.kernel_size = kernel_size
93
+ self.stride = stride
94
+
95
+ self.causal = causal
96
+
97
+ self.resample = resample
98
+ self.growth = growth
99
+ self.max_hidden = max_hidden
100
+ self.do_normalize = do_normalize
101
+ self.rescale = rescale
102
+ self.floor = floor
103
+
104
+ if resample not in [1, 2, 4]:
105
+ raise ValueError("Resample should be 1, 2 or 4.")
106
+
107
+ self.encoder = nn.ModuleList()
108
+ self.decoder = nn.ModuleList()
109
+
110
+ for index in range(depth):
111
+ encode = []
112
+ encode += [
113
+ nn.Conv1d(in_channels, hidden_channels, kernel_size, stride),
114
+ nn.ReLU(),
115
+ nn.Conv1d(hidden_channels, hidden_channels * 2, 1),
116
+ nn.GLU(1),
117
+ ]
118
+ self.encoder.append(nn.Sequential(*encode))
119
+
120
+ decode = []
121
+ decode += [
122
+ nn.Conv1d(hidden_channels, 2 * hidden_channels, 1),
123
+ nn.GLU(1),
124
+ nn.ConvTranspose1d(hidden_channels, out_channels, kernel_size, stride),
125
+ ]
126
+ if index > 0:
127
+ decode.append(nn.ReLU())
128
+ self.decoder.insert(0, nn.Sequential(*decode))
129
+ out_channels = hidden_channels
130
+ in_channels = hidden_channels
131
+ hidden_channels = min(int(growth * hidden_channels), max_hidden)
132
+
133
+ self.lstm = BLSTM(in_channels, bidirectional=not causal)
134
+
135
+ if rescale:
136
+ rescale_module(self, reference=rescale)
137
+
138
+ @staticmethod
139
+ def valid_length(length: int, depth: int, kernel_size: int, stride: int, resample: int):
140
+ """
141
+ Return the nearest valid length to use with the model so that
142
+ no time steps are left over after the convolutions, i.e. for all
143
+ layers, (input_length - kernel_size) % stride == 0.
144
+
145
+ If the mixture has a valid length, the estimated sources
146
+ will have exactly the same length.
147
+ """
148
+ length = math.ceil(length * resample)
149
+ for idx in range(depth):
150
+ length = math.ceil((length - kernel_size) / stride) + 1
151
+ length = max(length, 1)
152
+ for idx in range(depth):
153
+ length = (length - 1) * stride + kernel_size
154
+ length = int(math.ceil(length / resample))
155
+ return int(length)
156
+
157
+ def forward(self, noisy: torch.Tensor):
158
+ """
159
+ :param noisy: Tensor, shape: [batch_size, num_samples] or [batch_size, channels, num_samples]
160
+ :return:
161
+ """
162
+ if noisy.dim() == 2:
163
+ noisy = noisy.unsqueeze(1)
164
+ # noisy shape: [batch_size, channels, num_samples]
165
+
166
+ if self.do_normalize:
167
+ mono = noisy.mean(dim=1, keepdim=True)
168
+ std = mono.std(dim=-1, keepdim=True)
169
+ noisy = noisy / (self.floor + std)
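+ # normalize by the std of the mono mixture; the same std rescales the output at the end of forward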
170
+ else:
171
+ std = 1
172
+
173
+ _, _, length = noisy.shape
174
+ x = noisy
175
+
176
+ length_ = self.valid_length(length, self.depth, self.kernel_size, self.stride, self.resample)
177
+ x = F.pad(x, (0, length_ - length))
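+ # pad to the nearest valid length so that every encoder/decoder layer lines up without leftover samples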
178
+
179
+ if self.resample == 2:
180
+ x = upsample2(x)
181
+ elif self.resample == 4:
182
+ x = upsample2(x)
183
+ x = upsample2(x)
184
+
185
+ skips = []
186
+ for encode in self.encoder:
187
+ x = encode(x)
188
+ skips.append(x)
189
+ x = x.permute(2, 0, 1)
190
+ x, _ = self.lstm(x)
191
+ x = x.permute(1, 2, 0)
192
+
193
+ for decode in self.decoder:
194
+ skip = skips.pop(-1)
195
+ x = x + skip[..., :x.shape[-1]]
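+ # U-Net style skip connection from the matching encoder layer, cropped to the current length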
196
+ x = decode(x)
197
+
198
+ if self.resample == 2:
199
+ x = downsample2(x)
200
+ elif self.resample == 4:
201
+ x = downsample2(x)
202
+ x = downsample2(x)
203
+
204
+ x = x[..., :length]
205
+ return std * x
206
+
207
+
208
+ MODEL_FILE = "model.pt"
209
+
210
+
211
+ class DemucsPretrainedModel(DemucsModel):
212
+ def __init__(self,
213
+ config: DemucsConfig,
214
+ ):
215
+ super(DemucsPretrainedModel, self).__init__(
216
+ # sample_rate=config.sample_rate,
217
+ in_channels=config.in_channels,
218
+ out_channels=config.out_channels,
219
+ hidden_channels=config.hidden_channels,
220
+ depth=config.depth,
221
+ kernel_size=config.kernel_size,
222
+ stride=config.stride,
223
+ causal=config.causal,
224
+ resample=config.resample,
225
+ growth=config.growth,
226
+ max_hidden=config.max_hidden,
227
+ do_normalize=config.do_normalize,
228
+ rescale=config.rescale,
229
+ floor=config.floor,
230
+ )
231
+ self.config = config
232
+
233
+ @classmethod
234
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
235
+ config = DemucsConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
236
+
237
+ model = cls(config)
238
+
239
+ if os.path.isdir(pretrained_model_name_or_path):
240
+ ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
241
+ else:
242
+ ckpt_file = pretrained_model_name_or_path
243
+
244
+ with open(ckpt_file, "rb") as f:
245
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
246
+ model.load_state_dict(state_dict, strict=True)
247
+ return model
248
+
249
+ def save_pretrained(self,
250
+ save_directory: Union[str, os.PathLike],
251
+ state_dict: Optional[dict] = None,
252
+ ):
253
+
254
+ model = self
255
+
256
+ if state_dict is None:
257
+ state_dict = model.state_dict()
258
+
259
+ os.makedirs(save_directory, exist_ok=True)
260
+
261
+ # save state dict
262
+ model_file = os.path.join(save_directory, MODEL_FILE)
263
+ torch.save(state_dict, model_file)
264
+
265
+ # save config
266
+ config_file = os.path.join(save_directory, CONFIG_FILE)
267
+ self.config.to_yaml_file(config_file)
268
+ return save_directory
269
+
270
+
271
+ def main():
272
+ config = DemucsConfig()
273
+ model = DemucsModel(
274
+ in_channels=config.in_channels,
275
+ out_channels=config.out_channels,
276
+ hidden_channels=config.hidden_channels,
277
+ depth=config.depth,
278
+ kernel_size=config.kernel_size,
279
+ stride=config.stride,
280
+ causal=config.causal,
281
+ resample=config.resample,
282
+ growth=config.growth,
283
+ max_hidden=config.max_hidden,
284
+ do_normalize=config.do_normalize,
285
+ rescale=config.rescale,
286
+ floor=config.floor,
287
+ )
288
+
289
+ print(model)
290
+
291
+ noisy = torch.rand(size=(1, 8000*4), dtype=torch.float32)
292
+
293
+ denoise = model.forward(noisy)
294
+ print(denoise.shape)
295
+ return
296
+
297
+
298
+ if __name__ == "__main__":
299
+ main()
toolbox/torchaudio/models/demucs/resample.py ADDED
@@ -0,0 +1,81 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) Facebook, Inc. and its affiliates.
4
+ # All rights reserved.
5
+ #
6
+ # This source code is licensed under the license found in the
7
+ # LICENSE file in the root directory of this source tree.
8
+ # author: adefossez
9
+
10
+ import math
11
+
12
+ import torch as th
13
+ from torch.nn import functional as F
14
+
15
+
16
+ def sinc(t):
17
+ """sinc.
18
+
19
+ :param t: the input tensor
20
+ """
21
+ return th.where(t == 0, th.tensor(1., device=t.device, dtype=t.dtype), th.sin(t) / t)
22
+
23
+
24
+ def kernel_upsample2(zeros=56):
25
+ """kernel_upsample2.
26
+
27
+ """
28
+ win = th.hann_window(4 * zeros + 1, periodic=False)
29
+ winodd = win[1::2]
30
+ t = th.linspace(-zeros + 0.5, zeros - 0.5, 2 * zeros)
31
+ t *= math.pi
32
+ kernel = (sinc(t) * winodd).view(1, 1, -1)
33
+ return kernel
34
+
35
+
36
+ def upsample2(x, zeros=56):
37
+ """
38
+ Upsampling the input by 2 using sinc interpolation.
39
+ Smith, Julius, and Phil Gossett. "A flexible sampling-rate conversion method."
40
+ ICASSP'84. IEEE International Conference on Acoustics, Speech, and Signal Processing.
41
+ Vol. 9. IEEE, 1984.
42
+ """
43
+ *other, time = x.shape
44
+ kernel = kernel_upsample2(zeros).to(x)
45
+ out = F.conv1d(x.view(-1, 1, time), kernel, padding=zeros)[..., 1:].view(*other, time)
46
+ y = th.stack([x, out], dim=-1)
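+ # interleave the original and interpolated samples, doubling the length along the last dimension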
47
+ return y.view(*other, -1)
48
+
49
+
50
+ def kernel_downsample2(zeros=56):
51
+ """kernel_downsample2.
52
+
53
+ """
54
+ win = th.hann_window(4 * zeros + 1, periodic=False)
55
+ winodd = win[1::2]
56
+ t = th.linspace(-zeros + 0.5, zeros - 0.5, 2 * zeros)
57
+ t.mul_(math.pi)
58
+ kernel = (sinc(t) * winodd).view(1, 1, -1)
59
+ return kernel
60
+
61
+
62
+ def downsample2(x, zeros=56):
63
+ """
64
+ Downsampling the input by 2 using sinc interpolation.
65
+ Smith, Julius, and Phil Gossett. "A flexible sampling-rate conversion method."
66
+ ICASSP'84. IEEE International Conference on Acoustics, Speech, and Signal Processing.
67
+ Vol. 9. IEEE, 1984.
68
+ """
69
+ if x.shape[-1] % 2 != 0:
70
+ x = F.pad(x, (0, 1))
71
+ xeven = x[..., ::2]
72
+ xodd = x[..., 1::2]
73
+ *other, time = xodd.shape
74
+ kernel = kernel_downsample2(zeros).to(x)
75
+ out = xeven + F.conv1d(xodd.view(-1, 1, time), kernel, padding=zeros)[..., :-1].view(
76
+ *other, time)
77
+ return out.view(*other, -1).mul(0.5)
78
+
79
+
80
+ if __name__ == "__main__":
81
+ pass
toolbox/torchaudio/models/nx_dfnet/configuration_nx_dfnet.py ADDED
@@ -0,0 +1,102 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from typing import Tuple
4
+
5
+ from toolbox.torchaudio.configuration_utils import PretrainedConfig
6
+
7
+
8
+ class NXDfNetConfig(PretrainedConfig):
9
+ def __init__(self,
10
+ sample_rate: int = 8000,
11
+ freq_bins: int = 256,
12
+ win_size: int = 200,
13
+ hop_size: int = 100,
14
+
15
+ conv_channels: int = 64,
16
+ conv_kernel_size_input: Tuple[int, int] = (3, 3),
17
+ conv_kernel_size_inner: Tuple[int, int] = (1, 3),
18
+ conv_lookahead: int = 0,
19
+
20
+ convt_kernel_size_inner: Tuple[int, int] = (1, 3),
21
+
22
+ embedding_hidden_size: int = 256,
23
+ encoder_combine_op: str = "concat",
24
+
25
+ encoder_emb_skip_op: str = "none",
26
+ encoder_emb_linear_groups: int = 16,
27
+ encoder_emb_hidden_size: int = 256,
28
+
29
+ encoder_linear_groups: int = 32,
30
+
31
+ lsnr_max: int = 30,
32
+ lsnr_min: int = -15,
33
+ norm_tau: float = 1.,
34
+
35
+ decoder_emb_num_layers: int = 3,
36
+ decoder_emb_skip_op: str = "none",
37
+ decoder_emb_linear_groups: int = 16,
38
+ decoder_emb_hidden_size: int = 256,
39
+
40
+ df_decoder_hidden_size: int = 256,
41
+ df_num_layers: int = 2,
42
+ df_order: int = 5,
43
+ df_bins: int = 96,
44
+ df_gru_skip: str = "grouped_linear",
45
+ df_decoder_linear_groups: int = 16,
46
+ df_pathway_kernel_size_t: int = 5,
47
+ df_lookahead: int = 2,
48
+
49
+ use_post_filter: bool = False,
50
+ **kwargs
51
+ ):
52
+ super(NXDfNetConfig, self).__init__(**kwargs)
53
+ # transform
54
+ self.sample_rate = sample_rate
55
+ self.freq_bins = freq_bins
56
+ self.win_size = win_size
57
+ self.hop_size = hop_size
58
+
59
+ # conv
60
+ self.conv_channels = conv_channels
61
+ self.conv_kernel_size_input = conv_kernel_size_input
62
+ self.conv_kernel_size_inner = conv_kernel_size_inner
63
+ self.conv_lookahead = conv_lookahead
64
+
65
+ self.convt_kernel_size_inner = convt_kernel_size_inner
66
+
67
+ self.embedding_hidden_size = embedding_hidden_size
68
+
69
+ # encoder
70
+ self.encoder_emb_skip_op = encoder_emb_skip_op
71
+ self.encoder_emb_linear_groups = encoder_emb_linear_groups
72
+ self.encoder_emb_hidden_size = encoder_emb_hidden_size
73
+
74
+ self.encoder_linear_groups = encoder_linear_groups
75
+ self.encoder_combine_op = encoder_combine_op
76
+
77
+ self.lsnr_max = lsnr_max
78
+ self.lsnr_min = lsnr_min
79
+ self.norm_tau = norm_tau
80
+
81
+ # decoder
82
+ self.decoder_emb_num_layers = decoder_emb_num_layers
83
+ self.decoder_emb_skip_op = decoder_emb_skip_op
84
+ self.decoder_emb_linear_groups = decoder_emb_linear_groups
85
+ self.decoder_emb_hidden_size = decoder_emb_hidden_size
86
+
87
+ # df decoder
88
+ self.df_decoder_hidden_size = df_decoder_hidden_size
89
+ self.df_num_layers = df_num_layers
90
+ self.df_order = df_order
91
+ self.df_bins = df_bins
92
+ self.df_gru_skip = df_gru_skip
93
+ self.df_decoder_linear_groups = df_decoder_linear_groups
94
+ self.df_pathway_kernel_size_t = df_pathway_kernel_size_t
95
+ self.df_lookahead = df_lookahead
96
+
97
+ # runtime
98
+ self.use_post_filter = use_post_filter
99
+
100
+
101
+ if __name__ == "__main__":
102
+ pass
toolbox/torchaudio/models/nx_dfnet/modeling_nx_dfnet.py ADDED
@@ -0,0 +1,989 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import os
4
+ import math
5
+ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.nn import functional as F
11
+ import torchaudio
12
+
13
+ from toolbox.torchaudio.models.nx_dfnet.utils import overlap_and_add
14
+ from toolbox.torchaudio.models.nx_dfnet.configuration_nx_dfnet import NXDfNetConfig
15
+ from toolbox.torchaudio.configuration_utils import CONFIG_FILE
16
+
17
+
18
+ MODEL_FILE = "model.pt"
19
+
20
+
21
+ norm_layer_dict = {
22
+ "batch_norm_2d": torch.nn.BatchNorm2d
23
+ }
24
+
25
+
26
+ activation_layer_dict = {
27
+ "relu": torch.nn.ReLU,
28
+ "identity": torch.nn.Identity,
29
+ "sigmoid": torch.nn.Sigmoid,
30
+ }
31
+
32
+
33
+ class CausalConv2d(nn.Sequential):
34
+ def __init__(self,
35
+ in_channels: int,
36
+ out_channels: int,
37
+ kernel_size: Union[int, Iterable[int]],
38
+ fstride: int = 1,
39
+ dilation: int = 1,
40
+ fpad: bool = True,
41
+ bias: bool = True,
42
+ separable: bool = False,
43
+ norm_layer: str = "batch_norm_2d",
44
+ activation_layer: str = "relu",
45
+ lookahead: int = 0
46
+ ):
47
+ """
48
+ Causal Conv2d by delaying the signal for any lookahead.
49
+
50
+ Expected input format: [batch_size, channels, time_steps, spec_dim]
51
+
52
+ :param in_channels:
53
+ :param out_channels:
54
+ :param kernel_size:
55
+ :param fstride:
56
+ :param dilation:
57
+ :param fpad:
58
+ """
59
+ super(CausalConv2d, self).__init__()
60
+ kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
61
+
62
+ if fpad:
63
+ fpad_ = kernel_size[1] // 2 + dilation - 1
64
+ else:
65
+ fpad_ = 0
66
+
67
+ # for last 2 dim, pad (left, right, top, bottom).
68
+ pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead)
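+ # pad (kernel_size[0] - 1 - lookahead) past frames and lookahead future frames on the time axis, keeping the convolution causal up to the configured lookahead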
69
+
70
+ layers = list()
71
+ if any(x > 0 for x in pad):
72
+ layers.append(nn.ConstantPad2d(pad, 0.0))
73
+
74
+ groups = math.gcd(in_channels, out_channels) if separable else 1
75
+ if groups == 1:
76
+ separable = False
77
+ if max(kernel_size) == 1:
78
+ separable = False
79
+
80
+ layers.append(
81
+ nn.Conv2d(
82
+ in_channels,
83
+ out_channels,
84
+ kernel_size=kernel_size,
85
+ padding=(0, fpad_),
86
+ stride=(1, fstride), # stride over time is always 1
87
+ dilation=(1, dilation), # dilation over time is always 1
88
+ groups=groups,
89
+ bias=bias,
90
+ )
91
+ )
92
+
93
+ if separable:
94
+ layers.append(
95
+ nn.Conv2d(
96
+ out_channels,
97
+ out_channels,
98
+ kernel_size=1,
99
+ bias=False,
100
+ )
101
+ )
102
+
103
+ if norm_layer is not None:
104
+ norm_layer = norm_layer_dict[norm_layer]
105
+ layers.append(norm_layer(out_channels))
106
+
107
+ if activation_layer is not None:
108
+ activation_layer = activation_layer_dict[activation_layer]
109
+ layers.append(activation_layer())
110
+
111
+ super().__init__(*layers)
112
+
113
+ def forward(self, inputs):
114
+ for module in self:
115
+ inputs = module(inputs)
116
+ return inputs
117
+
118
+
119
+ class CausalConvTranspose2d(nn.Sequential):
120
+ def __init__(self,
121
+ in_channels: int,
122
+ out_channels: int,
123
+ kernel_size: Union[int, Iterable[int]],
124
+ fstride: int = 1,
125
+ dilation: int = 1,
126
+ fpad: bool = True,
127
+ bias: bool = True,
128
+ separable: bool = False,
129
+ norm_layer: str = "batch_norm_2d",
130
+ activation_layer: str = "relu",
131
+ lookahead: int = 0
132
+ ):
133
+ """
134
+ Causal ConvTranspose2d.
135
+
136
+ Expected input format: [batch_size, channels, time_steps, spec_dim]
137
+ """
138
+ super(CausalConvTranspose2d, self).__init__()
139
+
140
+ kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
141
+
142
+ if fpad:
143
+ fpad_ = kernel_size[1] // 2
144
+ else:
145
+ fpad_ = 0
146
+
147
+ # for last 2 dim, pad (left, right, top, bottom).
148
+ pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead)
149
+
150
+ layers = []
151
+ if any(x > 0 for x in pad):
152
+ layers.append(nn.ConstantPad2d(pad, 0.0))
153
+
154
+ groups = math.gcd(in_channels, out_channels) if separable else 1
155
+ if groups == 1:
156
+ separable = False
157
+
158
+ layers.append(
159
+ nn.ConvTranspose2d(
160
+ in_channels,
161
+ out_channels,
162
+ kernel_size=kernel_size,
163
+ padding=(kernel_size[0] - 1, fpad_ + dilation - 1),
164
+ output_padding=(0, fpad_),
165
+ stride=(1, fstride), # stride over time is always 1
166
+ dilation=(1, dilation), # dilation over time is always 1
167
+ groups=groups,
168
+ bias=bias,
169
+ )
170
+ )
171
+
172
+ if separable:
173
+ layers.append(
174
+ nn.Conv2d(
175
+ out_channels,
176
+ out_channels,
177
+ kernel_size=1,
178
+ bias=False,
179
+ )
180
+ )
181
+
182
+ if norm_layer is not None:
183
+ norm_layer = norm_layer_dict[norm_layer]
184
+ layers.append(norm_layer(out_channels))
185
+
186
+ if activation_layer is not None:
187
+ activation_layer = activation_layer_dict[activation_layer]
188
+ layers.append(activation_layer())
189
+
190
+ super().__init__(*layers)
191
+
192
+
193
+ class GroupedLinear(nn.Module):
194
+
195
+ def __init__(self, input_size: int, hidden_size: int, groups: int = 1):
196
+ super().__init__()
197
+ # self.weight: Tensor
198
+ self.input_size = input_size
199
+ self.hidden_size = hidden_size
200
+ self.groups = groups
201
+ assert input_size % groups == 0, f"Input size {input_size} not divisible by {groups}"
202
+ assert hidden_size % groups == 0, f"Hidden size {hidden_size} not divisible by {groups}"
203
+ self.ws = input_size // groups
204
+ self.register_parameter(
205
+ "weight",
206
+ torch.nn.Parameter(
207
+ torch.zeros(groups, input_size // groups, hidden_size // groups), requires_grad=True
208
+ ),
209
+ )
210
+ self.reset_parameters()
211
+
212
+ def reset_parameters(self):
213
+ nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) # type: ignore
214
+
215
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
216
+ # x: [..., I]
217
+ b, t, _ = x.shape
218
+ # new_shape = list(x.shape)[:-1] + [self.groups, self.ws]
219
+ new_shape = (b, t, self.groups, self.ws)
220
+ x = x.view(new_shape)
221
+ # The better way, but not supported by torchscript
222
+ # x = x.unflatten(-1, (self.groups, self.ws)) # [..., G, I/G]
223
+ x = torch.einsum("btgi,gih->btgh", x, self.weight) # [..., G, H/G]
224
+ x = x.flatten(2, 3) # [B, T, H]
225
+ return x
226
+
227
+ def __repr__(self):
228
+ cls = self.__class__.__name__
229
+ return f"{cls}(input_size: {self.input_size}, hidden_size: {self.hidden_size}, groups: {self.groups})"
230
+
231
+
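A minimal sketch of what GroupedLinear computes (sizes are arbitrary examples): the feature dimension is split into `groups` blocks, each block gets its own small weight matrix, and the per-group outputs are concatenated, which cuts the parameter count roughly by a factor of `groups`.

import torch

batch_size, time_steps = 2, 10
input_size, hidden_size, groups = 64, 96, 8

weight = torch.randn(groups, input_size // groups, hidden_size // groups)
x = torch.randn(batch_size, time_steps, input_size)

# Same computation as GroupedLinear.forward: split, per-group matmul, merge.
xg = x.view(batch_size, time_steps, groups, input_size // groups)
y = torch.einsum("btgi,gih->btgh", xg, weight).flatten(2, 3)

print(y.shape)                                    # torch.Size([2, 10, 96])
print(weight.numel(), input_size * hidden_size)   # 768 vs 6144 parameters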
232
+ class SqueezedGRU_S(nn.Module):
233
+ """
234
+ SGE net: Video object detection with squeezed GRU and information entropy map
235
+ https://arxiv.org/abs/2106.07224
236
+ """
237
+
238
+ def __init__(
239
+ self,
240
+ input_size: int,
241
+ hidden_size: int,
242
+ output_size: Optional[int] = None,
243
+ num_layers: int = 1,
244
+ linear_groups: int = 8,
245
+ batch_first: bool = True,
246
+ skip_op: str = "none",
247
+ activation_layer: str = "identity",
248
+ ):
249
+ super().__init__()
250
+ self.input_size = input_size
251
+ self.hidden_size = hidden_size
252
+
253
+ self.linear_in = nn.Sequential(
254
+ GroupedLinear(
255
+ input_size=input_size,
256
+ hidden_size=hidden_size,
257
+ groups=linear_groups,
258
+ ),
259
+ activation_layer_dict[activation_layer](),
260
+ )
261
+
262
+ # gru skip operator
263
+ self.gru_skip_op = None
264
+
265
+ if skip_op == "none":
266
+ self.gru_skip_op = None
267
+ elif skip_op == "identity":
268
+ if input_size != output_size:
269
+ raise AssertionError("Dimensions do not match")
270
+ self.gru_skip_op = nn.Identity()
271
+ elif skip_op == "grouped_linear":
272
+ self.gru_skip_op = GroupedLinear(
273
+ input_size=hidden_size,
274
+ hidden_size=hidden_size,
275
+ groups=linear_groups,
276
+ )
277
+ else:
278
+ raise NotImplementedError()
279
+
280
+ self.gru = nn.GRU(
281
+ input_size=hidden_size,
282
+ hidden_size=hidden_size,
283
+ num_layers=num_layers,
284
+ batch_first=batch_first,
285
+ bidirectional=False,
286
+ )
287
+
288
+ if output_size is not None:
289
+ self.linear_out = nn.Sequential(
290
+ GroupedLinear(
291
+ input_size=hidden_size,
292
+ hidden_size=output_size,
293
+ groups=linear_groups,
294
+ ),
295
+ activation_layer_dict[activation_layer](),
296
+ )
297
+ else:
298
+ self.linear_out = nn.Identity()
299
+
300
+ def forward(self, inputs: torch.Tensor, h=None) -> Tuple[torch.Tensor, torch.Tensor]:
301
+ x = self.linear_in(inputs)
302
+
303
+ x, h = self.gru.forward(x, h)
304
+
305
+ x = self.linear_out(x)
306
+
307
+ if self.gru_skip_op is not None:
308
+ x = x + self.gru_skip_op(inputs)
309
+
310
+ return x, h
311
+
312
+
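Functionally, SqueezedGRU_S compresses the features before the GRU and expands them afterwards, optionally adding a skip from the input. An ungrouped sketch of the same data flow with hypothetical sizes (the real module uses GroupedLinear instead of nn.Linear):

import torch
import torch.nn as nn

input_size, hidden_size, output_size = 256, 64, 256

linear_in = nn.Sequential(nn.Linear(input_size, hidden_size), nn.ReLU())
gru = nn.GRU(hidden_size, hidden_size, num_layers=1, batch_first=True)
linear_out = nn.Linear(hidden_size, output_size)

x = torch.randn(2, 50, input_size)      # [batch, time, features]
y, h = gru(linear_in(x))
y = linear_out(y)

print(y.shape, h.shape)  # torch.Size([2, 50, 256]) torch.Size([1, 2, 64])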
313
+ class Add(nn.Module):
314
+ def forward(self, a, b):
315
+ return a + b
316
+
317
+
318
+ class Concat(nn.Module):
319
+ def forward(self, a, b):
320
+ return torch.cat((a, b), dim=-1)
321
+
322
+
323
+ class DeepSTFT(nn.Module):
324
+ def __init__(self, win_size: int, freq_bins: int):
325
+ super(DeepSTFT, self).__init__()
326
+ self.win_size = win_size
327
+ self.freq_bins = freq_bins
328
+
329
+ self.conv1d_U = nn.Conv1d(
330
+ in_channels=1,
331
+ out_channels=freq_bins * 2,
332
+ kernel_size=win_size,
333
+ stride=win_size // 2,
334
+ bias=False
335
+ )
336
+
337
+ def forward(self, signal: torch.Tensor):
338
+ """
339
+ :param signal: Tensor, shape: [batch_size, num_samples]
340
+ :return: v, Tensor, shape: [batch_size, freq_bins, time_steps, 2],
341
+ where time_steps = (num_samples - win_size) / (win_size / 2) + 1 = 2 * num_samples / win_size - 1
342
+ """
343
+ signal = torch.unsqueeze(signal, 1)
344
+ # signal shape: [batch_size, 1, num_samples]
345
+ spec = F.relu(self.conv1d_U(signal))
346
+ # spec shape: [batch_size, freq_bins * 2, time_steps]
347
+ b, f2, t = spec.shape
348
+ spec = spec.view(b, f2//2, 2, t).permute(0, 1, 3, 2)
349
+ # spec shape: [batch_size, freq_bins, time_steps, 2]
350
+ return spec
351
+
352
+
353
+ class DeepISTFT(nn.Module):
354
+ def __init__(self, win_size: int, freq_bins: int):
355
+ super(DeepISTFT, self).__init__()
356
+ self.win_size = win_size
357
+ self.freq_bins = freq_bins
358
+
359
+ self.basis_signals = nn.Linear(
360
+ in_features=freq_bins * 2,
361
+ out_features=win_size,
362
+ bias=False
363
+ )
364
+
365
+ def forward(self,
366
+ spec: torch.Tensor,
367
+ ):
368
+ """
369
+ :param spec: Tensor, shape: [batch_size, freq_bins, time_steps, 2],
370
+ where time_steps = (num_samples - win_size) / (win_size / 2) + 1 = 2 * num_samples / win_size - 1
371
+ :return: Tensor, shape: [batch_size, 1, num_samples]
372
+ """
373
+ b, f, t, _ = spec.shape
374
+ # spec shape: [b, f, t, 2]
375
+ spec = spec.permute(0, 2, 1, 3)
376
+ # spec shape: [b, t, f, 2]
377
+ spec = spec.view(b, 1, t, -1)
378
+ # spec shape: [b, 1, t, f2]
379
+ signal = self.basis_signals(spec)
380
+ # signal shape: [b, 1, t, win_size]
381
+ signal = overlap_and_add(signal, self.win_size//2)
382
+ # signal shape: [b, 1, num_samples]
383
+ return signal
384
+
385
+
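The time_steps formula quoted in the DeepSTFT/DeepISTFT docstrings can be checked directly against the underlying Conv1d; win_size, freq_bins and num_samples below are assumed example values, not config defaults.

import torch
import torch.nn as nn

win_size, freq_bins, num_samples = 320, 256, 16000

conv1d_U = nn.Conv1d(1, freq_bins * 2, kernel_size=win_size, stride=win_size // 2, bias=False)
signal = torch.randn(4, 1, num_samples)

spec = conv1d_U(signal)
print(spec.shape)                        # torch.Size([4, 512, 99])
print(2 * num_samples // win_size - 1)   # 99, matches the docstring formula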
386
+ class Encoder(nn.Module):
387
+ def __init__(self, config: NXDfNetConfig):
388
+ super(Encoder, self).__init__()
389
+ self.embedding_input_size = config.conv_channels * config.freq_bins // 4
390
+ self.embedding_output_size = config.conv_channels * config.freq_bins // 4
391
+ self.embedding_hidden_size = config.embedding_hidden_size
392
+
393
+ self.spec_conv0 = CausalConv2d(
394
+ in_channels=1,
395
+ out_channels=config.conv_channels,
396
+ kernel_size=config.conv_kernel_size_input,
397
+ bias=False,
398
+ separable=True,
399
+ fstride=1,
400
+ lookahead=config.conv_lookahead,
401
+ )
402
+ self.spec_conv1 = CausalConv2d(
403
+ in_channels=config.conv_channels,
404
+ out_channels=config.conv_channels,
405
+ kernel_size=config.conv_kernel_size_inner,
406
+ bias=False,
407
+ separable=True,
408
+ fstride=2,
409
+ lookahead=config.conv_lookahead,
410
+ )
411
+ self.spec_conv2 = CausalConv2d(
412
+ in_channels=config.conv_channels,
413
+ out_channels=config.conv_channels,
414
+ kernel_size=config.conv_kernel_size_inner,
415
+ bias=False,
416
+ separable=True,
417
+ fstride=2,
418
+ lookahead=config.conv_lookahead,
419
+ )
420
+ self.spec_conv3 = CausalConv2d(
421
+ in_channels=config.conv_channels,
422
+ out_channels=config.conv_channels,
423
+ kernel_size=config.conv_kernel_size_inner,
424
+ bias=False,
425
+ separable=True,
426
+ fstride=1,
427
+ lookahead=config.conv_lookahead,
428
+ )
429
+
430
+ self.df_conv0 = CausalConv2d(
431
+ in_channels=2,
432
+ out_channels=config.conv_channels,
433
+ kernel_size=config.conv_kernel_size_input,
434
+ bias=False,
435
+ separable=True,
436
+ fstride=1,
437
+ )
438
+ self.df_conv1 = CausalConv2d(
439
+ in_channels=config.conv_channels,
440
+ out_channels=config.conv_channels,
441
+ kernel_size=config.conv_kernel_size_inner,
442
+ bias=False,
443
+ separable=True,
444
+ fstride=2,
445
+ )
446
+ self.df_fc_emb = nn.Sequential(
447
+ GroupedLinear(
448
+ config.conv_channels * config.df_bins // 2,
449
+ self.embedding_input_size,
450
+ groups=config.encoder_linear_groups
451
+ ),
452
+ nn.ReLU(inplace=True)
453
+ )
454
+
455
+ if config.encoder_combine_op == "concat":
456
+ self.embedding_input_size *= 2
457
+ self.combine = Concat()
458
+ else:
459
+ self.combine = Add()
460
+
461
+ # emb_gru
462
+ if config.freq_bins % 8 != 0:
463
+ raise AssertionError("freq_bins should be divisible by 8")
464
+
465
+ self.emb_gru = SqueezedGRU_S(
466
+ self.embedding_input_size,
467
+ self.embedding_hidden_size,
468
+ output_size=self.embedding_output_size,
469
+ num_layers=1,
470
+ batch_first=True,
471
+ skip_op=config.encoder_emb_skip_op,
472
+ linear_groups=config.encoder_emb_linear_groups,
473
+ activation_layer="relu",
474
+ )
475
+
476
+ # lsnr
477
+ self.lsnr_fc = nn.Sequential(
478
+ nn.Linear(self.embedding_output_size, 1),
479
+ nn.Sigmoid()
480
+ )
481
+ self.lsnr_scale = config.lsnr_max - config.lsnr_min
482
+ self.lsnr_offset = config.lsnr_min
483
+
484
+ def forward(self,
485
+ power_spec: torch.Tensor,
486
+ df_spec: torch.Tensor,
487
+ hidden_state: torch.Tensor = None,
488
+ ):
489
+ # power_spec shape: (batch_size, 1, time_steps, spec_dim)
490
+ e0 = self.spec_conv0.forward(power_spec)
491
+ e1 = self.spec_conv1.forward(e0)
492
+ e2 = self.spec_conv2.forward(e1)
493
+ e3 = self.spec_conv3.forward(e2)
494
+ # e0 shape: [batch_size, channels, time_steps, spec_dim]
495
+ # e1 shape: [batch_size, channels, time_steps, spec_dim // 2]
496
+ # e2 shape: [batch_size, channels, time_steps, spec_dim // 4]
497
+ # e3 shape: [batch_size, channels, time_steps, spec_dim // 4]
498
+
499
+ # df_spec, shape: (batch_size, 2, time_steps, df_bins)
500
+ c0 = self.df_conv0(df_spec)
501
+ c1 = self.df_conv1(c0)
502
+ # c0 shape: [batch_size, channels, time_steps, df_bins]
503
+ # c1 shape: [batch_size, channels, time_steps, df_bins // 2]
504
+
505
+ cemb = c1.permute(0, 2, 3, 1)
506
+ # cemb shape: [batch_size, time_steps, df_bins // 2, channels]
507
+ cemb = cemb.flatten(2)
508
+ # cemb shape: [batch_size, time_steps, df_bins // 2 * channels]
509
+ cemb = self.df_fc_emb(cemb)
510
+ # cemb shape: [batch_size, time_steps, spec_dim // 4 * channels]
511
+
512
+ # e3 shape: [batch_size, channels, time_steps, spec_dim // 4]
513
+ emb = e3.permute(0, 2, 3, 1)
514
+ # emb shape: [batch_size, time_steps, spec_dim // 4, channels]
515
+ emb = emb.flatten(2)
516
+ # emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
517
+
518
+ emb = self.combine(emb, cemb)
519
+ # if concat; emb shape: [batch_size, time_steps, spec_dim // 4 * channels * 2]
520
+ # if add; emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
521
+
522
+ emb, h = self.emb_gru.forward(emb, hidden_state)
523
+ # emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
524
+ # h shape: [1, batch_size, embedding_hidden_size]
525
+
526
+ lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset
527
+ # lsnr shape: [batch_size, time_steps, 1]
528
+
529
+ return e0, e1, e2, e3, emb, c0, lsnr, h
530
+
531
+
532
+ class Decoder(nn.Module):
533
+ def __init__(self, config: NXDfNetConfig):
534
+ super(Decoder, self).__init__()
535
+
536
+ if config.freq_bins % 8 != 0:
537
+ raise AssertionError("freq_bins should be divisible by 8")
538
+
539
+ self.emb_in_dim = config.conv_channels * config.freq_bins // 4
540
+ self.emb_out_dim = config.conv_channels * config.freq_bins // 4
541
+ self.emb_hidden_dim = config.decoder_emb_hidden_size
542
+
543
+ self.emb_gru = SqueezedGRU_S(
544
+ self.emb_in_dim,
545
+ self.emb_hidden_dim,
546
+ output_size=self.emb_out_dim,
547
+ num_layers=config.decoder_emb_num_layers - 1,
548
+ batch_first=True,
549
+ skip_op=config.decoder_emb_skip_op,
550
+ linear_groups=config.decoder_emb_linear_groups,
551
+ activation_layer="relu",
552
+ )
553
+ self.conv3p = CausalConv2d(
554
+ in_channels=config.conv_channels,
555
+ out_channels=config.conv_channels,
556
+ kernel_size=1,
557
+ bias=False,
558
+ separable=True,
559
+ fstride=1,
560
+ lookahead=config.conv_lookahead,
561
+ )
562
+ self.convt3 = CausalConv2d(
563
+ in_channels=config.conv_channels,
564
+ out_channels=config.conv_channels,
565
+ kernel_size=config.conv_kernel_size_inner,
566
+ bias=False,
567
+ separable=True,
568
+ fstride=1,
569
+ lookahead=config.conv_lookahead,
570
+ )
571
+ self.conv2p = CausalConv2d(
572
+ in_channels=config.conv_channels,
573
+ out_channels=config.conv_channels,
574
+ kernel_size=1,
575
+ bias=False,
576
+ separable=True,
577
+ fstride=1,
578
+ lookahead=config.conv_lookahead,
579
+ )
580
+ self.convt2 = CausalConvTranspose2d(
581
+ in_channels=config.conv_channels,
582
+ out_channels=config.conv_channels,
583
+ kernel_size=config.convt_kernel_size_inner,
584
+ bias=False,
585
+ separable=True,
586
+ fstride=2,
587
+ lookahead=config.conv_lookahead,
588
+ )
589
+ self.conv1p = CausalConv2d(
590
+ in_channels=config.conv_channels,
591
+ out_channels=config.conv_channels,
592
+ kernel_size=1,
593
+ bias=False,
594
+ separable=True,
595
+ fstride=1,
596
+ lookahead=config.conv_lookahead,
597
+ )
598
+ self.convt1 = CausalConvTranspose2d(
599
+ in_channels=config.conv_channels,
600
+ out_channels=config.conv_channels,
601
+ kernel_size=config.convt_kernel_size_inner,
602
+ bias=False,
603
+ separable=True,
604
+ fstride=2,
605
+ lookahead=config.conv_lookahead,
606
+ )
607
+ self.conv0p = CausalConv2d(
608
+ in_channels=config.conv_channels,
609
+ out_channels=config.conv_channels,
610
+ kernel_size=1,
611
+ bias=False,
612
+ separable=True,
613
+ fstride=1,
614
+ lookahead=config.conv_lookahead,
615
+ )
616
+ self.conv0_out = CausalConv2d(
617
+ in_channels=config.conv_channels,
618
+ out_channels=1,
619
+ kernel_size=config.conv_kernel_size_inner,
620
+ activation_layer="sigmoid",
621
+ bias=False,
622
+ separable=True,
623
+ fstride=1,
624
+ lookahead=config.conv_lookahead,
625
+ )
626
+
627
+ def forward(self, emb, e3, e2, e1, e0) -> torch.Tensor:
628
+ # Estimates erb mask
629
+ b, _, t, f8 = e3.shape
630
+
631
+ # emb shape: [batch_size, time_steps, (freq_dim // 4) * conv_channels]
632
+ emb, _ = self.emb_gru(emb)
633
+ # emb shape: [batch_size, conv_channels, time_steps, freq_dim // 4]
634
+ emb = emb.view(b, t, f8, -1).permute(0, 3, 1, 2)
635
+ e3 = self.convt3(self.conv3p(e3) + emb)
636
+ # e3 shape: [batch_size, conv_channels, time_steps, freq_dim // 4]
637
+ e2 = self.convt2(self.conv2p(e2) + e3)
638
+ # e2 shape: [batch_size, conv_channels, time_steps, freq_dim // 2]
639
+ e1 = self.convt1(self.conv1p(e1) + e2)
640
+ # e1 shape: [batch_size, conv_channels, time_steps, freq_dim]
641
+ mask = self.conv0_out(self.conv0p(e0) + e1)
642
+ # mask shape: [batch_size, 1, time_steps, freq_dim]
643
+ return mask
644
+
645
+
646
+ class DfDecoder(nn.Module):
647
+ def __init__(self, config: NXDfNetConfig):
648
+ super(DfDecoder, self).__init__()
649
+
650
+ self.embedding_input_size = config.conv_channels * config.freq_bins // 4
651
+ self.df_decoder_hidden_size = config.df_decoder_hidden_size
652
+ self.df_num_layers = config.df_num_layers
653
+
654
+ self.df_order = config.df_order
655
+
656
+ self.df_bins = config.df_bins
657
+ self.df_out_ch = config.df_order * 2
658
+
659
+ self.df_convp = CausalConv2d(
660
+ config.conv_channels,
661
+ self.df_out_ch,
662
+ fstride=1,
663
+ kernel_size=(config.df_pathway_kernel_size_t, 1),
664
+ separable=True,
665
+ bias=False,
666
+ )
667
+ self.df_gru = SqueezedGRU_S(
668
+ self.embedding_input_size,
669
+ self.df_decoder_hidden_size,
670
+ num_layers=self.df_num_layers,
671
+ batch_first=True,
672
+ skip_op="none",
673
+ activation_layer="relu",
674
+ )
675
+
676
+ if config.df_gru_skip == "none":
677
+ self.df_skip = None
678
+ elif config.df_gru_skip == "identity":
679
+ if config.embedding_hidden_size != config.df_decoder_hidden_size:
680
+ raise AssertionError("Dimensions do not match")
681
+ self.df_skip = nn.Identity()
682
+ elif config.df_gru_skip == "grouped_linear":
683
+ self.df_skip = GroupedLinear(
684
+ self.embedding_input_size,
685
+ self.df_decoder_hidden_size,
686
+ groups=config.df_decoder_linear_groups
687
+ )
688
+ else:
689
+ raise NotImplementedError()
690
+
691
+ self.df_out: nn.Module
692
+ out_dim = self.df_bins * self.df_out_ch
693
+
694
+ self.df_out = nn.Sequential(
695
+ GroupedLinear(
696
+ input_size=self.df_decoder_hidden_size,
697
+ hidden_size=out_dim,
698
+ groups=config.df_decoder_linear_groups
699
+ ),
700
+ nn.Tanh()
701
+ )
702
+ self.df_fc_a = nn.Sequential(
703
+ nn.Linear(self.df_decoder_hidden_size, 1),
704
+ nn.Sigmoid()
705
+ )
706
+
707
+ def forward(self, emb: torch.Tensor, c0: torch.Tensor) -> torch.Tensor:
708
+ # emb shape: [batch_size, time_steps, df_bins // 4 * channels]
709
+ b, t, _ = emb.shape
710
+ df_coefs, _ = self.df_gru(emb)
711
+ if self.df_skip is not None:
712
+ df_coefs = df_coefs + self.df_skip(emb)
713
+ # df_coefs shape: [batch_size, time_steps, df_decoder_hidden_size]
714
+
715
+ # c0 shape: [batch_size, channels, time_steps, df_bins]
716
+ c0 = self.df_convp(c0)
717
+ # c0 shape: [batch_size, df_order * 2, time_steps, df_bins]
718
+ c0 = c0.permute(0, 2, 3, 1)
719
+ # c0 shape: [batch_size, time_steps, df_bins, df_order * 2]
720
+
721
+ df_coefs = self.df_out(df_coefs) # [B, T, F*O*2], O: df_order
722
+ # df_coefs shape: [batch_size, time_steps, df_bins * df_order * 2]
723
+ df_coefs = df_coefs.view(b, t, self.df_bins, self.df_out_ch)
724
+ # df_coefs shape: [batch_size, time_steps, df_bins, df_order * 2]
725
+ df_coefs = df_coefs + c0
726
+ # df_coefs shape: [batch_size, time_steps, df_bins, df_order * 2]
727
+ return df_coefs
728
+
729
+
730
+ class DfOutputReshapeMF(nn.Module):
731
+ """Coefficients output reshape for multiframe/MultiFrameModule
732
+
733
+ Reshapes coefficients of shape [B, T, F, O*2] to [B, O, T, F, 2] as expected by the multi-frame filtering module.
734
+ """
735
+
736
+ def __init__(self, df_order: int, df_bins: int):
737
+ super().__init__()
738
+ self.df_order = df_order
739
+ self.df_bins = df_bins
740
+
741
+ def forward(self, coefs: torch.Tensor) -> torch.Tensor:
742
+ # [B, T, F, O*2] -> [B, O, T, F, 2]
743
+ new_shape = list(coefs.shape)
744
+ new_shape[-1] = -1
745
+ new_shape.append(2)
746
+ coefs = coefs.view(new_shape)
747
+ coefs = coefs.permute(0, 3, 1, 2, 4)
748
+ return coefs
749
+
750
+
751
+ class Mask(nn.Module):
752
+ def __init__(self, use_post_filter: bool = False, eps: float = 1e-12):
753
+ super().__init__()
754
+ self.use_post_filter = use_post_filter
755
+ self.eps = eps
756
+
757
+ def post_filter(self, mask: torch.Tensor, beta: float = 0.02) -> torch.Tensor:
758
+ """
759
+ Post-Filter
760
+
761
+ A Perceptually-Motivated Approach for Low-Complexity, Real-Time Enhancement of Fullband Speech.
762
+ https://arxiv.org/abs/2008.04259
763
+
764
+ :param mask: Real valued mask, typically of shape [B, C, T, F].
765
+ :param beta: Global gain factor.
766
+ :return:
767
+ """
768
+ mask_sin = mask * torch.sin(np.pi * mask / 2)
769
+ mask_pf = (1 + beta) * mask / (1 + beta * mask.div(mask_sin.clamp_min(self.eps)).pow(2))
770
+ return mask_pf
771
+
772
+ def forward(self, spec: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
773
+ # spec shape: [batch_size, 1, time_steps, freq_bins, 2]
774
+
775
+ if not self.training and self.use_post_filter:
776
+ mask = self.post_filter(mask)
777
+
778
+ # mask shape: [batch_size, 1, time_steps, freq_bins]
779
+ mask = mask.unsqueeze(4)
780
+ # mask shape: [batch_size, 1, time_steps, freq_bins, 1]
781
+ return spec * mask
782
+
783
+
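A quick numeric check of the post-filter above (standalone sketch using the default beta and eps): small mask values are attenuated further while values near 1 are left essentially unchanged.

import torch

mask = torch.tensor([0.1, 0.5, 0.9, 1.0])
beta, eps = 0.02, 1e-12

mask_sin = mask * torch.sin(torch.pi * mask / 2)
mask_pf = (1 + beta) * mask / (1 + beta * mask.div(mask_sin.clamp_min(eps)).pow(2))

print(mask_pf)  # roughly [0.06, 0.49, 0.90, 1.00]: low gains suppressed, high gains kept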
784
+ class DeepFiltering(nn.Module):
785
+ def __init__(self,
786
+ df_bins: int,
787
+ df_order: int,
788
+ lookahead: int = 0,
789
+ ):
790
+ super(DeepFiltering, self).__init__()
791
+ self.df_bins = df_bins
792
+ self.df_order = df_order
793
+ self.need_unfold = df_order > 1
794
+ self.lookahead = lookahead
795
+
796
+ self.pad = nn.ConstantPad2d((0, 0, df_order - 1 - lookahead, lookahead), 0.0)
797
+
798
+ def spec_unfold(self, spec: torch.Tensor):
799
+ """
800
+ Pads and unfolds the spectrogram according to frame_size.
801
+ :param spec: complex Tensor, Spectrogram of shape [B, C, T, F].
802
+ :return: Tensor, Unfolded spectrogram of shape [B, C, T, F, N], where N: frame_size.
803
+ """
804
+ if self.need_unfold:
805
+ # spec shape: [batch_size, 1, time_steps, freq_bins] (complex)
806
+ spec_pad = self.pad(spec)
807
+ # spec_pad shape: [batch_size, 1, time_steps_pad, freq_bins]
808
+ spec_unfold = spec_pad.unfold(2, self.df_order, 1)
809
+ # spec_unfold shape: [batch_size, 1, time_steps, freq_bins, df_order]
810
+ return spec_unfold
811
+ else:
812
+ return spec.unsqueeze(-1)
813
+
814
+ def forward(self,
815
+ spec: torch.Tensor,
816
+ coefs: torch.Tensor,
817
+ ):
818
+ # spec shape: [batch_size, 1, time_steps, freq_bins, 2]
819
+ spec = spec.contiguous()
820
+ spec_u = self.spec_unfold(torch.view_as_complex(spec))
821
+ # spec_u shape: [batch_size, 1, time_steps, freq_bins, df_order]
822
+
823
+ # coefs shape: [batch_size, df_order, time_steps, df_bins, 2]
824
+ coefs = torch.view_as_complex(coefs)
825
+ # coefs shape: [batch_size, df_order, time_steps, df_bins]
826
+ spec_f = spec_u.narrow(-2, 0, self.df_bins)
827
+ # spec_f shape: [batch_size, 1, time_steps, df_bins, df_order]
828
+
829
+ coefs = coefs.view(coefs.shape[0], -1, self.df_order, *coefs.shape[2:])
830
+ # coefs shape: [batch_size, 1, df_order, time_steps, df_bins]
831
+
832
+ spec_f = self.df(spec_f, coefs)
833
+ # spec_f shape: [batch_size, 1, time_steps, df_bins]
834
+
835
+ if self.training:
836
+ spec = spec.clone()
837
+ spec[..., :self.df_bins, :] = torch.view_as_real(spec_f)
838
+ # spec shape: [batch_size, 1, time_steps, freq_bins, 2]
839
+ return spec
840
+
841
+ @staticmethod
842
+ def df(spec: torch.Tensor, coefs: torch.Tensor) -> torch.Tensor:
843
+ """
844
+ Deep filter implementation using `torch.einsum`. Requires unfolded spectrogram.
845
+ :param spec: (complex Tensor). Spectrogram of shape [B, C, T, F, N].
846
+ :param coefs: (complex Tensor). Coefficients of shape [B, C, N, T, F].
847
+ :return: (complex Tensor). Spectrogram of shape [B, C, T, F].
848
+ """
849
+ return torch.einsum("...tfn,...ntf->...tf", spec, coefs)
850
+
851
+
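The einsum in DeepFiltering.df is an order-N complex FIR filter along the unfolded frame axis; for each (t, f) bin it is just a dot product over the df_order taps. A standalone equivalence check with random complex tensors (sizes are illustrative):

import torch

B, C, T, F, N = 1, 1, 10, 96, 5   # N = df_order

spec_u = torch.randn(B, C, T, F, N, dtype=torch.complex64)   # unfolded spectrogram
coefs = torch.randn(B, C, N, T, F, dtype=torch.complex64)    # filter coefficients

out = torch.einsum("...tfn,...ntf->...tf", spec_u, coefs)

# Explicit sum over the filter taps gives the same result.
ref = sum(spec_u[..., n] * coefs[:, :, n] for n in range(N))
print(out.shape, torch.allclose(out, ref))   # torch.Size([1, 1, 10, 96]) True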
852
+ class NXDfNet(nn.Module):
853
+ def __init__(self, config: NXDfNetConfig):
854
+ super(NXDfNet, self).__init__()
855
+ self.config = config
856
+
857
+ self.stft = DeepSTFT(win_size=config.win_size, freq_bins=config.freq_bins)
858
+ self.istft = DeepISTFT(win_size=config.win_size, freq_bins=config.freq_bins)
859
+
860
+ self.encoder = Encoder(config)
861
+ self.decoder = Decoder(config)
862
+
863
+ self.df_decoder = DfDecoder(config)
864
+ self.df_out_transform = DfOutputReshapeMF(config.df_order, config.df_bins)
865
+ self.df_op = DeepFiltering(
866
+ df_bins=config.df_bins,
867
+ df_order=config.df_order,
868
+ lookahead=config.df_lookahead,
869
+ )
870
+
871
+ self.mask = Mask(use_post_filter=config.use_post_filter)
872
+
873
+ def forward(self,
874
+ noisy: torch.Tensor,
875
+ ):
876
+ """
877
+ :param noisy: Tensor, shape: [batch_size, num_samples]
878
+ :return: denoise, Tensor, shape: [batch_size, 1, num_samples]
879
+ """
880
+ spec = self.stft.forward(noisy)
881
+ # spec shape: [batch_size, freq_bins, time_steps, 2]
882
+ power_spec = torch.sum(torch.square(spec), dim=-1)
883
+ power_spec = power_spec.unsqueeze(1).permute(0, 1, 3, 2)
884
+ # power_spec shape: [batch_size, freq_bins, time_steps]
885
+ # power_spec shape: [batch_size, 1, freq_bins, time_steps]
886
+ # power_spec shape: [batch_size, 1, time_steps, freq_bins]
887
+
888
+ df_spec = spec.permute(0, 3, 2, 1)
889
+ # df_spec shape: [batch_size, 2, time_steps, freq_bins]
890
+ df_spec = df_spec[..., :self.df_decoder.df_bins]
891
+ # df_spec shape: [batch_size, 2, time_steps, df_bins]
892
+
893
+ # spec shape: [batch_size, freq_bins, time_steps, 2]
894
+ spec = torch.transpose(spec, dim0=1, dim1=2)
895
+ # spec shape: [batch_size, time_steps, freq_bins, 2]
896
+ spec = torch.unsqueeze(spec, dim=1)
897
+ # spec shape: [batch_size, 1, time_steps, freq_bins, 2]
898
+
899
+ e0, e1, e2, e3, emb, c0, _, h = self.encoder.forward(power_spec, df_spec)
900
+
901
+ mask = self.decoder.forward(emb, e3, e2, e1, e0)
902
+ # mask shape: [batch_size, 1, time_steps, freq_bins]
903
+ if torch.any(mask > 1) or torch.any(mask < 0):
904
+ raise AssertionError
905
+
906
+ spec_m = self.mask.forward(spec, mask)
907
+
908
+ # lsnr shape: [batch_size, time_steps, 1]
909
+ # lsnr = torch.transpose(lsnr, dim0=2, dim1=1)
910
+ # lsnr shape: [batch_size, 1, time_steps]
911
+
912
+ df_coefs = self.df_decoder.forward(emb, c0)
913
+ df_coefs = self.df_out_transform(df_coefs)
914
+ # df_coefs shape: [batch_size, df_order, time_steps, df_bins, 2]
915
+
916
+ spec_e = self.df_op.forward(spec.clone(), df_coefs)
917
+ # spec_e shape: [batch_size, 1, time_steps, freq_bins, 2]
918
+
919
+ spec_e[..., self.df_decoder.df_bins:, :] = spec_m[..., self.df_decoder.df_bins:, :]
920
+
921
+ spec_e = torch.squeeze(spec_e, dim=1)
922
+ spec_e = spec_e.permute(0, 2, 1, 3)
923
+ # spec_e shape: [batch_size, freq_bins, time_steps, 2]
924
+
925
+ denoise = self.istft.forward(spec_e)
926
+ # denoise shape: [batch_size, 1, num_samples]
927
+ return denoise
928
+
929
+
930
+ class NXDfNetPretrainedModel(NXDfNet):
931
+ def __init__(self,
932
+ config: NXDfNetConfig,
933
+ ):
934
+ super(NXDfNetPretrainedModel, self).__init__(
935
+ config=config,
936
+ )
937
+
938
+ @classmethod
939
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
940
+ config = NXDfNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
941
+
942
+ model = cls(config)
943
+
944
+ if os.path.isdir(pretrained_model_name_or_path):
945
+ ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
946
+ else:
947
+ ckpt_file = pretrained_model_name_or_path
948
+
949
+ with open(ckpt_file, "rb") as f:
950
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
951
+ model.load_state_dict(state_dict, strict=True)
952
+ return model
953
+
954
+ def save_pretrained(self,
955
+ save_directory: Union[str, os.PathLike],
956
+ state_dict: Optional[dict] = None,
957
+ ):
958
+
959
+ model = self
960
+
961
+ if state_dict is None:
962
+ state_dict = model.state_dict()
963
+
964
+ os.makedirs(save_directory, exist_ok=True)
965
+
966
+ # save state dict
967
+ model_file = os.path.join(save_directory, MODEL_FILE)
968
+ torch.save(state_dict, model_file)
969
+
970
+ # save config
971
+ config_file = os.path.join(save_directory, CONFIG_FILE)
972
+ self.config.to_yaml_file(config_file)
973
+ return save_directory
974
+
975
+
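Assuming the repository root is on PYTHONPATH, a save/load round trip with the pretrained wrapper might look like the sketch below (the import paths follow the file layout of this commit; the target directory is arbitrary):

import torch

from toolbox.torchaudio.models.nx_dfnet.configuration_nx_dfnet import NXDfNetConfig
from toolbox.torchaudio.models.nx_dfnet.modeling_nx_dfnet import NXDfNetPretrainedModel

config = NXDfNetConfig()
model = NXDfNetPretrainedModel(config)

# Writes the state dict and the yaml config, then restores both into a new instance.
save_dir = model.save_pretrained("trained_models/nx_dfnet_demo")
restored = NXDfNetPretrainedModel.from_pretrained(save_dir)

noisy = torch.randn(1, 16000)
print(restored.forward(noisy).shape)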
976
+ def main():
977
+
978
+ config = NXDfNetConfig()
979
+ model = NXDfNet(config=config)
980
+
981
+ inputs = torch.randn(size=(1, 16000), dtype=torch.float32)
982
+
983
+ denoise = model.forward(inputs)
984
+ print(denoise.shape)
985
+ return
986
+
987
+
988
+ if __name__ == "__main__":
989
+ main()
toolbox/torchaudio/models/nx_dfnet/utils.py ADDED
@@ -0,0 +1,55 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/kaituoxu/Conv-TasNet/blob/master/src/utils.py
5
+ """
6
+ import math
7
+ import torch
8
+
9
+
10
+ def overlap_and_add(signal: torch.Tensor, frame_step: int):
11
+ """
12
+ Reconstructs a signal from a framed representation.
13
+
14
+ Adds potentially overlapping frames of a signal with shape
15
+ `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`.
16
+ The resulting tensor has shape `[..., output_size]` where
17
+
18
+ output_size = (frames - 1) * frame_step + frame_length
19
+
20
+ Based on https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py
21
+
22
+ :param signal: Tensor, shape: [..., frames, frame_length]. All dimensions may be unknown, and rank must be at least 2.
23
+ :param frame_step: int, overlap offsets. Must be less than or equal to frame_length.
24
+ :return: Tensor, shape: [..., output_size].
25
+ containing the overlap-added frames of signal's inner-most two dimensions.
26
+ output_size = (frames - 1) * frame_step + frame_length
27
+ """
28
+ outer_dimensions = signal.size()[:-2]
29
+ frames, frame_length = signal.size()[-2:]
30
+
31
+ subframe_length = math.gcd(frame_length, frame_step) # gcd=Greatest Common Divisor
32
+ subframe_step = frame_step // subframe_length
33
+ subframes_per_frame = frame_length // subframe_length
34
+
35
+ output_size = frame_step * (frames - 1) + frame_length
36
+ output_subframes = output_size // subframe_length
37
+
38
+ subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)
39
+
40
+ frame = torch.arange(0, output_subframes).unfold(0, subframes_per_frame, subframe_step)
41
+
42
+ frame = frame.clone().detach()
43
+ frame = frame.to(signal.device)
44
+ frame = frame.long()
45
+
46
+ frame = frame.contiguous().view(-1)
47
+
48
+ result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
49
+ result.index_add_(-2, frame, subframe_signal)
50
+ result = result.view(*outer_dimensions, -1)
51
+ return result
52
+
53
+
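A minimal usage check for overlap_and_add (run from within this module, or import it; frame sizes are arbitrary): the output length matches (frames - 1) * frame_step + frame_length.

import torch

frames, frame_length, frame_step = 9, 320, 160
signal = torch.randn(2, 1, frames, frame_length)

out = overlap_and_add(signal, frame_step)
print(out.shape)  # torch.Size([2, 1, 1600]) == (9 - 1) * 160 + 320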
54
+ if __name__ == "__main__":
55
+ pass